From 857f596e178b99308c366f39ead31f0810edc077 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Tue, 10 Mar 2026 23:47:44 +0000 Subject: [PATCH 1/9] scaffolding --- .../remote/aegis_ai_content_safety_dataset.py | 13 +++++- .../seed_datasets/seed_dataset_provider.py | 41 +++++++++++++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 33 +++++++++++++++ 3 files changed, 78 insertions(+), 9 deletions(-) create mode 100644 pyrit/datasets/seed_datasets/seed_metadata.py diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index 4b9004f772..ecc952f35a 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -11,6 +11,8 @@ ) from pyrit.models import SeedDataset, SeedPrompt +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata + logger = logging.getLogger(__name__) @@ -107,7 +109,8 @@ def __init__( # Validate harm categories if provided if harm_categories: - invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} + invalid_categories = { + cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} if invalid_categories: raise ValueError( f"Invalid harm categories: {invalid_categories}. Valid categories are: {self.HARM_CATEGORIES}" @@ -157,7 +160,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: prompt_harm_categories = [] if violated_categories: # The violated_categories field contains comma-separated category names - categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()] + categories = [ + cat.strip() for cat in violated_categories.split(",") if cat.strip()] prompt_harm_categories = categories # Filter by harm_categories if specified @@ -186,3 +190,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: ) return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) + + def metadata_factory(self) -> SeedMetadata: + return SeedMetadata( + size= + ) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 56b61b3996..cb0e7ed11a 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -10,6 +10,7 @@ from tqdm import tqdm from pyrit.models.seeds import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata logger = logging.getLogger(__name__) @@ -51,6 +52,12 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ + @abstractmethod + def metadata_factory(self) -> SeedMetadata: + """ + Build metadata from tags and derived fields (e.g. dataset size). + """ + @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -78,10 +85,13 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]: """ Get the names of all registered datasets. + Args: + filters (Optional[Dict[str, str]]): List of filters to apply. + Returns: List[str]: List of dataset names from all registered providers. @@ -97,9 +107,21 @@ def get_all_dataset_names(cls) -> list[str]: try: # Instantiate to get dataset name provider = provider_class() + + # Injection point for filtering. TODO + + # 1 Remove invalid filters by checking ground truth in seed_metadata + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + + # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets + # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -142,9 +164,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -170,7 +194,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -208,10 +233,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py new file mode 100644 index 0000000000..8ac0c99fd5 --- /dev/null +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from enum import Enum +from dataclasses import dataclass + + +class DatasetLoadingRank(Enum): + """Represents the general difficulty of loading in a dataset.""" + DEFAULT = "default" + EXTENDED = "extended" + SLOW = "slow" + + +class DatasetModalities(Enum): + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class DatasetSourceType(Enum): + GENERIC_URL = "generic_url" + LOCAL = "local" + HUGGING_FACE = "hugging_face" + + +@dataclass +class DatasetMetadata: + size: int + modalities: list[DatasetModalities] + source: DatasetSourceType + loading_rank: DatasetLoadingRank From 15b58e8a47fd9d673b6e58c0e7f5b01e24d52e9d Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Wed, 11 Mar 2026 19:52:05 +0000 Subject: [PATCH 2/9] more scaffolding --- pyrit/datasets/__init__.py | 3 ++ .../seed_datasets/seed_dataset_provider.py | 43 +++++++++++++------ pyrit/datasets/seed_datasets/seed_metadata.py | 32 +++++++++++++- .../test_seed_dataset_provider_integration.py | 15 +++++-- .../datasets/test_seed_dataset_metadata.py | 32 ++++++++++++++ 5 files changed, 109 insertions(+), 16 deletions(-) create mode 100644 tests/unit/datasets/test_seed_dataset_metadata.py diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index 5eb89b6f44..c8d8592625 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,8 +8,11 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters __all__ = [ + "DatasetMetadata", + "DatasetFilters", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index cb0e7ed11a..7e11bf5f4c 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -26,6 +26,10 @@ class SeedDatasetProvider(ABC): Subclasses must implement: - fetch_dataset(): Fetch and return the dataset as a SeedDataset - dataset_name property: Human-readable name for the dataset + + All subclasses also have a _metadata property that is optional to make + dataset addition easier, but failing to complete it makes downstream + analysis more difficult. """ _registry: dict[str, type["SeedDatasetProvider"]] = {} @@ -41,6 +45,10 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") + # Providing metadata is optional + if getattr(cls, "_metadata", False): + logger.debug( + f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -52,12 +60,6 @@ def dataset_name(self) -> str: str: The dataset name (e.g., "HarmBench", "JailbreakBench JBB-Behaviors") """ - @abstractmethod - def metadata_factory(self) -> SeedMetadata: - """ - Build metadata from tags and derived fields (e.g. dataset size). - """ - @abstractmethod async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: """ @@ -103,21 +105,38 @@ def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() + # 1 Remove invalid filters by checking ground truth in seed_metadata + if filters: + valid_filters = [f.value for f in SeedMetadata.DatasetFilters] + # Prefer doing this to a list or set comprehension so we can raise ValueError on + # specific unsupported filters + for filter, _ in filters.items(): + if filter not in valid_filters: + raise ValueError( + f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!") + for provider_class in cls._registry.values(): try: # Instantiate to get dataset name provider = provider_class() - # Injection point for filtering. TODO + if filters: + # 1 Check if it has metadata + # should this be none or false + if getattr(provider, "_metadata", False): + # Skip a dataset without metadata if we have filters enabled + continue + + # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) - # 1 Remove invalid filters by checking ground truth in seed_metadata + # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition - # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) + # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets + # since we can't check them statically - # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + # Solution: If filter is dynamic, then just download or load into central memory early to retrieve it + # and present a warning to the user that this is occuring - # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets - # since we can't check them statically dataset_names.add(provider.dataset_name) except Exception as e: raise ValueError( diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ac0c99fd5..8ad37940d4 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -4,6 +4,20 @@ from enum import Enum from dataclasses import dataclass +""" +TODO Finish docstring + +Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). + +We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets +using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to +a private variable there. + +Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are left as +NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in +CentralMemory. +""" + class DatasetLoadingRank(Enum): """Represents the general difficulty of loading in a dataset.""" @@ -27,7 +41,23 @@ class DatasetSourceType(Enum): @dataclass class DatasetMetadata: + # TODO: separate dynamic fields from static fields and mark dynamic fields as None size: int modalities: list[DatasetModalities] source: DatasetSourceType - loading_rank: DatasetLoadingRank + rank: DatasetLoadingRank + + +class DatasetFilters(Enum): + # TODO: This is a bad way of extracting the fields from DatasetMetadata. + # A metaclass or even just calling getattr might be better. + SIZE = "size" + MODALITIES = "modalities" + SOURCE = "source" + RANK = "rank" + +# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass? + + +def _validate_filter_value(v): + """Check if the filter value given is valid.""" diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index a3ede4beab..ceacc2a860 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -37,10 +37,12 @@ async def test_fetch_dataset_integration(self, name, provider_cls): try: # Use max_examples for slow providers that fetch many remote images - provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = provider_cls( + max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() dataset = await provider.fetch_dataset(cache=False) - assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert isinstance( + dataset, SeedDataset), f"{name} did not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" assert dataset.dataset_name, f"{name} has no dataset_name" @@ -51,7 +53,14 @@ async def test_fetch_dataset_integration(self, name, provider_cls): f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" ) - logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") + logger.info( + f"Successfully verified {name} with {len(dataset.seeds)} seeds") except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") + + @pytest.mark.asyncio + @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) + async def test_fetch_dataset_with_filtering(self, name, provider_cls): + # TODO + pass diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py new file mode 100644 index 0000000000..7f38572311 --- /dev/null +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -0,0 +1,32 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +TODO + +Tests for SeedDatasetMetadata +""" + + +class TestMetadataParsing: + def test_invalid_filter_key(self): + pass + + def test_invalid_filter_value(self): + pass + + +class TestMetadataLifecycle: + def test_static_values_populated(self): + pass + + def test_dynamic_values_populated(self): + pass + + +class TestMetadataPerformance: + def test_quick_retrieval_for_static_values(self): + pass + + def test_acceptable_retrieval_for_dynamic_values(self): + pass From fc43c8c6f7e198ce9669f7cd8dd7047795791977 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 12 Mar 2026 00:32:21 +0000 Subject: [PATCH 3/9] . --- pyrit/datasets/seed_datasets/seed_metadata.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 8ad37940d4..b5a4070f89 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -61,3 +61,13 @@ class DatasetFilters(Enum): def _validate_filter_value(v): """Check if the filter value given is valid.""" + + +def _metadata_builder(): + """ + Force build metadata for all datasets. + Download/load into local memory. + Add a timestamp. + Add all derived attributes. + Make sure every dataset subclass has it. + """ From 9f357e64178f3311332136fceb93b3c568d43296 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Thu, 12 Mar 2026 20:48:39 +0000 Subject: [PATCH 4/9] data types --- pyrit/datasets/__init__.py | 17 ++- .../remote/aegis_ai_content_safety_dataset.py | 13 +-- .../seed_datasets/remote/harmbench_dataset.py | 16 +++ .../seed_datasets/seed_dataset_provider.py | 108 +++++++++++------- pyrit/datasets/seed_datasets/seed_metadata.py | 83 +++++++++----- 5 files changed, 153 insertions(+), 84 deletions(-) diff --git a/pyrit/datasets/__init__.py b/pyrit/datasets/__init__.py index c8d8592625..8f4b543238 100644 --- a/pyrit/datasets/__init__.py +++ b/pyrit/datasets/__init__.py @@ -8,11 +8,22 @@ from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak from pyrit.datasets.seed_datasets import local, remote # noqa: F401 from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider -from pyrit.datasets.seed_datasets.seed_metadata import DatasetMetadata, DatasetFilters +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) __all__ = [ - "DatasetMetadata", - "DatasetFilters", + "SeedDatasetFilter", + "SeedDatasetMetadata", + "SeedDatasetLoadingRank", + "SeedDatasetModality", + "SeedDatasetSize", + "SeedDatasetSourceType", "SeedDatasetProvider", "TextJailBreak", ] diff --git a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py index ecc952f35a..4b9004f772 100644 --- a/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/aegis_ai_content_safety_dataset.py @@ -11,8 +11,6 @@ ) from pyrit.models import SeedDataset, SeedPrompt -from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata - logger = logging.getLogger(__name__) @@ -109,8 +107,7 @@ def __init__( # Validate harm categories if provided if harm_categories: - invalid_categories = { - cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} + invalid_categories = {cat for cat in harm_categories if cat not in self.HARM_CATEGORIES} if invalid_categories: raise ValueError( f"Invalid harm categories: {invalid_categories}. Valid categories are: {self.HARM_CATEGORIES}" @@ -160,8 +157,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: prompt_harm_categories = [] if violated_categories: # The violated_categories field contains comma-separated category names - categories = [ - cat.strip() for cat in violated_categories.split(",") if cat.strip()] + categories = [cat.strip() for cat in violated_categories.split(",") if cat.strip()] prompt_harm_categories = categories # Filter by harm_categories if specified @@ -190,8 +186,3 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: ) return SeedDataset(seeds=seed_prompts, dataset_name=self.dataset_name) - - def metadata_factory(self) -> SeedMetadata: - return SeedMetadata( - size= - ) diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index fc6d46e54d..a31ca1cf58 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -6,6 +6,13 @@ from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( _RemoteDatasetLoader, ) +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset, SeedObjective @@ -19,6 +26,15 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ + _metadata = SeedDatasetMetadata( + tags={"default, safety"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + source=SeedDatasetSourceType.GENERIC_URL, + rank=SeedDatasetLoadingRank.DEFAULT, + harm_categories=["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"], + ) + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 7e11bf5f4c..cb7ea6fed7 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,8 +9,8 @@ from tqdm import tqdm +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset -from pyrit.datasets.seed_datasets.seed_metadata import SeedMetadata logger = logging.getLogger(__name__) @@ -45,10 +45,9 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") - # Providing metadata is optional - if getattr(cls, "_metadata", False): - logger.debug( - f"Dataset provider {cls.__name__} provided metadata.") + # Providing metadata is optional. + if getattr(cls, "_metadata", True): + logger.debug(f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -87,12 +86,12 @@ def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: return cls._registry.copy() @classmethod - def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list[str]: + def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]: """ Get the names of all registered datasets. Args: - filters (Optional[Dict[str, str]]): List of filters to apply. + filters (Optional[SeedDatasetFilter]): List of filters to apply. Returns: List[str]: List of dataset names from all registered providers. @@ -105,44 +104,72 @@ def get_all_dataset_names(cls, filters: Optional[dict[str, str]] = None) -> list >>> print(f"Available datasets: {', '.join(names)}") """ dataset_names = set() - # 1 Remove invalid filters by checking ground truth in seed_metadata - if filters: - valid_filters = [f.value for f in SeedMetadata.DatasetFilters] - # Prefer doing this to a list or set comprehension so we can raise ValueError on - # specific unsupported filters - for filter, _ in filters.items(): - if filter not in valid_filters: - raise ValueError( - f"Tried to pass invalid filter `{filter}` to SeedDatasetProvider.get_all_dataset_names!") - for provider_class in cls._registry.values(): try: # Instantiate to get dataset name provider = provider_class() - if filters: - # 1 Check if it has metadata - # should this be none or false - if getattr(provider, "_metadata", False): - # Skip a dataset without metadata if we have filters enabled - continue - - # 2 Remove invalid filter values by invoking helpers (e.g. size: <100 is fine, size: foobar is not) - - # 3 Only execute the following line if the filter key is valid and so is the value, AND the dataset meets the condition + # Extract metadata, default to False if not found + metadata = getattr(provider, "_metadata", False) + if filters and not metadata: + continue - # Problem: We don't know size at this point because we're just collecting the name. Size and source are tricky for remote datasets - # since we can't check them statically + # Type safety for metadata object given getattr return type + if isinstance(metadata, bool): + raise ValueError - # Solution: If filter is dynamic, then just download or load into central memory early to retrieve it - # and present a warning to the user that this is occuring + # Filters detected but no match -> don't add this dataset + if filters and not cls._match_filter(metadata=metadata, filters=filters): + continue + # This triggers when filters match (and filters exist) dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) + @classmethod + def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool: + """ + + Match the filter(s) with the metadata provided by the SeedDatasetProvider subclass. + By default, filters across dimensions (e.g. size, harm categories) are treated as AND + requirements. Filters within a dimension (e.g. SeedDatasetSize.SMALL, + SeedDatasetSize.LARGE) are treated as OR requirements. + + Args: + metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider + subclass. + filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names. + + Returns: + bool: Whether or not the filters match or not. + """ + # Tags + if metadata.tags and "all" in metadata.tags: + # This is the only condition that returns true, because we want the "all" + # tag to override everything else in the filter. + return True + + # These lines all disable SIM103 because metadata and filters tags can be optional, so + # directly checking for membership breaks type checking. + if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 + return False + + # Size + if metadata.size and filters.sizes and metadata.size not in filters.sizes: # noqa: SIM103 + return False + + # Harm Categories + + # Source Type + + # Modalities + + # Rank + + return True + @classmethod async def fetch_datasets_async( cls, @@ -183,11 +210,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -213,8 +238,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -252,12 +276,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index b5a4070f89..6b87f2fceb 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from enum import Enum from dataclasses import dataclass +from enum import Enum +from typing import Optional """ TODO Finish docstring @@ -19,55 +20,83 @@ """ -class DatasetLoadingRank(Enum): +class SeedDatasetSize(Enum): + """Ordinal size (by bucket) of the dataset.""" + + TINY = "tiny" # < 10 + SMALL = "small" # >= 10, < 100 + MEDIUM = "medium" # >= 100, < 500 + LARGE = "large" # >= 500, < 5000 + HUGE = "huge" # >= 5000 + + +class SeedDatasetLoadingRank(Enum): """Represents the general difficulty of loading in a dataset.""" + DEFAULT = "default" EXTENDED = "extended" SLOW = "slow" -class DatasetModalities(Enum): +class SeedDatasetModality(Enum): + """ + ... + """ + TEXT = "text" IMAGE = "image" VIDEO = "video" AUDIO = "audio" -class DatasetSourceType(Enum): +class SeedDatasetSourceType(Enum): + """ + ... + """ + GENERIC_URL = "generic_url" LOCAL = "local" HUGGING_FACE = "hugging_face" @dataclass -class DatasetMetadata: - # TODO: separate dynamic fields from static fields and mark dynamic fields as None - size: int - modalities: list[DatasetModalities] - source: DatasetSourceType - rank: DatasetLoadingRank - +class SeedDatasetFilter: + """ + ... + """ -class DatasetFilters(Enum): - # TODO: This is a bad way of extracting the fields from DatasetMetadata. - # A metaclass or even just calling getattr might be better. - SIZE = "size" - MODALITIES = "modalities" - SOURCE = "source" - RANK = "rank" + tags: Optional[set[str]] + sizes: Optional[list[SeedDatasetSize]] + modalities: Optional[list[SeedDatasetModality]] + sources: Optional[list[SeedDatasetSourceType]] + ranks: Optional[list[SeedDatasetLoadingRank]] + harm_categories: Optional[list[str]] -# TODO These stubs should be moved somewhere, maybe as static methods to the metadata dataclass? +@dataclass(frozen=True) +class SeedDatasetMetadata: + """ + ... + """ -def _validate_filter_value(v): - """Check if the filter value given is valid.""" + tags: Optional[set[str]] + size: Optional[SeedDatasetSize] + modalities: Optional[list[SeedDatasetModality]] + source: Optional[SeedDatasetSourceType] + rank: Optional[SeedDatasetLoadingRank] + harm_categories: Optional[list[str]] -def _metadata_builder(): +class SeedDatasetMetadataUtilities: """ - Force build metadata for all datasets. - Download/load into local memory. - Add a timestamp. - Add all derived attributes. - Make sure every dataset subclass has it. + Collected utilities for managing and updating SeedDatasetMetadata. """ + + @staticmethod + def populate_metadata() -> None: + """ + WARNING: Because this function updates the metadata for each SeedDatasetProvider, + it changes the provider's corresopnding source file. Run with caution! + + Update the metadata per SeedDatasetProvider. + """ From 34f8953f14f5ae9ef0e24faea0ae1db6d5e28ece Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 19:30:35 +0000 Subject: [PATCH 5/9] redesign --- .../local/local_dataset_loader.py | 45 ++++++++-- .../seed_datasets/remote/harmbench_dataset.py | 21 ++--- .../remote/jbb_behaviors_dataset.py | 2 +- .../remote/remote_dataset_loader.py | 48 ++++++++--- .../seed_datasets/seed_dataset_provider.py | 44 ++++++---- pyrit/datasets/seed_datasets/seed_metadata.py | 47 +++++++---- .../test_seed_dataset_provider_integration.py | 2 +- .../datasets/test_seed_dataset_metadata.py | 83 ++++++++++++++++--- .../datasets/test_seed_dataset_provider.py | 75 +++++++++++++++-- 9 files changed, 283 insertions(+), 84 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 270fba1568..a54062c779 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -2,12 +2,14 @@ # Licensed under the MIT license. import logging +import yaml from collections.abc import Callable from pathlib import Path from typing import Any from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -36,7 +38,8 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr( + dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -67,9 +70,32 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error(f"Failed to load local dataset from {self.file_path}: {e}") + logger.error( + f"Failed to load local dataset from {self.file_path}: {e}") raise + def _parse_metadata(self) -> SeedDatasetMetadata | None: + """ + Extract metadata from class attributes and format into SeedDatasetMetadata schema. + + Raises: + Exception: If the dataset cannot be loaded. + """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + try: + with open(self.file_path, 'r') as f: + dataset = yaml.safe_load(f) + except Exception as e: + logger.error( + f"Failed to load local datset from {self.file_path}: {e}" + ) + raise + self_metadata = {k: v for k, v in dataset if k in valid_fields} + if not self_metadata: + return None + return SeedDatasetMetadata(**self_metadata) + + def _register_local_datasets() -> None: """ @@ -93,21 +119,26 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__(file_path=path) + super(self.__class__, self).__init__( + file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init( + yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug( + f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning(f"Failed to register local dataset {yaml_file}: {e}") + logger.warning( + f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning( + f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index a31ca1cf58..4759d4b7ee 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -26,14 +26,13 @@ class _HarmBenchDataset(_RemoteDatasetLoader): Reference: https://github.com/centerforaisafety/HarmBench """ - _metadata = SeedDatasetMetadata( - tags={"default, safety"}, - size=SeedDatasetSize.LARGE, - modalities=[SeedDatasetModality.TEXT], - source=SeedDatasetSourceType.GENERIC_URL, - rank=SeedDatasetLoadingRank.DEFAULT, - harm_categories=["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"], - ) + # Metadata + harm_categories: list[str] = ["cybercrime", "illegal", + "harmful", "chemical_biological", "harassment"] + modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] + size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds + # "default" means included in curated set + tags: set[str] = {"default", "safety"} def __init__( self, @@ -88,7 +87,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # Check for missing keys in the example missing_keys = required_keys - example.keys() if missing_keys: - raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}") + raise ValueError( + f"Missing keys in example: {', '.join(missing_keys)}") # Extract data category = example["SemanticCategory"] @@ -104,7 +104,8 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: "biological, illegal activities, etc." ), source="https://github.com/centerforaisafety/HarmBench", - authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", "Andy Zou", "Zifan Wang", "Norman Mu"], + authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", + "Andy Zou", "Zifan Wang", "Norman Mu"], ) seeds.append(seed_prompt) diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index a622a4a018..b2b45c2a33 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -23,7 +23,7 @@ class _JBBBehaviorsDataset(_RemoteDatasetLoader): and may contain offensive content. Users should check with their legal department before using these prompts against production LLMs. """ - + def __init__( self, *, diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 5cd9212846..2f234f451b 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -8,6 +8,7 @@ import tempfile from abc import ABC from collections.abc import Callable +from dataclasses import fields from pathlib import Path from typing import Any, Literal, Optional, TextIO, cast @@ -19,6 +20,7 @@ from pyrit.common.path import DB_DATA_PATH from pyrit.common.text_helper import read_txt, write_txt from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -74,7 +76,8 @@ def _validate_file_type(self, file_type: str) -> None: """ if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ @@ -131,15 +134,19 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st if file_type in FILE_TYPE_HANDLERS: if file_type == "json": return cast( - "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"]( + io.StringIO(response.text)) ) return cast( "list[dict[str, str]]", - FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), + FILE_TYPE_HANDLERS[file_type]["read"]( + io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") - raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") + raise Exception( + f"Failed to fetch examples from public URL. Status code: {response.status_code}") def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ @@ -159,7 +166,8 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str if file_type in FILE_TYPE_HANDLERS: return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") def _fetch_from_url( self, @@ -191,21 +199,26 @@ def _fetch_from_url( file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError( + f"Invalid file_type. Expected one of: {valid_types}.") data_home = DB_DATA_PATH / "seed-prompt-entries" - cache_file = data_home / self._get_cache_file_name(source=source, file_type=file_type) + cache_file = data_home / \ + self._get_cache_file_name(source=source, file_type=file_type) if cache and cache_file.exists(): return self._read_cache(cache_file=cache_file, file_type=file_type) if source_type == "public_url": - examples = self._fetch_from_public_url(source=source, file_type=file_type) + examples = self._fetch_from_public_url( + source=source, file_type=file_type) elif source_type == "file": - examples = self._fetch_from_file(source=source, file_type=file_type) + examples = self._fetch_from_file( + source=source, file_type=file_type) if cache: - self._write_cache(cache_file=cache_file, examples=examples, file_type=file_type) + self._write_cache(cache_file=cache_file, + examples=examples, file_type=file_type) else: with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=f".{file_type}", encoding="utf-8" @@ -283,5 +296,16 @@ def _load_dataset_sync() -> Any: # Run the synchronous load_dataset in a thread pool to avoid blocking the event loop return await asyncio.to_thread(_load_dataset_sync) except Exception as e: - logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") + logger.error( + f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise + + def _parse_metadata(self) -> SeedDatasetMetadata | None: + """ + Extract metadata from class attributes and format into SeedDatasetMetadata schema. + """ + valid_fields = [f.name for f in fields(SeedDatasetMetadata)] + self_metadata = {k: v for k, v in self.__dict__.items() if k in valid_fields} + if not self_metadata: + return None + return SeedDatasetMetadata(**self_metadata) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index cb7ea6fed7..ae4e33eb31 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -45,9 +45,6 @@ def __init_subclass__(cls, **kwargs: Any) -> None: if not inspect.isabstract(cls) and getattr(cls, "should_register", True): SeedDatasetProvider._registry[cls.__name__] = cls logger.debug(f"Registered dataset provider: {cls.__name__}") - # Providing metadata is optional. - if getattr(cls, "_metadata", True): - logger.debug(f"Dataset provider {cls.__name__} provided metadata.") @property @abstractmethod @@ -109,23 +106,20 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l # Instantiate to get dataset name provider = provider_class() - # Extract metadata, default to False if not found - metadata = getattr(provider, "_metadata", False) - if filters and not metadata: + # Parser ensures a standard metadata format + metadata: SeedDatasetMetadata = cls._parse_metadata() + if filters and not metadata and "all" not in filters.tags: + # Datasets without metadata are skipped unless we want "all" continue - # Type safety for metadata object given getattr return type - if isinstance(metadata, bool): - raise ValueError - # Filters detected but no match -> don't add this dataset if filters and not cls._match_filter(metadata=metadata, filters=filters): continue - # This triggers when filters match (and filters exist) dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError( + f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -153,6 +147,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter # These lines all disable SIM103 because metadata and filters tags can be optional, so # directly checking for membership breaks type checking. + if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103 return False @@ -161,12 +156,22 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Harm Categories + if metadata.harm_categories and filters.harm_categories and \ + not set(metadata.harm_categories) & set(filters.harm_categories): # noqa: SIM103 + return False # Source Type + if metadata.source and filters.sources and metadata.source not in filters.sources: # noqa: SIM103 + return False # Modalities + if metadata.modalities and filters.modalities and \ + not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 + return False # Rank + if metadata.rank and filters.ranks and metadata.rank not in filters.ranks: # noqa: SIM103 + return False return True @@ -210,9 +215,11 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [name for name in dataset_names if name not in available_names] + invalid_names = [ + name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError( + f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -238,7 +245,8 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, + desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -276,10 +284,12 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list( + existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info( + f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 6b87f2fceb..6037d811b0 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -5,18 +5,12 @@ from enum import Enum from typing import Optional -""" -TODO Finish docstring +from pyrit.common.path import DATASETS_PATH +""" Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). -We have one DatasetMetadata dataclass that is our ground truth. As we instantiate datasets -using the subclass call in SeedDatasetProvider, we create DatasetMetadata and assign it to -a private variable there. - -Some fields are dynamic (e.g. loading statistics, timestamp, dataset size) and are left as -NoneType until the SeedDatasetProvider actually downloads/parses the dataset and puts it in -CentralMemory. +The ground truth is SeedDatasetMetadata. This is """ @@ -31,7 +25,9 @@ class SeedDatasetSize(Enum): class SeedDatasetLoadingRank(Enum): - """Represents the general difficulty of loading in a dataset.""" + """ + Represents the general difficulty of loading in a dataset. + """ DEFAULT = "default" EXTENDED = "extended" @@ -40,7 +36,7 @@ class SeedDatasetLoadingRank(Enum): class SeedDatasetModality(Enum): """ - ... + Type of data contained in the dataset. """ TEXT = "text" @@ -51,18 +47,18 @@ class SeedDatasetModality(Enum): class SeedDatasetSourceType(Enum): """ - ... + Where the dataset is pulled from. """ - GENERIC_URL = "generic_url" + REMOTE = "remote" LOCAL = "local" - HUGGING_FACE = "hugging_face" @dataclass class SeedDatasetFilter: """ - ... + Filter object for datasets. Passed to `get_all_dataset_names` in + SeedDatasetProvider. """ tags: Optional[set[str]] @@ -76,7 +72,8 @@ class SeedDatasetFilter: @dataclass(frozen=True) class SeedDatasetMetadata: """ - ... + Metadata object for datasets. Holds the same fields as the filter + object. """ tags: Optional[set[str]] @@ -89,7 +86,7 @@ class SeedDatasetMetadata: class SeedDatasetMetadataUtilities: """ - Collected utilities for managing and updating SeedDatasetMetadata. + Collected utilities for managing and updating metadata. """ @staticmethod @@ -98,5 +95,19 @@ def populate_metadata() -> None: WARNING: Because this function updates the metadata for each SeedDatasetProvider, it changes the provider's corresopnding source file. Run with caution! - Update the metadata per SeedDatasetProvider. + Updates the metadata per SeedDatasetProvider. """ + + # 1 Gather all dataset files + + # 2 For each file, download and store in the database (in-memory) + + # 3 Count the number of entries exactly and identify its threshold + + # 4 If harm categories are found in source, add them + + # 5 Inspect type of prompts to identify modalities present + + # 6 Inspect source file to find where it pulled from + + # 7 Leave rank optional for now diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index ceacc2a860..491f97e92a 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -61,6 +61,6 @@ async def test_fetch_dataset_integration(self, name, provider_cls): @pytest.mark.asyncio @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) - async def test_fetch_dataset_with_filtering(self, name, provider_cls): + async def test_fetch_dataset_integration_with_filtering(self, name, provider_cls): # TODO pass diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 7f38572311..5487a1c848 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -2,31 +2,90 @@ # Licensed under the MIT license. """ -TODO - -Tests for SeedDatasetMetadata +Tests for metadata components related to SeedDatasetProvider. """ -class TestMetadataParsing: - def test_invalid_filter_key(self): +class TestMetadataLifecycle: + """ + Test that the metadata object can be created with different + subsets of values. + """ + + def test_has_no_values(self): pass - def test_invalid_filter_value(self): + def test_has_some_values(self): pass + def test_has_all_values(self): + pass -class TestMetadataLifecycle: - def test_static_values_populated(self): + +class TestFilterLifecycle: + """ + Test that the metadata object can be created with different + subsets of values. + """ + + def test_has_no_values(self): + pass + + def test_has_some_values(self): pass - def test_dynamic_values_populated(self): + def test_has_all_values(self): pass -class TestMetadataPerformance: - def test_quick_retrieval_for_static_values(self): +class TestMetadataProperties: + """ + Test that the metadata fields populate correctly. + """ + + def test_size_value(self): + pass + + def test_loading_rank_value(self): + pass + + def test_source_value(self): + pass + + def test_modality_value(self): + pass + + def test_tags_value(self): pass - def test_acceptable_retrieval_for_dynamic_values(self): + def test_harm_categories_value(self): + pass + + +class TestFilterProperties: + """ + Test that the filter fields popualte correctly. + """ + + def test_sizes_values(self): + pass + + def test_loading_ranks_values(self): + pass + + def test_sources_values(self): + pass + + def test_modalities_values(self): + pass + + def test_tags_values(self): + pass + + def test_harm_categories_values(self): + pass + + +class TestMetadataUtilities: + def test_population_works(self, tmp_path): pass diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index d61e2291a2..0dbf0e13b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -78,13 +78,15 @@ async def test_fetch_datasets_async(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -97,12 +99,14 @@ async def test_fetch_datasets_async_with_filter(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) + mock_provider2.return_value.fetch_dataset = AsyncMock( + side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -115,13 +119,15 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset( + seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -236,3 +242,60 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): assert call_kwargs["dataset_name"] == "custom/darkbench" assert call_kwargs["config"] == "custom_config" assert call_kwargs["split"] == "test" + + +class TestMetadataParsingRemote: + def test_all_tag(self): + pass + + def test_tags(self): + pass + + def test_sizes(self): + pass + + def test_modalities(self): + pass + + def test_sources(self): + pass + + def test_ranks(self): + pass + + def test_harm_categories(self): + pass + + def test_empty_fitler(self): + pass + + def test_no_metadata(self): + pass + +class TestMetadataParsingLocal: + def test_all_tag(self): + pass + + def test_tags(self): + pass + + def test_sizes(self): + pass + + def test_modalities(self): + pass + + def test_sources(self): + pass + + def test_ranks(self): + pass + + def test_harm_categories(self): + pass + + def test_empty_fitler(self): + pass + + def test_no_metadata(self): + pass From 8dcbd5ff44541bf8ff0276f3df79886c1a21405c Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:02:20 +0000 Subject: [PATCH 6/9] review --- .../local/local_dataset_loader.py | 93 +++-- .../seed_datasets/remote/harmbench_dataset.py | 14 +- .../remote/remote_dataset_loader.py | 48 ++- .../seed_datasets/seed_dataset_provider.py | 57 +-- pyrit/datasets/seed_datasets/seed_metadata.py | 63 +-- .../test_seed_dataset_provider_integration.py | 283 +++++++++++++- .../datasets/test_seed_dataset_metadata.py | 126 ++++-- .../datasets/test_seed_dataset_provider.py | 360 +++++++++++++++--- 8 files changed, 830 insertions(+), 214 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index a54062c779..2c6dd4b778 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -2,14 +2,22 @@ # Licensed under the MIT license. import logging -import yaml from collections.abc import Callable +from dataclasses import fields from pathlib import Path -from typing import Any +from typing import Any, Optional + +import yaml from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset -from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata logger = logging.getLogger(__name__) @@ -38,8 +46,7 @@ def __init__(self, *, file_path: Path): dataset = SeedDataset.from_yaml_file(file_path) # Use the dataset_name from the YAML if available, otherwise use filename self._dataset_name = ( - getattr(dataset, "dataset_name", None) or getattr( - dataset, "name", None) or file_path.stem + getattr(dataset, "dataset_name", None) or getattr(dataset, "name", None) or file_path.stem ) except Exception as e: logger.warning(f"Could not pre-load dataset from {file_path}: {e}") @@ -70,31 +77,66 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: dataset.dataset_name = self.dataset_name return dataset except Exception as e: - logger.error( - f"Failed to load local dataset from {self.file_path}: {e}") + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - def _parse_metadata(self) -> SeedDatasetMetadata | None: + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ - Extract metadata from class attributes and format into SeedDatasetMetadata schema. - + Extract metadata from a local YAML file and coerce raw values into typed schema fields. + + YAML produces raw Python primitives (str, list) that must be converted to the + enum and set types expected by SeedDatasetMetadata before _match_filter can work. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. + Raises: - Exception: If the dataset cannot be loaded. + Exception: If the dataset file cannot be read. """ valid_fields = [f.name for f in fields(SeedDatasetMetadata)] try: - with open(self.file_path, 'r') as f: + with open(self.file_path, encoding="utf-8") as f: dataset = yaml.safe_load(f) except Exception as e: - logger.error( - f"Failed to load local datset from {self.file_path}: {e}" - ) + logger.error(f"Failed to load local dataset from {self.file_path}: {e}") raise - self_metadata = {k: v for k, v in dataset if k in valid_fields} - if not self_metadata: + + if not isinstance(dataset, dict): return None - return SeedDatasetMetadata(**self_metadata) + raw = {k: v for k, v in dataset.items() if k in valid_fields} + if not raw: + return None + + coerced = self._coerce_metadata_values(raw_metadata=raw) + return SeedDatasetMetadata(**coerced) + + @staticmethod + def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: + """ + Convert YAML primitive values into the enum/set types expected by SeedDatasetMetadata. + + Args: + raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values. + + Returns: + dict[str, Any]: Dictionary with values coerced to the correct types. + """ + coerced: dict[str, Any] = {} + for key, value in raw_metadata.items(): + if key == "tags" and isinstance(value, list): + coerced[key] = set(value) + elif key == "size" and isinstance(value, str): + coerced[key] = SeedDatasetSize(value) + elif key == "source" and isinstance(value, str): + coerced[key] = SeedDatasetSourceType(value) + elif key == "rank" and isinstance(value, str): + coerced[key] = SeedDatasetLoadingRank(value) + elif key == "modalities" and isinstance(value, list): + coerced[key] = [SeedDatasetModality(v) for v in value] + else: + coerced[key] = value + return coerced def _register_local_datasets() -> None: @@ -119,26 +161,21 @@ def _register_local_datasets() -> None: def make_init(path: Path) -> Callable[[Any], None]: def __init__(self: Any) -> None: # noqa: N807 - super(self.__class__, self).__init__( - file_path=path) + super(self.__class__, self).__init__(file_path=path) return __init__ type( class_name, (_LocalDatasetLoader,), - {"__init__": make_init( - yaml_file), "should_register": True, "__module__": __name__}, + {"__init__": make_init(yaml_file), "should_register": True, "__module__": __name__}, ) - logger.debug( - f"Registered local dataset loader: {class_name} for {yaml_file.name}") + logger.debug(f"Registered local dataset loader: {class_name} for {yaml_file.name}") except Exception as e: - logger.warning( - f"Failed to register local dataset {yaml_file}: {e}") + logger.warning(f"Failed to register local dataset {yaml_file}: {e}") else: - logger.warning( - f"Seed datasets directory not found: {seed_datasets_path}") + logger.warning(f"Seed datasets directory not found: {seed_datasets_path}") # Execute registration diff --git a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py index 4759d4b7ee..5d30fac2c4 100644 --- a/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/harmbench_dataset.py @@ -7,11 +7,8 @@ _RemoteDatasetLoader, ) from pyrit.datasets.seed_datasets.seed_metadata import ( - SeedDatasetLoadingRank, - SeedDatasetMetadata, SeedDatasetModality, SeedDatasetSize, - SeedDatasetSourceType, ) from pyrit.models import SeedDataset, SeedObjective @@ -27,10 +24,9 @@ class _HarmBenchDataset(_RemoteDatasetLoader): """ # Metadata - harm_categories: list[str] = ["cybercrime", "illegal", - "harmful", "chemical_biological", "harassment"] + harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT] - size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds + size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds # "default" means included in curated set tags: set[str] = {"default", "safety"} @@ -87,8 +83,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: # Check for missing keys in the example missing_keys = required_keys - example.keys() if missing_keys: - raise ValueError( - f"Missing keys in example: {', '.join(missing_keys)}") + raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}") # Extract data category = example["SemanticCategory"] @@ -104,8 +99,7 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: "biological, illegal activities, etc." ), source="https://github.com/centerforaisafety/HarmBench", - authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", - "Andy Zou", "Zifan Wang", "Norman Mu"], + authors=["Mantas Mazeika", "Long Phan", "Xuwang Yin", "Andy Zou", "Zifan Wang", "Norman Mu"], ) seeds.append(seed_prompt) diff --git a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py index 2f234f451b..9587a743f0 100644 --- a/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py @@ -76,8 +76,7 @@ def _validate_file_type(self, file_type: str) -> None: """ if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") def _read_cache(self, *, cache_file: Path, file_type: str) -> list[dict[str, str]]: """ @@ -134,19 +133,15 @@ def _fetch_from_public_url(self, *, source: str, file_type: str) -> list[dict[st if file_type in FILE_TYPE_HANDLERS: if file_type == "json": return cast( - "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"]( - io.StringIO(response.text)) + "list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO(response.text)) ) return cast( "list[dict[str, str]]", - FILE_TYPE_HANDLERS[file_type]["read"]( - io.StringIO("\n".join(response.text.splitlines()))), + FILE_TYPE_HANDLERS[file_type]["read"](io.StringIO("\n".join(response.text.splitlines()))), ) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") - raise Exception( - f"Failed to fetch examples from public URL. Status code: {response.status_code}") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") + raise Exception(f"Failed to fetch examples from public URL. Status code: {response.status_code}") def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str]]: """ @@ -166,8 +161,7 @@ def _fetch_from_file(self, *, source: str, file_type: str) -> list[dict[str, str if file_type in FILE_TYPE_HANDLERS: return cast("list[dict[str, str]]", FILE_TYPE_HANDLERS[file_type]["read"](file)) valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") def _fetch_from_url( self, @@ -199,26 +193,21 @@ def _fetch_from_url( file_type = source.split(".")[-1] if file_type not in FILE_TYPE_HANDLERS: valid_types = ", ".join(FILE_TYPE_HANDLERS.keys()) - raise ValueError( - f"Invalid file_type. Expected one of: {valid_types}.") + raise ValueError(f"Invalid file_type. Expected one of: {valid_types}.") data_home = DB_DATA_PATH / "seed-prompt-entries" - cache_file = data_home / \ - self._get_cache_file_name(source=source, file_type=file_type) + cache_file = data_home / self._get_cache_file_name(source=source, file_type=file_type) if cache and cache_file.exists(): return self._read_cache(cache_file=cache_file, file_type=file_type) if source_type == "public_url": - examples = self._fetch_from_public_url( - source=source, file_type=file_type) + examples = self._fetch_from_public_url(source=source, file_type=file_type) elif source_type == "file": - examples = self._fetch_from_file( - source=source, file_type=file_type) + examples = self._fetch_from_file(source=source, file_type=file_type) if cache: - self._write_cache(cache_file=cache_file, - examples=examples, file_type=file_type) + self._write_cache(cache_file=cache_file, examples=examples, file_type=file_type) else: with tempfile.NamedTemporaryFile( delete=False, mode="w", suffix=f".{file_type}", encoding="utf-8" @@ -296,16 +285,23 @@ def _load_dataset_sync() -> Any: # Run the synchronous load_dataset in a thread pool to avoid blocking the event loop return await asyncio.to_thread(_load_dataset_sync) except Exception as e: - logger.error( - f"Failed to load HuggingFace dataset {dataset_name}: {e}") + logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}") raise - def _parse_metadata(self) -> SeedDatasetMetadata | None: + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: """ Extract metadata from class attributes and format into SeedDatasetMetadata schema. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None. """ valid_fields = [f.name for f in fields(SeedDatasetMetadata)] - self_metadata = {k: v for k, v in self.__dict__.items() if k in valid_fields} + + provider_class = type(self) + self_metadata = { + key: getattr(provider_class, key) for key in valid_fields if getattr(provider_class, key, None) is not None + } + if not self_metadata: return None return SeedDatasetMetadata(**self_metadata) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ae4e33eb31..ae45fcb200 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -72,6 +72,19 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset: Exception: If the dataset cannot be fetched or processed. """ + def _parse_metadata(self) -> Optional[SeedDatasetMetadata]: + """ + Parse provider-specific metadata into the shared schema. + + Subclasses can override this to source metadata from class attributes, + prompt files, or any other backing format. The default implementation + returns None, which means metadata is not available for this provider. + + Returns: + Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None. + """ + return None + @classmethod def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]: """ @@ -107,19 +120,24 @@ def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> l provider = provider_class() # Parser ensures a standard metadata format - metadata: SeedDatasetMetadata = cls._parse_metadata() - if filters and not metadata and "all" not in filters.tags: + metadata = provider._parse_metadata() + + # "all" bypasses metadata filtering and returns every dataset. + if filters and filters.tags and "all" in filters.tags: + dataset_names.add(provider.dataset_name) + continue + + if filters and not metadata: # Datasets without metadata are skipped unless we want "all" continue # Filters detected but no match -> don't add this dataset - if filters and not cls._match_filter(metadata=metadata, filters=filters): + if filters and metadata and not cls._match_filter(metadata=metadata, filters=filters): continue dataset_names.add(provider.dataset_name) except Exception as e: - raise ValueError( - f"Could not get dataset name from {provider_class.__name__}: {e}") from e + raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e return sorted(dataset_names) @classmethod @@ -140,9 +158,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter bool: Whether or not the filters match or not. """ # Tags - if metadata.tags and "all" in metadata.tags: - # This is the only condition that returns true, because we want the "all" - # tag to override everything else in the filter. + if filters.tags and "all" in filters.tags: return True # These lines all disable SIM103 because metadata and filters tags can be optional, so @@ -156,8 +172,11 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Harm Categories - if metadata.harm_categories and filters.harm_categories and \ - not set(metadata.harm_categories) & set(filters.harm_categories): # noqa: SIM103 + if ( + metadata.harm_categories + and filters.harm_categories + and not set(metadata.harm_categories) & set(filters.harm_categories) + ): # noqa: SIM103 return False # Source Type @@ -165,8 +184,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Modalities - if metadata.modalities and filters.modalities and \ - not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 + if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103 return False # Rank @@ -215,11 +233,9 @@ async def fetch_datasets_async( # Validate dataset names if specified if dataset_names is not None: available_names = cls.get_all_dataset_names() - invalid_names = [ - name for name in dataset_names if name not in available_names] + invalid_names = [name for name in dataset_names if name not in available_names] if invalid_names: - raise ValueError( - f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") + raise ValueError(f"Dataset(s) not found: {invalid_names}. Available datasets: {available_names}") async def fetch_single_dataset( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -245,8 +261,7 @@ async def fetch_single_dataset( # Progress tracking total_count = len(cls._registry) - pbar = tqdm(total=total_count, - desc="Loading datasets - this can take a few minutes", unit="dataset") + pbar = tqdm(total=total_count, desc="Loading datasets - this can take a few minutes", unit="dataset") async def fetch_with_semaphore( provider_name: str, provider_class: type["SeedDatasetProvider"] @@ -284,12 +299,10 @@ async def fetch_with_semaphore( logger.info(f"Merging multiple sources for {dataset_name}.") existing_dataset = datasets[dataset_name] - combined_seeds = list( - existing_dataset.seeds) + list(dataset.seeds) + combined_seeds = list(existing_dataset.seeds) + list(dataset.seeds) existing_dataset.seeds = combined_seeds else: datasets[dataset_name] = dataset - logger.info( - f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") + logger.info(f"Successfully fetched {len(datasets)} unique datasets from {len(cls._registry)} providers") return list(datasets.values()) diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 6037d811b0..f5926eacb0 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -5,12 +5,16 @@ from enum import Enum from typing import Optional -from pyrit.common.path import DATASETS_PATH - """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). -The ground truth is SeedDatasetMetadata. This is +SeedDatasetMetadata is the internal schema used to normalize metadata fields +from different sources: +- Remote providers that declare metadata as class attributes +- Local prompt files that store metadata at the top level + +SeedDatasetFilter is the user-facing filter schema consumed by +SeedDatasetProvider.get_all_dataset_names(). """ @@ -61,12 +65,12 @@ class SeedDatasetFilter: SeedDatasetProvider. """ - tags: Optional[set[str]] - sizes: Optional[list[SeedDatasetSize]] - modalities: Optional[list[SeedDatasetModality]] - sources: Optional[list[SeedDatasetSourceType]] - ranks: Optional[list[SeedDatasetLoadingRank]] - harm_categories: Optional[list[str]] + tags: Optional[set[str]] = None + sizes: Optional[list[SeedDatasetSize]] = None + modalities: Optional[list[SeedDatasetModality]] = None + sources: Optional[list[SeedDatasetSourceType]] = None + ranks: Optional[list[SeedDatasetLoadingRank]] = None + harm_categories: Optional[list[str]] = None @dataclass(frozen=True) @@ -76,38 +80,9 @@ class SeedDatasetMetadata: object. """ - tags: Optional[set[str]] - size: Optional[SeedDatasetSize] - modalities: Optional[list[SeedDatasetModality]] - source: Optional[SeedDatasetSourceType] - rank: Optional[SeedDatasetLoadingRank] - harm_categories: Optional[list[str]] - - -class SeedDatasetMetadataUtilities: - """ - Collected utilities for managing and updating metadata. - """ - - @staticmethod - def populate_metadata() -> None: - """ - WARNING: Because this function updates the metadata for each SeedDatasetProvider, - it changes the provider's corresopnding source file. Run with caution! - - Updates the metadata per SeedDatasetProvider. - """ - - # 1 Gather all dataset files - - # 2 For each file, download and store in the database (in-memory) - - # 3 Count the number of entries exactly and identify its threshold - - # 4 If harm categories are found in source, add them - - # 5 Inspect type of prompts to identify modalities present - - # 6 Inspect source file to find where it pulled from - - # 7 Leave rank optional for now + tags: Optional[set[str]] = None + size: Optional[SeedDatasetSize] = None + modalities: Optional[list[SeedDatasetModality]] = None + source: Optional[SeedDatasetSourceType] = None + rank: Optional[SeedDatasetLoadingRank] = None + harm_categories: Optional[list[str]] = None diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 491f97e92a..22d4261ae1 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -2,12 +2,21 @@ # Licensed under the MIT license. import logging +import textwrap +from pathlib import Path +from unittest.mock import patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote import _VLSUMultimodalDataset -from pyrit.models import SeedDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetModality, + SeedDatasetSize, +) +from pyrit.models import SeedDataset, SeedPrompt logger = logging.getLogger(__name__) @@ -37,12 +46,10 @@ async def test_fetch_dataset_integration(self, name, provider_cls): try: # Use max_examples for slow providers that fetch many remote images - provider = provider_cls( - max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() + provider = provider_cls(max_examples=6) if provider_cls == _VLSUMultimodalDataset else provider_cls() dataset = await provider.fetch_dataset(cache=False) - assert isinstance( - dataset, SeedDataset), f"{name} did not return a SeedDataset" + assert isinstance(dataset, SeedDataset), f"{name} did not return a SeedDataset" assert len(dataset.seeds) > 0, f"{name} returned an empty dataset" assert dataset.dataset_name, f"{name} has no dataset_name" @@ -53,14 +60,266 @@ async def test_fetch_dataset_integration(self, name, provider_cls): f"Seed dataset_name mismatch in {name}: {seed.dataset_name} != {dataset.dataset_name}" ) - logger.info( - f"Successfully verified {name} with {len(dataset.seeds)} seeds") + logger.info(f"Successfully verified {name} with {len(dataset.seeds)} seeds") except Exception as e: pytest.fail(f"Failed to fetch dataset from {name}: {str(e)}") - @pytest.mark.asyncio - @pytest.mark.parameterize("name,provider_cls", get_dataset_providers()) - async def test_fetch_dataset_integration_with_filtering(self, name, provider_cls): - # TODO - pass + +class TestRemoteFilteringIntegration: + """ + Integration test for remote dataset filtering. + + Uses a mocked remote provider with class-level metadata attributes to + validate the full flow: metadata population, filter matching, and + get_all_dataset_names output. + """ + + def _make_remote_provider_cls( + self, + *, + name: str, + tags: set, + size: SeedDatasetSize, + modalities: list, + harm_categories: list, + ) -> type: + """Build a minimal concrete SeedDatasetProvider with class-level metadata.""" + from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader + + captured_name = name + + async def _fetch_dataset(self, *, cache=True): + return SeedDataset( + seeds=[SeedPrompt(value="x", data_type="text")], + dataset_name=captured_name, + ) + + attrs = { + "tags": tags, + "size": size, + "modalities": modalities, + "harm_categories": harm_categories, + "should_register": False, + "__module__": __name__, + # Concrete implementations satisfy ABC requirements + "dataset_name": property(lambda self: captured_name), + "fetch_dataset": _fetch_dataset, + "_fetch_from_url": lambda self, **kw: [], + } + + return type(f"_Mock_{name}", (_RemoteDatasetLoader,), attrs) + + def test_filter_matches_correct_remote_provider(self): + """Filter by size returns only providers that match.""" + large_cls = self._make_remote_provider_cls( + name="large_ds", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + small_cls = self._make_remote_provider_cls( + name="small_ds", + tags={"default"}, + size=SeedDatasetSize.SMALL, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["cybercrime"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert names == ["large_ds"] + + def test_filter_all_tag_returns_everything(self): + """tags={'all'} bypasses filtering and returns every provider.""" + cls1 = self._make_remote_provider_cls( + name="ds_a", + tags={"safety"}, + size=SeedDatasetSize.TINY, + modalities=[SeedDatasetModality.TEXT], + harm_categories=[], + ) + cls2 = self._make_remote_provider_cls( + name="ds_b", + tags={"custom"}, + size=SeedDatasetSize.HUGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"A": cls1, "B": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert sorted(names) == ["ds_a", "ds_b"] + + def test_multi_axis_filter(self): + """Multiple filter axes are ANDed together.""" + cls1 = self._make_remote_provider_cls( + name="text_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.TEXT], + harm_categories=["violence"], + ) + cls2 = self._make_remote_provider_cls( + name="image_large", + tags={"default"}, + size=SeedDatasetSize.LARGE, + modalities=[SeedDatasetModality.IMAGE], + harm_categories=["violence"], + ) + + with patch.dict( + SeedDatasetProvider._registry, + {"TL": cls1, "IL": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter( + sizes=[SeedDatasetSize.LARGE], + modalities=[SeedDatasetModality.TEXT], + ), + ) + assert names == ["text_large"] + + +class TestLocalFilteringIntegration: + """ + Integration test for local dataset filtering. + + Creates real YAML prompt files on disk, registers them as local providers, + and validates the full flow through get_all_dataset_names with filters. + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_local_filter_by_size(self, tmp_path): + """Local YAML with size metadata is correctly coerced and filtered.""" + large_yaml = tmp_path / "large_ds.prompt" + large_yaml.write_text( + textwrap.dedent("""\ + dataset_name: large_local + size: large + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + small_yaml = tmp_path / "small_ds.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_local + size: small + harm_categories: + - cybercrime + seeds: + - value: test + data_type: text + """) + ) + + large_cls = self._make_local_cls(large_yaml) + small_cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Large": large_cls, "Small": small_cls}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + # dataset_name falls back to file stem when SeedDataset.from_yaml_file + # rejects extra keys like "size" during __init__ pre-loading + assert names == ["large_ds"] + + def test_local_filter_by_tags(self, tmp_path): + """Local YAML tags (list) are coerced to set for intersection.""" + yaml_path = tmp_path / "tagged.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: tagged_local + tags: + - safety + - default + harm_categories: + - violence + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Tagged": cls}, + clear=True, + ): + # dataset_name falls back to file stem ("tagged") when + # SeedDataset.from_yaml_file rejects extra keys like "tags" + matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert matched == ["tagged"] + + not_matched = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"unrelated"}), + ) + assert not_matched == [] + + def test_local_no_metadata_skipped(self, tmp_path): + """Local YAML without metadata fields is skipped when filters are provided.""" + yaml_path = tmp_path / "bare.prompt" + yaml_path.write_text( + textwrap.dedent("""\ + dataset_name: bare_local + seeds: + - value: test + data_type: text + """) + ) + cls = self._make_local_cls(yaml_path) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Without filters, the dataset is included + all_names = SeedDatasetProvider.get_all_dataset_names() + assert "bare_local" in all_names + + # With filters, it's skipped (no metadata to match against) + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 5487a1c848..a99fe2450d 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -5,6 +5,15 @@ Tests for metadata components related to SeedDatasetProvider. """ +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) + class TestMetadataLifecycle: """ @@ -13,29 +22,76 @@ class TestMetadataLifecycle: """ def test_has_no_values(self): - pass + metadata = SeedDatasetMetadata() + assert metadata.tags is None + assert metadata.size is None + assert metadata.modalities is None + assert metadata.source is None + assert metadata.rank is None + assert metadata.harm_categories is None def test_has_some_values(self): - pass + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + assert metadata.tags == {"safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities is None + assert metadata.source is None + assert metadata.rank is None + assert metadata.harm_categories is None def test_has_all_values(self): - pass + metadata = SeedDatasetMetadata( + tags={"default", "safety"}, + size=SeedDatasetSize.MEDIUM, + modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], + source=SeedDatasetSourceType.REMOTE, + rank=SeedDatasetLoadingRank.DEFAULT, + harm_categories=["violence", "illegal"], + ) + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.MEDIUM + assert len(metadata.modalities) == 2 + assert metadata.source == SeedDatasetSourceType.REMOTE + assert metadata.rank == SeedDatasetLoadingRank.DEFAULT + assert metadata.harm_categories == ["violence", "illegal"] class TestFilterLifecycle: """ - Test that the metadata object can be created with different + Test that the filter object can be created with different subsets of values. """ def test_has_no_values(self): - pass + f = SeedDatasetFilter() + assert f.tags is None + assert f.sizes is None + assert f.modalities is None + assert f.sources is None + assert f.ranks is None + assert f.harm_categories is None def test_has_some_values(self): - pass + f = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert f.sizes == [SeedDatasetSize.LARGE] + assert f.tags is None + assert f.modalities is None def test_has_all_values(self): - pass + f = SeedDatasetFilter( + tags={"default"}, + sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], + modalities=[SeedDatasetModality.TEXT], + sources=[SeedDatasetSourceType.REMOTE], + ranks=[SeedDatasetLoadingRank.DEFAULT], + harm_categories=["violence"], + ) + assert f.tags == {"default"} + assert len(f.sizes) == 2 + assert f.modalities == [SeedDatasetModality.TEXT] + assert f.sources == [SeedDatasetSourceType.REMOTE] + assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] + assert f.harm_categories == ["violence"] class TestMetadataProperties: @@ -44,48 +100,68 @@ class TestMetadataProperties: """ def test_size_value(self): - pass + for size in SeedDatasetSize: + metadata = SeedDatasetMetadata(size=size) + assert metadata.size == size def test_loading_rank_value(self): - pass + for rank in SeedDatasetLoadingRank: + metadata = SeedDatasetMetadata(rank=rank) + assert metadata.rank == rank def test_source_value(self): - pass + for source in SeedDatasetSourceType: + metadata = SeedDatasetMetadata(source=source) + assert metadata.source == source def test_modality_value(self): - pass + for modality in SeedDatasetModality: + metadata = SeedDatasetMetadata(modalities=[modality]) + assert modality in metadata.modalities def test_tags_value(self): - pass + metadata = SeedDatasetMetadata(tags={"safety", "default", "custom"}) + assert "safety" in metadata.tags + assert "default" in metadata.tags + assert "custom" in metadata.tags def test_harm_categories_value(self): - pass + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert "violence" in metadata.harm_categories + assert "cybercrime" in metadata.harm_categories class TestFilterProperties: """ - Test that the filter fields popualte correctly. + Test that the filter fields populate correctly. """ def test_sizes_values(self): - pass + f = SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.LARGE]) + assert SeedDatasetSize.SMALL in f.sizes + assert SeedDatasetSize.LARGE in f.sizes def test_loading_ranks_values(self): - pass + f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.SLOW]) + assert SeedDatasetLoadingRank.DEFAULT in f.ranks + assert SeedDatasetLoadingRank.SLOW in f.ranks def test_sources_values(self): - pass + f = SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) + assert SeedDatasetSourceType.LOCAL in f.sources + assert SeedDatasetSourceType.REMOTE in f.sources def test_modalities_values(self): - pass + f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetModality.TEXT in f.modalities + assert SeedDatasetModality.IMAGE in f.modalities def test_tags_values(self): - pass + f = SeedDatasetFilter(tags={"safety", "default"}) + assert "safety" in f.tags + assert "default" in f.tags def test_harm_categories_values(self): - pass - - -class TestMetadataUtilities: - def test_population_works(self, tmp_path): - pass + f = SeedDatasetFilter(harm_categories=["violence", "cybercrime"]) + assert "violence" in f.harm_categories + assert "cybercrime" in f.harm_categories diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 0dbf0e13b7..b24d8b56b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -1,13 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import textwrap from unittest.mock import AsyncMock, MagicMock, patch import pytest from pyrit.datasets import SeedDatasetProvider +from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset +from pyrit.datasets.seed_datasets.seed_metadata import ( + SeedDatasetFilter, + SeedDatasetLoadingRank, + SeedDatasetMetadata, + SeedDatasetModality, + SeedDatasetSize, + SeedDatasetSourceType, +) from pyrit.models import SeedDataset, SeedObjective, SeedPrompt @@ -78,15 +88,13 @@ async def test_fetch_datasets_async(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -99,14 +107,12 @@ async def test_fetch_datasets_async_with_filter(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" - mock_provider2.return_value.fetch_dataset = AsyncMock( - side_effect=Exception("Should not be called")) + mock_provider2.return_value.fetch_dataset = AsyncMock(side_effect=Exception("Should not be called")) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): datasets = await SeedDatasetProvider.fetch_datasets_async(dataset_names=["d1"]) @@ -119,15 +125,13 @@ async def test_fetch_datasets_async_invalid_dataset_name(self): mock_provider1 = MagicMock() mock_provider1.return_value.dataset_name = "d1" mock_provider1.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") + return_value=SeedDataset(seeds=[SeedPrompt(value="p1", data_type="text")], dataset_name="d1") ) mock_provider2 = MagicMock() mock_provider2.return_value.dataset_name = "d2" mock_provider2.return_value.fetch_dataset = AsyncMock( - return_value=SeedDataset( - seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") + return_value=SeedDataset(seeds=[SeedPrompt(value="p2", data_type="text")], dataset_name="d2") ) with patch.dict(SeedDatasetProvider._registry, {"P1": mock_provider1, "P2": mock_provider2}, clear=True): @@ -245,57 +249,319 @@ async def test_fetch_dataset_with_custom_config(self, mock_darkbench_data): class TestMetadataParsingRemote: + """Test metadata parsing and filter matching for remote providers.""" + + def test_parse_metadata_from_class_attrs(self): + """Test _parse_metadata correctly extracts class-level metadata attributes.""" + loader = _HarmBenchDataset() + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.tags == {"default", "safety"} + assert metadata.size == SeedDatasetSize.LARGE + assert metadata.modalities == [SeedDatasetModality.TEXT] + assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] + # source and rank are not declared as class attributes on HarmBench + assert metadata.source is None + assert metadata.rank is None + def test_all_tag(self): - pass + """Filter with tags={'all'} matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}) + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_tags(self): - pass + """Tag filter uses set intersection.""" + metadata = SeedDatasetMetadata(tags={"safety", "default"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"safety"})) + assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=SeedDatasetFilter(tags={"unrelated"})) def test_sizes(self): - pass + """Size filter checks membership in the sizes list.""" + metadata = SeedDatasetMetadata(size=SeedDatasetSize.LARGE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE, SeedDatasetSize.HUGE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL]), + ) def test_modalities(self): - pass + """Modality filter uses set intersection.""" + metadata = SeedDatasetMetadata(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(modalities=[SeedDatasetModality.AUDIO]), + ) def test_sources(self): - pass + """Source filter checks membership.""" + metadata = SeedDatasetMetadata(source=SeedDatasetSourceType.REMOTE) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL]), + ) def test_ranks(self): - pass + """Rank filter checks membership.""" + metadata = SeedDatasetMetadata(rank=SeedDatasetLoadingRank.DEFAULT) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.SLOW]), + ) def test_harm_categories(self): - pass + """Harm category filter uses set intersection.""" + metadata = SeedDatasetMetadata(harm_categories=["violence", "cybercrime"]) + assert SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["violence"]), + ) + assert not SeedDatasetProvider._match_filter( + metadata=metadata, + filters=SeedDatasetFilter(harm_categories=["unrelated"]), + ) - def test_empty_fitler(self): - pass + def test_empty_filter(self): + """Empty filter (all None) matches any metadata.""" + metadata = SeedDatasetMetadata(tags={"safety"}, size=SeedDatasetSize.LARGE) + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_no_metadata(self): - pass - -class TestMetadataParsingLocal: - def test_all_tag(self): - pass - - def test_tags(self): - pass - - def test_sizes(self): - pass - - def test_modalities(self): - pass - - def test_sources(self): - pass - - def test_ranks(self): - pass + """Provider without metadata is skipped when filters are applied.""" + mock_provider_cls = MagicMock() + mock_provider_instance = mock_provider_cls.return_value + mock_provider_instance.dataset_name = "no_metadata" + mock_provider_instance._parse_metadata.return_value = None - def test_harm_categories(self): - pass + with patch.dict(SeedDatasetProvider._registry, {"NoProv": mock_provider_cls}, clear=True): + names = SeedDatasetProvider.get_all_dataset_names(filters=SeedDatasetFilter(tags={"safety"})) + assert names == [] - def test_empty_fitler(self): - pass - def test_no_metadata(self): - pass +class TestMetadataParsingLocal: + """Test metadata parsing and filter matching for local YAML providers.""" + + def _make_loader(self, yaml_path): + """Create a _LocalDatasetLoader bypassing SeedDataset pre-loading.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = yaml_path + loader._dataset_name = yaml_path.stem + return loader + + def _write_yaml(self, tmp_path, name, content): + """Write a .prompt YAML file and return its path.""" + path = tmp_path / f"{name}.prompt" + path.write_text(content) + return path + + def test_parse_metadata_extracts_fields(self, tmp_path): + """Test _parse_metadata correctly extracts metadata fields from YAML.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["violence"] + + def test_all_tag(self, tmp_path): + """Filter with tags={'all'} matches regardless of metadata types.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"all"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_tags(self, tmp_path): + """YAML produces tags as list; set intersection in _match_filter expects a set.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + tags: + - safety + - default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(tags={"safety"}) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sizes(self, tmp_path): + """YAML produces size as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + size: large + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_modalities(self, tmp_path): + """YAML produces modalities as list of strings; _match_filter uses enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + modalities: + - text + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_sources(self, tmp_path): + """YAML produces source as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + source: remote + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_ranks(self, tmp_path): + """YAML produces rank as string; _match_filter compares against enum values.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + rank: default + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_harm_categories(self, tmp_path): + """Both YAML and filter use list[str], so intersection works correctly.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + - cybercrime + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter(harm_categories=["violence"]) + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_empty_filter(self, tmp_path): + """Empty filter matches everything.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + harm_categories: + - violence + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is not None + filters = SeedDatasetFilter() + assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) + + def test_no_metadata(self, tmp_path): + """YAML without any metadata fields returns None from _parse_metadata.""" + yaml_path = self._write_yaml( + tmp_path, + "test", + textwrap.dedent("""\ + dataset_name: test + seeds: + - value: test prompt + data_type: text + """), + ) + loader = self._make_loader(yaml_path) + metadata = loader._parse_metadata() + assert metadata is None From 32b6752beeb71377b7f73c0a17cb741ea4e6e760 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:34:15 +0000 Subject: [PATCH 7/9] tests --- .../local/local_dataset_loader.py | 6 +- .../seed_datasets/seed_dataset_provider.py | 2 +- pyrit/datasets/seed_datasets/seed_metadata.py | 4 +- .../test_seed_dataset_provider_integration.py | 269 ++++++++++++++++++ .../datasets/test_seed_dataset_metadata.py | 26 +- .../datasets/test_seed_dataset_provider.py | 95 ++++++- 6 files changed, 378 insertions(+), 24 deletions(-) diff --git a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py index 2c6dd4b778..1ef745f628 100644 --- a/pyrit/datasets/seed_datasets/local/local_dataset_loader.py +++ b/pyrit/datasets/seed_datasets/local/local_dataset_loader.py @@ -128,12 +128,16 @@ def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]: coerced[key] = set(value) elif key == "size" and isinstance(value, str): coerced[key] = SeedDatasetSize(value) - elif key == "source" and isinstance(value, str): + elif key == "source_type" and isinstance(value, str): coerced[key] = SeedDatasetSourceType(value) elif key == "rank" and isinstance(value, str): coerced[key] = SeedDatasetLoadingRank(value) elif key == "modalities" and isinstance(value, list): coerced[key] = [SeedDatasetModality(v) for v in value] + elif key == "harm_categories" and isinstance(value, str): + coerced[key] = [value] + elif key == "tags" and isinstance(value, str): + coerced[key] = {value} else: coerced[key] = value return coerced diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index ae45fcb200..4d65b932c2 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -180,7 +180,7 @@ def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter return False # Source Type - if metadata.source and filters.sources and metadata.source not in filters.sources: # noqa: SIM103 + if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: # noqa: SIM103 return False # Modalities diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index f5926eacb0..01c97b8934 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -68,7 +68,7 @@ class SeedDatasetFilter: tags: Optional[set[str]] = None sizes: Optional[list[SeedDatasetSize]] = None modalities: Optional[list[SeedDatasetModality]] = None - sources: Optional[list[SeedDatasetSourceType]] = None + source_types: Optional[list[SeedDatasetSourceType]] = None ranks: Optional[list[SeedDatasetLoadingRank]] = None harm_categories: Optional[list[str]] = None @@ -83,6 +83,6 @@ class SeedDatasetMetadata: tags: Optional[set[str]] = None size: Optional[SeedDatasetSize] = None modalities: Optional[list[SeedDatasetModality]] = None - source: Optional[SeedDatasetSourceType] = None + source_type: Optional[SeedDatasetSourceType] = None rank: Optional[SeedDatasetLoadingRank] = None harm_categories: Optional[list[str]] = None diff --git a/tests/integration/datasets/test_seed_dataset_provider_integration.py b/tests/integration/datasets/test_seed_dataset_provider_integration.py index 22d4261ae1..f2da3a292a 100644 --- a/tests/integration/datasets/test_seed_dataset_provider_integration.py +++ b/tests/integration/datasets/test_seed_dataset_provider_integration.py @@ -323,3 +323,272 @@ def test_local_no_metadata_skipped(self, tmp_path): filters=SeedDatasetFilter(tags={"safety"}), ) assert filtered == [] + + +class TestEndToEndLocalDatasetWorkflow: + """ + End-to-end integration test that simulates a user workflow: + 1. Write a temporary local dataset to disk + 2. Register it into the provider registry + 3. Discover it via filtered get_all_dataset_names + 4. Fetch the actual dataset by name + 5. Verify metadata is accessible and correct + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + @pytest.mark.asyncio + async def test_user_discovers_and_fetches_filtered_dataset(self, tmp_path): + """ + Simulate a user who wants small text datasets about cybercrime: + 1. Filter returns the matching dataset name + 2. Fetch by that name succeeds and returns correct seeds + 3. Metadata on the provider is accessible and consistent + """ + # --- Step 0: Write two datasets to disk --- + matching_yaml = tmp_path / "cyber_small.prompt" + matching_yaml.write_text( + textwrap.dedent("""\ + dataset_name: cyber_prompts + harm_categories: + - cybercrime + seeds: + - value: Explain phishing techniques + data_type: text + - value: Describe SQL injection + data_type: text + """) + ) + unrelated_yaml = tmp_path / "violence_large.prompt" + unrelated_yaml.write_text( + textwrap.dedent("""\ + dataset_name: violence_prompts + harm_categories: + - violence + seeds: + - value: Describe a violent scenario + data_type: text + """) + ) + + matching_cls = self._make_local_cls(matching_yaml) + unrelated_cls = self._make_local_cls(unrelated_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Cyber": matching_cls, "Violence": unrelated_cls}, + clear=True, + ): + # --- Step 1: User filters by harm_categories --- + names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(harm_categories=["cybercrime"]), + ) + assert len(names) == 1 + dataset_name = names[0] + + # --- Step 2: User fetches the dataset by name --- + datasets = await SeedDatasetProvider.fetch_datasets_async( + dataset_names=[dataset_name], + ) + assert len(datasets) == 1 + dataset = datasets[0] + assert len(dataset.seeds) == 2 + assert dataset.seeds[0].value == "Explain phishing techniques" + assert dataset.seeds[1].value == "Describe SQL injection" + + # --- Step 3: User inspects metadata --- + provider = matching_cls() + metadata = provider._parse_metadata() + assert metadata is not None + assert metadata.harm_categories == ["cybercrime"] + + @pytest.mark.asyncio + async def test_user_fetches_unfiltered(self, tmp_path): + """ + Without filters, get_all_dataset_names returns everything, + and fetch_datasets_async retrieves all of them. + """ + ds1 = tmp_path / "ds_one.prompt" + ds1.write_text( + textwrap.dedent("""\ + dataset_name: dataset_one + seeds: + - value: prompt one + data_type: text + """) + ) + ds2 = tmp_path / "ds_two.prompt" + ds2.write_text( + textwrap.dedent("""\ + dataset_name: dataset_two + seeds: + - value: prompt two + data_type: text + """) + ) + + cls1 = self._make_local_cls(ds1) + cls2 = self._make_local_cls(ds2) + + with patch.dict( + SeedDatasetProvider._registry, + {"One": cls1, "Two": cls2}, + clear=True, + ): + names = SeedDatasetProvider.get_all_dataset_names() + assert len(names) == 2 + + datasets = await SeedDatasetProvider.fetch_datasets_async() + assert len(datasets) == 2 + fetched_names = sorted(d.dataset_name for d in datasets) + assert fetched_names == ["dataset_one", "dataset_two"] + + +class TestAllTagBypassIntegration: + """ + Integration tests for the tags={'all'} bypass pattern. + + The 'all' tag is a special escape hatch that returns every registered + dataset regardless of metadata presence or other filter axes. + """ + + @staticmethod + def _make_local_cls(yaml_path: Path) -> type: + """Build a dynamic local provider class for a YAML file.""" + + def make_init(path: Path): + def init_fn(self): + _LocalDatasetLoader.__init__(self, file_path=path) + + return init_fn + + return type( + f"LocalTest_{yaml_path.stem}", + (_LocalDatasetLoader,), + {"__init__": make_init(yaml_path), "should_register": False, "__module__": __name__}, + ) + + def test_all_tag_includes_datasets_without_metadata(self, tmp_path): + """ + A dataset whose YAML has no metadata fields at all is normally + skipped when filters are present. tags={'all'} overrides that. + """ + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Bare": cls}, + clear=True, + ): + # Normal filter skips it + filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"safety"}), + ) + assert filtered == [] + + # 'all' includes it + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert "bare_dataset" in all_names + + def test_all_tag_ignores_other_filter_axes(self, tmp_path): + """ + tags={'all'} returns everything even when other filter axes + would exclude datasets. + """ + small_yaml = tmp_path / "small.prompt" + small_yaml.write_text( + textwrap.dedent("""\ + dataset_name: small_dataset + size: small + harm_categories: + - cybercrime + seeds: + - value: small prompt + data_type: text + """) + ) + cls = self._make_local_cls(small_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Small": cls}, + clear=True, + ): + # Size filter alone would exclude it + size_filtered = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE]), + ) + assert size_filtered == [] + + # 'all' tag overrides the size filter + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}, sizes=[SeedDatasetSize.LARGE]), + ) + assert "small" in all_names + + def test_all_tag_with_mixed_metadata_and_bare_datasets(self, tmp_path): + """ + With a mix of metadata-rich and metadata-bare datasets, + tags={'all'} returns all of them. + """ + rich_yaml = tmp_path / "rich.prompt" + rich_yaml.write_text( + textwrap.dedent("""\ + dataset_name: rich_dataset + harm_categories: + - violence + tags: + - safety + seeds: + - value: rich prompt + data_type: text + """) + ) + bare_yaml = tmp_path / "bare.prompt" + bare_yaml.write_text( + textwrap.dedent("""\ + dataset_name: bare_dataset + seeds: + - value: bare prompt + data_type: text + """) + ) + + rich_cls = self._make_local_cls(rich_yaml) + bare_cls = self._make_local_cls(bare_yaml) + + with patch.dict( + SeedDatasetProvider._registry, + {"Rich": rich_cls, "Bare": bare_cls}, + clear=True, + ): + all_names = SeedDatasetProvider.get_all_dataset_names( + filters=SeedDatasetFilter(tags={"all"}), + ) + assert len(all_names) == 2 + assert "bare_dataset" in all_names diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index a99fe2450d..4aaaed1fd4 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -26,7 +26,7 @@ def test_has_no_values(self): assert metadata.tags is None assert metadata.size is None assert metadata.modalities is None - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None assert metadata.harm_categories is None @@ -35,7 +35,7 @@ def test_has_some_values(self): assert metadata.tags == {"safety"} assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities is None - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None assert metadata.harm_categories is None @@ -44,14 +44,14 @@ def test_has_all_values(self): tags={"default", "safety"}, size=SeedDatasetSize.MEDIUM, modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE], - source=SeedDatasetSourceType.REMOTE, + source_type=SeedDatasetSourceType.REMOTE, rank=SeedDatasetLoadingRank.DEFAULT, harm_categories=["violence", "illegal"], ) assert metadata.tags == {"default", "safety"} assert metadata.size == SeedDatasetSize.MEDIUM assert len(metadata.modalities) == 2 - assert metadata.source == SeedDatasetSourceType.REMOTE + assert metadata.source_type == SeedDatasetSourceType.REMOTE assert metadata.rank == SeedDatasetLoadingRank.DEFAULT assert metadata.harm_categories == ["violence", "illegal"] @@ -67,7 +67,7 @@ def test_has_no_values(self): assert f.tags is None assert f.sizes is None assert f.modalities is None - assert f.sources is None + assert f.source_types is None assert f.ranks is None assert f.harm_categories is None @@ -82,14 +82,14 @@ def test_has_all_values(self): tags={"default"}, sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.MEDIUM], modalities=[SeedDatasetModality.TEXT], - sources=[SeedDatasetSourceType.REMOTE], + source_types=[SeedDatasetSourceType.REMOTE], ranks=[SeedDatasetLoadingRank.DEFAULT], harm_categories=["violence"], ) assert f.tags == {"default"} assert len(f.sizes) == 2 assert f.modalities == [SeedDatasetModality.TEXT] - assert f.sources == [SeedDatasetSourceType.REMOTE] + assert f.source_types == [SeedDatasetSourceType.REMOTE] assert f.ranks == [SeedDatasetLoadingRank.DEFAULT] assert f.harm_categories == ["violence"] @@ -110,9 +110,9 @@ def test_loading_rank_value(self): assert metadata.rank == rank def test_source_value(self): - for source in SeedDatasetSourceType: - metadata = SeedDatasetMetadata(source=source) - assert metadata.source == source + for source_type in SeedDatasetSourceType: + metadata = SeedDatasetMetadata(source_type=source_type) + assert metadata.source_type == source_type def test_modality_value(self): for modality in SeedDatasetModality: @@ -147,9 +147,9 @@ def test_loading_ranks_values(self): assert SeedDatasetLoadingRank.SLOW in f.ranks def test_sources_values(self): - f = SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) - assert SeedDatasetSourceType.LOCAL in f.sources - assert SeedDatasetSourceType.REMOTE in f.sources + f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) + assert SeedDatasetSourceType.LOCAL in f.source_types + assert SeedDatasetSourceType.REMOTE in f.source_types def test_modalities_values(self): f = SeedDatasetFilter(modalities=[SeedDatasetModality.TEXT, SeedDatasetModality.IMAGE]) diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index b24d8b56b7..7095ed57cd 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -2,9 +2,12 @@ # Licensed under the MIT license. import textwrap +from dataclasses import fields as dc_fields +from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import pytest +import yaml from pyrit.datasets import SeedDatasetProvider from pyrit.datasets.seed_datasets.local.local_dataset_loader import _LocalDatasetLoader @@ -261,7 +264,7 @@ def test_parse_metadata_from_class_attrs(self): assert metadata.modalities == [SeedDatasetModality.TEXT] assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] # source and rank are not declared as class attributes on HarmBench - assert metadata.source is None + assert metadata.source_type is None assert metadata.rank is None def test_all_tag(self): @@ -302,14 +305,14 @@ def test_modalities(self): def test_sources(self): """Source filter checks membership.""" - metadata = SeedDatasetMetadata(source=SeedDatasetSourceType.REMOTE) + metadata = SeedDatasetMetadata(source_type=SeedDatasetSourceType.REMOTE) assert SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]), + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]), ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(sources=[SeedDatasetSourceType.LOCAL]), + filters=SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL]), ) def test_ranks(self): @@ -472,13 +475,13 @@ def test_modalities(self, tmp_path): assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_sources(self, tmp_path): - """YAML produces source as string; _match_filter compares against enum values.""" + """YAML produces source_type as string; _match_filter compares against enum values.""" yaml_path = self._write_yaml( tmp_path, "test", textwrap.dedent("""\ dataset_name: test - source: remote + source_type: remote seeds: - value: test prompt data_type: text @@ -487,7 +490,7 @@ def test_sources(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is not None - filters = SeedDatasetFilter(sources=[SeedDatasetSourceType.REMOTE]) + filters = SeedDatasetFilter(source_types=[SeedDatasetSourceType.REMOTE]) assert SeedDatasetProvider._match_filter(metadata=metadata, filters=filters) def test_ranks(self, tmp_path): @@ -565,3 +568,81 @@ def test_no_metadata(self, tmp_path): loader = self._make_loader(yaml_path) metadata = loader._parse_metadata() assert metadata is None + + +class TestLocalDatasetMetadataCollisions: + """ + Regression tests that scan every real .prompt file under seed_datasets/local + to verify _parse_metadata does not crash from field-name collisions between + the YAML schema and SeedDatasetMetadata. + + The previous `source` field collision (URLs parsed as SeedDatasetSourceType) + is the motivating example. + """ + + @staticmethod + def _get_local_prompt_files() -> list: + """Collect all .prompt and .yaml files under the local datasets directory.""" + local_dir = Path(__file__).resolve().parents[3] / "pyrit" / "datasets" / "seed_datasets" / "local" + return sorted(local_dir.glob("**/*.prompt")) + sorted(local_dir.glob("**/*.yaml")) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_parse_metadata_does_not_crash(self, prompt_file): + """_parse_metadata must not raise on any real local dataset file.""" + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + # This must not raise — if a YAML key collides with a metadata field + # name but holds an incompatible value, the coercion layer should + # either handle it or skip it gracefully. + metadata = loader._parse_metadata() + # metadata can be None (no matching fields) or a valid SeedDatasetMetadata + if metadata is not None: + assert isinstance(metadata, SeedDatasetMetadata) + + @pytest.mark.parametrize("prompt_file", _get_local_prompt_files.__func__(), ids=lambda p: p.stem) + def test_no_yaml_key_shadows_metadata_field_with_wrong_type(self, prompt_file): + """ + If a YAML top-level key matches a SeedDatasetMetadata field name, the + coerced value must be the correct type (enum, set, list) — not a raw + string or other primitive that would silently break filtering. + """ + with open(prompt_file, encoding="utf-8") as f: + data = yaml.safe_load(f) + + if not isinstance(data, dict): + return + + metadata_field_names = {fld.name for fld in dc_fields(SeedDatasetMetadata)} + overlapping_keys = metadata_field_names & data.keys() + + if not overlapping_keys: + return + + # Coerce and construct — must not raise + loader = _LocalDatasetLoader.__new__(_LocalDatasetLoader) + loader.file_path = prompt_file + loader._dataset_name = prompt_file.stem + + raw = {k: data[k] for k in overlapping_keys} + coerced = _LocalDatasetLoader._coerce_metadata_values(raw_metadata=raw) + metadata = SeedDatasetMetadata(**coerced) + + # Verify coerced types match expectations + expected_types = { + "tags": (set, type(None)), + "size": (SeedDatasetSize, type(None)), + "modalities": (list, type(None)), + "source_type": (SeedDatasetSourceType, type(None)), + "rank": (SeedDatasetLoadingRank, type(None)), + "harm_categories": (list, type(None)), + } + for key in overlapping_keys: + value = getattr(metadata, key) + valid_types = expected_types.get(key) + if valid_types: + assert isinstance(value, valid_types), ( + f"Field '{key}' in {prompt_file.name} has type {type(value).__name__}, " + f"expected one of {valid_types}" + ) From c94a6da2056ec221d5a2dc69d6e50e94f0fd1917 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 13 Mar 2026 20:43:10 +0000 Subject: [PATCH 8/9] precommit --- pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py index b2b45c2a33..a622a4a018 100644 --- a/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py +++ b/pyrit/datasets/seed_datasets/remote/jbb_behaviors_dataset.py @@ -23,7 +23,7 @@ class _JBBBehaviorsDataset(_RemoteDatasetLoader): and may contain offensive content. Users should check with their legal department before using these prompts against production LLMs. """ - + def __init__( self, *, From 2e7e9375735f702761e38d739766faddaefba00a Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Sat, 14 Mar 2026 00:03:13 +0000 Subject: [PATCH 9/9] utilities scaffolding --- .../seed_datasets/seed_dataset_provider.py | 3 +- pyrit/datasets/seed_datasets/seed_metadata.py | 119 +++++++++++++++++- .../datasets/test_seed_dataset_metadata.py | 8 +- .../datasets/test_seed_dataset_provider.py | 7 +- 4 files changed, 125 insertions(+), 12 deletions(-) diff --git a/pyrit/datasets/seed_datasets/seed_dataset_provider.py b/pyrit/datasets/seed_datasets/seed_dataset_provider.py index 4d65b932c2..f5ef0d2736 100644 --- a/pyrit/datasets/seed_datasets/seed_dataset_provider.py +++ b/pyrit/datasets/seed_datasets/seed_dataset_provider.py @@ -9,7 +9,7 @@ from tqdm import tqdm -from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetMetadata +from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadingRank, SeedDatasetMetadata from pyrit.models.seeds import SeedDataset logger = logging.getLogger(__name__) @@ -33,6 +33,7 @@ class SeedDatasetProvider(ABC): """ _registry: dict[str, type["SeedDatasetProvider"]] = {} + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN def __init_subclass__(cls, **kwargs: Any) -> None: """ diff --git a/pyrit/datasets/seed_datasets/seed_metadata.py b/pyrit/datasets/seed_datasets/seed_metadata.py index 01c97b8934..d203d7e1cf 100644 --- a/pyrit/datasets/seed_datasets/seed_metadata.py +++ b/pyrit/datasets/seed_datasets/seed_metadata.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import Optional, TypedDict """ Contains metadata objects for datasets (i.e. subclasses of SeedDatasetProvider). @@ -33,9 +33,18 @@ class SeedDatasetLoadingRank(Enum): Represents the general difficulty of loading in a dataset. """ + # Default is equivalent to "fastest" in the sense that datasets marked + # with a default rank will always get loaded. DEFAULT = "default" - EXTENDED = "extended" - SLOW = "slow" + + # These represent actual ranks. + PRIMARY = "primary" + SECONDARY = "secondary" + TERTIARY = "tertiary" + + # Unknown corresponds to an untested dataset that won't be loaded. It is the + # default provided in SeedDatasetProvider. + UNKNOWN = "unknown" class SeedDatasetModality(Enum): @@ -84,5 +93,107 @@ class SeedDatasetMetadata: size: Optional[SeedDatasetSize] = None modalities: Optional[list[SeedDatasetModality]] = None source_type: Optional[SeedDatasetSourceType] = None - rank: Optional[SeedDatasetLoadingRank] = None + rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN harm_categories: Optional[list[str]] = None + + +class SeedDatasetMetadataUtilities: + """ + Utilities for deriving metadata for datasets. Currently, only static attributes + are supported. + + The default working location for datasets is the in-memory database. + """ + + class Metrics(TypedDict): + """ + Typed dictionary for easier retrieval and calculation of dataset metrics. + """ + + exact_size: int + loading_time_ms: float + modalities_found: set[str] + source_type: str + harm_categories_found: set[str] + tags: set[str] + + # Stores working dataset calculations. + # Maps name to metrics, which are later converted into SeedDatasetMetadata. + _cache: dict[str, Metrics] = {} + + @classmethod + def populate_datasets(cls) -> None: + """ + Populate metadata for all registered datasets. + + WARNING: Because metadata is stored as class attributes, this method can directly + change source files. Be extra careful when running it. + """ + # Get all dataset names + # Calling SeedDatasetProvider would create a circular import, so we do this explicitly + datasets: list[str] = [] + + # Populate cache with empty (name, metrics) pairs + for dataset in datasets: + metrics: SeedDatasetMetadataUtilities.Metrics = { + "exact_size": -1, + "loading_time_ms": -1.0, + "modalities_found": {"None"}, + "source_type": "None", + "harm_categories_found": {"None"}, + "tags": {"None"}, + } + cls._cache[dataset] = metrics + + # Using a list, for each dataset name, load it in depending on class type + # Invoke the appropriate helper to parse it + + # If local, local_helper + + # If remote, remote_helper + + # Get contents from the memory database + # Note that we have to load it into the memory_database to get timing + # We also want the helper to do no initialization, just extract the relevant + # types and get ready to call a timing library + + # Calculate metrics one by one + + # Once out of the loop, calculate metadata fields + + # Loading rank by comparing relative speeds + + # Size by comparing buckets + + # Convert all others to types + + # Update (if update = True) the datasets + + # If remote, write to the file using regex + # E.g. harm_categories: ... should appear in source + + # If local, make sure the .prompt is formatted nicely + + @classmethod + def _local_helper(cls) -> None: + """ + Load local datasets into the working cache. + """ + + @classmethod + def _remote_helper(cls) -> None: + """ + Load remote datasets into the working cache. + """ + + @classmethod + def _remote_writer(cls) -> None: + """ + Write updated metadata to a remote dataset source file. + """ + + @classmethod + def _local_writer(cls) -> None: + """ + Write updated metadata to a local .prompt file. + """ diff --git a/tests/unit/datasets/test_seed_dataset_metadata.py b/tests/unit/datasets/test_seed_dataset_metadata.py index 4aaaed1fd4..73fa3dad3e 100644 --- a/tests/unit/datasets/test_seed_dataset_metadata.py +++ b/tests/unit/datasets/test_seed_dataset_metadata.py @@ -27,7 +27,7 @@ def test_has_no_values(self): assert metadata.size is None assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN assert metadata.harm_categories is None def test_has_some_values(self): @@ -36,7 +36,7 @@ def test_has_some_values(self): assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities is None assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN assert metadata.harm_categories is None def test_has_all_values(self): @@ -142,9 +142,9 @@ def test_sizes_values(self): assert SeedDatasetSize.LARGE in f.sizes def test_loading_ranks_values(self): - f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.SLOW]) + f = SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.DEFAULT, SeedDatasetLoadingRank.TERTIARY]) assert SeedDatasetLoadingRank.DEFAULT in f.ranks - assert SeedDatasetLoadingRank.SLOW in f.ranks + assert SeedDatasetLoadingRank.TERTIARY in f.ranks def test_sources_values(self): f = SeedDatasetFilter(source_types=[SeedDatasetSourceType.LOCAL, SeedDatasetSourceType.REMOTE]) diff --git a/tests/unit/datasets/test_seed_dataset_provider.py b/tests/unit/datasets/test_seed_dataset_provider.py index 7095ed57cd..52029850b7 100644 --- a/tests/unit/datasets/test_seed_dataset_provider.py +++ b/tests/unit/datasets/test_seed_dataset_provider.py @@ -263,9 +263,10 @@ def test_parse_metadata_from_class_attrs(self): assert metadata.size == SeedDatasetSize.LARGE assert metadata.modalities == [SeedDatasetModality.TEXT] assert metadata.harm_categories == ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"] - # source and rank are not declared as class attributes on HarmBench + # source_type is not declared as a class attribute on HarmBench; + # rank inherits the UNKNOWN default from SeedDatasetProvider base class assert metadata.source_type is None - assert metadata.rank is None + assert metadata.rank == SeedDatasetLoadingRank.UNKNOWN def test_all_tag(self): """Filter with tags={'all'} matches any metadata.""" @@ -324,7 +325,7 @@ def test_ranks(self): ) assert not SeedDatasetProvider._match_filter( metadata=metadata, - filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.SLOW]), + filters=SeedDatasetFilter(ranks=[SeedDatasetLoadingRank.TERTIARY]), ) def test_harm_categories(self):