Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package dev.braintrust.eval;

import java.util.Map;
import javax.annotation.Nullable;

/**
 * A single structured classification produced by a {@link Classifier}.
 *
 * <p>Unlike a {@link Score} (numeric 0-1), a Classification carries a stable id, an optional
 * display label, and optional metadata. The {@code name} acts as the grouping key in the aggregated
 * result map; when {@code name} is {@code null} or blank, the owning classifier's resolved name is
 * used instead.
 *
 * <p>A non-null {@code metadata} map is snapshotted at construction time into an unmodifiable copy
 * that keeps the source map's iteration order. Later changes to the caller's map therefore do not
 * alter an existing Classification, and {@link #metadata()} cannot be used to mutate it.
 *
 * <p>The {@code id} is intentionally not checked here; this record accepts any value, including
 * {@code null} or blank.
 *
 * @param name optional grouping key; defaults to the owning classifier's resolved name when null or
 *     blank
 * @param id stable identifier for the classification (required)
 * @param label optional display label
 * @param metadata optional arbitrary metadata; copied defensively when non-null
 */
public record Classification(
        @Nullable String name,
        String id,
        @Nullable String label,
        @Nullable Map<String, Object> metadata) {

    /** Canonical constructor: replaces a non-null {@code metadata} with an unmodifiable copy. */
    public Classification {
        if (metadata != null) {
            // LinkedHashMap (rather than Map.copyOf) keeps iteration order and tolerates null
            // values. Names are fully qualified so the file's import list stays untouched.
            metadata =
                    java.util.Collections.unmodifiableMap(
                            new java.util.LinkedHashMap<String, Object>(metadata));
        }
    }

    /**
     * Creates a classification with only an id.
     *
     * @param id stable identifier for the classification
     * @return a classification with no name, label, or metadata
     */
    public static Classification of(String id) {
        return new Classification(null, id, null, null);
    }

    /**
     * Creates a classification with an id and a display label.
     *
     * @param id stable identifier for the classification
     * @param label display label
     * @return a classification with no name or metadata
     */
    public static Classification of(String id, String label) {
        return new Classification(null, id, label, null);
    }

    /**
     * Creates a classification with an explicit grouping name, an id, and a display label.
     *
     * @param name grouping key for the aggregated result map
     * @param id stable identifier for the classification
     * @param label display label
     * @return a classification with no metadata
     */
    public static Classification of(String name, String id, String label) {
        return new Classification(name, id, label, null);
    }
}
98 changes: 98 additions & 0 deletions braintrust-sdk/src/main/java/dev/braintrust/eval/Classifier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package dev.braintrust.eval;

import java.util.List;
import java.util.function.Function;

/**
 * Labels eval outputs with structured categories, yielding zero or more {@link Classification}
 * items per task result.
 *
 * <p>Classifiers run independently from {@link Scorer}s. The value returned by {@link #getName()}
 * is used as the span name and as the fallback grouping key for any classification whose own
 * {@code name} is blank.
 *
 * @param <INPUT> type of the input data
 * @param <OUTPUT> type of the output data
 */
public interface Classifier<INPUT, OUTPUT> {
    /** Message prefix of the exception thrown when a returned classification fails validation. */
    String INVALID_CLASSIFICATION_MESSAGE =
            "When returning structured classifier results, each classification must be"
                    + " a non-empty object.";

    /** Returns this classifier's name. */
    String getName();

    /**
     * Classifies the result of a successful task execution.
     *
     * @param taskResult the task output and originating dataset case
     * @return zero or more classifications; an empty list means "no classifications for this case"
     */
    List<Classification> classify(TaskResult<INPUT, OUTPUT> taskResult);

