class ClassCount(NamedTuple):
    """Sample tally for a single class (intent).

    Instances are ordinary 2-tuples ``(id, n_samples)``, so they can be
    unpacked positionally and compare equal to plain tuples.
    """

    # Class (intent) index.
    id: int

    # Number of samples from the class (intent).
    n_samples: int
def check_split_readiness(
    dataset: Dataset,
    split: str,
    config: DataConfig,
    allow_oos_in_train: bool | None = None,
) -> SplitReadinessResult:
    """Check whether the dataset has enough samples per class for autointent pipeline.

    The split is prepared through ``StratifiedSplitter.get_stratify_inputs`` and
    every class is compared against the minimum derived from ``config``. For
    multiclass data that passes the per-class check, the train/test sizes are
    additionally checked for feasibility.

    Args:
        dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`).
        split: The split name to check (e.g. ``Split.TRAIN``).
        config: Data config. ``validation_size`` is used as the test proportion and
            the per-class minimum is derived from the config as a whole.
        allow_oos_in_train: Same as in :func:`split_dataset`. If the split contains OOS samples
            and this is ``None``, this function raises ``ValueError`` (mirrors splitting behavior).

    Returns:
        A :class:`SplitReadinessResult`. ``ready`` is False when the split is missing,
        when some class is below the minimum (listed in ``underpopulated_classes``),
        or when the multiclass train/test sizes cannot hold every class; ``reason``
        then carries a human-readable explanation.
    """
    min_samples_per_class = _min_samples_per_class_for_config(config=config)
    if split not in dataset:
        return SplitReadinessResult(
            ready=False,
            underpopulated_classes=[],
            min_samples_per_class_required=min_samples_per_class,
            reason=f"Dataset has no split '{split}'.",
        )

    splitter = StratifiedSplitter(
        test_size=config.validation_size,
        label_feature=dataset.label_feature,
        random_seed=None,
    )
    inputs = splitter.get_stratify_inputs(dataset[split], dataset.multilabel, allow_oos_in_train)

    reason: str | None = None
    if inputs.multilabel:
        underpopulated = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, min_samples_per_class)
        ready = not underpopulated
    else:
        # Only the multiclass checks need the class count, so it is computed here
        # rather than up front: the multilabel variant of the helper indexes the
        # first row and would fail on a split without rows.
        expected_n_classes = _expected_n_classes(dataset, inputs.dataset, splitter.label_feature)
        underpopulated = _find_underpopulated_multiclass(
            inputs.dataset,
            splitter.label_feature,
            min_samples_per_class,
            expected_n_classes=expected_n_classes,
        )
        ready = not underpopulated
        if ready:
            # Enough samples per class does not guarantee that both sides of the
            # split can contain every class; verify the resulting sizes as well.
            ready, reason = _check_multiclass_split_size_feasibility(
                dataset=inputs.dataset,
                label_feature=splitter.label_feature,
                test_size=inputs.test_size,
                expected_n_classes=expected_n_classes,
            )

    if not ready and reason is None:
        details = "; ".join(
            f"class {item.id!r}: {item.n_samples} (need {min_samples_per_class})" for item in underpopulated
        )
        reason = (
            f"Stratification requires at least {min_samples_per_class} samples per class. "
            f"Underpopulated: {details}."
        )
    return SplitReadinessResult(
        ready=ready,
        underpopulated_classes=underpopulated,
        min_samples_per_class_required=min_samples_per_class,
        reason=reason,
    )
def _expected_n_classes(dataset: Dataset, prepared: HFDataset, label_feature: str) -> int:
    """Return how many classes the readiness check has to account for.

    Args:
        dataset: Source dataset; only ``multilabel`` and ``n_classes`` are read.
        prepared: Split as prepared for stratification; its ``label_feature`` column is read.
        label_feature: Name of the label column in ``prepared``.

    Returns:
        For multilabel data, the length of the first label vector in ``prepared``,
        or ``dataset.n_classes`` when ``prepared`` has no rows. For multiclass data,
        the larger of ``dataset.n_classes`` and ``max(label) + 1``, so a label above
        the declared range is still covered.
    """
    labels = prepared[label_feature]
    if dataset.multilabel:
        if not labels:
            # No row to measure: fall back to the declared class count instead of
            # failing with IndexError on ``labels[0]``.
            return int(dataset.n_classes)
        return len(labels[0])
    highest_seen = max(labels) if labels else -1
    return max(dataset.n_classes, int(highest_seen) + 1)
