Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _plot_actor_bounding_box(
color: Desired color for the bounding box.
bbox_size: Desired size for the bounding box (length, width).
"""
(bbox_length, bbox_width) = bbox_size
bbox_length, bbox_width = bbox_size

# Compute coordinate for pivot point of bounding box
d = np.hypot(bbox_length, bbox_width)
Expand Down
2 changes: 1 addition & 1 deletion src/av2/evaluation/scenario_mining/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ class ScenarioMiningCategories(str, Enum):
"HOTA": "HOTA",
}

AV2_CATEGORIES: Final = tuple(x.value for x in ScenarioMiningCategories)
SCENARIO_MINING_CATEGORIES: Final = tuple(x.value for x in ScenarioMiningCategories)
186 changes: 111 additions & 75 deletions src/av2/evaluation/scenario_mining/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
setattr(np, "int", numpy_int)
setattr(np, "bool", numpy_bool)

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt

from av2.map.map_api import ArgoverseStaticMap, RasterLayerType
Expand All @@ -38,7 +41,7 @@
_tune_score_thresholds,
evaluate_tracking,
)
from av2.evaluation.scenario_mining import AV2_CATEGORIES
from av2.evaluation.scenario_mining import SCENARIO_MINING_CATEGORIES
from av2.evaluation.tracking import utils as sm_utils
from av2.utils.typing import NDArrayFloat
from av2.evaluation.typing import Sequences
Expand Down Expand Up @@ -271,12 +274,18 @@ def compute_temporal_metrics(
timestamp_pred = np.zeros(len(labels[description]), dtype=bool)

for j, frame in enumerate(labels[description]):
if len(frame["label"]) > 0 and 0 in frame["label"]:
if "is_positive" in frame and frame["is_positive"]:
timestamp_gt[j] = True
scenario_gt[i] = True
elif "label" in frame and len(frame["label"]) > 0 and 0 in frame["label"]:
timestamp_gt[j] = True
scenario_gt[i] = True

for j, frame in enumerate(track_predictions[description]):
if len(frame["label"]) > 0 and 0 in frame["label"]:
if "is_positive" in frame and frame["is_positive"]:
timestamp_pred[j] = True
scenario_pred[i] = True
elif "label" in frame and len(frame["label"]) > 0 and 0 in frame["label"]:
timestamp_pred[j] = True
scenario_pred[i] = True

Expand All @@ -292,39 +301,40 @@ def compute_temporal_metrics(

# Balanced Accuracy: https://en.wikipedia.org/wiki/Evaluation_of_binary_classifiers
# (TPR + TNR) / 2
scenario_tpr = scenario_tp / (scenario_tp + scenario_fp)
scenario_tnr = scenario_tn / (scenario_tn + scenario_fn)
timestamp_tpr = timestamp_tp / (timestamp_tp + timestamp_fp)
timestamp_tnr = timestamp_tn / (timestamp_tn + timestamp_fn)
scenario_tpr = scenario_tp / (scenario_tp + scenario_fn)
scenario_tnr = scenario_tn / (scenario_tn + scenario_fp)
timestamp_tpr = timestamp_tp / (timestamp_tp + timestamp_fn)
timestamp_tnr = timestamp_tn / (timestamp_tn + timestamp_fp)

if scenario_tp + scenario_fp == 0:
if scenario_tp + scenario_fn == 0:
scenario_tpr = 1.0
if scenario_tn + scenario_fn == 0:
if scenario_tn + scenario_fp == 0:
scenario_tnr = 1.0
if timestamp_tp + timestamp_fp == 0:
if timestamp_tp + timestamp_fn == 0:
timestamp_tpr = 1.0
if timestamp_tn + timestamp_fn == 0:
if timestamp_tn + timestamp_fp == 0:
timestamp_tnr = 1.0

scenario_ba = float((scenario_tpr + scenario_tnr) / 2)
timestamp_ba = float((timestamp_tpr + timestamp_tnr) / 2)

_plot_confusion_matrix(
scenario_tp,
scenario_fn,
scenario_fp,
scenario_tn,
title="scenario",
output_dir=output_dir,
)
_plot_confusion_matrix(
timestamp_tp,
timestamp_fn,
timestamp_fp,
timestamp_tn,
title="timestamp",
output_dir=output_dir,
)
if output_dir:
_plot_confusion_matrix(
scenario_tp,
scenario_fn,
scenario_fp,
scenario_tn,
title="scenario",
output_dir=output_dir,
)
_plot_confusion_matrix(
timestamp_tp,
timestamp_fn,
timestamp_fp,
timestamp_tn,
title="timestamp",
output_dir=output_dir,
)

return scenario_ba, timestamp_ba

Expand Down Expand Up @@ -356,17 +366,19 @@ def _relabel_seq_ids(sequences: Sequences) -> Sequences:


def evaluate(
track_predictions: Sequences,
scenario_predictions: Sequences,
labels: Sequences,
objective_metric: str,
max_range_m: int,
dataset_dir: Any,
dataset_dir: str,
out: str,
) -> tuple[float, float, float, float]:
"""Run scenario mining evaluation on the supplied prediction and label pkl files.

If tracks are not submitted within the scenario_predictions dictionary, only temporal metrics will be computed.

Args:
track_predictions: Prediction sequences.
scenario_predictions: Prediction sequences.
labels: Ground truth sequences.
objective_metric: Metric to optimize.
max_range_m: Maximum evaluation range.
Expand All @@ -379,33 +391,67 @@ def evaluate(
timestamp_ba: A retrieval/classification metric for determining if each timestamp contains any instance of the prompt.
scenario_ba: A retrieval/classification metric for determining if each data log contains any instance of the prompt.
"""
output_dir = out + "/partial_tracks"
Path(output_dir).mkdir(parents=True, exist_ok=True)
contains_tracking = False
for frames in scenario_predictions.values():
for frame in frames:
if "is_positive" not in frame:
contains_tracking = True
break
elif (
"track_id" in frame
and isinstance(frame["track_id"], np.ndarray)
and len(frame["track_id"] > 0)
):
contains_tracking = True
break

if contains_tracking:
break

if not contains_tracking:

partial_track_hota = 0.0
full_track_hota = 0.0
scenario_ba, timestamp_ba = compute_temporal_metrics(
scenario_predictions, labels, out
)

partial_track_hota, scenario_ba, timestamp_ba = evaluate_scenario_mining(
track_predictions,
labels,
objective_metric=objective_metric,
max_range_m=max_range_m,
dataset_dir=dataset_dir,
out=output_dir,
)
else:
labels = filter_max_dist(labels, max_range_m)
scenario_predictions = filter_max_dist(scenario_predictions, max_range_m)

full_track_preds = referred_full_tracks(track_predictions)
full_track_labels = referred_full_tracks(labels)
if dataset_dir is not None:
labels = filter_drivable_area(labels, dataset_dir)
scenario_predictions = filter_drivable_area(
scenario_predictions, dataset_dir
)

output_dir = out + "/full_tracks"
Path(output_dir).mkdir(parents=True, exist_ok=True)
scenario_predictions = _relabel_seq_ids(scenario_predictions)
labels = _relabel_seq_ids(labels)

full_track_hota, _, _ = evaluate_scenario_mining(
full_track_preds,
full_track_labels,
objective_metric=objective_metric,
max_range_m=max_range_m,
dataset_dir=dataset_dir,
out=output_dir,
full_tracks=True,
)
output_dir = out + "/partial_tracks"
Path(output_dir).mkdir(parents=True, exist_ok=True)

partial_track_hota, scenario_ba, timestamp_ba = evaluate_scenario_mining(
scenario_predictions,
labels,
objective_metric=objective_metric,
out=output_dir,
)

full_track_preds = referred_full_tracks(scenario_predictions)
full_track_labels = referred_full_tracks(labels)

output_dir = out + "/full_tracks"
Path(output_dir).mkdir(parents=True, exist_ok=True)

full_track_hota, _, _ = evaluate_scenario_mining(
full_track_preds,
full_track_labels,
objective_metric=objective_metric,
out=output_dir,
full_tracks=True,
)

return (
partial_track_hota,
Expand All @@ -419,8 +465,6 @@ def evaluate_scenario_mining(
track_predictions: Sequences,
labels: Sequences,
objective_metric: str,
max_range_m: int,
dataset_dir: Any,
out: str,
full_tracks: bool = False,
) -> Tuple[float, float, float]:
Expand All @@ -442,17 +486,7 @@ def evaluate_scenario_mining(
scenario_ba: A retrieval/classification metric for determining if each data log contains any instance of the prompt.
timestamp_ba: A retrieval/classification metric for determining if each timestamp contains any instance of the prompt.
"""
classes = list(AV2_CATEGORIES)

labels = filter_max_dist(labels, max_range_m)
track_predictions = filter_max_dist(track_predictions, max_range_m)

if dataset_dir is not None:
labels = filter_drivable_area(labels, dataset_dir)
track_predictions = filter_drivable_area(track_predictions, dataset_dir)

track_predictions = _relabel_seq_ids(track_predictions)
labels = _relabel_seq_ids(labels)
classes = list(SCENARIO_MINING_CATEGORIES)

score_thresholds, tuned_metric_values, _ = _tune_score_thresholds(
labels,
Expand All @@ -466,13 +500,15 @@ def evaluate_scenario_mining(
track_predictions, score_thresholds
)

res = evaluate_tracking(
labels,
filtered_track_predictions,
classes,
tracker_name="TRACKER",
output_dir=out,
)
if out is not None:
# Calculate in-depth metrics and generate HOTA plots by recall.
evaluate_tracking(
labels,
filtered_track_predictions,
classes,
tracker_name="TRACKER",
output_dir=out,
)

referrred_hota = tuned_metric_values["REFERRED_OBJECT"]

Expand All @@ -481,8 +517,8 @@ def evaluate_scenario_mining(
filtered_track_predictions, labels, out
)
else:
scenario_ba = 0
timestamp_ba = 0
scenario_ba = 0.0
timestamp_ba = 0.0

return referrred_hota, scenario_ba, timestamp_ba

Expand Down
Loading
Loading