15 changes: 15 additions & 0 deletions tests/unit/vertexai/genai/replays/test_evaluate_instances.py
@@ -79,6 +79,21 @@ def test_rouge_metric(client):
    assert len(response.rouge_results.rouge_metric_values) == 1


def test_evaluate_with_metric_resource_name(client):
    """Tests evaluate() with a metric identified by its full resource name."""
    metric_res = "projects/my-project/locations/us-central1/evaluationMetrics/my-metric-id"
    dataset = types.EvaluationDataset(
        eval_dataset_df=pd.DataFrame({
            "prompt": ["What is 1+1?"],
            "response": ["2"],
        })
    )
    result = client.evals.evaluate(
        dataset=dataset,
        metrics=[types.Metric(name="my_metric", metric_resource_name=metric_res)],
    )
    assert result is not None


def test_pointwise_metric(client):
"""Tests the _evaluate_instances method with PointwiseMetricInput."""
instance_dict = {
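For context, the call pattern this new test exercises can be sketched standalone as below. The project, location, and metric ID are placeholders, and `client` is assumed to be an initialized Vertex AI GenAI client (the replay tests get one from a fixture); `name` labels the metric in results, while `metric_resource_name` points the backend at a pre-registered metric definition.

```python
import pandas as pd

from vertexai._genai import types  # import path used by the tests in this PR

# Placeholder resource name; substitute your own project/location/metric ID.
METRIC_RESOURCE = (
    "projects/my-project/locations/us-central1/"
    "evaluationMetrics/my-metric-id"
)

dataset = types.EvaluationDataset(
    eval_dataset_df=pd.DataFrame({"prompt": ["What is 1+1?"], "response": ["2"]})
)

# `client` is assumed to be an already-initialized GenAI client.
result = client.evals.evaluate(
    dataset=dataset,
    metrics=[types.Metric(name="my_metric", metric_resource_name=METRIC_RESOURCE)],
)
print(result)
```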
13 changes: 12 additions & 1 deletion vertexai/_genai/_evals_common.py
@@ -45,6 +45,7 @@
from . import _gcs_utils
from . import evals
from . import types
from . import _transformers as t

logger = logging.getLogger(__name__)

@@ -1328,7 +1329,7 @@ def _resolve_dataset_inputs(


def _resolve_evaluation_run_metrics(
    metrics: list[types.EvaluationRunMetric], api_client: Any
    metrics: list[types.EvaluationRunMetric | types.Metric], api_client: Any
) -> list[types.EvaluationRunMetric]:
"""Resolves a list of evaluation run metric instances, loading RubricMetric if necessary."""
if not metrics:
@@ -1361,6 +1362,16 @@
                    e,
                )
                raise
        elif isinstance(metric_instance, types.Metric):
            # Translate the SDK-level Metric into its request payload; a
            # resource-name-only metric yields no inline config.
            config_dict = t.t_metrics([metric_instance])[0]
            res_name = config_dict.pop("metric_resource_name", None)
            resolved_metrics_list.append(
                types.EvaluationRunMetric(
                    metric=metric_instance.name,
                    metric_config=config_dict if config_dict else None,
                    metric_resource_name=res_name,
                )
            )
        else:
            try:
                metric_name_str = str(metric_instance)
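A minimal sketch of what the new `elif` branch resolves a resource-name-backed `Metric` into, mirroring the PR's own call into `_transformers`. Since such a payload carries only the resource-name pointer, popping it leaves an empty dict and `metric_config` stays `None`:

```python
from vertexai._genai import _transformers as t
from vertexai._genai import types

metric = types.Metric(
    name="my_metric",
    metric_resource_name=(
        "projects/my-project/locations/us-central1/"
        "evaluationMetrics/my-metric-id"
    ),
)

# Mirrors the branch above: build the payload, then split off the pointer.
config_dict = t.t_metrics([metric])[0]
res_name = config_dict.pop("metric_resource_name", None)

run_metric = types.EvaluationRunMetric(
    metric=metric.name,                                  # "my_metric"
    metric_config=config_dict if config_dict else None,  # None in this case
    metric_resource_name=res_name,                       # the full resource name
)
```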
5 changes: 5 additions & 0 deletions vertexai/_genai/_transformers.py
@@ -38,6 +38,8 @@ def t_metrics(

    for metric in metrics:
        metric_payload_item: dict[str, Any] = {}
        if hasattr(metric, "metric_resource_name") and metric.metric_resource_name:
            metric_payload_item["metric_resource_name"] = metric.metric_resource_name

        metric_name = getv(metric, ["name"]).lower()

@@ -79,6 +81,9 @@
"return_raw_output": return_raw_output
}
metric_payload_item["pointwise_metric_spec"] = pointwise_spec
elif "metric_resource_name" in metric_payload_item:
# Valid case: Metric is identified by resource name; no inline spec required.
pass
else:
raise ValueError(
f"Unsupported metric type or invalid metric name: {metric_name}"
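To make the control flow of the two `_transformers.py` edits explicit, here is a simplified mimic (not the real `t_metrics`, which builds full per-metric specs): known metric names get an inline spec, unknown names are still accepted when a resource name was copied into the payload up front, and everything else raises.

```python
from typing import Any, Optional


def build_payload_sketch(name: str, resource_name: Optional[str]) -> dict[str, Any]:
    """Simplified mimic of the t_metrics branch logic touched in this PR."""
    payload: dict[str, Any] = {}
    if resource_name:  # mirrors the new hasattr/truthiness check
        payload["metric_resource_name"] = resource_name

    if name.lower() in ("rouge", "bleu"):  # stand-ins for the known spec names
        payload["spec"] = {"metric": name.lower()}  # real code builds full specs
    elif "metric_resource_name" in payload:
        pass  # valid: identified by resource name, no inline spec required
    else:
        raise ValueError(f"Unsupported metric type or invalid metric name: {name}")
    return payload


assert build_payload_sketch(
    "my_metric", "projects/p/locations/l/evaluationMetrics/m"
) == {"metric_resource_name": "projects/p/locations/l/evaluationMetrics/m"}
```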
14 changes: 14 additions & 0 deletions vertexai/_genai/evals.py
@@ -392,6 +392,13 @@ def _EvaluationRunMetric_from_vertex(
if getv(from_object, ["metric"]) is not None:
setv(to_object, ["metric"], getv(from_object, ["metric"]))

if getv(from_object, ["metricResourceName"]) is not None:
setv(
to_object,
["metric_resource_name"],
getv(from_object, ["metricResourceName"]),
)

if getv(from_object, ["metricConfig"]) is not None:
setv(
to_object,
@@ -410,6 +417,13 @@ def _EvaluationRunMetric_to_vertex(
if getv(from_object, ["metric"]) is not None:
setv(to_object, ["metric"], getv(from_object, ["metric"]))

if getv(from_object, ["metric_resource_name"]) is not None:
setv(
to_object,
["metricResourceName"],
getv(from_object, ["metric_resource_name"]),
)

if getv(from_object, ["metric_config"]) is not None:
setv(
to_object,
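The two converter edits are symmetric: `_to_vertex` renames the SDK's snake_case field to the camelCase wire field, and `_from_vertex` reverses it. Illustrative only, with flat dicts standing in for the objects that `getv`/`setv` traverse:

```python
# SDK-side representation (snake_case), as EvaluationRunMetric stores it:
sdk_object = {
    "metric": "my_metric",
    "metric_resource_name": "projects/my-project/locations/us-central1/evaluationMetrics/my-metric-id",
}

# What _EvaluationRunMetric_to_vertex emits on the wire (camelCase):
wire_object = {
    "metric": "my_metric",
    "metricResourceName": sdk_object["metric_resource_name"],
}

# _EvaluationRunMetric_from_vertex inverts the mapping on responses:
assert wire_object["metricResourceName"] == sdk_object["metric_resource_name"]
```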
14 changes: 14 additions & 0 deletions vertexai/_genai/types/common.py
@@ -2479,6 +2479,10 @@ class EvaluationRunMetric(_common.BaseModel):
    metric: Optional[str] = Field(
        default=None, description="""The name of the metric."""
    )
    metric_resource_name: Optional[str] = Field(
        default=None,
        description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
    )
    metric_config: Optional[UnifiedMetric] = Field(
        default=None, description="""The unified metric used for the evaluation run."""
    )
@@ -2490,6 +2494,9 @@ class EvaluationRunMetricDict(TypedDict, total=False):
    metric: Optional[str]
    """The name of the metric."""

    metric_resource_name: Optional[str]
    """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}"""

    metric_config: Optional[UnifiedMetricDict]
    """The unified metric used for the evaluation run."""

@@ -4439,6 +4446,10 @@ class Metric(_common.BaseModel):
        default=None,
        description="""Optional steering instruction parameters for the automated predefined metric.""",
    )
    metric_resource_name: Optional[str] = Field(
        default=None,
        description="""The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}""",
    )

    # Allow extra fields to support metric-specific config fields.
    model_config = ConfigDict(extra="allow")
@@ -4643,6 +4654,9 @@ class MetricDict(TypedDict, total=False):
    metric_spec_parameters: Optional[dict[str, Any]]
    """Optional steering instruction parameters for the automated predefined metric."""

    metric_resource_name: Optional[str]
    """The resource name of the metric definition. Example: projects/{project}/locations/{location}/evaluationMetrics/{evaluation_metric_id}"""


MetricOrDict = Union[Metric, MetricDict]

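Finally, the field is exposed in both spellings of the public types: the Pydantic models and their TypedDict mirrors. A short sketch of each, assuming `MetricDict` also mirrors `Metric`'s `name` field (not shown in this hunk) and that both types are re-exported from the `types` package as the rest of this PR does:

```python
from vertexai._genai import types

RESOURCE = (
    "projects/my-project/locations/us-central1/"
    "evaluationMetrics/my-metric-id"
)

# Pydantic model form, as built by _resolve_evaluation_run_metrics:
run_metric = types.EvaluationRunMetric(
    metric="my_metric",
    metric_resource_name=RESOURCE,
)

# TypedDict form; `name` is assumed to be among MetricDict's other keys.
metric_dict: types.MetricDict = {
    "name": "my_metric",
    "metric_resource_name": RESOURCE,
}
```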