def safe_multilabel_split_indices(
    y: npt.NDArray[Any], test_size: float, random_seed: int | None
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    """Split multilabel data with coverage guarantees for rare labels.

    Rare labels are handled first: the singleton and pair helpers pre-assign
    the samples that carry them. Whatever is left is distributed by iterative
    stratification, and the result is validated before being returned.

    Args:
        y: Label matrix; must be two-dimensional.
        test_size: Proportion handed to the iterative stratification step.
        random_seed: Seed for the pair coin flips and the stratification step.

    Returns:
        ``(train_indices, test_indices)`` as sorted integer arrays.
    """
    _validate_multilabel_matrix(y)
    total = int(y.shape[0])
    generator = np.random.default_rng(random_seed)
    positives_per_label = y.sum(axis=0).astype(int)

    placed_train: set[int] = set()
    placed_test: set[int] = set()
    _force_singleton_labels(y=y, label_counts=positives_per_label, train_idx=placed_train)
    _force_pair_labels(
        y=y,
        label_counts=positives_per_label,
        train_idx=placed_train,
        test_idx=placed_test,
        rng=generator,
    )

    # Everything not pinned above, in ascending index order.
    pinned = placed_train | placed_test
    leftover = np.array([idx for idx in range(total) if idx not in pinned], dtype=int)
    _iterative_stratify_remaining(
        y=y,
        remaining=leftover,
        test_size=test_size,
        random_seed=random_seed,
        train_idx=placed_train,
        test_idx=placed_test,
    )
    return _finalize_partition(n_samples=total, train_idx=placed_train, test_idx=placed_test)
def _force_pair_samples(a: int, b: int, train_idx: set[int], test_idx: set[int], rng: np.random.Generator) -> None:
    """Put samples ``a`` and ``b`` on opposite sides of the split where still possible.

    If exactly one of them is already placed, the other goes to the opposite
    side. If neither is placed, a coin flip decides which one goes to train.
    If both are already placed, nothing is changed. Both sets are mutated in place.
    """
    a_in_train = a in train_idx
    b_in_train = b in train_idx
    a_placed = a_in_train or a in test_idx
    b_placed = b_in_train or b in test_idx

    if a_placed and b_placed:
        # Both were fixed by earlier assignments; leave them untouched.
        return
    if a_placed:
        (test_idx if a_in_train else train_idx).add(b)
    elif b_placed:
        (test_idx if b_in_train else train_idx).add(a)
    elif rng.random() < _COIN_FLIP_P:
        train_idx.add(a)
        test_idx.add(b)
    else:
        train_idx.add(b)
        test_idx.add(a)
def _finalize_partition(
    n_samples: int, train_idx: set[int], test_idx: set[int]
) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
    """Validate the accumulated index sets and return them as sorted arrays.

    Args:
        n_samples: Number of samples the two sets are expected to hold together.
        train_idx: Indices assigned to train.
        test_idx: Indices assigned to test.

    Returns:
        ``(train, test)`` integer arrays, each in ascending order.

    Raises:
        RuntimeError: If the set sizes do not add up to ``n_samples``, or if the
            two sets share an index (checked in that order).
    """
    n_train, n_test = len(train_idx), len(test_idx)
    if n_train + n_test != n_samples:
        msg = (
            "Multilabel split did not partition all samples: "
            f"n_samples={n_samples}, train={n_train}, test={n_test}."
        )
        raise RuntimeError(msg)
    if not train_idx.isdisjoint(test_idx):
        msg = "Multilabel split produced overlapping train/test indices."
        raise RuntimeError(msg)

    ordered_train = np.array(sorted(train_idx), dtype=int)
    ordered_test = np.array(sorted(test_idx), dtype=int)
    return ordered_train, ordered_test
""" if multilabel: - in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) + try: + in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) + except StopIteration as e: + msg = ( + "Cannot infer multilabel dimensionality: dataset contains only OOS samples " + f"({self.label_feature}=None for all rows)." + ) + raise ValueError(msg) from e n_classes = len(in_domain_sample[self.label_feature]) mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) @@ -190,7 +179,7 @@ def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDat return StratifyInputs( dataset=mapped_dataset, - multilabel=False, + multilabel=True, test_size=self.test_size, post_split_fn=unmap_oos_multilabel, ) @@ -277,14 +266,9 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np Returns: A sequence containing indices for train and test splits. """ - if self.random_seed is not None: - set_seed(self.random_seed) # workaround for buggy nature of IterativeStratification from skmultilearn - splitter = IterativeStratification( - n_splits=2, - order=2, - sample_distribution_per_fold=[test_size, 1.0 - test_size], - ) - return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature]))) + y = np.asarray(dataset[self.label_feature]) + train_arr, test_arr = safe_multilabel_split_indices(y=y, test_size=test_size, random_seed=self.random_seed) + return (train_arr, test_arr) def _map_label( self, sample: dict[str, str | LabelType], old: LabelType, new: LabelType @@ -315,7 +299,9 @@ def _add_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> """ if sample[self.label_feature] is None: sample[self.label_feature] = [0] * n_classes - sample[self.label_feature] += [1] # type: ignore[operator] + sample[self.label_feature] += [1] # type: ignore[operator] + else: + sample[self.label_feature] += [0] # type: 
ignore[operator] return sample def _remove_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> dict[str, str | LabelType]: @@ -376,78 +362,6 @@ def _get_adjusted_test_size(self, n: int, k: int) -> float: return res -def _check_multiclass_counts( - dataset: HFDataset, label_feature: str, min_samples_per_class: int -) -> list[tuple[int, int]]: - """Return (label, count) for each class with fewer than min_samples_per_class samples.""" - labels: list[int] = dataset[label_feature] - counts = Counter(labels) - return [(label, count) for label, count in counts.items() if count < min_samples_per_class] - - -def check_split_readiness( - dataset: Dataset, - split: str, - test_size: float, - min_samples_per_class: int = 2, - allow_oos_in_train: bool | None = None, -) -> SplitReadinessResult: - """Check whether the dataset has enough samples per class for stratified splitting. - - Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream - code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset` - and handle underpopulated classes (e.g. skip phase, log, or fail with a clear message). - - Args: - dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). - split: The split name to check (e.g. ``Split.TRAIN``). - test_size: Proportion used for the test split (must match the value used when splitting). - min_samples_per_class: Minimum number of samples per class required for stratification. - Default 2 matches sklearn's requirement for a 2-way stratified split. - allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples - and this is not set, the function returns ``ready=False`` with a reason. - - Returns: - SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. 
- """ - if split not in dataset: - return SplitReadinessResult( - ready=False, - underpopulated_classes=[], - min_samples_per_class_required=min_samples_per_class, - reason=f"Dataset has no split '{split}'.", - ) - hf_split = dataset[split] - splitter = StratifiedSplitter( - test_size=test_size, - label_feature=dataset.label_feature, - random_seed=None, - ) - inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) - if inputs.multilabel: - # Multilabel stratification uses IterativeStratification; we do not validate it here. - return SplitReadinessResult( - ready=True, - underpopulated_classes=[], - min_samples_per_class_required=min_samples_per_class, - reason=None, - ) - underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) - ready = len(underpopulated) == 0 - reason = None - if not ready: - parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] - reason = "Stratification requires at least {} samples per class. 
@pytest.fixture
def dataset_three_classes_two_each():
    """3 classes, 2 samples each (no OOS). Useful for split-size feasibility tests."""
    # Rows come out as a1, a2, b1, b2, c1, c2 with labels 0, 0, 1, 1, 2, 2.
    train_rows = [
        {"utterance": f"{prefix}{ordinal}", "label": class_id}
        for class_id, prefix in enumerate("abc")
        for ordinal in (1, 2)
    ]
    intents = [{"id": class_id, "regex_full_match": [], "regex_partial_match": []} for class_id in range(3)]
    return Dataset.from_dict({"train": train_rows, "intents": intents})
