Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 134 additions & 48 deletions CHANGELOG.md

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,36 @@ Scale Nucleus helps you:

Nucleus is a new way—the right way—to develop ML models, helping us move away from the concept of one dataset and towards a paradigm of collections of scenarios.

.. _evaluations-v2:

Evaluations V2
--------------

Evaluation V2 measures how well a **model run** matches ground-truth annotations.
Create an evaluation with :meth:`nucleus.NucleusClient.create_evaluation_v2`, wait for it to finish with
:meth:`nucleus.evaluation_v2.EvaluationV2.wait_for_completion`, then read summary metrics with
:meth:`nucleus.evaluation_v2.EvaluationV2.charts` or individual matches with
:meth:`nucleus.evaluation_v2.EvaluationV2.examples`.

.. code-block:: python

    import nucleus

    client = nucleus.NucleusClient(api_key="YOUR_API_KEY")
    evaluation = client.create_evaluation_v2(
        model_run_id="run_xxx",
        name="my-eval",
        allowed_label_matches=[
            nucleus.AllowedLabelMatch(
                ground_truth_label="car",
                model_prediction_label="vehicle",
            ),
        ],
    )
    evaluation.wait_for_completion()
    charts = evaluation.charts(iou_threshold=0.5)
    fps = evaluation.examples(match_type="FP", limit=20)

.. _installation:

