Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -270,9 +270,9 @@ public Evaluation evaluate(final List<Message> messages, final Options options)
WafMetricCollector.get().aiGuardRequest(action, shouldBlock);
if (shouldBlock) {
span.setTag(BLOCKED_TAG, true);
throw new AIGuardAbortError(action, reason, tags);
throw new AIGuardAbortError(action, reason, tags, sdsFindings);
}
return new Evaluation(action, reason, tags);
return new Evaluation(action, reason, tags, sdsFindings);
}
} catch (AIGuardAbortError e) {
span.addThrowable(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,11 +210,13 @@ class AIGuardInternalTests extends DDSpecification {
error.action == suite.action
error.reason == suite.reason
error.tags == suite.tags
error.sds == []
} else {
error == null
eval.action == suite.action
eval.reason == suite.reason
eval.tags == suite.tags
eval.sds == []
}
assertTelemetry('ai_guard.requests', "action:$suite.action", "block:$throwAbortError", 'error:false')

Expand Down Expand Up @@ -366,14 +368,15 @@ class AIGuardInternalTests extends DDSpecification {
Map<String, Object> receivedMeta

when:
aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT)
final result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT)

then:
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> {
receivedMeta = it[1] as Map<String, Object>
return span
}
receivedMeta.sds == sdsFindings
result.sds == sdsFindings
}

void 'test evaluate with empty sds findings'() {
Expand All @@ -382,19 +385,41 @@ class AIGuardInternalTests extends DDSpecification {
Map<String, Object> receivedMeta

when:
aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT)
final result = aiguard.evaluate(PROMPT, AIGuard.Options.DEFAULT)

then:
1 * span.setMetaStruct(AIGuardInternal.META_STRUCT_TAG, _) >> {
receivedMeta = it[1] as Map<String, Object>
return span
}
!receivedMeta.containsKey('sds')
result.sds == (sdsFindings ?: [])

where:
sdsFindings << [null, []]
}

void 'test evaluate with sds findings in abort error'() {
given:
final sdsFindings = [
[
rule_display_name: 'Credit Card Number',
rule_tag: 'credit_card',
category: 'pii',
matched_text: '4111111111111111',
location: [start_index: 10, end_index_exclusive: 26, path: 'messages[0].content[0].text']
]
]
final aiguard = mockClient(200, [data: [attributes: [action: 'ABORT', reason: 'PII detected', tags: ['pii'], sds_findings: sdsFindings, is_blocking_enabled: true]]])

when:
aiguard.evaluate(PROMPT, new AIGuard.Options().block(true))

then:
final error = thrown(AIGuard.AIGuardAbortError)
error.sds == sdsFindings
}

void 'test missing tool name'() {
given:
final aiguard = mockClient(200, [data: [attributes: [action: 'ALLOW', reason: 'Just do it']]])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,15 @@ public static class AIGuardAbortError extends RuntimeException {
private final Action action;
private final String reason;
private final List<String> tags;
private final List<?> sds;

public AIGuardAbortError(final Action action, final String reason, final List<String> tags) {
public AIGuardAbortError(
final Action action, final String reason, final List<String> tags, final List<?> sds) {
super(reason);
this.action = action;
this.reason = reason;
this.tags = tags;
this.sds = sds != null ? sds : Collections.emptyList();
}

public Action getAction() {
Expand All @@ -88,6 +91,10 @@ public String getReason() {
public List<String> getTags() {
return tags;
}

public List<?> getSds() {
return sds;
}
}

/**
Expand Down Expand Up @@ -149,18 +156,22 @@ public static class Evaluation {
final Action action;
final String reason;
final List<String> tags;
final List<?> sds;

/**
* Creates a new evaluation result.
*
* @param action the recommended action for the evaluated content
* @param reason human-readable explanation for the decision
* @param tags list of tags associated with the evaluation (e.g. indirect-prompt-injection)
* @param sds list of Sensitive Data Scanner findings
*/
public Evaluation(final Action action, final String reason, final List<String> tags) {
public Evaluation(
final Action action, final String reason, final List<String> tags, final List<?> sds) {
this.action = action;
this.reason = reason;
this.tags = tags;
this.sds = sds != null ? sds : Collections.emptyList();
}

/**
Expand Down Expand Up @@ -189,6 +200,15 @@ public String getReason() {
public List<String> getTags() {
return tags;
}

/**
* Returns the list of Sensitive Data Scanner findings.
*
* @return list of SDS findings.
*/
public List<?> getSds() {
return sds;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ public final class NoOpEvaluator implements Evaluator {

@Override
public Evaluation evaluate(final List<Message> messages, final Options options) {
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList());
return new Evaluation(ALLOW, "AI Guard is not enabled", emptyList(), emptyList());
}
}