    /**
     * Builds a classifier around a function that produces a list of classifications.
     *
     * <p>A {@code null} list from the function is reported as an empty list. Otherwise every
     * element is validated (it must be non-null with a non-blank {@code id}) and the function's
     * own list instance is returned. The first invalid element makes {@code classify} throw an
     * {@link IllegalArgumentException}.
     *
     * @param classifierName value returned by {@link #getName()}
     * @param classifierFn function invoked once per {@code classify} call
     * @return a classifier backed by {@code classifierFn}
     */
    static <INPUT, OUTPUT> Classifier<INPUT, OUTPUT> of(
            String classifierName,
            Function<TaskResult<INPUT, OUTPUT>, List<Classification>> classifierFn) {
        return new Classifier<>() {
            @Override
            public String getName() {
                return classifierName;
            }

            @Override
            public List<Classification> classify(TaskResult<INPUT, OUTPUT> taskResult) {
                List<Classification> produced = classifierFn.apply(taskResult);
                if (produced != null) {
                    for (Classification candidate : produced) {
                        validate(candidate);
                    }
                    return produced;
                }
                return List.of();
            }
        };
    }

    /**
     * Builds a classifier around a function that produces at most one classification.
     *
     * <p>A {@code null} result from the function is reported as an empty list. A non-null result
     * is validated the same way as in {@link #of(String, Function)} and returned as a
     * single-element list.
     *
     * @param classifierName value returned by {@link #getName()}
     * @param classifierFn function invoked once per {@code classify} call
     * @return a classifier backed by {@code classifierFn}
     */
    static <INPUT, OUTPUT> Classifier<INPUT, OUTPUT> single(
            String classifierName,
            Function<TaskResult<INPUT, OUTPUT>, Classification> classifierFn) {
        // Adapt the single-item function to the list-based factory; null stays null so that the
        // list-based path turns it into an empty result.
        return Classifier.<INPUT, OUTPUT>of(
                classifierName,
                taskResult -> {
                    Classification only = classifierFn.apply(taskResult);
                    return only == null ? null : List.of(only);
                });
    }

