14 changes: 14 additions & 0 deletions pyrit/datasets/__init__.py
@@ -8,8 +8,22 @@
from pyrit.datasets.jailbreak.text_jailbreak import TextJailBreak
from pyrit.datasets.seed_datasets import local, remote # noqa: F401
from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider
from pyrit.datasets.seed_datasets.seed_metadata import (
SeedDatasetFilter,
SeedDatasetLoadingRank,
SeedDatasetMetadata,
SeedDatasetModality,
SeedDatasetSize,
SeedDatasetSourceType,
)

__all__ = [
"SeedDatasetFilter",
"SeedDatasetMetadata",
"SeedDatasetLoadingRank",
"SeedDatasetModality",
"SeedDatasetSize",
"SeedDatasetSourceType",
"SeedDatasetProvider",
"TextJailBreak",
]
74 changes: 73 additions & 1 deletion pyrit/datasets/seed_datasets/local/local_dataset_loader.py
@@ -3,10 +3,20 @@

import logging
from collections.abc import Callable
from dataclasses import fields
from pathlib import Path
from typing import Any
from typing import Any, Optional

import yaml

from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider
from pyrit.datasets.seed_datasets.seed_metadata import (
SeedDatasetLoadingRank,
SeedDatasetMetadata,
SeedDatasetModality,
SeedDatasetSize,
SeedDatasetSourceType,
)
from pyrit.models import SeedDataset

logger = logging.getLogger(__name__)
@@ -70,6 +80,68 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
logger.error(f"Failed to load local dataset from {self.file_path}: {e}")
raise

def _parse_metadata(self) -> Optional[SeedDatasetMetadata]:
"""
Extract metadata from a local YAML file and coerce raw values into typed schema fields.

YAML produces raw Python primitives (str, list) that must be converted to the
enum and set types expected by SeedDatasetMetadata before _match_filter can work.

Returns:
Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None.

Raises:
Exception: If the dataset file cannot be read.
"""
valid_fields = [f.name for f in fields(SeedDatasetMetadata)]
try:
with open(self.file_path, encoding="utf-8") as f:
dataset = yaml.safe_load(f)
except Exception as e:
logger.error(f"Failed to load local dataset from {self.file_path}: {e}")
raise

if not isinstance(dataset, dict):
return None

raw = {k: v for k, v in dataset.items() if k in valid_fields}
if not raw:
return None

coerced = self._coerce_metadata_values(raw_metadata=raw)
return SeedDatasetMetadata(**coerced)

@staticmethod
def _coerce_metadata_values(*, raw_metadata: dict[str, Any]) -> dict[str, Any]:
"""
Convert YAML primitive values into the enum/set types expected by SeedDatasetMetadata.

Args:
raw_metadata (dict[str, Any]): Dictionary of field names to raw YAML-parsed values.

Returns:
dict[str, Any]: Dictionary with values coerced to the correct types.
"""
coerced: dict[str, Any] = {}
for key, value in raw_metadata.items():
if key == "tags" and isinstance(value, list):
coerced[key] = set(value)
elif key == "size" and isinstance(value, str):
coerced[key] = SeedDatasetSize(value)
elif key == "source_type" and isinstance(value, str):
coerced[key] = SeedDatasetSourceType(value)
elif key == "rank" and isinstance(value, str):
coerced[key] = SeedDatasetLoadingRank(value)
elif key == "modalities" and isinstance(value, list):
coerced[key] = [SeedDatasetModality(v) for v in value]
elif key == "harm_categories" and isinstance(value, str):
coerced[key] = [value]
elif key == "tags" and isinstance(value, str):
coerced[key] = {value}
else:
coerced[key] = value
return coerced
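
To make the coercion concrete, a minimal sketch; the enum string values ("small", "text") and the loader class name _LocalDatasetLoader are assumptions, since seed_metadata.py and the class header are not part of this diff:

import yaml
from dataclasses import fields

from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata

# Hypothetical metadata header of a local dataset YAML file.
dataset = yaml.safe_load(
    """
    size: small
    modalities:
      - text
    harm_categories: violence
    tags: [default, safety]
    """
)

valid = {f.name for f in fields(SeedDatasetMetadata)}
raw = {k: v for k, v in dataset.items() if k in valid}

# _LocalDatasetLoader is an assumed name for the loader class defined above.
coerced = _LocalDatasetLoader._coerce_metadata_values(raw_metadata=raw)
metadata = SeedDatasetMetadata(**coerced)
assert metadata.tags == {"default", "safety"}  # the YAML list was promoted to a set
assert metadata.harm_categories == ["violence"]  # the bare string became a one-element list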


def _register_local_datasets() -> None:
"""
11 changes: 11 additions & 0 deletions pyrit/datasets/seed_datasets/remote/harmbench_dataset.py
@@ -6,6 +6,10 @@
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
_RemoteDatasetLoader,
)
from pyrit.datasets.seed_datasets.seed_metadata import (
SeedDatasetModality,
SeedDatasetSize,
)
from pyrit.models import SeedDataset, SeedObjective


@@ -19,6 +23,13 @@ class _HarmBenchDataset(_RemoteDatasetLoader):
Reference: https://github.com/centerforaisafety/HarmBench
"""

# Metadata
harm_categories: list[str] = ["cybercrime", "illegal", "harmful", "chemical_biological", "harassment"]
modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT]
size: SeedDatasetSize = SeedDatasetSize.LARGE # 504 seeds
Contributor: I mention this in another comment, but harmbench actually loads super fast, which makes me think we may want a better measure.

Contributor (author): This is the role I imagined for SeedDatasetLoadingRank, since that's more of a performance measurement than a literal description of the dataset. I'll change it a bit to make that more obvious. What do you think of using SeedDatasetLoadingRank?

# "default" means included in curated set
tags: set[str] = {"default", "safety"}
Contributor: I don't like mixing these pieces in with tags. Would "default" actually be part of "ranks"?

Contributor (author): Agreed, this needs a refactor.


def __init__(
self,
*,
20 changes: 20 additions & 0 deletions pyrit/datasets/seed_datasets/remote/remote_dataset_loader.py
@@ -8,6 +8,7 @@
import tempfile
from abc import ABC
from collections.abc import Callable
from dataclasses import fields
from pathlib import Path
from typing import Any, Literal, Optional, TextIO, cast

@@ -19,6 +20,7 @@
from pyrit.common.path import DB_DATA_PATH
from pyrit.common.text_helper import read_txt, write_txt
from pyrit.datasets.seed_datasets.seed_dataset_provider import SeedDatasetProvider
from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetMetadata

logger = logging.getLogger(__name__)

@@ -285,3 +287,21 @@ def _load_dataset_sync() -> Any:
except Exception as e:
logger.error(f"Failed to load HuggingFace dataset {dataset_name}: {e}")
raise

def _parse_metadata(self) -> Optional[SeedDatasetMetadata]:
"""
Extract metadata from class attributes and format into SeedDatasetMetadata schema.

Returns:
Optional[SeedDatasetMetadata]: Parsed metadata if available, otherwise None.
"""
valid_fields = [f.name for f in fields(SeedDatasetMetadata)]

provider_class = type(self)
self_metadata = {
key: getattr(provider_class, key) for key in valid_fields if getattr(provider_class, key, None) is not None
}

if not self_metadata:
return None
return SeedDatasetMetadata(**self_metadata)
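
For illustration, the same extraction _parse_metadata performs, run against a toy class; _ToyProvider is hypothetical and sidesteps the real loader's constructor and abstract-method requirements:

from dataclasses import fields

from pyrit.datasets.seed_datasets.seed_metadata import (
    SeedDatasetMetadata,
    SeedDatasetModality,
    SeedDatasetSize,
)


class _ToyProvider:  # hypothetical stand-in; real providers subclass _RemoteDatasetLoader
    harm_categories: list[str] = ["illegal"]
    modalities: list[SeedDatasetModality] = [SeedDatasetModality.TEXT]
    size: SeedDatasetSize = SeedDatasetSize.SMALL


valid_fields = [f.name for f in fields(SeedDatasetMetadata)]
self_metadata = {
    key: getattr(_ToyProvider, key)
    for key in valid_fields
    if getattr(_ToyProvider, key, None) is not None
}
metadata = SeedDatasetMetadata(**self_metadata)  # mirrors the method body above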
94 changes: 93 additions & 1 deletion pyrit/datasets/seed_datasets/seed_dataset_provider.py
@@ -9,6 +9,7 @@

from tqdm import tqdm

from pyrit.datasets.seed_datasets.seed_metadata import SeedDatasetFilter, SeedDatasetLoadingRank, SeedDatasetMetadata
from pyrit.models.seeds import SeedDataset

logger = logging.getLogger(__name__)
@@ -25,9 +26,14 @@ class SeedDatasetProvider(ABC):
Subclasses must implement:
- fetch_dataset(): Fetch and return the dataset as a SeedDataset
- dataset_name property: Human-readable name for the dataset

All subclasses can also supply optional metadata via _parse_metadata(). It is
optional to make adding a dataset easier, but omitting it makes downstream
analysis more difficult.
"""

_registry: dict[str, type["SeedDatasetProvider"]] = {}
rank: SeedDatasetLoadingRank = SeedDatasetLoadingRank.UNKNOWN

def __init_subclass__(cls, **kwargs: Any) -> None:
"""
@@ -67,6 +73,19 @@ async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
Exception: If the dataset cannot be fetched or processed.
"""

def _parse_metadata(self) -> Optional[SeedDatasetMetadata]:
"""
Parse provider-specific metadata into the shared schema.

Subclasses can override this to source metadata from class attributes,
prompt files, or any other backing format. The default implementation
returns None, which means metadata is not available for this provider.

Returns:
Optional[SeedDatasetMetadata]: Parsed metadata for this provider, or None.
"""
return None

@classmethod
def get_all_providers(cls) -> dict[str, type["SeedDatasetProvider"]]:
"""
@@ -78,10 +97,13 @@
return cls._registry.copy()

@classmethod
def get_all_dataset_names(cls) -> list[str]:
def get_all_dataset_names(cls, filters: Optional[SeedDatasetFilter] = None) -> list[str]:
"""
Get the names of all registered datasets.

Args:
filters (Optional[SeedDatasetFilter]): Filter to apply across metadata dimensions; when omitted, all dataset names are returned.

Returns:
List[str]: List of dataset names from all registered providers.

@@ -97,11 +119,81 @@ def get_all_dataset_names(cls) -> list[str]:
try:
# Instantiate to get dataset name
provider = provider_class()

# Parser ensures a standard metadata format
metadata = provider._parse_metadata()

# "all" bypasses metadata filtering and returns every dataset.
if filters and filters.tags and "all" in filters.tags:
dataset_names.add(provider.dataset_name)
continue

if filters and not metadata:
# Datasets without metadata are skipped unless we want "all"
continue

# Filters detected but no match -> don't add this dataset
if filters and metadata and not cls._match_filter(metadata=metadata, filters=filters):
continue

dataset_names.add(provider.dataset_name)
except Exception as e:
raise ValueError(f"Could not get dataset name from {provider_class.__name__}: {e}") from e
return sorted(dataset_names)
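
A usage sketch for the filtering path above; the SeedDatasetFilter constructor fields are assumed to mirror the attributes _match_filter reads (tags, sizes, harm_categories, source_types, modalities, ranks), since seed_metadata.py is not shown in this diff:

from pyrit.datasets import SeedDatasetFilter, SeedDatasetProvider, SeedDatasetSize

# Hypothetical call: only datasets tagged "default" whose size is LARGE survive;
# the special tag "all" would bypass metadata filtering and return every dataset.
filters = SeedDatasetFilter(tags={"default"}, sizes=[SeedDatasetSize.LARGE])
names = SeedDatasetProvider.get_all_dataset_names(filters=filters)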

@classmethod
def _match_filter(cls, metadata: SeedDatasetMetadata, filters: SeedDatasetFilter) -> bool:
"""

Match the filter(s) with the metadata provided by the SeedDatasetProvider subclass.
By default, filters across dimensions (e.g. size, harm categories) are treated as AND
requirements. Filters within a dimension (e.g. SeedDatasetSize.SMALL,
SeedDatasetSize.LARGE) are treated as OR requirements.

Args:
metadata (SeedDatasetMetadata): The metadata object extracted from the SeedDatasetProvider
subclass.
filters (SeedDatasetFilter): The filter object provided by the user to get_all_dataset_names.

Returns:
bool: True if the metadata matches the filters, False otherwise.
"""
# Tags
if filters.tags and "all" in filters.tags:
return True

# These checks disable SIM103 because the metadata and filter fields are Optional,
# so returning the boolean expression directly would break type checking.

if metadata.tags and filters.tags and not (filters.tags & metadata.tags): # noqa: SIM103
return False

# Size
if metadata.size and filters.sizes and metadata.size not in filters.sizes: # noqa: SIM103
return False

# Harm Categories
if (
metadata.harm_categories
and filters.harm_categories
and not set(metadata.harm_categories) & set(filters.harm_categories)
): # noqa: SIM103
return False

# Source Type
if metadata.source_type and filters.source_types and metadata.source_type not in filters.source_types: # noqa: SIM103
return False

# Modalities
if metadata.modalities and filters.modalities and not set(metadata.modalities) & set(filters.modalities): # noqa: SIM103
return False

# Rank
if metadata.rank and filters.ranks and metadata.rank not in filters.ranks: # noqa: SIM103
return False

return True
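
And a concrete check of the AND-across / OR-within semantics; constructing SeedDatasetMetadata and SeedDatasetFilter from keyword subsets assumes both dataclasses default their remaining fields:

from pyrit.datasets import (
    SeedDatasetFilter,
    SeedDatasetMetadata,
    SeedDatasetProvider,
    SeedDatasetSize,
)

metadata = SeedDatasetMetadata(size=SeedDatasetSize.LARGE, tags={"default", "safety"})

# OR within a dimension: either size is acceptable.
either = SeedDatasetFilter(sizes=[SeedDatasetSize.SMALL, SeedDatasetSize.LARGE])
assert SeedDatasetProvider._match_filter(metadata=metadata, filters=either)

# AND across dimensions: the size matches but no tag overlaps, so the dataset is rejected.
strict = SeedDatasetFilter(sizes=[SeedDatasetSize.LARGE], tags={"regression"})
assert not SeedDatasetProvider._match_filter(metadata=metadata, filters=strict)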

@classmethod
async def fetch_datasets_async(
cls,
Expand Down