Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
243 changes: 242 additions & 1 deletion crates/cli/tests/coverage/plugins_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use nemo_relay::config_editor::{EditorConfig, EditorSchema};
use nemo_relay::observability::plugin_component::{OBSERVABILITY_PLUGIN_KIND, ObservabilityConfig};
use nemo_relay::plugin::{ConfigPolicy, PluginComponentSpec, PluginConfig};
use nemo_relay::plugins::nemo_guardrails::component::{
NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig,
LocalBackendConfig, NEMO_GUARDRAILS_PLUGIN_KIND, NeMoGuardrailsConfig, RemoteBackendConfig,
};
use nemo_relay_adaptive::AdaptiveConfig;
use nemo_relay_adaptive::plugin_component::ADAPTIVE_PLUGIN_KIND;
Expand Down Expand Up @@ -50,6 +50,40 @@ fn guardrails_component_config(config_id: &str) -> serde_json::Map<String, Value
.clone()
}

fn local_guardrails_component_config(config_path: &str) -> serde_json::Map<String, Value> {
json!({
"mode": "local",
"input": false,
"output": false,
"config_path": config_path,
"tool_input": true,
"tool_output": true,
"local": {
"python_module": "custom_guardrails"
}
})
.as_object()
.unwrap()
.clone()
}

fn local_llm_guardrails_component_config(config_yaml: &str) -> serde_json::Map<String, Value> {
json!({
"mode": "local",
"codec": "openai_chat",
"input": true,
"output": true,
"config_yaml": config_yaml,
"colang_content": "define flow noop\n pass",
"local": {
"python_module": "custom_guardrails"
}
})
.as_object()
.unwrap()
.clone()
}

#[test]
fn target_scope_defaults_to_user_and_rejects_conflicts() {
assert_eq!(
Expand Down Expand Up @@ -160,6 +194,24 @@ fn typed_editor_model_contains_nemo_guardrails_options() {
EditorFieldKind::StringMap
);

let local = schema.field("local").unwrap().schema().unwrap();
assert_eq!(
local.field("python_module").unwrap().kind,
EditorFieldKind::String
);
assert_eq!(
schema.field("config_path").unwrap().kind,
EditorFieldKind::String
);
assert_eq!(
schema.field("config_yaml").unwrap().kind,
EditorFieldKind::String
);
assert_eq!(
schema.field("colang_content").unwrap().kind,
EditorFieldKind::String
);

let request_defaults = schema.field("request_defaults").unwrap().schema().unwrap();
let rails = request_defaults.field("rails").unwrap().schema().unwrap();
assert_eq!(
Expand Down Expand Up @@ -1137,6 +1189,98 @@ fn validate_config_accepts_nemo_guardrails_component() {
validate_config(&config).unwrap();
}

#[test]
fn validate_config_accepts_local_tool_only_nemo_guardrails_component() {
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: local_guardrails_component_config("./rails"),
}],
..PluginConfig::default()
};

validate_config(&config).unwrap();
}

#[test]
fn validate_config_rejects_local_nemo_guardrails_request_defaults() {
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: json!({
"mode": "local",
"codec": "openai_chat",
"input": true,
"output": true,
"config_yaml": "models: []",
"request_defaults": {
"context": {"tenant": "demo"}
}
})
.as_object()
.unwrap()
.clone(),
}],
..PluginConfig::default()
};

let error = validate_config(&config).unwrap_err().to_string();
assert!(error.contains("request_defaults"), "error was: {error}");
assert!(error.contains("local mode"), "error was: {error}");
}

#[test]
fn validate_config_rejects_local_nemo_guardrails_multiple_config_sources() {
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: json!({
"mode": "local",
"config_path": "./rails",
"config_yaml": "models: []"
})
.as_object()
.unwrap()
.clone(),
}],
..PluginConfig::default()
};

let error = validate_config(&config).unwrap_err().to_string();
assert!(
error.contains("exactly one of config_path or config_yaml"),
"error was: {error}"
);
}

#[test]
fn validate_config_rejects_local_nemo_guardrails_colang_without_yaml() {
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: json!({
"mode": "local",
"config_path": "./rails",
"colang_content": "define flow noop\n pass"
})
.as_object()
.unwrap()
.clone(),
}],
..PluginConfig::default()
};

let error = validate_config(&config).unwrap_err().to_string();
assert!(
error.contains("colang_content can only be used with config_yaml"),
"error was: {error}"
);
}

#[test]
fn nemo_guardrails_config_map_prunes_default_version() {
let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig {
Expand All @@ -1155,6 +1299,103 @@ fn nemo_guardrails_config_map_prunes_default_version() {
assert_eq!(map["remote"]["config_id"], json!("default"));
}

#[test]
fn write_plugin_config_round_trips_local_nemo_guardrails_component() {
let temp = tempfile::tempdir().unwrap();
let path = temp.path().join("plugins.toml");
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: local_guardrails_component_config("./rails"),
}],
..PluginConfig::default()
};

write_plugin_config(&path, &config).unwrap();

let rendered = std::fs::read_to_string(&path).unwrap();
assert!(rendered.contains("mode = \"local\""));
assert!(rendered.contains("config_path = \"./rails\""));
assert!(rendered.contains("tool_input = true"));
assert!(rendered.contains("python_module = \"custom_guardrails\""));

let round_tripped = read_plugin_config(&path).unwrap();
let guardrails = round_tripped
.components
.iter()
.find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND)
.unwrap();
assert!(guardrails.enabled);
assert_eq!(guardrails.config["mode"], json!("local"));
assert_eq!(guardrails.config["config_path"], json!("./rails"));
assert_eq!(guardrails.config["tool_input"], json!(true));
assert_eq!(
guardrails.config["local"]["python_module"],
json!("custom_guardrails")
);
}