Not ready for stratification.""" @@ -87,7 +113,7 @@ def test_check_split_readiness_ready_when_enough_samples(dataset_enough_samples) result = check_split_readiness( dataset_enough_samples, split=Split.TRAIN, - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), allow_oos_in_train=False, ) assert isinstance(result, SplitReadinessResult) @@ -102,7 +128,7 @@ def test_check_split_readiness_not_ready_underpopulated(dataset_underpopulated): result = check_split_readiness( dataset_underpopulated, split=Split.TRAIN, - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is False @@ -121,7 +147,7 @@ def test_check_split_readiness_missing_split(dataset_enough_samples): result = check_split_readiness( dataset_enough_samples, split="nonexistent_split", - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), ) assert result.ready is False assert result.underpopulated_classes == [] @@ -134,7 +160,7 @@ def test_check_split_readiness_oos_allow_none(dataset_unsplitted): check_split_readiness( dataset_unsplitted, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=None, ) @@ -144,7 +170,7 @@ def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitt result = check_split_readiness( dataset_unsplitted, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is True @@ -152,13 +178,40 @@ def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitt assert result.reason is None +def test_check_split_readiness_multiclass_too_small_test_split(dataset_three_classes_two_each): + """Even with >=2/class, stratification can fail if test split can't include all classes.""" + result = check_split_readiness( + dataset_three_classes_two_each, + split=Split.TRAIN, + 
def test_check_split_readiness_multiclass_too_small_train_split(dataset_three_classes_two_each):
    """Even with >=2/class, stratification can fail if train split can't include all classes."""
    # A large validation share leaves too few rows on the train side for three classes.
    oversized_validation = DataConfig(validation_size=0.8, separation_ratio=None)

    result = check_split_readiness(
        dataset_three_classes_two_each,
        split=Split.TRAIN,
        config=oversized_validation,
        allow_oos_in_train=False,
    )

    assert result.ready is False
    # Every class meets the per-class minimum, so the failure is purely about sizes.
    assert result.underpopulated_classes == []
    assert result.reason is not None
    assert "too few train samples" in result.reason
