diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index 5eb89b6f44..8f4b543238 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,8 +8,22 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) __all__ = [ + "SeedDatasetFilter", + "SeedDatasetMetadata", + "SeedDatasetLoadingRank", + "SeedDatasetModality", + "SeedDatasetSize", + "SeedDatasetSourceType", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 270fba1568..1ef745f628 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -3,10 +3,20 @@ import logging from collections.abc import Callable +from dataclasses import fields from pathlib import Path -from typing import Any +from typing import Any, Optional + +import yaml from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset logger = logging.getLogger(__name__) @@ -70,6 +80,68 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Extract metadata from a local YAML file and coerce raw values into typed schema fields. + + YAML produces raw Python primitives (str, list) that must be converted to the + enum and set types expected by SeedDatasetMetadata before _match_filter can work. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + + Raises: + Exception: If the dataset file cannot be read. + """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + try: + with open(self.file_path, encoding="utf-8") as f: + dataset = yaml.safe_load(f) + except Exception as e: + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + raise + + if not isinstance(dataset, dict): + return None + + raw = {k: v for k, v in dataset.items() if k in valid_fields} + if not raw: + return None + + coerced = self._coerce_metadata_values(raw_metadata=raw) + return SeedDatasetMetadata(**coerced) + + @staticmethod + def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: + """ + Convert YAML primitive values into the enum/set types expected by SeedDatasetMetadata. + + Args: + raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. + + Returns: + dict[str, Any]: Dictionary with values coerced to the correct types. 
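+
+        Example (illustrative):
+            raw_metadata of {"size": "large", "tags": ["safety"]} is coerced to
+            {"size": SeedDatasetSize.LARGE, "tags": {"safety"}}.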
+ """ + coerced: dict[str, Any] = {} + for key, value in raw_metadata.items(): + if key == "tags" and isinstance(value, list): + coerced[key] = set(value) + elif key == "size" and isinstance(value, str): + coerced[key] = SeedDatasetSize(value) + elif key == "source_type" and isinstance(value, str): + coerced[key] = SeedDatasetSourceType(value) + elif key == "rank" and isinstance(value, str): + coerced[key] = SeedDatasetLoadingRank(value) + elif key == "modalities" and isinstance(value, list): + coerced[key] = [SeedDatasetModality(v) for v in value] + elif key == "harm_categories" and isinstance(value, str): + coerced[key] = [value] + elif key == "tags" and isinstance(value, str): + coerced[key] = {value} + else: + coerced[key] = value + return coerced + def _register_local_datasets() -> None: """ diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index fc6d46e54d..5d30fac2c4 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -6,6 +6,10 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetModality, + SeedDatasetSize, +) from pyrit.models import SeedDataset, SeedObjective @@ -19,6 +23,13 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ + # Metadata + harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] + size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds + # "default" means included in curated set + tags: set[str] = {"default", "safety"} + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5cd9212846..9587a743f0 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from collections.abc import Callable +from dataclasses import fields from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast @@ -19,6 +20,7 @@ from pyrit.common.path import DB_DATA_PATH from pyrit.common.text_helper import read_txt, write_txt from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -285,3 +287,21 @@ def _load_dataset_sync() -> Any: except Exception as e: logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise + + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Extract metadata from class attributes and format into SeedDatasetMetadata schema. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. 
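+
+        Example (illustrative):
+            A loader declaring class attributes such as tags = {"default"} and
+            size = SeedDatasetSize.LARGE is parsed into
+            SeedDatasetMetadata(tags={"default"}, size=SeedDatasetSize.LARGE).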
+ """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + + provider_class = type(self) + self_metadata = { + key: getattr(provider_class, key) for key in valid_fields if getattr(provider_class, key, None) is not None + } + + if not self_metadata: + return None + return SeedDatasetMetadata(**self_metadata) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 56b61b3996..f5ef0d2736 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,6 +9,7 @@ from tqdm import tqdm +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadingRank, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset logger = logging.getLogger(__name__) @@ -25,9 +26,14 @@ class SeedDatasetProvider(ABC): Subclasses must implement: - fetch_dataset(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset + + All subclasses also have a _metadata property that is optional to make + dataset addition easier, but failing to complete it makes downstream + analysis more difficult. """ _registry: dict[str, type["SeedDatasetProvider"]] = {} + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN def __init_subclass__(cls, **kwargs: Any) -> None: """ @@ -67,6 +73,19 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: Exception: If the dataset cannot be fetched or processed. """ + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Parse provider-specific metadata into the shared schema. + + Subclasses can override this to source metadata from class attributes, + prompt files, or any other backing format. The default implementation + returns None, which means metadata is not available for this provider. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None. + """ + return None + @classmethod def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: """ @@ -78,10 +97,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: """ Get the names of all registered datasets. + Args: + filters (Optional[SeedDatasetFilter]): List of filters to apply. + Returns: List[str]: List of dataset names from all registered providers. @@ -97,11 +119,81 @@ def get_all_dataset_names(cls) -> list[str]: try: # Instantiate to get dataset name provider = provider_class() + + # Parser ensures a standard metadata format + metadata = provider._parse_metadata() + + # "all" bypasses metadata filtering and returns every dataset. 
+ if filters and filters.tags and "all" in filters.tags: + dataset_names.add(provider.dataset_name) + continue + + if filters and not metadata: + # Datasets without metadata are skipped unless we want "all" + continue + + # Filters detected but no match -> don't add this dataset + if filters and metadata and not cls._match_filter(metadata=metadata, filters=filters): + continue + dataset_names.add(provider.dataset_name) except Exception as e: raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) + @classmethod + def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: + """ + + Match the filter(s) with the metadata provided by the SeedDatasetProvider subclass. + By default, filters across dimensions (e.g. size, harm categories) are treated as AND + requirements. Filters within a dimension (e.g. SeedDatasetSize.SMALL, + SeedDatasetSize.LARGE) are treated as OR requirements. + + Args: + metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider + subclass. + filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. + + Returns: + bool: Whether or not the filters match or not. + """ + # Tags + if filters.tags and "all" in filters.tags: + return True + + # These lines all disable SIM103 because metadata and filters tags can be optional, so + # directly checking for membership breaks type checking. + + if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 + return False + + # Size + if metadata.size and filters.sizes and metadata.size not in filters.sizes: # noqa: SIM103 + return False + + # Harm Categories + if ( + metadata.harm_categories + and filters.harm_categories + and not set(metadata.harm_categories) & set(filters.harm_categories) + ): # noqa: SIM103 + return False + + # Source Type + if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: # noqa: SIM103 + return False + + # Modalities + if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 + return False + + # Rank + if metadata.rank and filters.ranks and metadata.rank not in filters.ranks: # noqa: SIM103 + return False + + return True + @classmethod async def fetch_datasets_async( cls, diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py new file mode 100644 index 0000000000..d203d7e1cf --- /dev/null +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -0,0 +1,199 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass +from enum import Enum +from typing import Optional, TypedDict + +""" +Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). + +SeedDatasetMetadata is the internal schema used to normalize metadata fields +from different sources: +- Remote providers that declare metadata as class attributes +- Local prompt files that store metadata at the top level + +SeedDatasetFilter is the user-facing filter schema consumed by +SeedDatasetProvider.get_all_dataset_names(). 
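+
+Example (illustrative):
+    SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL], modalities=[SeedDatasetModality.TEXT])
+    selects datasets whose metadata declares a small size and the text modality.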
+""" + + +class SeedDatasetSize(Enum): + """Ordinal size (by bucket) of the dataset.""" + + TINY = "tiny" # < 10 + SMALL = "small" # >= 10, < 100 + MEDIUM = "medium" # >= 100, < 500 + LARGE = "large" # >= 500, < 5000 + HUGE = "huge" # >= 5000 + + +class SeedDatasetLoadingRank(Enum): + """ + Represents the general difficulty of loading in a dataset. + """ + + # Default is equivalent to "fastest" in the sense that datasets marked + # with a default rank will always get loaded. + DEFAULT = "default" + + # These represent actual ranks. + PRIMARY = "primary" + SECONDARY = "secondary" + TERTIARY = "tertiary" + + # Unknown corresponds to an untested dataset that won't be loaded. It is the + # default provided in SeedDatasetProvider. + UNKNOWN = "unknown" + + +class SeedDatasetModality(Enum): + """ + Type of data contained in the dataset. + """ + + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class SeedDatasetSourceType(Enum): + """ + Where the dataset is pulled from. + """ + + REMOTE = "remote" + LOCAL = "local" + + +@dataclass +class SeedDatasetFilter: + """ + Filter object for datasets. Passed to `get_all_dataset_names` in + SeedDatasetProvider. + """ + + tags: Optional[set[str]] = None + sizes: Optional[list[SeedDatasetSize]] = None + modalities: Optional[list[SeedDatasetModality]] = None + source_types: Optional[list[SeedDatasetSourceType]] = None + ranks: Optional[list[SeedDatasetLoadingRank]] = None + harm_categories: Optional[list[str]] = None + + +@dataclass(frozen=True) +class SeedDatasetMetadata: + """ + Metadata object for datasets. Holds the same fields as the filter + object. + """ + + tags: Optional[set[str]] = None + size: Optional[SeedDatasetSize] = None + modalities: Optional[list[SeedDatasetModality]] = None + source_type: Optional[SeedDatasetSourceType] = None + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN + harm_categories: Optional[list[str]] = None + + +class SeedDatasetMetadataUtilities: + """ + Utilities for deriving metadata for datasets. Currently, only static attributes + are supported. + + The default working location for datasets is the in-memory database. + """ + + class Metrics(TypedDict): + """ + Typed dictionary for easier retrieval and calculation of dataset metrics. + """ + + exact_size: int + loading_time_ms: float + modalities_found: set[str] + source_type: str + harm_categories_found: set[str] + tags: set[str] + + # Stores working dataset calculations. + # Maps name to metrics, which are later converted into SeedDatasetMetadata. + _cache: dict[str, Metrics] = {} + + @classmethod + def populate_datasets(cls) -> None: + """ + Populate metadata for all registered datasets. + + WARNING: Because metadata is stored as class attributes, this method can directly + change source files. Be extra careful when running it. 
+ """ + # Get all dataset names + # Calling SeedDatasetProvider would create a circular import, so we do this explicitly + datasets: list[str] = [] + + # Populate cache with empty (name, metrics) pairs + for dataset in datasets: + metrics: SeedDatasetMetadataUtilities.Metrics = { + "exact_size": -1, + "loading_time_ms": -1.0, + "modalities_found": {"None"}, + "source_type": "None", + "harm_categories_found": {"None"}, + "tags": {"None"}, + } + cls._cache[dataset] = metrics + + # Using a list, for each dataset name, load it in depending on class type + # Invoke the appropriate helper to parse it + + # If local, local_helper + + # If remote, remote_helper + + # Get contents from the memory database + # Note that we have to load it into the memory_database to get timing + # We also want the helper to do no initialization, just extract the relevant + # types and get ready to call a timing library + + # Calculate metrics one by one + + # Once out of the loop, calculate metadata fields + + # Loading rank by comparing relative speeds + + # Size by comparing buckets + + # Convert all others to types + + # Update (if update = True) the datasets + + # If remote, write to the file using regex + # E.g. harm_categories: ... should appear in source + + # If local, make sure the .prompt is formatted nicely + + @classmethod + def _local_helper(cls) -> None: + """ + Load local datasets into the working cache. + """ + + @classmethod + def _remote_helper(cls) -> None: + """ + Load remote datasets into the working cache. + """ + + @classmethod + def _remote_writer(cls) -> None: + """ + Write updated metadata to a remote dataset source file. + """ + + @classmethod + def _local_writer(cls) -> None: + """ + Write updated metadata to a local .prompt file. + """ diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index a3ede4beab..f2da3a292a 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -2,12 +2,21 @@ # Licensed under the MIT license. import logging +import textwrap +from pathlib import Path +from unittest.mock import patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset -from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetModality, + SeedDatasetSize, +) +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -55,3 +64,531 @@ async def test_fetch_dataset_integration(self, name, provider_cls): except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + + +class TestRemoteFilteringIntegration: + """ + Integration test for remote dataset filtering. + + Uses a mocked remote provider with class-level metadata attributes to + validate the full flow: metadata population, filter matching, and + get_all_dataset_names output. 
+ """ + + def _make_remote_provider_cls( + self, + *, + name: str, + tags: set, + size: SeedDatasetSize, + modalities: list, + harm_categories: list, + ) -> type: + """Build a minimal concrete SeedDatasetProvider with class-level metadata.""" + from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader + + captured_name = name + + async def _fetch_dataset(self, *, cache=True): + return SeedDataset( + seeds=[SeedPrompt(value="x", data_type="text")], + dataset_name=captured_name, + ) + + attrs = { + "tags": tags, + "size": size, + "modalities": modalities, + "harm_categories": harm_categories, + "should_register": False, + "__module__": __name__, + # Concrete implementations satisfy ABC requirements + "dataset_name": property(lambda self: captured_name), + "fetch_dataset": _fetch_dataset, + "_fetch_from_url": lambda self, **kw: [], + } + + return type(f"_Mock_{name}", (_RemoteDatasetLoader,), attrs) + + def test_filter_matches_correct_remote_provider(self): + """Filter by size returns only providers that match.""" + large_cls = self._make_remote_provider_cls( + name="large_ds", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + small_cls = self._make_remote_provider_cls( + name="small_ds", + tags={"default"}, + size=SeedDatasetSize.SMALL, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["cybercrime"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert names == ["large_ds"] + + def test_filter_all_tag_returns_everything(self): + """tags={'all'} bypasses filtering and returns every provider.""" + cls1 = self._make_remote_provider_cls( + name="ds_a", + tags={"safety"}, + size=SeedDatasetSize.TINY, + modalities=[SeedDatasetModality.TEXT], + harm_categories=[], + ) + cls2 = self._make_remote_provider_cls( + name="ds_b", + tags={"custom"}, + size=SeedDatasetSize.HUGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"A": cls1, "B": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert sorted(names) == ["ds_a", "ds_b"] + + def test_multi_axis_filter(self): + """Multiple filter axes are ANDed together.""" + cls1 = self._make_remote_provider_cls( + name="text_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + cls2 = self._make_remote_provider_cls( + name="image_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"TL": cls1, "IL": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter( + sizes=[SeedDatasetSize.LARGE], + modalities=[SeedDatasetModality.TEXT], + ), + ) + assert names == ["text_large"] + + +class TestLocalFilteringIntegration: + """ + Integration test for local dataset filtering. + + Creates real YAML prompt files on disk, registers them as local providers, + and validates the full flow through get_all_dataset_names with filters. 
+ """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_local_filter_by_size(self, tmp_path): + """Local YAML with size metadata is correctly coerced and filtered.""" + large_yaml = tmp_path / "large_ds.prompt" + large_yaml.write_text( + textwrap.dedent("""\ + dataset_name: large_local + size: large + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + small_yaml = tmp_path / "small_ds.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_local + size: small + harm_categories: + - cybercrime + seeds: + - value: test + data_type: text + """) + ) + + large_cls = self._make_local_cls(large_yaml) + small_cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + # dataset_name falls back to file stem when SeedDataset.from_yaml_file + # rejects extra keys like "size" during __init__ pre-loading + assert names == ["large_ds"] + + def test_local_filter_by_tags(self, tmp_path): + """Local YAML tags (list) are coerced to set for intersection.""" + yaml_path = tmp_path / "tagged.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: tagged_local + tags: + - safety + - default + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Tagged": cls}, + clear=True, + ): + # dataset_name falls back to file stem ("tagged") when + # SeedDataset.from_yaml_file rejects extra keys like "tags" + matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert matched == ["tagged"] + + not_matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"unrelated"}), + ) + assert not_matched == [] + + def test_local_no_metadata_skipped(self, tmp_path): + """Local YAML without metadata fields is skipped when filters are provided.""" + yaml_path = tmp_path / "bare.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: bare_local + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Without filters, the dataset is included + all_names = SeedDatasetProvider.get_all_dataset_names() + assert "bare_local" in all_names + + # With filters, it's skipped (no metadata to match against) + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + +class TestEndToEndLocalDatasetWorkflow: + """ + End-to-end integration test that simulates a user workflow: + 1. Write a temporary local dataset to disk + 2. Register it into the provider registry + 3. Discover it via filtered get_all_dataset_names + 4. Fetch the actual dataset by name + 5. 
Verify metadata is accessible and correct + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + @pytest.mark.asyncio + async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path): + """ + Simulate a user who wants small text datasets about cybercrime: + 1. Filter returns the matching dataset name + 2. Fetch by that name succeeds and returns correct seeds + 3. Metadata on the provider is accessible and consistent + """ + # --- Step 0: Write two datasets to disk --- + matching_yaml = tmp_path / "cyber_small.prompt" + matching_yaml.write_text( + textwrap.dedent("""\ + dataset_name: cyber_prompts + harm_categories: + - cybercrime + seeds: + - value: Explain phishing techniques + data_type: text + - value: Describe SQL injection + data_type: text + """) + ) + unrelated_yaml = tmp_path / "violence_large.prompt" + unrelated_yaml.write_text( + textwrap.dedent("""\ + dataset_name: violence_prompts + harm_categories: + - violence + seeds: + - value: Describe a violent scenario + data_type: text + """) + ) + + matching_cls = self._make_local_cls(matching_yaml) + unrelated_cls = self._make_local_cls(unrelated_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Cyber": matching_cls, "Violence": unrelated_cls}, + clear=True, + ): + # --- Step 1: User filters by harm_categories --- + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(harm_categories=["cybercrime"]), + ) + assert len(names) == 1 + dataset_name = names[0] + + # --- Step 2: User fetches the dataset by name --- + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=[dataset_name], + ) + assert len(datasets) == 1 + dataset = datasets[0] + assert len(dataset.seeds) == 2 + assert dataset.seeds[0].value == "Explain phishing techniques" + assert dataset.seeds[1].value == "Describe SQL injection" + + # --- Step 3: User inspects metadata --- + provider = matching_cls() + metadata = provider._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["cybercrime"] + + @pytest.mark.asyncio + async def test_user_fetches_unfiltered(self, tmp_path): + """ + Without filters, get_all_dataset_names returns everything, + and fetch_datasets_async retrieves all of them. + """ + ds1 = tmp_path / "ds_one.prompt" + ds1.write_text( + textwrap.dedent("""\ + dataset_name: dataset_one + seeds: + - value: prompt one + data_type: text + """) + ) + ds2 = tmp_path / "ds_two.prompt" + ds2.write_text( + textwrap.dedent("""\ + dataset_name: dataset_two + seeds: + - value: prompt two + data_type: text + """) + ) + + cls1 = self._make_local_cls(ds1) + cls2 = self._make_local_cls(ds2) + + with patch.dict( + SeedDatasetProvider._registry, + {"One": cls1, "Two": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names() + assert len(names) == 2 + + datasets = await SeedDatasetProvider.fetch_datasets_async() + assert len(datasets) == 2 + fetched_names = sorted(d.dataset_name for d in datasets) + assert fetched_names == ["dataset_one", "dataset_two"] + + +class TestAllTagBypassIntegration: + """ + Integration tests for the tags={'all'} bypass pattern. 
+ + The 'all' tag is a special escape hatch that returns every registered + dataset regardless of metadata presence or other filter axes. + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_all_tag_includes_datasets_without_metadata(self, tmp_path): + """ + A dataset whose YAML has no metadata fields at all is normally + skipped when filters are present. tags={'all'} overrides that. + """ + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Normal filter skips it + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + # 'all' includes it + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert "bare_dataset" in all_names + + def test_all_tag_ignores_other_filter_axes(self, tmp_path): + """ + tags={'all'} returns everything even when other filter axes + would exclude datasets. + """ + small_yaml = tmp_path / "small.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_dataset + size: small + harm_categories: + - cybercrime + seeds: + - value: small prompt + data_type: text + """) + ) + cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Small": cls}, + clear=True, + ): + # Size filter alone would exclude it + size_filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert size_filtered == [] + + # 'all' tag overrides the size filter + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}, sizes=[SeedDatasetSize.LARGE]), + ) + assert "small" in all_names + + def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): + """ + With a mix of metadata-rich and metadata-bare datasets, + tags={'all'} returns all of them. 
+ """ + rich_yaml = tmp_path / "rich.prompt" + rich_yaml.write_text( + textwrap.dedent("""\ + dataset_name: rich_dataset + harm_categories: + - violence + tags: + - safety + seeds: + - value: rich prompt + data_type: text + """) + ) + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + + rich_cls = self._make_local_cls(rich_yaml) + bare_cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Rich": rich_cls, "Bare": bare_cls}, + clear=True, + ): + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert len(all_names) == 2 + assert "bare_dataset" in all_names diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py new file mode 100644 index 0000000000..73fa3dad3e --- /dev/null +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Tests for metadata components related to SeedDatasetProvider. +""" + +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) + + +class TestMetadataLifecycle: + """ + Test that the metadata object can be created with different + subsets of values. + """ + + def test_has_no_values(self): + metadata = SeedDatasetMetadata() + assert metadata.tags is None + assert metadata.size is None + assert metadata.modalities is None + assert metadata.source_type is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + assert metadata.harm_categories is None + + def test_has_some_values(self): + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + assert metadata.tags == {"safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities is None + assert metadata.source_type is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + assert metadata.harm_categories is None + + def test_has_all_values(self): + metadata = SeedDatasetMetadata( + tags={"default", "safety"}, + size=SeedDatasetSize.MEDIUM, + modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], + source_type=SeedDatasetSourceType.REMOTE, + rank=SeedDatasetLoadingRank.DEFAULT, + harm_categories=["violence", "illegal"], + ) + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.MEDIUM + assert len(metadata.modalities) == 2 + assert metadata.source_type == SeedDatasetSourceType.REMOTE + assert metadata.rank == SeedDatasetLoadingRank.DEFAULT + assert metadata.harm_categories == ["violence", "illegal"] + + +class TestFilterLifecycle: + """ + Test that the filter object can be created with different + subsets of values. 
+ """ + + def test_has_no_values(self): + f = SeedDatasetFilter() + assert f.tags is None + assert f.sizes is None + assert f.modalities is None + assert f.source_types is None + assert f.ranks is None + assert f.harm_categories is None + + def test_has_some_values(self): + f = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert f.sizes == [SeedDatasetSize.LARGE] + assert f.tags is None + assert f.modalities is None + + def test_has_all_values(self): + f = SeedDatasetFilter( + tags={"default"}, + sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], + modalities=[SeedDatasetModality.TEXT], + source_types=[SeedDatasetSourceType.REMOTE], + ranks=[SeedDatasetLoadingRank.DEFAULT], + harm_categories=["violence"], + ) + assert f.tags == {"default"} + assert len(f.sizes) == 2 + assert f.modalities == [SeedDatasetModality.TEXT] + assert f.source_types == [SeedDatasetSourceType.REMOTE] + assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] + assert f.harm_categories == ["violence"] + + +class TestMetadataProperties: + """ + Test that the metadata fields populate correctly. + """ + + def test_size_value(self): + for size in SeedDatasetSize: + metadata = SeedDatasetMetadata(size=size) + assert metadata.size == size + + def test_loading_rank_value(self): + for rank in SeedDatasetLoadingRank: + metadata = SeedDatasetMetadata(rank=rank) + assert metadata.rank == rank + + def test_source_value(self): + for source_type in SeedDatasetSourceType: + metadata = SeedDatasetMetadata(source_type=source_type) + assert metadata.source_type == source_type + + def test_modality_value(self): + for modality in SeedDatasetModality: + metadata = SeedDatasetMetadata(modalities=[modality]) + assert modality in metadata.modalities + + def test_tags_value(self): + metadata = SeedDatasetMetadata(tags={"safety", "default", "custom"}) + assert "safety" in metadata.tags + assert "default" in metadata.tags + assert "custom" in metadata.tags + + def test_harm_categories_value(self): + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert "violence" in metadata.harm_categories + assert "cybercrime" in metadata.harm_categories + + +class TestFilterProperties: + """ + Test that the filter fields populate correctly. 
+ """ + + def test_sizes_values(self): + f = SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.LARGE]) + assert SeedDatasetSize.SMALL in f.sizes + assert SeedDatasetSize.LARGE in f.sizes + + def test_loading_ranks_values(self): + f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.TERTIARY]) + assert SeedDatasetLoadingRank.DEFAULT in f.ranks + assert SeedDatasetLoadingRank.TERTIARY in f.ranks + + def test_sources_values(self): + f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) + assert SeedDatasetSourceType.LOCAL in f.source_types + assert SeedDatasetSourceType.REMOTE in f.source_types + + def test_modalities_values(self): + f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetModality.TEXT in f.modalities + assert SeedDatasetModality.IMAGE in f.modalities + + def test_tags_values(self): + f = SeedDatasetFilter(tags={"safety", "default"}) + assert "safety" in f.tags + assert "default" in f.tags + + def test_harm_categories_values(self): + f = SeedDatasetFilter(harm_categories=["violence", "cybercrime"]) + assert "violence" in f.harm_categories + assert "cybercrime" in f.harm_categories diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index d61e2291a2..52029850b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -1,13 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import textwrap +from dataclasses import fields as dc_fields +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest +import yaml from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset, SeedObjective, SeedPrompt @@ -236,3 +249,401 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): assert call_kwargs["dataset_name"] == "custom/darkbench" assert call_kwargs["config"] == "custom_config" assert call_kwargs["split"] == "test" + + +class TestMetadataParsingRemote: + """Test metadata parsing and filter matching for remote providers.""" + + def test_parse_metadata_from_class_attrs(self): + """Test _parse_metadata correctly extracts class-level metadata attributes.""" + loader = _HarmBenchDataset() + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities == [SeedDatasetModality.TEXT] + assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + # source_type is not declared as a class attribute on HarmBench; + # rank inherits the UNKNOWN default from SeedDatasetProvider base class + assert metadata.source_type is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN + + def test_all_tag(self): + """Filter with tags={'all'} matches any metadata.""" + metadata = 
SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_tags(self): + """Tag filter uses set intersection.""" + metadata = SeedDatasetMetadata(tags={"safety", "default"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) + assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) + + def test_sizes(self): + """Size filter checks membership in the sizes list.""" + metadata = SeedDatasetMetadata(size=SeedDatasetSize.LARGE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE, SeedDatasetSize.HUGE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL]), + ) + + def test_modalities(self): + """Modality filter uses set intersection.""" + metadata = SeedDatasetMetadata(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.AUDIO]), + ) + + def test_sources(self): + """Source filter checks membership.""" + metadata = SeedDatasetMetadata(source_type=SeedDatasetSourceType.REMOTE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL]), + ) + + def test_ranks(self): + """Rank filter checks membership.""" + metadata = SeedDatasetMetadata(rank=SeedDatasetLoadingRank.DEFAULT) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.TERTIARY]), + ) + + def test_harm_categories(self): + """Harm category filter uses set intersection.""" + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["violence"]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["unrelated"]), + ) + + def test_empty_filter(self): + """Empty filter (all None) matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_no_metadata(self): + """Provider without metadata is skipped when filters are applied.""" + mock_provider_cls = MagicMock() + mock_provider_instance = mock_provider_cls.return_value + mock_provider_instance.dataset_name = "no_metadata" + mock_provider_instance._parse_metadata.return_value = None + + with patch.dict(SeedDatasetProvider._registry, {"NoProv": mock_provider_cls}, clear=True): + names = SeedDatasetProvider.get_all_dataset_names(filters=SeedDatasetFilter(tags={"safety"})) + assert names == [] + + +class TestMetadataParsingLocal: + """Test metadata parsing and filter matching for 
local YAML providers.""" + + def _make_loader(self, yaml_path): + """Create a _LocalDatasetLoader bypassing SeedDataset pre-loading.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = yaml_path + loader._dataset_name = yaml_path.stem + return loader + + def _write_yaml(self, tmp_path, name, content): + """Write a .prompt YAML file and return its path.""" + path = tmp_path / f"{name}.prompt" + path.write_text(content) + return path + + def test_parse_metadata_extracts_fields(self, tmp_path): + """Test _parse_metadata correctly extracts metadata fields from YAML.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["violence"] + + def test_all_tag(self, tmp_path): + """Filter with tags={'all'} matches regardless of metadata types.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_tags(self, tmp_path): + """YAML produces tags as list; set intersection in _match_filter expects a set.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + - default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"safety"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sizes(self, tmp_path): + """YAML produces size as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + size: large + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_modalities(self, tmp_path): + """YAML produces modalities as list of strings; _match_filter uses enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + modalities: + - text + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sources(self, tmp_path): + """YAML produces source_type as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + source_type: remote + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert 
metadata is not None + filters = SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_ranks(self, tmp_path): + """YAML produces rank as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + rank: default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_harm_categories(self, tmp_path): + """Both YAML and filter use list[str], so intersection works correctly.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + - cybercrime + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(harm_categories=["violence"]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_empty_filter(self, tmp_path): + """Empty filter matches everything.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_no_metadata(self, tmp_path): + """YAML without any metadata fields returns None from _parse_metadata.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is None + + +class TestLocalDatasetMetadataCollisions: + """ + Regression tests that scan every real .prompt file under seed_datasets/local + to verify _parse_metadata does not crash from field-name collisions between + the YAML schema and SeedDatasetMetadata. + + The previous `source` field collision (URLs parsed as SeedDatasetSourceType) + is the motivating example. + """ + + @staticmethod + def _get_local_prompt_files() -> list: + """Collect all .prompt and .yaml files under the local datasets directory.""" + local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_parse_metadata_does_not_crash(self, prompt_file): + """_parse_metadata must not raise on any real local dataset file.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + # This must not raise — if a YAML key collides with a metadata field + # name but holds an incompatible value, the coercion layer should + # either handle it or skip it gracefully. 
+ metadata = loader._parse_metadata() + # metadata can be None (no matching fields) or a valid SeedDatasetMetadata + if metadata is not None: + assert isinstance(metadata, SeedDatasetMetadata) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): + """ + If a YAML top-level key matches a SeedDatasetMetadata field name, the + coerced value must be the correct type (enum, set, list) — not a raw + string or other primitive that would silently break filtering. + """ + with open(prompt_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + return + + metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} + overlapping_keys = metadata_field_names & data.keys() + + if not overlapping_keys: + return + + # Coerce and construct — must not raise + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + raw = {k: data[k] for k in overlapping_keys} + coerced = _LocalDatasetLoader._coerce_metadata_values(raw_metadata=raw) + metadata = SeedDatasetMetadata(**coerced) + + # Verify coerced types match expectations + expected_types = { + "tags": (set, type(None)), + "size": (SeedDatasetSize, type(None)), + "modalities": (list, type(None)), + "source_type": (SeedDatasetSourceType, type(None)), + "rank": (SeedDatasetLoadingRank, type(None)), + "harm_categories": (list, type(None)), + } + for key in overlapping_keys: + value = getattr(metadata, key) + valid_types = expected_types.get(key) + if valid_types: + assert isinstance(value, valid_types), ( + f"Field '{key}' in {prompt_file.name} has type {type(value).__name__}, " + f"expected one of {valid_types}" + )
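
A minimal usage sketch of the filtering API added above (illustrative; the specific filter values are arbitrary examples, not part of the patch):

    import asyncio

    from pyrit.datasets import (
        SeedDatasetFilter,
        SeedDatasetModality,
        SeedDatasetProvider,
        SeedDatasetSize,
    )


    async def main() -> None:
        # Filter axes are ANDed together; values within a single axis are ORed.
        names = SeedDatasetProvider.get_all_dataset_names(
            filters=SeedDatasetFilter(
                sizes=[SeedDatasetSize.LARGE],
                modalities=[SeedDatasetModality.TEXT],
                harm_categories=["cybercrime"],
            ),
        )

        # Fetch only the datasets whose metadata matched the filter.
        datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=names)
        for dataset in datasets:
            print(dataset.dataset_name, len(dataset.seeds))


    asyncio.run(main())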