#[test]
fn nemo_guardrails_config_map_serializes_local_mode_fields() {
let map = nemo_guardrails_config_map(&NeMoGuardrailsConfig {
mode: "local".into(),
config_path: Some("./rails".into()),
tool_input: true,
tool_output: true,
local: Some(LocalBackendConfig {
python_module: Some("custom_guardrails".into()),
}),
..NeMoGuardrailsConfig::default()
})
.unwrap();

assert!(!map.contains_key("version"));
assert_eq!(map.get("mode"), Some(&json!("local")));
assert_eq!(map.get("config_path"), Some(&json!("./rails")));
assert_eq!(map.get("tool_input"), Some(&json!(true)));
assert_eq!(map["local"]["python_module"], json!("custom_guardrails"));
}

#[test]
fn write_plugin_config_round_trips_local_llm_nemo_guardrails_component() {
let temp = tempfile::tempdir().unwrap();
let path = temp.path().join("plugins.toml");
let config = PluginConfig {
components: vec![PluginComponentSpec {
kind: NEMO_GUARDRAILS_PLUGIN_KIND.to_string(),
enabled: true,
config: local_llm_guardrails_component_config("models: []"),
}],
..PluginConfig::default()
};

write_plugin_config(&path, &config).unwrap();

let rendered = std::fs::read_to_string(&path).unwrap();
assert!(rendered.contains("mode = \"local\""));
assert!(rendered.contains("codec = \"openai_chat\""));
assert!(rendered.contains("input = true"));
assert!(rendered.contains("output = true"));
assert!(rendered.contains("config_yaml = \"models: []\""));

let round_tripped = read_plugin_config(&path).unwrap();
let guardrails = round_tripped
.components
.iter()
.find(|component| component.kind == NEMO_GUARDRAILS_PLUGIN_KIND)
.unwrap();
assert_eq!(guardrails.config["mode"], json!("local"));
assert_eq!(guardrails.config["codec"], json!("openai_chat"));
assert_eq!(guardrails.config["input"], json!(true));
assert_eq!(guardrails.config["output"], json!(true));
assert_eq!(guardrails.config["config_yaml"], json!("models: []"));
assert_eq!(
guardrails.config["colang_content"],
json!("define flow noop\n pass")
);
}

#[test]
fn display_helpers_render_scalars_json_and_defaults() {
assert_eq!(display_value(&json!("logs")), "logs");
Expand Down
20 changes: 17 additions & 3 deletions crates/core/src/plugins/nemo_guardrails/component.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,13 @@ use crate::plugin::{
register_plugin,
};

#[path = "local.rs"]
mod local;
#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))]
#[path = "remote.rs"]
mod remote;
use local::register_local_backend;
pub use local::{clear_local_backend_provider, register_local_backend_provider};
#[cfg(all(feature = "guardrails-remote", not(target_arch = "wasm32")))]
use remote::register_remote_backend;

Expand Down Expand Up @@ -447,9 +451,7 @@ fn register_nemo_guardrails_backend(
) -> PluginResult<()> {
match config.mode.as_str() {
"remote" => register_remote_backend(config, ctx),
"local" => Err(PluginError::RegistrationFailed(
"built-in NeMo Guardrails local backend is not implemented yet".to_string(),
)),
"local" => register_local_backend(config, ctx),
other => Err(PluginError::InvalidConfig(format!(
"unsupported NeMo Guardrails mode '{other}'"
))),
Expand Down Expand Up @@ -955,6 +957,18 @@ fn validate_request_defaults(
return;
};

if config.mode == "local" {
push_policy_diag(
diagnostics,
policy.unsupported_value,
"nemo_guardrails.unsupported_value",
Some(NEMO_GUARDRAILS_PLUGIN_KIND.to_string()),
Some("request_defaults".to_string()),
"local mode does not currently support request_defaults".to_string(),
);
return;
}

validate_json_object_field(
diagnostics,
policy,
Expand Down
52 changes: 52 additions & 0 deletions crates/core/src/plugins/nemo_guardrails/local.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::sync::{Arc, LazyLock, Mutex, MutexGuard};

use crate::plugin::{PluginError, PluginRegistrationContext, Result as PluginResult};

use super::NeMoGuardrailsConfig;

type LocalBackendProvider = Arc<
dyn Fn(NeMoGuardrailsConfig, &mut PluginRegistrationContext) -> PluginResult<()> + Send + Sync,
>;

static LOCAL_BACKEND_PROVIDER: LazyLock<Mutex<Option<LocalBackendProvider>>> =
LazyLock::new(|| Mutex::new(None));

fn local_backend_provider_guard() -> PluginResult<MutexGuard<'static, Option<LocalBackendProvider>>>
{
LOCAL_BACKEND_PROVIDER.lock().map_err(|e| {
PluginError::Internal(format!(
"NeMo Guardrails local backend provider lock poisoned: {e}"
))
})
}

#[doc(hidden)]
pub fn register_local_backend_provider(provider: LocalBackendProvider) -> PluginResult<()> {
let mut guard = local_backend_provider_guard()?;
*guard = Some(provider);
Ok(())
}

#[doc(hidden)]
pub fn clear_local_backend_provider() -> PluginResult<()> {
let mut guard = local_backend_provider_guard()?;
*guard = None;
Ok(())
}

pub(super) fn register_local_backend(
config: NeMoGuardrailsConfig,
ctx: &mut PluginRegistrationContext,
) -> PluginResult<()> {
let provider = local_backend_provider_guard()?.clone();

match provider {
Some(provider) => provider(config, ctx),
None => Err(PluginError::RegistrationFailed(
"built-in NeMo Guardrails local backend is unavailable in this runtime".to_string(),
)),
}
}
Loading
Loading