diff --git a/docs/examples/simbauq/README.md b/docs/examples/simbauq/README.md new file mode 100644 index 000000000..6d4e89f7a --- /dev/null +++ b/docs/examples/simbauq/README.md @@ -0,0 +1,270 @@ +# SIMBA-UQ Sampling Strategy + +Confidence-aware sample selection using the SIMBA-UQ framework +(Bhattacharjya et al., 2025). Generates multiple samples across a range of +temperatures and selects the one with the highest estimated confidence. + +**Paper:** [SIMBA UQ: Similarity-Based Aggregation for Uncertainty Quantification in Large Language Models](https://arxiv.org/abs/2510.13836) + +## Files + +### simbauq_example.py + +Complete example demonstrating all four confidence estimation variants with +Ollama and granite4:micro: + +1. **Aggregation** — data-free, no training data required. +2. **Classifier (synthetic)** — trained on hand-coded labeled groups. +3. **Classifier (HF data)** — training data generated live from a Hugging Face + dataset via Ollama. Calls `generate_training_data()` which streams items + from TriviaQA or SAMSum, generates `len(temperatures) * n_per_temp` responses + per item at the configured temperature schedule, and labels each response by + similarity to the ground-truth reference. Groups where all labels are + identical are discarded. Requires `pip install datasets`. +4. **Classifier (pre-trained)** — same HF-generated training data, but the + `RandomForestClassifier` is trained externally via `train_classifier()` and + passed directly to `SIMBAUQSamplingStrategy` via the `classifier=` argument. + Useful when you want to persist, inspect, or swap the classifier independently + of the sampling strategy. + +## Running the example + +``` +ollama serve +uv run python docs/examples/simbauq/simbauq_example.py +``` + +The script runs against the demo query +`"Which magazine was started first Arthur's Magazine or First for Women?"`. + +Which variant runs is controlled by the **CONFIG block** at the top of the +script (immediately below the imports). 
Edit the values in place: + +| Variable | Purpose | Allowed values | +|----------|---------|----------------| +| `EXAMPLE` | Which variant(s) to run | `"aggregation"`, `"synthetic"`, `"hf"`, `"pretrained"`, `"all"` | +| `DATASET` | HF dataset for the `hf` and `pretrained` variants | `"triviaqa"`, `"samsum"` | +| `METRIC` | Pairwise similarity metric (used by both the strategy and HF labelling) | `"rouge"`, `"jaccard"`, `"sbert"`, `"difflib"`, `"levenshtein"` | +| `AGGREGATION` | Aggregation function for the `aggregation` confidence method | `"mean"`, `"geometric_mean"`, `"harmonic_mean"`, `"median"`, `"max"`, `"min"` | +| `THRESHOLD` | Similarity score at or above which an HF-generated response is labelled correct (1) | float; tune per metric/dataset | + +The shipped defaults are `EXAMPLE="aggregation"`, `DATASET="triviaqa"`, +`METRIC="sbert"`, `AGGREGATION="mean"`, `THRESHOLD=0.2`. Set `EXAMPLE="all"` +to run all four variants sequentially. + +Reasonable starting points for `THRESHOLD`: + +- `sbert` + `triviaqa`: 0.5 - 0.7 +- `sbert` + `samsum`: 0.2 - 0.4 +- `rouge` + `triviaqa`: 0.3 - 0.5 + +A threshold that is too strict labels every response incorrect; one that is too +loose labels every response correct. Either way the group carries a single label +and is discarded automatically, so the classifier never sees it. + +The number of HF training groups is controlled by the module-level constant +`N_TRAINING_GROUPS=5` — increase for stronger classifier signal at the cost +of more LLM calls. + +## Architecture + +``` +User Query + | + v +Generate N samples (across temperatures) + | + v +Compute pairwise similarity matrix (N x N) + | + +---> [Aggregation] Aggregate similarities per sample -> confidence + | + +---> [Classifier] Extract features per sample -> RF predicts P(correct) + | + v +Select sample with highest confidence + | + v +Result (with confidence metadata in mot._meta["simba_uq"]) +``` + +## Confidence Methods + +### 1. Aggregation (data-free) + +No training data required. 
For each sample, computes its similarity to every +other sample, then aggregates those values into a confidence score. Samples +that are more similar to the majority get higher confidence. + +```python +from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy + +strategy = SIMBAUQSamplingStrategy( + temperatures=[0.3, 0.5, 0.7, 1.0], + n_per_temp=3, + similarity_metric="sbert", + confidence_method="aggregation", + aggregation="mean", +) + +result = m.instruct("Your query here", strategy=strategy, return_sampling_results=True) +``` + +### 2. Classifier (trained) + +Uses a random forest classifier trained on labeled examples. The classifier +learns to predict P(correct) from pairwise similarity features. Provide +either training data or a pre-trained sklearn classifier. + +Each training group must have exactly `len(temperatures) * n_per_temp` samples +so the feature vectors match at inference time. + +**Option A — synthetic training data:** + +With `temperatures=[0.3, 0.5, 0.7, 1.0]` and `n_per_temp=3`, each group must +contain exactly 12 samples (4 temperatures × 3 per temp). See +`run_classifier_synthetic_example()` in `simbauq_example.py` for a complete, +hand-coded example. + +```python +strategy = SIMBAUQSamplingStrategy( + temperatures=[0.3, 0.5, 0.7, 1.0], + n_per_temp=3, + similarity_metric="rouge", + confidence_method="classifier", + training_samples=[ + ["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 1 (12 samples) + ["correct answer 1", "correct answer 2", ..., "wrong answer"], # group 2 (12 samples) + ], + training_labels=[ + [1, 1, ..., 0], # labels for group 1 (12 entries) + [1, 1, ..., 0], # labels for group 2 (12 entries) + ], +) +``` + +**Option B — HF-generated training data (requires `pip install datasets`):** + +`generate_training_data()` in `simbauq_example.py` streams items from a HF +dataset, generates responses at each temperature, and labels them by similarity +to the ground-truth reference. 
Supported datasets: `"triviaqa"` (short QA) +and `"samsum"` (dialogue summarization). + +```python +from simbauq_example import generate_training_data, make_session + +m = make_session() +temperatures = [0.3, 0.5, 0.7, 1.0] +n_per_temp = 3 + +training_samples, training_labels = generate_training_data( + m, + temperatures, + n_per_temp, + dataset="triviaqa", # or "samsum" + similarity_metric="rouge", + threshold=0.5, # similarity >= threshold → label 1 +) + +strategy = SIMBAUQSamplingStrategy( + temperatures=temperatures, + n_per_temp=n_per_temp, + similarity_metric="rouge", + confidence_method="classifier", + training_samples=training_samples, + training_labels=training_labels, +) +``` + +Groups where all responses receive the same label are discarded automatically. +If no valid groups are collected, lower the `threshold` (scores are too low) +or raise it (all responses score above threshold). + +**Option C — pre-trained classifier:** + +Train the classifier externally with `train_classifier()`, then pass the fitted +object via `classifier=`. The feature extraction reuses +`SIMBAUQSamplingStrategy._compute_similarity_matrix` and `_extract_features` +internally, so the feature space is identical to what the strategy uses at +inference time. 
+ +```python +from simbauq_example import generate_training_data, train_classifier, make_session + +m = make_session() +temperatures = [0.3, 0.5, 0.7, 1.0] +n_per_temp = 3 + +training_samples, training_labels = generate_training_data( + m, temperatures, n_per_temp, dataset="triviaqa", similarity_metric="rouge", threshold=0.5 +) + +clf = train_classifier(training_samples, training_labels, similarity_metric="rouge") + +strategy = SIMBAUQSamplingStrategy( + temperatures=temperatures, + n_per_temp=n_per_temp, + similarity_metric="rouge", + confidence_method="classifier", + classifier=clf, +) +``` + +## Constructor Parameters + +| Parameter | Type | Default | Description | +|-----------|------|---------|-------------| +| `temperatures` | `list[float]` | `[0.3, 0.5, 0.7, 1.0]` | Temperature values to sample at | +| `n_per_temp` | `int` | `4` | Number of samples per temperature | +| `similarity_metric` | `"rouge"`, `"jaccard"`, `"sbert"`, `"difflib"`, `"levenshtein"` | `"rouge"` | Pairwise similarity metric | +| `confidence_method` | `"aggregation"`, `"classifier"` | `"aggregation"` | Confidence estimation method | +| `aggregation` | `"mean"`, `"geometric_mean"`, `"harmonic_mean"`, `"median"`, `"max"`, `"min"` | `"mean"` | Aggregation function (for `aggregation` method) | +| `classifier` | sklearn classifier | `None` | Pre-trained classifier with `predict_proba` | +| `training_samples` | `list[list[str]]` | `None` | Training data for classifier | +| `training_labels` | `list[list[int]]` | `None` | Binary correctness labels (0/1) | +| `clf_max_depth` | `int` | `4` | Max tree depth for random forest | +| `rouge_type` | `str` | `"rougeL"` | Rouge variant | +| `sbert_model` | `str` | `"all-MiniLM-L6-v2"` | Sentence-BERT model name | +| `requirements` | `list[Requirement]` | `None` | Requirements to validate the selected sample | + +## Similarity Metrics + +- **rouge** (default): RougeL F-measure. Good general-purpose text similarity. 
+ No extra dependencies beyond `rouge-score` (already in Mellea). +- **jaccard**: Word-level set overlap (intersection / union). Fast, no + external dependencies, works well for short structured answers. +- **sbert**: Cosine similarity of Sentence-BERT embeddings. Best semantic + similarity but requires `sentence-transformers`. +- **difflib**: `difflib.SequenceMatcher` ratio. Character-level similarity + from the Python standard library; no extra dependencies. +- **levenshtein**: Normalized Levenshtein edit distance (`1 - dist / max_len`). + Exact character-level metric; no extra dependencies. + +## Inspecting Results + +The selected sample's `ModelOutputThunk` stores confidence metadata: + +```python +result = m.instruct(..., strategy=strategy, return_sampling_results=True) + +# Best sample +best_mot = result.result +meta = best_mot._meta["simba_uq"] + +meta["confidence"] # float: confidence of the selected sample +meta["all_confidences"] # list[float]: confidence for every sample +meta["similarity_matrix"] # list[list[float]]: N x N pairwise similarity matrix +meta["temperatures_used"] # list[float]: temperature used for each sample +meta["confidence_method"] # "aggregation" or "classifier" +meta["similarity_metric"] # "rouge", "jaccard", "sbert", "difflib", or "levenshtein" +meta["aggregation"] # aggregation function name + +# All generated samples +for i, mot in enumerate(result.sample_generations): + print(f"Sample {i}: {mot.value}") +``` + +## Related Files + +- `mellea/stdlib/sampling/simbauq.py` -- Strategy implementation +- `test/stdlib/sampling/test_simbauq.py` -- Unit and integration tests diff --git a/docs/examples/simbauq/simbauq_example.py b/docs/examples/simbauq/simbauq_example.py new file mode 100644 index 000000000..0ae4d801e --- /dev/null +++ b/docs/examples/simbauq/simbauq_example.py @@ -0,0 +1,647 @@ +# pytest: ollama, llm, qualitative + +"""SIMBA-UQ Sampling Strategy Example. 
+ +This example demonstrates the SIMBAUQSamplingStrategy in four variants that +cover both confidence estimation methods: + +1. **Aggregation** (data-free) - Computes pairwise similarity between all + generated samples and aggregates them into per-sample confidence scores. + The sample with the highest confidence is selected. + +2. **Classifier with synthetic data** - Uses a random forest classifier + trained on hand-coded labeled examples. + +3. **Classifier with HF data** - Same classifier method, but training data + is generated live via Ollama from one of two supported HF datasets: + ``"triviaqa"`` (short factoid QA) or ``"samsum"`` (dialogue summarization). + No other datasets are supported by ``generate_training_data()``. + +4. **Classifier with pre-trained classifier** - Trains a RandomForestClassifier + externally using HF-generated data, then passes the fitted object directly + to SIMBAUQSamplingStrategy via the ``classifier=`` argument. + +All variants generate multiple samples across different temperature values, +compute a pairwise similarity matrix, and select the most confident response. + +Available similarity metrics (set via ``METRIC`` in the CONFIG block below): +``"rouge"``, ``"jaccard"``, ``"sbert"``, ``"difflib"``, ``"levenshtein"``. + +To control which example runs, edit the CONFIG block below the imports: set +``EXAMPLE``, ``DATASET``, ``METRIC``, ``AGGREGATION``, and ``THRESHOLD`` there. + +The example uses OllamaModelBackend with granite4:micro. 
To run: + + ollama serve + uv run python docs/examples/simbauq/simbauq_example.py +""" + +import logging +from typing import Literal + +import numpy as np +from sklearn.ensemble import RandomForestClassifier # type: ignore[import-not-found] +from tqdm import tqdm + +from mellea import MelleaSession +from mellea.backends import ModelOption +from mellea.backends.ollama import OllamaModelBackend +from mellea.core import FancyLogger, SamplingResult +from mellea.stdlib.context import ChatContext +from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy + +# ============================================================================ +# CONFIG — edit these to control which example(s) run. +# ============================================================================ + +# Which example(s) to run. +# "aggregation" — data-free similarity aggregation +# "synthetic" — classifier with hand-coded labeled groups +# "hf" — classifier trained on data generated from an HF dataset +# "pretrained" — classifier trained externally, passed in via `classifier=` +# "all" — run all four sequentially +EXAMPLE = "aggregation" + +# HF dataset for the `hf` and `pretrained` examples. Only two datasets are +# supported by `generate_training_data()`: +# "triviaqa" — short factoid QA (rc.nocontext split) +# "samsum" — dialogue summarization +DATASET = "triviaqa" + +# Pairwise similarity metric. Used by both the strategy and the labelling +# pass in `generate_training_data()`. Available metrics: +# "rouge" — RougeL F-measure (default; needs `rouge-score`) +# "jaccard" — word-level set overlap, fast, no extra deps +# "sbert" — Sentence-BERT cosine; needs `sentence-transformers` +# "difflib" — `difflib.SequenceMatcher` ratio; stdlib only +# "levenshtein" — normalized edit distance; stdlib only +METRIC: Literal["rouge", "jaccard", "sbert", "difflib", "levenshtein"] = "sbert" + +# Aggregation method for the "aggregation" confidence method. 
Options: +# "mean" — Arithmetic mean (default) +# "geometric_mean" — Geometric mean +# "harmonic_mean" — Harmonic mean +# "median" — Median +# "max" — Maximum +# "min" — Minimum +AGGREGATION: Literal[ + "mean", "geometric_mean", "harmonic_mean", "median", "max", "min" +] = "mean" + +# Similarity threshold for labelling generated responses against the HF +# reference: score >= THRESHOLD → label 1 (correct), else 0. Tune per +# dataset/metric/model — too strict drops every group, too loose makes every +# response "correct" and the classifier sees no negatives. Groups where every +# response receives the same label are discarded automatically. +# Reasonable starting points: +# sbert + triviaqa: 0.5 - 0.7 +# sbert + samsum: 0.2 - 0.4 +# rouge + triviaqa: 0.3 - 0.5 +THRESHOLD = 0.2 + +# ============================================================================ + +# Allowed values, used for CONFIG validation in `main()`. +_VALID_EXAMPLES = ("aggregation", "synthetic", "hf", "pretrained", "all") +_VALID_DATASETS = ("triviaqa", "samsum") +_VALID_METRICS = ("rouge", "jaccard", "sbert", "difflib", "levenshtein") +_VALID_AGGREGATIONS = ( + "mean", + "geometric_mean", + "harmonic_mean", + "median", + "max", + "min", +) + +# Number of training groups collected per dataset. +# Each group has len(temperatures) * n_per_temp samples. +# Increase for better classifier signal at the cost of more LLM calls. 
+N_TRAINING_GROUPS = 5 + + +def make_session() -> MelleaSession: + """Create a MelleaSession with OllamaModelBackend.""" + FancyLogger.get_logger().setLevel(logging.WARNING) + backend = OllamaModelBackend(model_options={ModelOption.MAX_NEW_TOKENS: 150}) + return MelleaSession(backend, ctx=ChatContext()) + + +def print_results(result: SamplingResult) -> None: + """Print detailed results from a SIMBA-UQ sampling run.""" + meta = result.result._meta["simba_uq"] + confidences = meta["all_confidences"] + temperatures = meta["temperatures_used"] + sim_matrix = np.array(meta["similarity_matrix"]) + + # --- Best response --- + print("=" * 70) + print("BEST RESPONSE") + print("=" * 70) + print(f" Index: {result.result_index}") + print(f" Confidence: {meta['confidence']:.4f}") + print(f" Method: {meta['confidence_method']}") + print(f" Metric: {meta['similarity_metric']}") + print(f" Aggregation: {meta['aggregation']}") + print(f" Text:\n {result.result!s}") + print() + + # --- All samples --- + print("=" * 70) + print("ALL SAMPLES") + print("=" * 70) + print(f"{'Idx':>4} {'Temp':>5} {'Conf':>8} {'Text'}") + print("-" * 70) + for i, mot in enumerate(result.sample_generations): + text = str(mot).replace("\n", " ") + truncated = (text[:100] + "...") if len(text) > 100 else text + marker = " <-- best" if i == result.result_index else "" + print( + f"{i:>4} {temperatures[i]:>5.2f} {confidences[i]:>8.4f} " + f"{truncated}{marker}" + ) + print() + + # --- Similarity matrix --- + n = sim_matrix.shape[0] + print("=" * 70) + print("SIMILARITY MATRIX") + print("=" * 70) + header = " " + "".join(f" [{i:>2}] " for i in range(n)) + print(header) + for i in range(n): + row = f"[{i:>2}] " + "".join(f" {sim_matrix[i, j]:.3f} " for j in range(n)) + print(row) + print() + + +def generate_training_data( + session: MelleaSession, + temperatures: list[float], + n_per_temp: int, + dataset: str = "triviaqa", + similarity_metric: Literal[ + "rouge", "jaccard", "sbert", "difflib", "levenshtein" + ] 
= "rouge", + threshold: float = 0.5, + n_groups: int = N_TRAINING_GROUPS, +) -> tuple[list[list[str]], list[list[int]]]: + """Generate classifier training data from a single HF dataset via Ollama. + + For each dataset item, generates one group of ``len(temperatures) * + n_per_temp`` responses at the configured temperature schedule. Each + response is labelled 1 if its similarity to the ground-truth reference + meets ``threshold``, 0 otherwise. Groups where all labels are identical + are discarded as they provide no classifier signal. + + Args: + session: Active MelleaSession to use for generation. + temperatures: Temperature schedule (must match inference-time schedule). + n_per_temp: Samples per temperature (must match inference-time value). + dataset: HF dataset to use. One of ``"triviaqa"`` (short QA) or + ``"samsum"`` (dialogue summarization). + similarity_metric: Metric used for labelling (should match the + ``similarity_metric`` passed to SIMBAUQSamplingStrategy). + threshold: Similarity score >= threshold → label 1. + n_groups: Target number of valid groups to collect. + + Returns: + Tuple of (training_samples, training_labels), each a list of groups + with exactly ``len(temperatures) * n_per_temp`` entries per group. + """ + if dataset not in _VALID_DATASETS: + raise ValueError( + f"Unknown dataset {dataset!r}. Supported: {list(_VALID_DATASETS)}." + ) + + try: + from datasets import load_dataset # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "The 'datasets' package is required for HF training data generation. " + "Install it with: pip install datasets" + ) from err + + group_size = len(temperatures) * n_per_temp + print(f"Generating training data with {group_size} samples per group.") + + # Reused solely for _compute_similarity — no training data needed at init. 
+ scorer = SIMBAUQSamplingStrategy( + similarity_metric=similarity_metric, confidence_method="aggregation" + ) + + def _load_triviaqa(n: int) -> list[dict]: + ds = load_dataset("trivia_qa", "rc.nocontext", split="train", streaming=True) + items = [] + for row in ds: + ref = row.get("answer", {}).get("value", "") + if not ref: + continue + items.append( + { + "prompt": f"Answer the following question briefly: {row['question']}", + "reference": ref, + } + ) + if len(items) >= n: + break + return items + + def _load_samsum(n: int) -> list[dict]: + ds = load_dataset("samsum", split="train", streaming=True) + items = [] + for row in ds: + dialogue = row.get("dialogue", "")[:1000] + ref = row.get("summary", "") + if not dialogue or not ref: + continue + items.append( + { + "prompt": f"Summarize the following dialogue in one sentence:\n\n{dialogue}", + "reference": ref, + } + ) + if len(items) >= n: + break + return items + + loaders = {"triviaqa": _load_triviaqa, "samsum": _load_samsum} + print(f" Collecting {n_groups} training groups from {dataset}...") + items = loaders[dataset](n_groups * 3) + + training_samples: list[list[str]] = [] + training_labels: list[list[int]] = [] + collected = 0 + + with tqdm(total=n_groups, desc=f"Generating [{dataset}]", unit="group") as pbar: + for item in items: + if collected >= n_groups: + break + + responses: list[str] = [] + for temp in temperatures: + for _ in range(n_per_temp): + try: + mot = session.instruct( + item["prompt"], + model_options={ + ModelOption.TEMPERATURE: temp, + ModelOption.MAX_NEW_TOKENS: 150, + }, + ) + responses.append(str(mot)) + except Exception: + responses.append("") + + scores = [ + scorer._compute_similarity(r, item["reference"]) for r in responses + ] + labels = [1 if s >= threshold else 0 for s in scores] + + print( + "Responses:\n " + + "\n ".join( + f"{i}. 
{r!r} (score={s:.4f}, label={label})" + for i, (r, s, label) in enumerate(zip(responses, scores, labels)) + ) + ) + if len(set(labels)) < 2: + pbar.set_postfix_str( + f"discarded (scores {min(scores):.2f}–{max(scores):.2f}, threshold={threshold})" + ) + continue + + training_samples.append(responses) + training_labels.append(labels) + collected += 1 + pbar.update(1) + + return training_samples, training_labels + + +def train_classifier( + training_samples: list[list[str]], + training_labels: list[list[int]], + similarity_metric: Literal[ + "rouge", "jaccard", "sbert", "difflib", "levenshtein" + ] = "rouge", + clf_max_depth: int = 4, +) -> RandomForestClassifier: + """Train a RandomForestClassifier on similarity features extracted from training data. + + Uses the same feature extraction as SIMBAUQSamplingStrategy internally + (_compute_similarity_matrix + _extract_features), ensuring the feature + space is identical at train and inference time. + + Args: + training_samples: List of groups, each with the same number of samples + as ``len(temperatures) * n_per_temp`` used at inference time. + training_labels: Binary correctness labels (0/1) matching + ``training_samples``. + similarity_metric: Similarity metric for feature extraction. + clf_max_depth: Maximum tree depth for the random forest. + + Returns: + Fitted RandomForestClassifier. 
+ """ + extractor = SIMBAUQSamplingStrategy( + similarity_metric=similarity_metric, confidence_method="aggregation" + ) + + x_train: list[np.ndarray] = [] + y_train: list[int] = [] + for samples, labels in zip(training_samples, training_labels): + sim_matrix = extractor._compute_similarity_matrix(samples) + for i, label in enumerate(labels): + x_train.append(extractor._extract_features(sim_matrix, i)) + y_train.append(label) + + clf = RandomForestClassifier(max_depth=clf_max_depth, random_state=0) + clf.fit(x_train, y_train) + return clf + + +def run_aggregation_example( + session: MelleaSession, + similarity_metric: Literal["rouge", "jaccard", "sbert", "difflib", "levenshtein"], + aggregation: Literal[ + "mean", "geometric_mean", "harmonic_mean", "median", "max", "min" + ], +) -> None: + """Run SIMBA-UQ with data-free similarity aggregation.""" + print("\n>>> AGGREGATION CONFIDENCE METHOD <<<\n") + + strategy = SIMBAUQSamplingStrategy( + temperatures=[0.3, 0.5, 0.7, 1.0], + n_per_temp=3, + similarity_metric=similarity_metric, + confidence_method="aggregation", + aggregation=aggregation, + ) + + result: SamplingResult = session.instruct( + "Which magazine was started first Arthur's Magazine or First for Women?", + strategy=strategy, + return_sampling_results=True, + ) + + print(f"Total samples generated: {len(result.sample_generations)}") + print_results(result) + + +def run_classifier_synthetic_example( + session: MelleaSession, + similarity_metric: Literal["rouge", "jaccard", "sbert", "difflib", "levenshtein"], +) -> None: + """Run SIMBA-UQ classifier with hand-coded synthetic training data.""" + print("\n>>> CLASSIFIER CONFIDENCE METHOD (synthetic training data) <<<\n") + + temperatures = [0.3, 0.5, 0.7, 1.0] + n_per_temp = 3 + + # Synthetic training data: 3 groups of 12 samples (4 temps * 3 per temp). + # Each group has mostly "correct" similar answers and a few outliers. 
+ training_samples = [ + [ + "Paris is the capital of France.", + "The capital of France is Paris.", + "France's capital city is Paris.", + "Paris, the capital of France.", + "The capital city of France is Paris.", + "France has Paris as its capital.", + "Paris serves as France's capital.", + "In France, Paris is the capital.", + "The French capital is Paris.", + "Bananas are a yellow fruit.", + "Dogs are loyal pets.", + "The ocean is very deep.", + ], + [ + "Water boils at 100 degrees Celsius.", + "At 100C water reaches boiling point.", + "The boiling point of water is 100 degrees.", + "Water boils when heated to 100C.", + "100 degrees Celsius is water's boiling point.", + "Boiling occurs at 100C for water.", + "Water starts boiling at one hundred degrees.", + "At 100 degrees water boils.", + "The temperature for boiling water is 100C.", + "Cats like to sleep a lot.", + "Mountains can be very high.", + "Stars shine in the night sky.", + ], + [ + "Python is a programming language.", + "Python is a popular programming language.", + "The Python programming language is widely used.", + "Python is used for programming.", + "Programming in Python is common.", + "Python is a well-known language for coding.", + "Many developers use Python.", + "Python is a general-purpose language.", + "The language Python is popular.", + "Pizza originated in Italy.", + "Rain falls from clouds.", + "Books contain many pages.", + ], + ] + training_labels = [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + ] + + strategy = SIMBAUQSamplingStrategy( + temperatures=temperatures, + n_per_temp=n_per_temp, + similarity_metric=similarity_metric, + confidence_method="classifier", + training_samples=training_samples, + training_labels=training_labels, + ) + + result: SamplingResult = session.instruct( + "Which magazine was started first Arthur's Magazine or First for Women?", + strategy=strategy, + 
return_sampling_results=True, + ) + + print(f"Total samples generated: {len(result.sample_generations)}") + print_results(result) + + +def run_classifier_hf_example( + session: MelleaSession, + dataset: str, + similarity_metric: Literal["rouge", "jaccard", "sbert", "difflib", "levenshtein"], + threshold: float, +) -> None: + """Run SIMBA-UQ classifier with training data generated from an HF dataset.""" + print(f"\n>>> CLASSIFIER CONFIDENCE METHOD (HF / {dataset}) <<<\n") + + temperatures = [0.3, 0.5, 0.7, 1.0] + n_per_temp = 3 + + print(f"Generating training data from {dataset}...") + training_samples, training_labels = generate_training_data( + session, + temperatures, + n_per_temp, + dataset=dataset, + similarity_metric=similarity_metric, + threshold=threshold, + ) + print( + f"Training data ready: {len(training_samples)} groups of {len(temperatures) * n_per_temp} samples each.\n" + ) + + if not training_samples: + print( + f" No valid training groups collected (threshold={threshold} may be " + "too strict or too loose for this model/dataset combination). " + "Try adjusting THRESHOLD in the CONFIG block." 
+ ) + return + + print("--- Training examples sample ---") + for group_idx, (samples, labels) in enumerate( + zip(training_samples, training_labels) + ): + correct = [s for s, lab in zip(samples, labels) if lab == 1] + incorrect = [s for s, lab in zip(samples, labels) if lab == 0] + print(f" Group {group_idx}:") + if correct: + print(f" [correct] {correct[0]!r}") + if incorrect: + print(f" [incorrect] {incorrect[0]!r}") + print() + + strategy = SIMBAUQSamplingStrategy( + temperatures=temperatures, + n_per_temp=n_per_temp, + similarity_metric=similarity_metric, + confidence_method="classifier", + training_samples=training_samples, + training_labels=training_labels, + ) + + result: SamplingResult = session.instruct( + "Which magazine was started first Arthur's Magazine or First for Women?", + strategy=strategy, + return_sampling_results=True, + ) + + print(f"Total samples generated: {len(result.sample_generations)}") + print_results(result) + + +def run_classifier_pretrained_example( + session: MelleaSession, + dataset: str, + similarity_metric: Literal["rouge", "jaccard", "sbert", "difflib", "levenshtein"], + threshold: float, +) -> None: + """Run SIMBA-UQ classifier with a pre-trained RandomForestClassifier.""" + print(f"\n>>> CLASSIFIER CONFIDENCE METHOD (pre-trained / {dataset}) <<<\n") + + temperatures = [0.3, 0.5, 0.7, 1.0] + n_per_temp = 3 + + print(f"Generating training data from {dataset}...") + training_samples, training_labels = generate_training_data( + session, + temperatures, + n_per_temp, + dataset=dataset, + similarity_metric=similarity_metric, + threshold=threshold, + ) + print( + f"Training data ready: {len(training_samples)} groups of {len(temperatures) * n_per_temp} samples each.\n" + ) + + if not training_samples: + print( + f" No valid training groups collected (threshold={threshold} may be " + "too strict or too loose for this model/dataset combination). " + "Try adjusting the threshold." 
+ ) + return + + print("--- Training examples sample ---") + for group_idx, (samples, labels) in enumerate( + zip(training_samples, training_labels) + ): + correct = [s for s, lab in zip(samples, labels) if lab == 1] + incorrect = [s for s, lab in zip(samples, labels) if lab == 0] + print(f" Group {group_idx}:") + if correct: + print(f" [correct] {correct[0]!r}") + if incorrect: + print(f" [incorrect] {incorrect[0]!r}") + print() + + clf = train_classifier( + training_samples, training_labels, similarity_metric=similarity_metric + ) + print(f"Classifier trained: {clf}\n") + + strategy = SIMBAUQSamplingStrategy( + temperatures=temperatures, + n_per_temp=n_per_temp, + similarity_metric=similarity_metric, + confidence_method="classifier", + classifier=clf, + ) + + result: SamplingResult = session.instruct( + "Which magazine was started first Arthur's Magazine or First for Women?", + strategy=strategy, + return_sampling_results=True, + ) + + print(f"Total samples generated: {len(result.sample_generations)}") + print_results(result) + + +def main() -> None: + """Run the SIMBA-UQ example(s) selected via the CONFIG block at the top of this file.""" + if EXAMPLE not in _VALID_EXAMPLES: + raise ValueError( + f"Unknown EXAMPLE={EXAMPLE!r}. Choose one of: {list(_VALID_EXAMPLES)}." + ) + if DATASET not in _VALID_DATASETS: + raise ValueError( + f"Unknown DATASET={DATASET!r}. Choose one of: {list(_VALID_DATASETS)}." + ) + if METRIC not in _VALID_METRICS: + raise ValueError( + f"Unknown METRIC={METRIC!r}. Choose one of: {list(_VALID_METRICS)}." + ) + if AGGREGATION not in _VALID_AGGREGATIONS: + raise ValueError( + f"Unknown AGGREGATION={AGGREGATION!r}. Choose one of: {list(_VALID_AGGREGATIONS)}." + ) + + # Start a Mellea session with OllamaModelBackend. 
+ m = make_session() + + runners = { + "aggregation": lambda: run_aggregation_example(m, METRIC, AGGREGATION), + "synthetic": lambda: run_classifier_synthetic_example(m, METRIC), + "hf": lambda: run_classifier_hf_example(m, DATASET, METRIC, THRESHOLD), + "pretrained": lambda: run_classifier_pretrained_example( + m, DATASET, METRIC, THRESHOLD + ), + } + + to_run = list(runners.values()) if EXAMPLE == "all" else [runners[EXAMPLE]] + + for i, run in enumerate(to_run): + if i > 0: + print("\n" + "=" * 70 + "\n") + run() + + +if __name__ == "__main__": + main() diff --git a/mellea/stdlib/sampling/__init__.py b/mellea/stdlib/sampling/__init__.py index d89d45c98..9ec8d4762 100644 --- a/mellea/stdlib/sampling/__init__.py +++ b/mellea/stdlib/sampling/__init__.py @@ -8,6 +8,7 @@ RejectionSamplingStrategy, RepairTemplateStrategy, ) +from .simbauq import SIMBAUQSamplingStrategy from .sofai import SOFAISamplingStrategy __all__ = [ @@ -15,6 +16,7 @@ "MultiTurnStrategy", "RejectionSamplingStrategy", "RepairTemplateStrategy", + "SIMBAUQSamplingStrategy", "SamplingResult", "SamplingStrategy", ] diff --git a/mellea/stdlib/sampling/simbauq.py b/mellea/stdlib/sampling/simbauq.py new file mode 100644 index 000000000..e5395b2e6 --- /dev/null +++ b/mellea/stdlib/sampling/simbauq.py @@ -0,0 +1,602 @@ +"""SIMBA-UQ Sampling Strategy. + +Implements confidence-aware sample selection using the SIMBA-UQ framework +(Bhattacharjya et al., 2025). Generates multiple samples across a range of +temperatures and selects the most confident one. + +Two confidence estimation methods are supported: + +* **aggregation** (data-free) — computes pairwise similarity between all + samples, then aggregates per-sample similarities into a confidence score. +* **classifier** — extracts pairwise similarity features and feeds them into + a trained probabilistic classifier (e.g. random forest) that predicts + P(correct) for each sample. + +Reference: + Bhattacharjya et al. 
(2025), "SIMBA UQ: Similarity-Based Aggregation for + Uncertainty Quantification in Large Language Models", https://arxiv.org/abs/2510.13836 +""" + +import asyncio +from copy import deepcopy +from difflib import SequenceMatcher +from typing import Literal, Protocol, runtime_checkable + +import numpy as np +from rouge_score.rouge_scorer import RougeScorer + +from mellea.stdlib import context # codespell:ignore + +from ...core import ( + Backend, + BaseModelSubclass, + Component, + Context, + Requirement, + S, + SamplingResult, + SamplingStrategy, + ValidationResult, +) +from .. import functional as mfuncs + + +@runtime_checkable +class ProbabilisticClassifier(Protocol): + """Protocol for sklearn-compatible probabilistic classifiers.""" + + def predict_proba(self, X: list[np.ndarray]) -> np.ndarray: + """Return class probability estimates for the given samples.""" + ... + + +class SIMBAUQSamplingStrategy(SamplingStrategy): + """Sampling strategy that selects the most confident sample using SIMBA-UQ. + + Generates ``len(temperatures) * n_per_temp`` samples across a range of + temperature values, computes pairwise similarity between all samples, and + uses either similarity aggregation or a trained classifier to estimate + per-sample confidence. The sample with the highest confidence is returned. + + Confidence metadata is stored on the selected ``ModelOutputThunk`` in + ``mot._meta['simba_uq']``. + + Args: + temperatures (list[float]): Temperature values to sample at. + n_per_temp (int): Number of samples to generate per temperature value. + similarity_metric (Literal['rouge', 'jaccard', 'sbert', 'difflib', + 'levenshtein']): Pairwise similarity metric. ``'rouge'`` uses + RougeL F-measure; ``'jaccard'`` uses word-level Jaccard index; + ``'sbert'`` uses cosine similarity of Sentence-BERT embeddings + (requires ``sentence-transformers``); ``'difflib'`` uses + ``difflib.SequenceMatcher`` ratio; ``'levenshtein'`` uses + one minus the Levenshtein edit distance normalized by the longer + string's length.
+ confidence_method (Literal['aggregation', 'classifier']): How to + compute confidence from the similarity matrix. ``'aggregation'`` + uses a data-free aggregation function; ``'classifier'`` uses a + trained probabilistic classifier. + aggregation (Literal['mean', 'geometric_mean', 'harmonic_mean', + 'median', 'max', 'min']): Aggregation function used when + ``confidence_method='aggregation'``. + classifier (ProbabilisticClassifier | None): Pre-trained + sklearn-compatible probabilistic classifier (any estimator with a + ``predict_proba`` method). Used when + ``confidence_method='classifier'``. If not provided, a random + forest is trained from ``training_samples`` and + ``training_labels``. + training_samples (list[list[str]] | None): Training data for the + classifier — a list of query groups, each containing sample + strings. Each group must have the same number of samples as + ``len(temperatures) * n_per_temp``. + training_labels (list[list[int]] | None): Binary correctness labels + (0/1) matching ``training_samples``. + clf_max_depth (int): Maximum tree depth for the random forest when + training from data. + rouge_type (str): Rouge variant when ``similarity_metric='rouge'``. + sbert_model (str): Sentence-BERT model name when + ``similarity_metric='sbert'``. + requirements (list[Requirement] | None): Optional global requirements + to validate the selected sample against. 
+ """ + + _CLF_EPS = 1e-6 + + def __init__( + self, + *, + temperatures: list[float] | None = None, + n_per_temp: int = 4, + similarity_metric: Literal[ + "rouge", "jaccard", "sbert", "difflib", "levenshtein" + ] = "rouge", + confidence_method: Literal["aggregation", "classifier"] = "aggregation", + aggregation: Literal[ + "mean", "geometric_mean", "harmonic_mean", "median", "max", "min" + ] = "mean", + classifier: ProbabilisticClassifier | None = None, + training_samples: list[list[str]] | None = None, + training_labels: list[list[int]] | None = None, + clf_max_depth: int = 4, + rouge_type: str = "rougeL", + sbert_model: str = "all-MiniLM-L6-v2", + requirements: list[Requirement] | None = None, + ) -> None: + """Initialize SIMBAUQSamplingStrategy with temperature schedule and confidence parameters.""" + if temperatures is None: + temperatures = [0.3, 0.5, 0.7, 1.0] + + if len(temperatures) == 0: + raise ValueError("Temperatures must be a non-empty list") + if n_per_temp <= 0: + raise ValueError("n_per_temp must be > 0") + + self.temperatures = temperatures + self.n_per_temp = n_per_temp + self.similarity_metric = similarity_metric + self.confidence_method = confidence_method + self.aggregation = aggregation + self.clf_max_depth = clf_max_depth + self.rouge_type = rouge_type + self.sbert_model = sbert_model + self.requirements = requirements + + # --- Similarity metric initialization --- + if similarity_metric == "rouge": + self._rouge_scorer = RougeScorer([rouge_type], use_stemmer=True) + elif similarity_metric == "sbert": + try: + import sentence_transformers # type: ignore[import-not-found] + except ImportError: + msg = ( + "sentence-transformers is required for sbert similarity. " + "Please install with `pip install sentence-transformers`." 
+ ) + raise ImportError(msg) + self._sbert_model_obj = sentence_transformers.SentenceTransformer( + sbert_model + ) + + # --- Classifier initialization --- + self._classifier: ProbabilisticClassifier | None = None + if confidence_method == "classifier": + if classifier is not None: + self._classifier = classifier + + # If a classifier is provided, do a sanity check to ensure the feature + # dimensionality matches the expected number of samples. + expected = len(temperatures) * n_per_temp - 1 + n_features = getattr(classifier, "n_features_in_", None) + if n_features is not None and n_features != expected: + raise ValueError( + f"Classifier expects {n_features} features but this configuration " + f"produces {expected} (len(temperatures) * n_per_temp - 1)." + ) + elif training_samples is not None and training_labels is not None: + n_samples = len(temperatures) * n_per_temp + for i, group in enumerate(training_samples): + msg = ( + f"Training group {i} has {len(group)} samples, " + f"expected {n_samples} " + f"(len(temperatures) * n_per_temp)" + ) + if len(group) != n_samples: + raise ValueError(msg) + + msg = ( + f"Training labels group {i} has " + f"{len(training_labels[i])} labels, " + f"expected {n_samples}" + ) + if len(training_labels[i]) != n_samples: + raise ValueError(msg) + + self._classifier = self._train_classifier( + training_samples, training_labels + ) + else: + msg = ( + "confidence_method='classifier' requires either a " + "'classifier' or both 'training_samples' and " + "'training_labels'" + ) + raise ValueError(msg) + + async def sample( + self, + action: Component[S], + context: Context, + backend: Backend, + requirements: list[Requirement] | None, + *, + validation_ctx: Context | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult[S]: + """Sample across temperatures and select the most confident result. 
+ + Args: + action: The action object to be sampled. + context: The context to be passed to the sampling strategy. + backend: The backend used for generating samples. + requirements: List of requirements to test against (merged with + global requirements). + validation_ctx: Optional context to use for validation. + format: Output format for structured outputs. + model_options: Model options to pass to the backend during + generation. + tool_calls: True if tool calls should be used during this sampling + strategy. + + Returns: + SamplingResult with the most confident sample selected. + """ + if model_options is None: + model_options = {} + + # Merge global and local requirements, deduplicating by identity. + reqs = self._merge_requirements(requirements) + + # --- Phase 1: Generate samples across temperatures --- + generation_tasks: list[asyncio.Task] = [] + task_actions: list[Component[S]] = [] + task_temps: list[float] = [] + + for temp in self.temperatures: + for _ in range(self.n_per_temp): + opts = {**model_options, "temperature": temp} + task_action = deepcopy(action) + task = asyncio.create_task( + backend.generate_from_context( + task_action, + ctx=context, + format=format, + model_options=opts, + tool_calls=tool_calls, + ) + ) + generation_tasks.append(task) + task_actions.append(task_action) + task_temps.append(temp) + + generation_results = await asyncio.gather( + *generation_tasks, return_exceptions=True + ) + + # Resolve all thunks and parse. Skip failed tasks but keep + # all_mots / all_contexts / all_actions / temp_assignments aligned positionally. + all_mots = [] + all_contexts = [] + all_actions: list[Component[S]] = [] + temp_assignments: list[float] = [] + for gen_result, task_action, task_temp in zip( + generation_results, task_actions, task_temps + ): + if isinstance(gen_result, BaseException): + continue # Skip failed generations.
+ result_mot, result_ctx = gen_result + await result_mot.avalue() + result_mot.parsed_repr = task_action.parse(result_mot) + all_mots.append(result_mot) + all_contexts.append(result_ctx) + all_actions.append(task_action) + temp_assignments.append(task_temp) + + # --- Phase 2: Compute SIMBA-UQ confidence scores --- + sample_strings = [str(mot) for mot in all_mots] + n = len(sample_strings) + if n == 0: + raise RuntimeError("No successful samples were generated.") + elif n == 1: + sim_matrix = np.ones((1, 1)) + confidences = np.array([0.5]) + else: + sim_matrix = self._compute_similarity_matrix(sample_strings) + if self.confidence_method == "classifier": + confidences = self._compute_confidences_classifier(sim_matrix, n) + else: + confidences = self._compute_confidences(sim_matrix) + + # Select the sample with the highest confidence. + best_index = int(np.argmax(confidences)) + best_confidence = float(confidences[best_index]) + + # Store confidence metadata in the selected MOT's meta dict. + best_mot = all_mots[best_index] + if best_mot._meta is None: + best_mot._meta = {} + best_mot._meta["simba_uq"] = { + "confidence": best_confidence, + "all_confidences": confidences.tolist(), + "similarity_matrix": sim_matrix.tolist(), + "temperatures_used": temp_assignments, + "confidence_method": self.confidence_method, + "similarity_metric": self.similarity_metric, + "aggregation": self.aggregation, + } + + # Mark as final result. 
+ if best_mot._generate_log is not None: + best_mot._generate_log.is_final_result = True + + # --- Phase 3: Validate selected sample (if requirements exist) --- + success = True + all_validations: list[list[tuple[Requirement, ValidationResult]]] = [ + [] for _ in all_mots + ] + + validation_ctx = ( + validation_ctx if validation_ctx is not None else all_contexts[best_index] + ) + + if reqs: + val_results = await mfuncs.avalidate( + reqs=reqs, + context=validation_ctx, + backend=backend, + output=best_mot, + format=None, + model_options=model_options, + ) + scored = list(zip(reqs, val_results)) + all_validations[best_index] = scored + success = all(vr.as_bool() for vr in val_results) + + return SamplingResult( + result_index=best_index, + success=success, + sample_generations=all_mots, + sample_validations=all_validations, + sample_actions=all_actions, + sample_contexts=all_contexts, + ) + + def _merge_requirements(self, local: list[Requirement] | None) -> list[Requirement]: + """Merge global and local requirements, deduplicating by identity.""" + combined: list[Requirement] = [] + seen: set[int] = set() + for req_list in (self.requirements, local): + if req_list is None: + continue + for req in req_list: + if id(req) not in seen: + combined.append(req) + seen.add(id(req)) + return combined + + @staticmethod + def _levenshtein_distance(s1: str, s2: str) -> int: + """Compute the Levenshtein edit distance between two strings.""" + m, n = len(s1), len(s2) + dp = list(range(n + 1)) + for i in range(1, m + 1): + prev = dp[0] + dp[0] = i + for j in range(1, n + 1): + temp = dp[j] + if s1[i - 1] == s2[j - 1]: + dp[j] = prev + else: + dp[j] = 1 + min(prev, dp[j], dp[j - 1]) + prev = temp + return dp[n] + + def _compute_similarity(self, text1: str, text2: str) -> float: + """Compute pairwise similarity between two text strings. + + Args: + text1: First text. + text2: Second text. + + Returns: + Similarity score in [0.0, 1.0]. 
+ """ + if self.similarity_metric == "rouge": + scores = self._rouge_scorer.score(text1, text2) + return scores[self.rouge_type].fmeasure + + if self.similarity_metric == "sbert": + try: + from sklearn.metrics.pairwise import ( + cosine_similarity, # type: ignore[import-not-found] + ) + except ImportError: + msg = ( + "sklearn.metrics.pairwise.cosine_similarity is required for sbert similarity. " + "Please install with extra dependencies: `pip install mellea[simbauq]`." + ) + raise ImportError(msg) + + embs = self._sbert_model_obj.encode([text1, text2]) + return float(cosine_similarity([embs[0]], [embs[1]])[0, 0]) + + if self.similarity_metric == "difflib": + return SequenceMatcher(None, text1, text2).ratio() + + if self.similarity_metric == "levenshtein": + dist = self._levenshtein_distance(text1, text2) + max_len = max(len(text1), len(text2)) + return 1.0 - dist / max_len if max_len > 0 else 1.0 + + if self.similarity_metric == "jaccard": + # Jaccard: word-level set overlap. + words1 = set(text1.lower().split()) + words2 = set(text2.lower().split()) + if len(words1) == 0 and len(words2) == 0: + return 1.0 + union = len(words1 | words2) + return len(words1 & words2) / union if union > 0 else 0.0 + + msg = f"Unknown similarity metric: {self.similarity_metric!r}" + raise ValueError(msg) + + def _compute_similarity_matrix(self, samples: list[str]) -> np.ndarray: + """Build a symmetric pairwise similarity matrix. + + For ``sbert``, batch-encodes all samples once and computes cosine + similarity in a single matrix operation. For all other metrics, + computes pairwise similarities individually (upper triangle, mirrored). + + Args: + samples: List of sample strings. + + Returns: + Symmetric (N, N) matrix with self-similarity = 1.0.
+ """ + if self.similarity_metric == "sbert": + try: + from sklearn.metrics.pairwise import ( + cosine_similarity, # type: ignore[import-not-found] + ) + except ImportError: + msg = ( + "sklearn.metrics.pairwise.cosine_similarity is required for sbert similarity. " + "Please install with extra dependencies: `pip install mellea[simbauq]`." + ) + raise ImportError(msg) + + embeddings = self._sbert_model_obj.encode(samples) + matrix = cosine_similarity(embeddings) + np.fill_diagonal(matrix, 1.0) + return matrix + + n = len(samples) + matrix = np.eye(n) + for i in range(n): + for j in range(i + 1, n): + sim = self._compute_similarity(samples[i], samples[j]) + matrix[i, j] = sim + matrix[j, i] = sim + return matrix + + def _aggregate(self, similarities: np.ndarray) -> float: + """Aggregate a vector of similarity scores into a single confidence value. + + Args: + similarities: 1-D array of similarity scores. + + Returns: + Aggregated confidence score. + """ + epsilon = 1e-10 + if len(similarities) == 0: + return 0.0 + + if self.aggregation == "mean": + return float(np.mean(similarities)) + + if self.aggregation == "geometric_mean": + log_sims = np.log(similarities + epsilon) + return float(np.exp(np.mean(log_sims))) + + if self.aggregation == "harmonic_mean": + return float(len(similarities) / np.sum(1.0 / (similarities + epsilon))) + + if self.aggregation == "median": + return float(np.median(similarities)) + + if self.aggregation == "max": + return float(np.max(similarities)) + + if self.aggregation == "min": + return float(np.min(similarities)) + + msg = f"Unknown aggregation method: {self.aggregation}" + raise ValueError(msg) + + def _extract_features( + self, sim_matrix: np.ndarray, sample_index: int + ) -> np.ndarray: + """Extract pairwise similarity features for a single sample. + + Returns the similarity row with self-similarity removed and values + clipped to ``(eps, 1 - eps)`` for numerical stability. + + Args: + sim_matrix: Symmetric (N, N) similarity matrix. 
+ sample_index: Index of the sample to extract features for. + + Returns: + 1-D feature array of length ``N - 1``. + """ + row = np.delete(sim_matrix[sample_index, :], sample_index) + return np.clip(row, self._CLF_EPS, 1.0 - self._CLF_EPS) + + def _train_classifier( + self, training_samples: list[list[str]], training_labels: list[list[int]] + ) -> ProbabilisticClassifier: + """Train a random forest classifier on similarity features. + + Args: + training_samples: List of query groups, each a list of sample + strings with the same length as the inference-time sample + count. + training_labels: Binary correctness labels (0/1) matching + ``training_samples``. + + Returns: + Trained ``RandomForestClassifier``. + """ + try: + from sklearn.ensemble import ( + RandomForestClassifier, # type: ignore[import-not-found] + ) + except ImportError: + msg = ( + "sklearn is required for training a Random Forest classifier. " + "Please install with extra dependencies: `pip install mellea[simbauq]`." + ) + raise ImportError(msg) + + x_train: list[np.ndarray] = [] + y_train: list[int] = [] + for samples, labels in zip(training_samples, training_labels): + sim_matrix = self._compute_similarity_matrix(samples) + for i, label in enumerate(labels): + x_train.append(self._extract_features(sim_matrix, i)) + y_train.append(label) + clf = RandomForestClassifier(max_depth=self.clf_max_depth, random_state=0) + clf.fit(x_train, y_train) + return clf + + def _compute_confidences_classifier( + self, sim_matrix: np.ndarray, n: int + ) -> np.ndarray: + """Compute per-sample confidence using the trained classifier. + + Args: + sim_matrix: Pre-computed (N, N) similarity matrix. + n: Number of samples. + + Returns: + Array of P(correct) confidence scores with shape ``(n,)``. + """ + x_test = [self._extract_features(sim_matrix, i) for i in range(n)] + if self._classifier is None: + raise RuntimeError( + "Classifier is not initialised — this is a bug in SIMBAUQSamplingStrategy." 
+ ) + probs = self._classifier.predict_proba(x_test) + return probs[:, 1] + + def _compute_confidences(self, sim_matrix: np.ndarray) -> np.ndarray: + """Compute per-sample confidence using similarity-based aggregation. + + For each sample, aggregates its similarities to every other sample + into a single confidence score. + + Args: + sim_matrix: Symmetric (N, N) pairwise similarity matrix. + + Returns: + Array of confidence scores with shape ``(N,)``. + """ + n = sim_matrix.shape[0] + if n == 1: + return np.array([0.5]) + + confidences = np.zeros(n) + for i in range(n): + others = np.concatenate([sim_matrix[i, :i], sim_matrix[i, i + 1 :]]) + confidences[i] = self._aggregate(others) + return confidences diff --git a/pyproject.toml b/pyproject.toml index f64a15d2a..8e09e9e4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -108,7 +108,9 @@ hooks = [ "grpcio>=1.78.0", ] -all = ["mellea[backends,docling,tools,telemetry,server,sandbox,granite_retriever,hooks]"] +simbauq = ["scikit-learn", "numpy<=2.2"] + +all = ["mellea[backends,docling,tools,telemetry,server,sandbox,granite_retriever,hooks,simbauq]"] [dependency-groups] # Development groups: uv sync --all-groups diff --git a/test/stdlib/sampling/test_simbauq.py b/test/stdlib/sampling/test_simbauq.py new file mode 100644 index 000000000..b71df29e6 --- /dev/null +++ b/test/stdlib/sampling/test_simbauq.py @@ -0,0 +1,428 @@ +"""Tests for SIMBAUQSamplingStrategy.""" + +import numpy as np +import pytest + +from mellea.stdlib.sampling.simbauq import SIMBAUQSamplingStrategy + +# --- Unit tests (no LLM required) --- + + +class TestComputeSimilarity: + def test_rouge_identical(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="rouge") + assert strategy._compute_similarity( + "hello world", "hello world" + ) == pytest.approx(1.0) + + def test_rouge_different(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="rouge") + score = strategy._compute_similarity("the cat sat on the mat", "dogs run fast") + 
assert 0.0 <= score < 0.5 + + def test_jaccard_identical(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + assert strategy._compute_similarity( + "hello world", "hello world" + ) == pytest.approx(1.0) + + def test_jaccard_partial_overlap(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + score = strategy._compute_similarity("hello world foo", "hello world bar") + # intersection = {"hello", "world"}, union = {"hello", "world", "foo", "bar"} + assert score == pytest.approx(2.0 / 4.0) + + def test_jaccard_no_overlap(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + score = strategy._compute_similarity("alpha beta", "gamma delta") + assert score == pytest.approx(0.0) + + def test_jaccard_empty_strings(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + assert strategy._compute_similarity("", "") == pytest.approx(1.0) + + def test_difflib_identical(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="difflib") + assert strategy._compute_similarity( + "hello world", "hello world" + ) == pytest.approx(1.0) + + def test_difflib_different(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="difflib") + score = strategy._compute_similarity("the cat sat on the mat", "dogs run fast") + assert 0.0 <= score < 0.5 + + def test_difflib_partial(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="difflib") + score = strategy._compute_similarity("hello world foo", "hello world bar") + assert 0.5 < score < 1.0 + + def test_difflib_empty(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="difflib") + assert strategy._compute_similarity("", "") == pytest.approx(1.0) + + def test_levenshtein_identical(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="levenshtein") + assert strategy._compute_similarity( + "hello world", "hello world" + ) == pytest.approx(1.0) + + def test_levenshtein_different(self): + strategy = 
SIMBAUQSamplingStrategy(similarity_metric="levenshtein") + score = strategy._compute_similarity("abc", "xyz") + assert score == pytest.approx(0.0) + + def test_levenshtein_partial(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="levenshtein") + score = strategy._compute_similarity("kitten", "sitting") + assert 0.0 < score < 1.0 + + def test_levenshtein_empty(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="levenshtein") + assert strategy._compute_similarity("", "") == pytest.approx(1.0) + + def test_levenshtein_single_edit(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="levenshtein") + score = strategy._compute_similarity("abcd", "abce") + assert score == pytest.approx(0.75) + + +class TestLevenshteinDistance: + def test_empty_strings(self): + assert SIMBAUQSamplingStrategy._levenshtein_distance("", "") == 0 + + def test_one_empty(self): + assert SIMBAUQSamplingStrategy._levenshtein_distance("a", "") == 1 + + def test_classic_example(self): + assert SIMBAUQSamplingStrategy._levenshtein_distance("kitten", "sitting") == 3 + + def test_identical(self): + assert SIMBAUQSamplingStrategy._levenshtein_distance("abc", "abc") == 0 + + +class TestAggregate: + def setup_method(self): + self.sims = np.array([0.8, 0.6, 0.4]) + + def test_mean(self): + strategy = SIMBAUQSamplingStrategy(aggregation="mean") + assert strategy._aggregate(self.sims) == pytest.approx(0.6) + + def test_median(self): + strategy = SIMBAUQSamplingStrategy(aggregation="median") + assert strategy._aggregate(self.sims) == pytest.approx(0.6) + + def test_max(self): + strategy = SIMBAUQSamplingStrategy(aggregation="max") + assert strategy._aggregate(self.sims) == pytest.approx(0.8) + + def test_min(self): + strategy = SIMBAUQSamplingStrategy(aggregation="min") + assert strategy._aggregate(self.sims) == pytest.approx(0.4) + + def test_geometric_mean(self): + strategy = SIMBAUQSamplingStrategy(aggregation="geometric_mean") + expected = (0.8 * 0.6 * 0.4) ** (1.0 / 3.0) 
+ assert strategy._aggregate(self.sims) == pytest.approx(expected, abs=1e-3) + + def test_harmonic_mean(self): + strategy = SIMBAUQSamplingStrategy(aggregation="harmonic_mean") + expected = 3.0 / (1.0 / 0.8 + 1.0 / 0.6 + 1.0 / 0.4) + assert strategy._aggregate(self.sims) == pytest.approx(expected, abs=1e-3) + + def test_empty(self): + strategy = SIMBAUQSamplingStrategy(aggregation="mean") + assert strategy._aggregate(np.array([])) == 0.0 + + +class TestComputeConfidences: + def test_single_sample(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + sim_matrix = strategy._compute_similarity_matrix(["hello world"]) + confs = strategy._compute_confidences(sim_matrix) + assert len(confs) == 1 + assert confs[0] == pytest.approx(0.5) + + def test_identical_samples_high_confidence(self): + strategy = SIMBAUQSamplingStrategy( + similarity_metric="jaccard", aggregation="mean" + ) + samples = ["the cat sat on the mat"] * 5 + sim_matrix = strategy._compute_similarity_matrix(samples) + confs = strategy._compute_confidences(sim_matrix) + assert len(confs) == 5 + for c in confs: + assert c == pytest.approx(1.0) + + def test_outlier_has_lower_confidence(self): + strategy = SIMBAUQSamplingStrategy( + similarity_metric="jaccard", aggregation="mean" + ) + samples = [ + "the capital of france is paris", + "paris is the capital of france", + "france capital is paris", + "bananas are yellow fruit", # outlier + ] + sim_matrix = strategy._compute_similarity_matrix(samples) + confs = strategy._compute_confidences(sim_matrix) + assert len(confs) == 4 + # The outlier (index 3) should have the lowest confidence. 
+ assert confs[3] < confs[0] + assert confs[3] < confs[1] + assert confs[3] < confs[2] + + def test_similarity_matrix_symmetric(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="rouge") + samples = ["hello world", "world hello", "foo bar"] + matrix = strategy._compute_similarity_matrix(samples) + assert matrix.shape == (3, 3) + np.testing.assert_array_almost_equal(matrix, matrix.T) + np.testing.assert_array_equal(np.diag(matrix), [1.0, 1.0, 1.0]) + + +class TestSBERTSimilarity: + @pytest.fixture(autouse=True) + def _require_sbert(self): + pytest.importorskip("sentence_transformers") + + def test_sbert_identical(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="sbert") + score = strategy._compute_similarity("hello world", "hello world") + assert score == pytest.approx(1.0, abs=0.01) + + def test_sbert_similar(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="sbert") + score = strategy._compute_similarity( + "The capital of France is Paris.", "Paris is the capital city of France." + ) + assert score > 0.7 + + def test_sbert_different(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="sbert") + score = strategy._compute_similarity( + "The capital of France is Paris.", "Bananas are a yellow tropical fruit." 
+ ) + assert score < 0.4 + + def test_sbert_matrix_symmetric(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="sbert") + samples = ["hello world", "world hello", "foo bar baz"] + matrix = strategy._compute_similarity_matrix(samples) + assert matrix.shape == (3, 3) + np.testing.assert_array_almost_equal(matrix, matrix.T) + np.testing.assert_array_almost_equal( + np.diag(matrix), [1.0, 1.0, 1.0], decimal=2 + ) + + def test_sbert_outlier_confidence(self): + strategy = SIMBAUQSamplingStrategy( + similarity_metric="sbert", aggregation="mean" + ) + samples = [ + "The capital of France is Paris.", + "Paris is the capital of France.", + "France has Paris as its capital.", + "Bananas are a yellow tropical fruit.", # outlier + ] + sim_matrix = strategy._compute_similarity_matrix(samples) + confs = strategy._compute_confidences(sim_matrix) + assert len(confs) == 4 + assert confs[3] < confs[0] + assert confs[3] < confs[1] + assert confs[3] < confs[2] + + +class TestClassifierConfidence: + """Tests for the classifier-based confidence estimation method.""" + + # Synthetic training data: 3 groups of 4 samples each. + # Groups have 3 "correct" similar answers and 1 "incorrect" outlier. 
+ TRAINING_SAMPLES = [ + [ + "The capital of France is Paris.", + "Paris is the capital of France.", + "France's capital city is Paris.", + "Bananas are a yellow tropical fruit.", + ], + [ + "Water boils at 100 degrees Celsius.", + "At 100 degrees Celsius water boils.", + "The boiling point of water is 100C.", + "The sky is often blue on clear days.", + ], + [ + "Python is a programming language.", + "Python is a popular programming language.", + "The Python programming language is widely used.", + "Mount Everest is very tall.", + ], + ] + TRAINING_LABELS = [[1, 1, 1, 0], [1, 1, 1, 0], [1, 1, 1, 0]] + + def test_extract_features(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + sim_matrix = np.array([[1.0, 0.8, 0.2], [0.8, 1.0, 0.3], [0.2, 0.3, 1.0]]) + features = strategy._extract_features(sim_matrix, 0) + assert len(features) == 2 + assert features[0] == pytest.approx(0.8) + assert features[1] == pytest.approx(0.2) + + def test_extract_features_clipping(self): + strategy = SIMBAUQSamplingStrategy(similarity_metric="jaccard") + sim_matrix = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 0.5], [1.0, 0.5, 1.0]]) + features = strategy._extract_features(sim_matrix, 0) + # 0.0 clipped to eps, 1.0 clipped to 1-eps + assert features[0] > 0.0 + assert features[1] < 1.0 + + def test_train_classifier(self): + strategy = SIMBAUQSamplingStrategy( + temperatures=[0.5], + n_per_temp=4, + similarity_metric="jaccard", + confidence_method="classifier", + training_samples=self.TRAINING_SAMPLES, + training_labels=self.TRAINING_LABELS, + ) + assert strategy._classifier is not None + assert hasattr(strategy._classifier, "predict_proba") + + def test_classifier_confidence_produces_valid_scores(self): + strategy = SIMBAUQSamplingStrategy( + temperatures=[0.5], + n_per_temp=4, + similarity_metric="jaccard", + confidence_method="classifier", + training_samples=self.TRAINING_SAMPLES, + training_labels=self.TRAINING_LABELS, + ) + # Test samples: 3 similar + 1 outlier (same 
structure as training) + test_samples = [ + "The capital of Germany is Berlin.", + "Berlin is the capital of Germany.", + "Germany's capital is Berlin.", + "Cats like to chase mice around.", + ] + sim_matrix = strategy._compute_similarity_matrix(test_samples) + confs = strategy._compute_confidences_classifier(sim_matrix, len(test_samples)) + assert len(confs) == 4 + for c in confs: + assert 0.0 <= c <= 1.0 + + def test_classifier_outlier_lower_confidence(self): + strategy = SIMBAUQSamplingStrategy( + temperatures=[0.5], + n_per_temp=4, + similarity_metric="jaccard", + confidence_method="classifier", + training_samples=self.TRAINING_SAMPLES, + training_labels=self.TRAINING_LABELS, + ) + test_samples = [ + "The capital of Italy is Rome.", + "Rome is the capital of Italy.", + "Italy has Rome as its capital.", + "Elephants are the largest land animals.", + ] + sim_matrix = strategy._compute_similarity_matrix(test_samples) + confs = strategy._compute_confidences_classifier(sim_matrix, len(test_samples)) + # Outlier (index 3) should have lower confidence than the similar ones. + assert confs[3] < confs[0] + + def test_pretrained_classifier(self): + from sklearn.ensemble import RandomForestClassifier + + # Train a classifier manually. + strategy_train = SIMBAUQSamplingStrategy( + temperatures=[0.5], + n_per_temp=4, + similarity_metric="jaccard", + confidence_method="classifier", + training_samples=self.TRAINING_SAMPLES, + training_labels=self.TRAINING_LABELS, + ) + trained_clf = strategy_train._classifier + + # Use pre-trained classifier in a new strategy. 
+ strategy = SIMBAUQSamplingStrategy( + temperatures=[0.5], + n_per_temp=4, + similarity_metric="jaccard", + confidence_method="classifier", + classifier=trained_clf, + ) + assert strategy._classifier is trained_clf + + def test_classifier_requires_data_or_clf(self): + with pytest.raises(ValueError, match="requires either"): + SIMBAUQSamplingStrategy(confidence_method="classifier") + + +class TestInit: + def test_default_temperatures(self): + strategy = SIMBAUQSamplingStrategy() + assert strategy.temperatures == [0.3, 0.5, 0.7, 1.0] + + def test_custom_temperatures(self): + strategy = SIMBAUQSamplingStrategy(temperatures=[0.1, 0.9]) + assert strategy.temperatures == [0.1, 0.9] + + def test_empty_temperatures_raises(self): + with pytest.raises(ValueError): + SIMBAUQSamplingStrategy(temperatures=[]) + + def test_zero_n_per_temp_raises(self): + with pytest.raises(ValueError): + SIMBAUQSamplingStrategy(n_per_temp=0) + + +# --- Integration test (requires Ollama) --- + + +@pytest.mark.ollama +@pytest.mark.e2e +@pytest.mark.qualitative +class TestSIMBAUQIntegration: + def test_simbauq_sampling(self): + from mellea import MelleaSession, start_session + from mellea.backends import ModelOption + from mellea.core import SamplingResult + + m: MelleaSession = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 30}) + + result: SamplingResult = m.instruct( + "What is the capital of France?", + strategy=SIMBAUQSamplingStrategy( + temperatures=[0.3, 0.7], + n_per_temp=2, + similarity_metric="rouge", + aggregation="mean", + ), + return_sampling_results=True, + ) + + assert isinstance(result, SamplingResult) + assert result.success is True + assert len(result.sample_generations) == 4 # 2 temps * 2 per temp + + # Check that the selected MOT has confidence metadata. 
+        best_mot = result.result
+        assert best_mot._meta is not None
+        simba_meta = best_mot._meta["simba_uq"]
+        assert "confidence" in simba_meta
+        assert 0.0 <= simba_meta["confidence"] <= 1.0
+        assert len(simba_meta["all_confidences"]) == 4
+        assert simba_meta["similarity_metric"] == "rouge"
+        assert simba_meta["aggregation"] == "mean"
+
+        output = str(best_mot)
+        print(f"Best output (confidence={simba_meta['confidence']:.3f}): {output}")
+        assert output
+
+        del m
+
+
+if __name__ == "__main__":
+    pytest.main(["-s", __file__])
diff --git a/uv.lock b/uv.lock
index 8c39705b8..1277ff6f5 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,5 +1,5 @@
 version = 1
-revision = 2
+revision = 3
 requires-python = ">=3.11"
 resolution-markers = [
     "python_full_version >= '3.14' and python_full_version < '4' and sys_platform == 'darwin'",
@@ -2009,7 +2009,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/ec/e8/2e1462c8fdbe0f210feb5ac7ad2d9029af8be3bf45bd9fa39765f821642f/greenlet-3.3.1-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:5fd23b9bc6d37b563211c6abbb1b3cab27db385a4449af5c32e932f93017080c", size = 274974, upload-time = "2026-01-23T15:31:02.891Z" },
     { url = "https://files.pythonhosted.org/packages/7e/a8/530a401419a6b302af59f67aaf0b9ba1015855ea7e56c036b5928793c5bd/greenlet-3.3.1-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:09f51496a0bfbaa9d74d36a52d2580d1ef5ed4fdfcff0a73730abfbbbe1403dd", size = 577175, upload-time = "2026-01-23T16:00:56.213Z" },
     { url = "https://files.pythonhosted.org/packages/8e/89/7e812bb9c05e1aaef9b597ac1d0962b9021d2c6269354966451e885c4e6b/greenlet-3.3.1-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cb0feb07fe6e6a74615ee62a880007d976cf739b6669cce95daa7373d4fc69c5", size = 590401, upload-time = "2026-01-23T16:05:26.365Z" },
-    { url =
"https://files.pythonhosted.org/packages/70/ae/e2d5f0e59b94a2269b68a629173263fa40b63da32f5c231307c349315871/greenlet-3.3.1-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:67ea3fc73c8cd92f42467a72b75e8f05ed51a0e9b1d15398c913416f2dafd49f", size = 601161, upload-time = "2026-01-23T16:15:53.456Z" }, { url = "https://files.pythonhosted.org/packages/5c/ae/8d472e1f5ac5efe55c563f3eabb38c98a44b832602e12910750a7c025802/greenlet-3.3.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:39eda9ba259cc9801da05351eaa8576e9aa83eb9411e8f0c299e05d712a210f2", size = 590272, upload-time = "2026-01-23T15:32:49.411Z" }, { url = "https://files.pythonhosted.org/packages/a8/51/0fde34bebfcadc833550717eade64e35ec8738e6b097d5d248274a01258b/greenlet-3.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:e2e7e882f83149f0a71ac822ebf156d902e7a5d22c9045e3e0d1daf59cee2cc9", size = 1550729, upload-time = "2026-01-23T16:04:20.867Z" }, { url = "https://files.pythonhosted.org/packages/16/c9/2fb47bee83b25b119d5a35d580807bb8b92480a54b68fef009a02945629f/greenlet-3.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:80aa4d79eb5564f2e0a6144fcc744b5a37c56c4a92d60920720e99210d88db0f", size = 1615552, upload-time = "2026-01-23T15:33:45.743Z" }, @@ -2018,7 +2017,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/c8/9d76a66421d1ae24340dfae7e79c313957f6e3195c144d2c73333b5bfe34/greenlet-3.3.1-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:7e806ca53acf6d15a888405880766ec84721aa4181261cd11a457dfe9a7a4975", size = 276443, upload-time = "2026-01-23T15:30:10.066Z" }, { url = "https://files.pythonhosted.org/packages/81/99/401ff34bb3c032d1f10477d199724f5e5f6fbfb59816ad1455c79c1eb8e7/greenlet-3.3.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d842c94b9155f1c9b3058036c24ffb8ff78b428414a19792b2380be9cecf4f36", size = 597359, upload-time = "2026-01-23T16:00:57.394Z" }, { url = 
"https://files.pythonhosted.org/packages/2b/bc/4dcc0871ed557792d304f50be0f7487a14e017952ec689effe2180a6ff35/greenlet-3.3.1-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:20fedaadd422fa02695f82093f9a98bad3dab5fcda793c658b945fcde2ab27ba", size = 607805, upload-time = "2026-01-23T16:05:28.068Z" }, - { url = "https://files.pythonhosted.org/packages/3b/cd/7a7ca57588dac3389e97f7c9521cb6641fd8b6602faf1eaa4188384757df/greenlet-3.3.1-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:c620051669fd04ac6b60ebc70478210119c56e2d5d5df848baec4312e260e4ca", size = 622363, upload-time = "2026-01-23T16:15:54.754Z" }, { url = "https://files.pythonhosted.org/packages/cf/05/821587cf19e2ce1f2b24945d890b164401e5085f9d09cbd969b0c193cd20/greenlet-3.3.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:14194f5f4305800ff329cbf02c5fcc88f01886cadd29941b807668a45f0d2336", size = 609947, upload-time = "2026-01-23T15:32:51.004Z" }, { url = "https://files.pythonhosted.org/packages/a4/52/ee8c46ed9f8babaa93a19e577f26e3d28a519feac6350ed6f25f1afee7e9/greenlet-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7b2fe4150a0cf59f847a67db8c155ac36aed89080a6a639e9f16df5d6c6096f1", size = 1567487, upload-time = "2026-01-23T16:04:22.125Z" }, { url = "https://files.pythonhosted.org/packages/8f/7c/456a74f07029597626f3a6db71b273a3632aecb9afafeeca452cfa633197/greenlet-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:49f4ad195d45f4a66a0eb9c1ba4832bb380570d361912fa3554746830d332149", size = 1636087, upload-time = "2026-01-23T15:33:47.486Z" }, @@ -2027,7 +2025,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/ab/d26750f2b7242c2b90ea2ad71de70cfcd73a948a49513188a0fc0d6fc15a/greenlet-3.3.1-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:7ab327905cabb0622adca5971e488064e35115430cec2c35a50fd36e72a315b3", size = 275205, upload-time = "2026-01-23T15:30:24.556Z" }, { url = 
"https://files.pythonhosted.org/packages/10/d3/be7d19e8fad7c5a78eeefb2d896a08cd4643e1e90c605c4be3b46264998f/greenlet-3.3.1-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:65be2f026ca6a176f88fb935ee23c18333ccea97048076aef4db1ef5bc0713ac", size = 599284, upload-time = "2026-01-23T16:00:58.584Z" }, { url = "https://files.pythonhosted.org/packages/ae/21/fe703aaa056fdb0f17e5afd4b5c80195bbdab701208918938bd15b00d39b/greenlet-3.3.1-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7a3ae05b3d225b4155bda56b072ceb09d05e974bc74be6c3fc15463cf69f33fd", size = 610274, upload-time = "2026-01-23T16:05:29.312Z" }, - { url = "https://files.pythonhosted.org/packages/06/00/95df0b6a935103c0452dad2203f5be8377e551b8466a29650c4c5a5af6cc/greenlet-3.3.1-cp313-cp313-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:12184c61e5d64268a160226fb4818af4df02cfead8379d7f8b99a56c3a54ff3e", size = 624375, upload-time = "2026-01-23T16:15:55.915Z" }, { url = "https://files.pythonhosted.org/packages/cb/86/5c6ab23bb3c28c21ed6bebad006515cfe08b04613eb105ca0041fecca852/greenlet-3.3.1-cp313-cp313-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6423481193bbbe871313de5fd06a082f2649e7ce6e08015d2a76c1e9186ca5b3", size = 612904, upload-time = "2026-01-23T15:32:52.317Z" }, { url = "https://files.pythonhosted.org/packages/c2/f3/7949994264e22639e40718c2daf6f6df5169bf48fb038c008a489ec53a50/greenlet-3.3.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:33a956fe78bbbda82bfc95e128d61129b32d66bcf0a20a1f0c08aa4839ffa951", size = 1567316, upload-time = "2026-01-23T16:04:23.316Z" }, { url = "https://files.pythonhosted.org/packages/8d/6e/d73c94d13b6465e9f7cd6231c68abde838bb22408596c05d9059830b7872/greenlet-3.3.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b065d3284be43728dd280f6f9a13990b56470b81be20375a207cdc814a983f2", size = 1636549, upload-time = "2026-01-23T15:33:48.643Z" }, @@ -2036,7 +2033,6 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/ae/fb/011c7c717213182caf78084a9bea51c8590b0afda98001f69d9f853a495b/greenlet-3.3.1-cp314-cp314-macosx_11_0_universal2.whl", hash = "sha256:bd59acd8529b372775cd0fcbc5f420ae20681c5b045ce25bd453ed8455ab99b5", size = 275737, upload-time = "2026-01-23T15:32:16.889Z" }, { url = "https://files.pythonhosted.org/packages/41/2e/a3a417d620363fdbb08a48b1dd582956a46a61bf8fd27ee8164f9dfe87c2/greenlet-3.3.1-cp314-cp314-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b31c05dd84ef6871dd47120386aed35323c944d86c3d91a17c4b8d23df62f15b", size = 646422, upload-time = "2026-01-23T16:01:00.354Z" }, { url = "https://files.pythonhosted.org/packages/b4/09/c6c4a0db47defafd2d6bab8ddfe47ad19963b4e30f5bed84d75328059f8c/greenlet-3.3.1-cp314-cp314-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:02925a0bfffc41e542c70aa14c7eda3593e4d7e274bfcccca1827e6c0875902e", size = 658219, upload-time = "2026-01-23T16:05:30.956Z" }, - { url = "https://files.pythonhosted.org/packages/e2/89/b95f2ddcc5f3c2bc09c8ee8d77be312df7f9e7175703ab780f2014a0e781/greenlet-3.3.1-cp314-cp314-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:3e0f3878ca3a3ff63ab4ea478585942b53df66ddde327b59ecb191b19dbbd62d", size = 671455, upload-time = "2026-01-23T16:15:57.232Z" }, { url = "https://files.pythonhosted.org/packages/80/38/9d42d60dffb04b45f03dbab9430898352dba277758640751dc5cc316c521/greenlet-3.3.1-cp314-cp314-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34a729e2e4e4ffe9ae2408d5ecaf12f944853f40ad724929b7585bca808a9d6f", size = 660237, upload-time = "2026-01-23T15:32:53.967Z" }, { url = "https://files.pythonhosted.org/packages/96/61/373c30b7197f9e756e4c81ae90a8d55dc3598c17673f91f4d31c3c689c3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:aec9ab04e82918e623415947921dea15851b152b822661cce3f8e4393c3df683", size = 1615261, upload-time = "2026-01-23T16:04:25.066Z" }, { url = 
"https://files.pythonhosted.org/packages/fd/d3/ca534310343f5945316f9451e953dcd89b36fe7a19de652a1dc5a0eeef3f/greenlet-3.3.1-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:71c767cf281a80d02b6c1bdc41c9468e1f5a494fb11bc8688c360524e273d7b1", size = 1683719, upload-time = "2026-01-23T15:33:50.61Z" }, @@ -2045,7 +2041,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/28/24/cbbec49bacdcc9ec652a81d3efef7b59f326697e7edf6ed775a5e08e54c2/greenlet-3.3.1-cp314-cp314t-macosx_11_0_universal2.whl", hash = "sha256:3e63252943c921b90abb035ebe9de832c436401d9c45f262d80e2d06cc659242", size = 282706, upload-time = "2026-01-23T15:33:05.525Z" }, { url = "https://files.pythonhosted.org/packages/86/2e/4f2b9323c144c4fe8842a4e0d92121465485c3c2c5b9e9b30a52e80f523f/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76e39058e68eb125de10c92524573924e827927df5d3891fbc97bd55764a8774", size = 651209, upload-time = "2026-01-23T16:01:01.517Z" }, { url = "https://files.pythonhosted.org/packages/d9/87/50ca60e515f5bb55a2fbc5f0c9b5b156de7d2fc51a0a69abc9d23914a237/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c9f9d5e7a9310b7a2f416dd13d2e3fd8b42d803968ea580b7c0f322ccb389b97", size = 654300, upload-time = "2026-01-23T16:05:32.199Z" }, - { url = "https://files.pythonhosted.org/packages/7c/25/c51a63f3f463171e09cb586eb64db0861eb06667ab01a7968371a24c4f3b/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4b9721549a95db96689458a1e0ae32412ca18776ed004463df3a9299c1b257ab", size = 662574, upload-time = "2026-01-23T16:15:58.364Z" }, { url = "https://files.pythonhosted.org/packages/1d/94/74310866dfa2b73dd08659a3d18762f83985ad3281901ba0ee9a815194fb/greenlet-3.3.1-cp314-cp314t-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:92497c78adf3ac703b57f1e3813c2d874f27f71a178f9ea5887855da413cd6d2", size = 653842, upload-time = "2026-01-23T15:32:55.671Z" }, { url 
= "https://files.pythonhosted.org/packages/97/43/8bf0ffa3d498eeee4c58c212a3905dd6146c01c8dc0b0a046481ca29b18c/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:ed6b402bc74d6557a705e197d47f9063733091ed6357b3de33619d8a8d93ac53", size = 1614917, upload-time = "2026-01-23T16:04:26.276Z" }, { url = "https://files.pythonhosted.org/packages/89/90/a3be7a5f378fc6e84abe4dcfb2ba32b07786861172e502388b4c90000d1b/greenlet-3.3.1-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:59913f1e5ada20fde795ba906916aea25d442abcc0593fba7e26c92b7ad76249", size = 1676092, upload-time = "2026-01-23T15:33:52.176Z" }, @@ -3773,6 +3768,7 @@ all = [ { name = "opentelemetry-sdk" }, { name = "peft" }, { name = "pyarrow" }, + { name = "scikit-learn" }, { name = "sentence-transformers" }, { name = "smolagents" }, { name = "transformers" }, @@ -3830,6 +3826,10 @@ server = [ { name = "fastapi" }, { name = "uvicorn" }, ] +simbauq = [ + { name = "numpy" }, + { name = "scikit-learn" }, +] telemetry = [ { name = "cpex" }, { name = "grpcio" }, @@ -3941,10 +3941,11 @@ requires-dist = [ { name = "llguidance", marker = "extra == 'hf'" }, { name = "llm-sandbox", extras = ["docker"], marker = "extra == 'sandbox'", specifier = ">=0.3.23" }, { name = "math-verify" }, - { name = "mellea", extras = ["backends", "docling", "tools", "telemetry", "server", "sandbox", "granite-retriever", "hooks"], marker = "extra == 'all'" }, + { name = "mellea", extras = ["backends", "docling", "tools", "telemetry", "server", "sandbox", "granite-retriever", "hooks", "simbauq"], marker = "extra == 'all'" }, { name = "mellea", extras = ["hooks"], marker = "extra == 'telemetry'" }, { name = "mellea", extras = ["watsonx", "hf", "vllm", "litellm"], marker = "extra == 'backends'" }, { name = "mistletoe", specifier = ">=1.4.0" }, + { name = "numpy", marker = "extra == 'simbauq'", specifier = "<=2.2" }, { name = "numpy", marker = "extra == 'vllm'", specifier = "<=2.2" }, { name = "ollama", specifier = ">=0.5.1" }, { 
name = "openai" }, @@ -3962,6 +3963,7 @@ requires-dist = [ { name = "pyyaml" }, { name = "requests", specifier = ">=2.32.3" }, { name = "rouge-score" }, + { name = "scikit-learn", marker = "extra == 'simbauq'" }, { name = "sentence-transformers", marker = "extra == 'granite-retriever'" }, { name = "smolagents", marker = "extra == 'tools'", specifier = ">=1.0.0" }, { name = "transformers", marker = "extra == 'hf'", specifier = ">=4.53.2,<5" }, @@ -3971,7 +3973,7 @@ requires-dist = [ { name = "vllm", marker = "sys_platform != 'darwin' and extra == 'vllm'", specifier = ">=0.13.0" }, { name = "xgrammar", marker = "extra == 'hf'" }, ] -provides-extras = ["hf", "vllm", "litellm", "watsonx", "tools", "telemetry", "docling", "granite-retriever", "server", "sandbox", "backends", "hooks", "all"] +provides-extras = ["hf", "vllm", "litellm", "watsonx", "tools", "telemetry", "docling", "granite-retriever", "server", "sandbox", "backends", "hooks", "simbauq", "all"] [package.metadata.requires-dev] build = [{ name = "pdm", specifier = ">=2.24.0" }] @@ -5908,7 +5910,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/75/2e/a9e28941c6dab6f06e6d3f6783d3373044be9b0f9a9d3492c3d8d2260ac0/pybase64-1.4.3-cp312-cp312-win32.whl", hash = "sha256:7bca1ed3a5df53305c629ca94276966272eda33c0d71f862d2d3d043f1e1b91a", size = 33686, upload-time = "2025-12-06T13:23:37.848Z" }, { url = "https://files.pythonhosted.org/packages/83/e3/507ab649d8c3512c258819c51d25c45d6e29d9ca33992593059e7b646a33/pybase64-1.4.3-cp312-cp312-win_amd64.whl", hash = "sha256:9f2da8f56d9b891b18b4daf463a0640eae45a80af548ce435be86aa6eff3603b", size = 35833, upload-time = "2025-12-06T13:23:38.877Z" }, { url = "https://files.pythonhosted.org/packages/bc/8a/6eba66cd549a2fc74bb4425fd61b839ba0ab3022d3c401b8a8dc2cc00c7a/pybase64-1.4.3-cp312-cp312-win_arm64.whl", hash = "sha256:0631d8a2d035de03aa9bded029b9513e1fee8ed80b7ddef6b8e9389ffc445da0", size = 31185, upload-time = "2025-12-06T13:23:39.908Z" }, + { url = 
"https://files.pythonhosted.org/packages/3a/50/b7170cb2c631944388fe2519507fe3835a4054a6a12a43f43781dae82be1/pybase64-1.4.3-cp313-cp313-android_21_arm64_v8a.whl", hash = "sha256:ea4b785b0607d11950b66ce7c328f452614aefc9c6d3c9c28bae795dc7f072e1", size = 33901, upload-time = "2025-12-06T13:23:40.951Z" }, { url = "https://files.pythonhosted.org/packages/48/8b/69f50578e49c25e0a26e3ee72c39884ff56363344b79fc3967f5af420ed6/pybase64-1.4.3-cp313-cp313-android_21_x86_64.whl", hash = "sha256:6a10b6330188c3026a8b9c10e6b9b3f2e445779cf16a4c453d51a072241c65a2", size = 40807, upload-time = "2025-12-06T13:23:42.006Z" }, + { url = "https://files.pythonhosted.org/packages/5c/8d/20b68f11adfc4c22230e034b65c71392e3e338b413bf713c8945bd2ccfb3/pybase64-1.4.3-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:27fdff227a0c0e182e0ba37a99109645188978b920dfb20d8b9c17eeee370d0d", size = 30932, upload-time = "2025-12-06T13:23:43.348Z" }, + { url = "https://files.pythonhosted.org/packages/f7/79/b1b550ac6bff51a4880bf6e089008b2e1ca16f2c98db5e039a08ac3ad157/pybase64-1.4.3-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:2a8204f1fdfec5aa4184249b51296c0de95445869920c88123978304aad42df1", size = 31394, upload-time = "2025-12-06T13:23:44.317Z" }, + { url = "https://files.pythonhosted.org/packages/82/70/b5d7c5932bf64ee1ec5da859fbac981930b6a55d432a603986c7f509c838/pybase64-1.4.3-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:874fc2a3777de6baf6aa921a7aa73b3be98295794bea31bd80568a963be30767", size = 38078, upload-time = "2025-12-06T13:23:45.348Z" }, { url = "https://files.pythonhosted.org/packages/1c/c9/24b3b905cf75e23a9a4deaf203b35ffcb9f473ac0e6d8257f91a05dfce62/pybase64-1.4.3-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:1d45c8fe8fe82b65c36b227bb4a2cf623d9ada16bed602ce2d3e18c35285b72a", size = 68244, upload-time = "2025-12-06T13:23:49.026Z" }, { url = 
"https://files.pythonhosted.org/packages/f8/cd/d15b0c3e25e5859fab0416dc5b96d34d6bd2603c1c96a07bb2202b68ab92/pybase64-1.4.3-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ad70c26ba091d8f5167e9d4e1e86a0483a5414805cdb598a813db635bd3be8b8", size = 71620, upload-time = "2025-12-06T13:23:50.081Z" }, { url = "https://files.pythonhosted.org/packages/0d/31/4ca953cc3dcde2b3711d6bfd70a6f4ad2ca95a483c9698076ba605f1520f/pybase64-1.4.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e98310b7c43145221e7194ac9fa7fffc84763c87bfc5e2f59f9f92363475bdc1", size = 59930, upload-time = "2025-12-06T13:23:51.68Z" }, @@ -5943,7 +5949,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/42/10/abb7757c330bb869ebb95dab0c57edf5961ffbd6c095c8209cbbf75d117d/pybase64-1.4.3-cp313-cp313t-win32.whl", hash = "sha256:46d75c9387f354c5172582a9eaae153b53a53afeb9c19fcf764ea7038be3bd8b", size = 33965, upload-time = "2025-12-06T13:24:28.548Z" }, { url = "https://files.pythonhosted.org/packages/63/a0/2d4e5a59188e9e6aed0903d580541aaea72dcbbab7bf50fb8b83b490b6c3/pybase64-1.4.3-cp313-cp313t-win_amd64.whl", hash = "sha256:d7344625591d281bec54e85cbfdab9e970f6219cac1570f2aa140b8c942ccb81", size = 36207, upload-time = "2025-12-06T13:24:29.646Z" }, { url = "https://files.pythonhosted.org/packages/1f/05/95b902e8f567b4d4b41df768ccc438af618f8d111e54deaf57d2df46bd76/pybase64-1.4.3-cp313-cp313t-win_arm64.whl", hash = "sha256:28a3c60c55138e0028313f2eccd321fec3c4a0be75e57a8d3eb883730b1b0880", size = 31505, upload-time = "2025-12-06T13:24:30.687Z" }, + { url = "https://files.pythonhosted.org/packages/e4/80/4bd3dff423e5a91f667ca41982dc0b79495b90ec0c0f5d59aca513e50f8c/pybase64-1.4.3-cp314-cp314-android_24_arm64_v8a.whl", hash = "sha256:015bb586a1ea1467f69d57427abe587469392215f59db14f1f5c39b52fdafaf5", size = 33835, upload-time = "2025-12-06T13:24:31.767Z" }, { url = 
"https://files.pythonhosted.org/packages/45/60/a94d94cc1e3057f602e0b483c9ebdaef40911d84a232647a2fe593ab77bb/pybase64-1.4.3-cp314-cp314-android_24_x86_64.whl", hash = "sha256:d101e3a516f837c3dcc0e5a0b7db09582ebf99ed670865223123fb2e5839c6c0", size = 40673, upload-time = "2025-12-06T13:24:32.82Z" }, + { url = "https://files.pythonhosted.org/packages/e3/71/cf62b261d431857e8e054537a5c3c24caafa331de30daede7b2c6c558501/pybase64-1.4.3-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8f183ac925a48046abe047360fe3a1b28327afb35309892132fe1915d62fb282", size = 30939, upload-time = "2025-12-06T13:24:34.001Z" }, + { url = "https://files.pythonhosted.org/packages/24/3e/d12f92a3c1f7c6ab5d53c155bff9f1084ba997a37a39a4f781ccba9455f3/pybase64-1.4.3-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:30bf3558e24dcce4da5248dcf6d73792adfcf4f504246967e9db155be4c439ad", size = 31401, upload-time = "2025-12-06T13:24:35.11Z" }, + { url = "https://files.pythonhosted.org/packages/9b/3d/9c27440031fea0d05146f8b70a460feb95d8b4e3d9ca8f45c972efb4c3d3/pybase64-1.4.3-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:a674b419de318d2ce54387dd62646731efa32b4b590907800f0bd40675c1771d", size = 38075, upload-time = "2025-12-06T13:24:36.53Z" }, { url = "https://files.pythonhosted.org/packages/db/26/b136a4b65e5c94ff06217f7726478df3f31ab1c777c2c02cf698e748183f/pybase64-1.4.3-cp314-cp314-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:b51204d349a4b208287a8aa5b5422be3baa88abf6cc8ff97ccbda34919bbc857", size = 68460, upload-time = "2025-12-06T13:24:41.735Z" }, { url = "https://files.pythonhosted.org/packages/68/6d/84ce50e7ee1ae79984d689e05a9937b2460d4efa1e5b202b46762fb9036c/pybase64-1.4.3-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:30f2fd53efecbdde4bdca73a872a68dcb0d1bf8a4560c70a3e7746df973e1ef3", size = 71688, upload-time = "2025-12-06T13:24:42.908Z" }, { url = 
"https://files.pythonhosted.org/packages/e3/57/6743e420416c3ff1b004041c85eb0ebd9c50e9cf05624664bfa1dc8b5625/pybase64-1.4.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0932b0c5cfa617091fd74f17d24549ce5de3628791998c94ba57be808078eeaf", size = 60040, upload-time = "2025-12-06T13:24:44.37Z" },