    /**
     * Checks that a classification is non-null and carries a non-blank id.
     *
     * @throws IllegalArgumentException with {@link #INVALID_CLASSIFICATION_MESSAGE} followed by the
     *     offending item when the check fails
     */
    private static void validate(Classification item) {
        boolean hasUsableId = item != null && item.id() != null && !item.id().isBlank();
        if (!hasUsableId) {
            throw new IllegalArgumentException(INVALID_CLASSIFICATION_MESSAGE + " Got: " + item);
        }
    }
}
127 changes: 125 additions & 2 deletions braintrust-sdk/src/main/java/dev/braintrust/eval/Eval.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public final class Eval<INPUT, OUTPUT> {
private final @Nonnull Dataset<INPUT, OUTPUT> dataset;
private final @Nonnull Task<INPUT, OUTPUT> task;
private final @Nonnull List<Scorer<INPUT, OUTPUT>> scorers;
private final @Nonnull List<Classifier<INPUT, OUTPUT>> classifiers;
private final @Nonnull List<String> tags;
private final @Nonnull Map<String, Object> metadata;
private final @Nonnull Parameters parameters;
Expand All @@ -58,6 +59,7 @@ private Eval(Builder<INPUT, OUTPUT> builder) {
this.dataset = builder.dataset;
this.task = Objects.requireNonNull(builder.task);
this.scorers = List.copyOf(builder.scorers);
this.classifiers = List.copyOf(builder.classifiers);
this.tags = List.copyOf(builder.tags);
this.metadata = Map.copyOf(builder.metadata);
this.parameters = builder.buildParameters();
Expand Down Expand Up @@ -172,6 +174,42 @@ private void evalOne(String experimentId, DatasetCase<INPUT, OUTPUT> datasetCase
for (var scorer : scorers) {
runScorer(experimentId, rootSpan, scorer, taskResult, trace);
}

// run classifiers - one span per classifier. Classifier exceptions are non-fatal:
// they are recorded on the classifier span and surfaced in the root span's metadata
// under `classifier_errors`, but do not abort the eval or affect other classifiers/
// scorers. Classifiers only run when the task succeeded (no scoreForTaskException
// analogue).
if (!classifiers.isEmpty()) {
Map<String, List<Map<String, Object>>> caseClassifications = new LinkedHashMap<>();
Map<String, String> classifierErrors = new LinkedHashMap<>();
for (int i = 0; i < classifiers.size(); i++) {
var classifier = classifiers.get(i);
var classifierName = classifier.getName();
if (classifierName == null || classifierName.isBlank()) {
classifierName = "classifier_" + i;
}
runClassifier(
experimentId,
classifier,
classifierName,
taskResult,
trace,
caseClassifications,
classifierErrors);
}
if (!caseClassifications.isEmpty()) {
rootSpan.setAttribute(
"braintrust.classifications", toJson(caseClassifications));
}
if (!classifierErrors.isEmpty()) {
Map<String, Object> mergedMetadata =
new LinkedHashMap<>(datasetCase.metadata());
mergedMetadata.put("classifier_errors", classifierErrors);
rootSpan.setAttribute(
AttributeKey.stringKey("braintrust.metadata"), toJson(mergedMetadata));
}
}
} finally {
rootSpan.end();
}
Expand Down Expand Up @@ -236,6 +274,84 @@ private void runScoreForTaskException(
}
}

/**
* Runs a classifier inside its own span. Exceptions are recorded on the classifier span and
* surfaced via {@code classifierErrors}; they do not propagate.
*/
private void runClassifier(
String experimentId,
Classifier<INPUT, OUTPUT> classifier,
String resolvedName,
TaskResult<INPUT, OUTPUT> taskResult,
BrainstoreTrace trace,
Map<String, List<Map<String, Object>>> caseClassifications,
Map<String, String> classifierErrors) {
var classifierSpan =
tracer.spanBuilder(resolvedName)
.setAttribute(PARENT, "experiment_id:" + experimentId)
.startSpan();
try (var unused =
BraintrustContext.ofExperiment(experimentId, classifierSpan).makeCurrent()) {
Map<String, Object> spanAttrs = new LinkedHashMap<>();
spanAttrs.put("type", "classifier");
spanAttrs.put("name", resolvedName);
spanAttrs.put("purpose", "scorer");
classifierSpan.setAttribute("braintrust.span_attributes", toJson(spanAttrs));

List<Classification> classifications;
try {
if (classifier instanceof TracedClassifier<INPUT, OUTPUT> tracedClassifier) {
classifications = tracedClassifier.classify(taskResult, trace);
} else {
classifications = classifier.classify(taskResult);
}
if (classifications == null) {
classifications = List.of();
}
} catch (Exception e) {
classifierSpan.setStatus(StatusCode.ERROR, e.getMessage());
classifierSpan.recordException(e);
log.debug("Classifier '{}' threw exception", resolvedName, e);
classifierErrors.put(
resolvedName, e.getMessage() == null ? e.toString() : e.getMessage());
return;
}

// Group results by resolved item name (item.name, falling back to the classifier
// name when blank). Same map is logged to the classifier span and merged into the
// per-case aggregate logged on the root span.
Map<String, List<Map<String, Object>>> outputByName = new LinkedHashMap<>();
for (var item : classifications) {
var itemName = item.name();
if (itemName == null || itemName.isBlank()) {
Comment on lines +324 to +326
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Keep classifier post-processing failures non-fatal

runClassifier only catches exceptions thrown by classifier.classify(...), but it processes returned items outside that try/catch. A custom Classifier implementation (which this commit explicitly supports) can return a list containing null or otherwise malformed items, and item.name() will throw here, escaping runClassifier and aborting the eval case. That breaks the intended contract that classifier failures are non-fatal and should be recorded under classifier_errors instead.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is valid. A classifier returning a list with null values seems like a contract breach. We could explicitly document that you're not allowed to do this, but that seems so unlikely I wouldn't say it's necessary.

itemName = resolvedName;
}
var itemMap = toClassificationItem(item);
outputByName.computeIfAbsent(itemName, k -> new ArrayList<>()).add(itemMap);
caseClassifications.computeIfAbsent(itemName, k -> new ArrayList<>()).add(itemMap);
}
classifierSpan.setAttribute("braintrust.output_json", toJson(outputByName));
} finally {
classifierSpan.end();
}
}

/**
 * Builds the wire-format {@code ClassificationItem} map for a {@link Classification}.
 *
 * <p>The map always contains {@code id}. {@code label} and {@code metadata} are added only when
 * they are non-null, and {@code name} is never included.
 *
 * @param c the classification to convert
 * @return a new insertion-ordered map with keys in the order id, label, metadata
 */
private static Map<String, Object> toClassificationItem(Classification c) {
    var label = c.label();
    var metadata = c.metadata();

    Map<String, Object> wireItem = new LinkedHashMap<>();
    wireItem.put("id", c.id());
    if (label != null) {
        wireItem.put("label", label);
    }
    if (metadata != null) {
        wireItem.put("metadata", metadata);
    }
    return wireItem;
}

/** Validates and records scores on the score span and root span. */
private void recordScores(
Span scoreSpan, Span rootSpan, Scorer<INPUT, OUTPUT> scorer, List<Score> scores) {
Expand Down Expand Up @@ -276,6 +392,7 @@ public static final class Builder<INPUT, OUTPUT> {
private @Nullable Tracer tracer = null;
private @Nullable Task<INPUT, OUTPUT> task;
private @Nonnull List<Scorer<INPUT, OUTPUT>> scorers = List.of();
private @Nonnull List<Classifier<INPUT, OUTPUT>> classifiers = List.of();
private @Nonnull List<ParameterDef<?>> parameterDefs = List.of();
private @Nonnull Map<String, Object> parameterValues = Map.of();
private @Nonnull List<String> tags = List.of();
Expand All @@ -291,8 +408,8 @@ public Eval<INPUT, OUTPUT> build() {
if (projectId == null) {
projectId = config.defaultProjectId().orElse(null);
}
if (scorers.isEmpty()) {
throw new RuntimeException("must provide at least one scorer");
if (scorers.isEmpty() && classifiers.isEmpty()) {
throw new RuntimeException("must provide at least one scorer or classifier");
}
if (null == apiClient) {
apiClient = BraintrustOpenApiClient.of(config);
Expand Down Expand Up @@ -380,6 +497,12 @@ public final Builder<INPUT, OUTPUT> scorers(Scorer<INPUT, OUTPUT>... scorers) {
return this;
}

/**
 * Sets the classifiers for this eval, replacing any previously configured ones.
 *
 * @param classifiers the classifiers to use; stored as an immutable list
 * @return this builder
 */
@SafeVarargs
public final Builder<INPUT, OUTPUT> classifiers(Classifier<INPUT, OUTPUT>... classifiers) {
    List<Classifier<INPUT, OUTPUT>> snapshot = List.of(classifiers);
    this.classifiers = snapshot;
    return this;
}

/** Sets tags for the experiment. */
public Builder<INPUT, OUTPUT> tags(List<String> tags) {
this.tags = List.copyOf(tags);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package dev.braintrust.eval;

import dev.braintrust.trace.BrainstoreTrace;
import java.util.List;

/**
 * A {@link Classifier} that is also given the full distributed trace of the task under evaluation.
 *
 * <p>Use this interface when classification depends on intermediate LLM calls, tool invocations,
 * or other spans produced during task execution, and the final {@link TaskResult} alone is not
 * enough.
 *
 * @param <INPUT> type of the input data
 * @param <OUTPUT> type of the output data
 */
public interface TracedClassifier<INPUT, OUTPUT> extends Classifier<INPUT, OUTPUT> {

    /**
     * Classifies the task result with the distributed trace as additional context. This overload
     * is called instead of {@link Classifier#classify(TaskResult)} when a {@link BrainstoreTrace}
     * is available.
     *
     * @param taskResult the task output and originating dataset case
     * @param trace lazy access to the distributed trace spans for this eval case
     * @return zero or more classifications
     */
    List<Classification> classify(TaskResult<INPUT, OUTPUT> taskResult, BrainstoreTrace trace);

    /**
     * {@inheritDoc}
     *
     * <p>Inside an {@link Eval} this overload is never called, because {@link
     * #classify(TaskResult, BrainstoreTrace)} is dispatched instead. The default implementation
     * therefore throws, so that an accidental direct call is noticed.
     *
     * @throws UnsupportedOperationException always, unless an implementation overrides this method
     */
    @Override
    default List<Classification> classify(TaskResult<INPUT, OUTPUT> taskResult) {
        String message =
                "traced classifier classify method directly called. This is likely an accident."
                        + " If you wish to support this, your implementation must override this"
                        + " method.";
        throw new UnsupportedOperationException(message);
    }
}
Loading
Loading