Installation
Expand Down
86 changes: 85 additions & 1 deletion nucleus/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Nucleus Python SDK. """
"""Nucleus Python SDK."""

__all__ = [
"AsyncJob",
"AllowedLabelMatch",
"EmbeddingsExportJob",
"BoxAnnotation",
"DeduplicationJob",
Expand All @@ -18,6 +19,12 @@
"DatasetInfo",
"DatasetItem",
"DatasetItemRetrievalError",
"EvaluationV2",
"EvaluationV2Charts",
"EvaluationV2ExamplesPage",
"EvaluationV2FilterArgs",
"EvaluationV2MatchExample",
"EvaluationV2Status",
"Frame",
"Keypoint",
"KeypointsAnnotation",
Expand Down Expand Up @@ -131,6 +138,12 @@
)
from .data_transfer_object.dataset_details import DatasetDetails
from .data_transfer_object.dataset_info import DatasetInfo
from .data_transfer_object.evaluation_v2 import (
EvaluationV2Charts,
EvaluationV2ExamplesPage,
EvaluationV2FilterArgs,
EvaluationV2MatchExample,
)
from .data_transfer_object.job_status import JobInfoRequestPayload
from .dataset import Dataset
from .dataset_item import DatasetItem
Expand All @@ -148,6 +161,7 @@
NotFoundError,
NucleusAPIError,
)
from .evaluation_v2 import AllowedLabelMatch, EvaluationV2, EvaluationV2Status
from .job import CustomerJobTypes
from .local_deduplication import (
LocalDeduplicationResult,
Expand Down Expand Up @@ -881,6 +895,76 @@ def commit_model_run(
payload = {}
return self.make_request(payload, f"modelRun/{model_run_id}/commit")

def create_evaluation_v2(
    self,
    model_run_id: str,
    *,
    name: Optional[str] = None,
    allowed_label_matches: Optional[List[AllowedLabelMatch]] = None,
    allowed_label_matches_id: Optional[str] = None,
) -> EvaluationV2:
    """Start an Evaluation V2 for a model run and return it.

    The evaluation runs in the background. Call
    :meth:`EvaluationV2.wait_for_completion`, then
    :meth:`EvaluationV2.charts` or :meth:`EvaluationV2.examples` for results.

    Parameters:
        model_run_id: Model run id (``run_*``).
        name: Optional display name.
        allowed_label_matches: Optional label pairs to treat as matches.
            Each entry is serialized with its ``to_api_dict()``.
        allowed_label_matches_id: Optional id of a saved label-match
            configuration.

    Returns:
        :class:`EvaluationV2`: The created evaluation, loaded through
        :meth:`get_evaluation_v2`.

    Raises:
        RuntimeError: If the response carries no (or an empty)
            ``evaluation_id``.
    """
    optional_fields = {
        "name": name,
        "allowed_label_matches": (
            None
            if allowed_label_matches is None
            else [match.to_api_dict() for match in allowed_label_matches]
        ),
        "allowed_label_matches_id": allowed_label_matches_id,
    }
    # Only arguments the caller actually supplied go on the wire; an empty
    # list is still "supplied" because the filter is on None, not falsiness.
    payload: Dict[str, Any] = {
        field: value
        for field, value in optional_fields.items()
        if value is not None
    }
    response = self.make_request(
        payload, f"modelRun/{model_run_id}/evaluationsV2"
    )
    evaluation_id = response.get("evaluation_id")
    if evaluation_id:
        # Load the full record by id instead of building the object from
        # the create response.
        return self.get_evaluation_v2(str(evaluation_id))
    raise RuntimeError(
        f"Unexpected create evaluation V2 response: {response}"
    )

def get_evaluation_v2(self, evaluation_id: str) -> EvaluationV2:
    """Fetch a single evaluation by its id.

    Parameters:
        evaluation_id: Evaluation id (``evalv2_*``).

    Returns:
        :class:`EvaluationV2` built from the ``evaluationsV2/{id}`` response
        and bound to this client.
    """
    return EvaluationV2.from_json(
        self.get(f"evaluationsV2/{evaluation_id}"), self
    )

def list_evaluations_v2(self, model_run_id: str) -> List[EvaluationV2]:
    """Return every evaluation recorded for a model run.

    Entries keep the order of the API response (documented as newest
    first); no sorting is done here.

    Parameters:
        model_run_id: Model run id (``run_*``).

    Returns:
        List of :class:`EvaluationV2`.

    Raises:
        RuntimeError: If the response is not a list.
    """
    response = self.get(f"modelRun/{model_run_id}/evaluationsV2")
    if isinstance(response, list):
        return [EvaluationV2.from_json(entry, self) for entry in response]
    raise RuntimeError(
        f"Unexpected list evaluations V2 response: {response!r}"
    )

@deprecated(msg="Prefer calling Dataset.info() directly.")
def dataset_info(self, dataset_id: str):
dataset = self.get_dataset(dataset_id)
Expand Down
162 changes: 162 additions & 0 deletions nucleus/data_transfer_object/evaluation_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
"""Response and filter models for Evaluation V2."""

from typing import Any, Dict, List, Literal, Optional

from nucleus.pydantic_base import DictCompatibleModel


def _snake_to_camel(name: str) -> str:
parts = name.split("_")
if len(parts) == 1:
return name
return parts[0] + "".join(part.capitalize() for part in parts[1:])


def _camelize_filter_value(value: Any) -> Any:
if isinstance(value, dict):
return {
_snake_to_camel(key): (
val if key == "value" else _camelize_filter_value(val)
)
for key, val in value.items()
}
if isinstance(value, list):
return [_camelize_filter_value(item) for item in value]
return value


class RangeNum(DictCompatibleModel):
    """Numeric range with optional lower and upper bounds.

    Used by ``EvaluationV2FilterArgs`` for ``confidence_range`` and
    ``iou_range``.
    """

    # Either bound may be left unset; both default to None.
    # NOTE(review): whether the bounds are inclusive is decided by the API,
    # not by this model — confirm before documenting it to users.
    min: Optional[float] = None
    max: Optional[float] = None


class MetadataPredicate(DictCompatibleModel):
    """Single metadata condition: a key, an operator code and an operand.

    Used in the ``item_metadata`` and ``prediction_metadata`` lists of
    ``EvaluationV2FilterArgs``.
    """

    # Metadata field the condition applies to.
    key: str
    # One of four operator codes (presumably equals / membership /
    # greater-than / less-than — confirm against the API).
    op: Literal["EQ", "IN", "GT", "LT"]
    # Operand for the operator; any type is accepted, defaults to None.
    value: Optional[Any] = None


# Maps each EvaluationV2FilterArgs field name (snake_case) to the camelCase
# key that EvaluationV2FilterArgs.to_api_filters emits for it. A field that
# is missing from this table is never included in the emitted filters.
_FILTER_API_KEYS = {
    "confidence_range": "confidenceRange",
    "iou_range": "iouRange",
    "pred_labels": "predLabels",
    "gt_labels": "gtLabels",
    "item_metadata": "itemMetadata",
    "prediction_metadata": "predictionMetadata",
    "label_equality": "labelEquality",
    "has_ground_truth": "hasGroundTruth",
    "tide_background": "tideBackground",
}


class EvaluationV2FilterArgs(DictCompatibleModel):
    """Optional filters for :meth:`nucleus.evaluation_v2.EvaluationV2.charts` and :meth:`nucleus.evaluation_v2.EvaluationV2.examples`.

    Every field defaults to ``None``. Fields left as ``None`` are omitted
    from the dict produced by :meth:`to_api_filters`.
    """

    # Numeric ranges (see RangeNum: optional min / max).
    confidence_range: Optional[RangeNum] = None
    iou_range: Optional[RangeNum] = None
    # Label lists for predictions and for ground truth.
    pred_labels: Optional[List[str]] = None
    gt_labels: Optional[List[str]] = None
    # Metadata conditions (see MetadataPredicate).
    item_metadata: Optional[List[MetadataPredicate]] = None
    prediction_metadata: Optional[List[MetadataPredicate]] = None
    # "EQ" or "NEQ"; presumably compares predicted and ground-truth labels —
    # confirm against the API.
    label_equality: Optional[Literal["EQ", "NEQ"]] = None
    has_ground_truth: Optional[bool] = None
    tide_background: Optional[bool] = None

    def to_api_filters(self) -> Dict[str, Any]:
        """Return filters as a dict ready for API requests.

        The model is dumped with ``exclude_none=True``, so unset fields
        (including unset fields of nested models) are dropped. Each
        remaining top-level field is renamed via ``_FILTER_API_KEYS`` and
        its content has nested keys camel-cased by
        ``_camelize_filter_value``.

        Returns:
            Dict keyed by the camelCase API names, in ``_FILTER_API_KEYS``
            order, containing only the fields that were set.
        """
        populated = self.dict(exclude_none=True)
        return {
            api_key: _camelize_filter_value(populated[snake_key])
            for snake_key, api_key in _FILTER_API_KEYS.items()
            if snake_key in populated
        }


class MapSummary(DictCompatibleModel):
    """Summary mAP values; the ``mapSummary`` part of ``EvaluationV2Charts``.

    Field names are camelCase, presumably mirroring the API's JSON keys.
    """

    # All three are optional and default to None.
    # NOTE(review): names suggest mAP at IoU 0.50, at 0.75, and averaged
    # over 0.50-0.95 — confirm against the API.
    mapAt50: Optional[float] = None
    mapAt75: Optional[float] = None
    mapAt5095: Optional[float] = None


class PerClassAp(DictCompatibleModel):
    """One entry of ``EvaluationV2Charts.perClassAp``: a class label and its AP value."""

    classLabel: str
    ap: float


class ConfusionEntry(DictCompatibleModel):
    """One entry of ``EvaluationV2Charts.confusionMatrix``.

    Pairs a ground-truth label with a predicted label and a count.
    """

    gtLabel: str
    predLabel: str
    count: int


class ScoreHistogramBucket(DictCompatibleModel):
    """One bucket of ``EvaluationV2Charts.scoreHistogram``: its bounds and a count."""

    # NOTE(review): which bucket edge is inclusive is not visible here.
    bucketMin: float
    bucketMax: float
    count: int


class TotalCounts(DictCompatibleModel):
    """Aggregate counts; the ``totalCounts`` part of ``EvaluationV2Charts``."""

    # Presumably true-positive / false-positive / false-negative totals —
    # confirm against the API.
    tp: int
    fp: int
    fn: int
    predsWithConfidence: int


class ApBySize(DictCompatibleModel):
    """AP values split into three size groups; the ``apBySize`` part of ``EvaluationV2Charts``."""

    # Each group is optional and defaults to None.
    # NOTE(review): the size thresholds are defined by the API, not here.
    small: Optional[float] = None
    medium: Optional[float] = None
    large: Optional[float] = None


class PrCurvePoint(DictCompatibleModel):
    """One point of ``EvaluationV2Charts.prCurve``: recall and precision for a class label."""

    classLabel: str
    recall: float
    precision: float


class TideAttribution(DictCompatibleModel):
    """Integer counts per category; the ``tideAttribution`` part of ``EvaluationV2Charts``.

    NOTE(review): the field names look like TIDE-style error categories —
    confirm the exact definitions against the API.
    """

    truePositive: int
    localization: int
    classification: int
    both: int
    duplicate: int
    background: int
    missed: int


class EvaluationV2Charts(DictCompatibleModel):
    """Chart data for an evaluation, grouped into typed sections.

    Every field is required. Presumably the payload behind
    ``EvaluationV2.charts`` — confirm against that method.
    """

    mapSummary: MapSummary
    perClassAp: List[PerClassAp]
    confusionMatrix: List[ConfusionEntry]
    scoreHistogram: List[ScoreHistogramBucket]
    # NOTE(review): presumably the IoU thresholds metrics were computed
    # for — confirm.
    computedIouRanges: List[float]
    totalCounts: TotalCounts
    apBySize: ApBySize
    prCurve: List[PrCurvePoint]
    tideAttribution: TideAttribution


class EvaluationV2MatchExample(DictCompatibleModel):
    """One match row; the element type of ``EvaluationV2ExamplesPage.rows``.

    Only ``id``, ``evaluation_id``, ``dataset_item_id``, ``true_positive``
    and ``match_type`` are required. Every other field is optional and
    defaults to ``None``.
    """

    id: str
    evaluation_id: str
    dataset_item_id: str
    # Ids of the two sides of the match; either may be absent.
    model_prediction_id: Optional[str] = None
    ground_truth_annotation_id: Optional[str] = None
    # Labels for each side, in a canonical and a raw form.
    pred_canonical_label: Optional[str] = None
    gt_canonical_label: Optional[str] = None
    pred_raw_label: Optional[str] = None
    gt_raw_label: Optional[str] = None
    iou: Optional[float] = None
    confidence: Optional[float] = None
    true_positive: bool
    # Plain string; the set of valid values is not constrained here.
    match_type: str
    gt_area: Optional[float] = None
    # Free-form dicts; their schema is not defined by this model.
    item_metadata: Optional[Dict[str, Any]] = None
    prediction_metadata: Optional[Dict[str, Any]] = None
    prediction_row: Optional[Dict[str, Any]] = None
    annotation_row: Optional[Dict[str, Any]] = None


class EvaluationV2ExamplesPage(DictCompatibleModel):
    """A list of match examples together with a total count."""

    rows: List[EvaluationV2MatchExample]
    # NOTE(review): presumably the number of matching rows overall rather
    # than len(rows) — confirm against the API.
    total: int
Loading