test_check_split_readiness_multilabel_returns_ready(): + """Multilabel datasets are checked by per-label positive counts.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "x1", "label": [1, 0, 1]}, + {"utterance": "x2", "label": [1, 0, 0]}, + {"utterance": "x3", "label": [0, 1, 0]}, + {"utterance": "x4", "label": [0, 0, 1]}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + # label 1 appears only once -> not ready for min_samples_per_class=2 result = check_split_readiness( dataset, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=False, ) + assert result.ready is False + assert result.underpopulated_classes == [(1, 1)] + assert result.reason is not None + + +def test_check_split_readiness_marks_declared_but_unseen_intent_as_underpopulated(): + """Intents with 0 samples should be flagged so callers can filter them out.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + {"utterance": "b2", "label": 1}, + ], + # Declare 3 intents, but only provide samples for ids 0 and 1. 
def test_check_split_readiness_multilabel_oos_allow_true_checks_oos_label():
    """Multilabel + OOS + allow_oos_in_train=True should not crash and should include OOS label."""
    in_domain = [
        ("x1", [1, 0]),
        ("x2", [1, 0]),
        ("x3", [0, 1]),
        ("x4", [0, 1]),
    ]
    train_rows = [{"utterance": text, "label": vector} for text, vector in in_domain]
    # A single out-of-scope row: below the minimum once OOS is treated as a label.
    train_rows.append({"utterance": "oos1", "label": None})
    intents = [{"id": class_id, "regex_full_match": [], "regex_partial_match": []} for class_id in range(2)]
    dataset = Dataset.from_dict({"train": train_rows, "intents": intents})

    result = check_split_readiness(
        dataset,
        split=Split.TRAIN,
        config=DataConfig(validation_size=0.5, separation_ratio=None),
        allow_oos_in_train=True,
    )

    assert result.ready is False
    # OOS indicator label is appended -> index == n_classes == 2
    assert (2, 1) in result.underpopulated_classes
    assert result.reason is not None
def test_split_dataset_multilabel_oos_allow_true_does_not_raise():
    """Sanity-check: split_dataset supports multilabel+OOS when allow_oos_in_train=True."""
    from autointent.context.data_handler import split_dataset

    labelled = [
        ("x1", [1, 0]),
        ("x2", [1, 0]),
        ("x3", [0, 1]),
        ("x4", [0, 1]),
        ("oos1", None),
        ("oos2", None),
    ]
    dataset = Dataset.from_dict(
        {
            "train": [{"utterance": text, "label": target} for text, target in labelled],
            "intents": [
                {"id": class_id, "regex_full_match": [], "regex_partial_match": []} for class_id in range(2)
            ],
        }
    )

    train, test = split_dataset(
        dataset,
        split=Split.TRAIN,
        test_size=0.5,
        random_seed=42,
        allow_oos_in_train=True,
    )

    # Only checks that the split completed and both parts are non-empty.
    assert len(train) > 0
    assert len(test) > 0
def test_stratified_splitter_multilabel_allow_oos_all_oos_raises_value_error():
    """Multilabel OOS mapping needs an in-domain row to infer label dimensionality."""
    from datasets import Dataset as HFDataset

    from autointent.context.data_handler._stratification import StratifiedSplitter

    # Every row is out-of-scope, so there is no label vector to take a length from.
    only_oos_rows = [{"utterance": f"oos{ordinal}", "label": None} for ordinal in (1, 2)]
    hf_ds = HFDataset.from_list(only_oos_rows)
    splitter = StratifiedSplitter(test_size=0.5, label_feature="label", random_seed=0)

    with pytest.raises(ValueError, match=r"only OOS|infer multilabel dimensionality"):
        splitter.get_stratify_inputs(hf_ds, multilabel=True, allow_oos_in_train=True)
dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST) == dataset.n_classes + + def test_multilabel_train_test_split_few_shot(dataset_unsplitted): dataset = dataset_unsplitted dataset = dataset.to_multilabel()