From 1d0fc2ffb79c5fd7c6c0cbc4ba77090f82977a02 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 14:21:35 +0300 Subject: [PATCH 01/14] decompose `_stratification.py` onto two files --- .../context/data_handler/__init__.py | 8 +- .../context/data_handler/_readiness_util.py | 121 ++++++++++++++++++ .../context/data_handler/_stratification.py | 94 +------------- 3 files changed, 125 insertions(+), 98 deletions(-) create mode 100644 src/autointent/context/data_handler/_readiness_util.py diff --git a/src/autointent/context/data_handler/__init__.py b/src/autointent/context/data_handler/__init__.py index 24a3d3b5..fa6f89e2 100644 --- a/src/autointent/context/data_handler/__init__.py +++ b/src/autointent/context/data_handler/__init__.py @@ -1,10 +1,6 @@ from ._data_handler import DataHandler -from ._stratification import ( - SplitReadinessResult, - StratifiedSplitter, - check_split_readiness, - split_dataset, -) +from ._readiness_util import SplitReadinessResult, check_split_readiness +from ._stratification import StratifiedSplitter, split_dataset __all__ = [ "DataHandler", diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py new file mode 100644 index 00000000..3a919c2f --- /dev/null +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from collections import Counter +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from datasets import Dataset as HFDataset + + from autointent import Dataset + +from ._safe_multilabel_stratification import _validate_multilabel_matrix +from ._stratification import StratifiedSplitter + + +@dataclass(frozen=True) +class SplitReadinessResult: + """Result of checking whether a dataset can be stratified split. + + Attributes: + ready: True if stratification can be performed (enough samples per class). + underpopulated_classes: List of (label, count) for classes below the minimum. + min_samples_per_class_required: Minimum samples per class used for the check. + reason: Human-readable reason when not ready (e.g. OOS not configured). + """ + + ready: bool + underpopulated_classes: list[tuple[int, int]] + min_samples_per_class_required: int + reason: str | None + + +def check_split_readiness( + dataset: Dataset, + split: str, + test_size: float, + min_samples_per_class: int = 2, + allow_oos_in_train: bool | None = None, +) -> SplitReadinessResult: + """Check whether the dataset has enough samples per class for stratified splitting. + + Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream + code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset` + and handle underpopulated classes (e.g. skip phase, log, or fail with a clear message). + + Args: + dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). + split: The split name to check (e.g. ``Split.TRAIN``). + test_size: Proportion used for the test split (must match the value used when splitting). + min_samples_per_class: Minimum number of samples per class required for stratification. + Default 2 matches sklearn's requirement for a 2-way stratified split. + allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples + and this is not set, the function returns ``ready=False`` with a reason. + + Returns: + SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. + """ + if split not in dataset: + return SplitReadinessResult( + ready=False, + underpopulated_classes=[], + min_samples_per_class_required=min_samples_per_class, + reason=f"Dataset has no split '{split}'.", + ) + hf_split = dataset[split] + splitter = StratifiedSplitter( + test_size=test_size, + label_feature=dataset.label_feature, + random_seed=None, + ) + inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) + if inputs.multilabel: + underpopulated = _check_multilabel_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) + ready = len(underpopulated) == 0 + reason = None + if not ready: + parts = [f"label {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] + reason = "Multilabel stratification requires at least {} positives per label. Underpopulated: {}.".format( + min_samples_per_class, "; ".join(parts) + ) + return SplitReadinessResult( + ready=ready, + underpopulated_classes=underpopulated, + min_samples_per_class_required=min_samples_per_class, + reason=reason, + ) + underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) + ready = len(underpopulated) == 0 + reason = None + if not ready: + parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] + reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format( + min_samples_per_class, "; ".join(parts) + ) + return SplitReadinessResult( + ready=ready, + underpopulated_classes=underpopulated, + min_samples_per_class_required=min_samples_per_class, + reason=reason, + ) + + +def _check_multiclass_counts( + dataset: HFDataset, label_feature: str, min_samples_per_class: int +) -> list[tuple[int, int]]: + """Return (label, count) for each class with fewer than min_samples_per_class samples.""" + labels: list[int] = dataset[label_feature] + counts = Counter(labels) + return [(label, count) for label, count in counts.items() if count < min_samples_per_class] + + +def _check_multilabel_counts( + dataset: HFDataset, label_feature: str, min_samples_per_class: int +) -> list[tuple[int, int]]: + """Return (label_idx, positive_count) for each label with fewer than min_samples_per_class positives.""" + y = np.asarray(dataset[label_feature]) + _validate_multilabel_matrix(y) + counts = y.sum(axis=0).astype(int) + return [(int(idx), int(count)) for idx, count in enumerate(counts) if count < min_samples_per_class] diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 2a6d48b4..9b070947 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -7,15 +7,12 @@ from __future__ import annotations import logging -from collections import Counter from dataclasses import dataclass from typing import TYPE_CHECKING import numpy as np from datasets import concatenate_datasets from sklearn.model_selection import train_test_split -from skmultilearn.model_selection import IterativeStratification -from transformers import set_seed if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -26,6 +23,8 @@ from autointent import Dataset from autointent.custom_types import LabelType +from ._safe_multilabel_stratification import safe_multilabel_split_indices + logger = logging.getLogger(__name__) @@ -43,23 +42,6 @@ class StratifyInputs: post_split_fn: Callable[[HFDataset, HFDataset], tuple[HFDataset, HFDataset]] -@dataclass(frozen=True) -class SplitReadinessResult: - """Result of checking whether a dataset can be stratified split. - - Attributes: - ready: True if stratification can be performed (enough samples per class). - underpopulated_classes: List of (label, count) for classes below the minimum. - min_samples_per_class_required: Minimum samples per class used for the check. - reason: Human-readable reason when not ready (e.g. OOS not configured). - """ - - ready: bool - underpopulated_classes: list[tuple[int, int]] - min_samples_per_class_required: int - reason: str | None - - class StratifiedSplitter: """A class for stratified splitting of datasets. @@ -376,78 +358,6 @@ def _get_adjusted_test_size(self, n: int, k: int) -> float: return res -def _check_multiclass_counts( - dataset: HFDataset, label_feature: str, min_samples_per_class: int -) -> list[tuple[int, int]]: - """Return (label, count) for each class with fewer than min_samples_per_class samples.""" - labels: list[int] = dataset[label_feature] - counts = Counter(labels) - return [(label, count) for label, count in counts.items() if count < min_samples_per_class] - - -def check_split_readiness( - dataset: Dataset, - split: str, - test_size: float, - min_samples_per_class: int = 2, - allow_oos_in_train: bool | None = None, -) -> SplitReadinessResult: - """Check whether the dataset has enough samples per class for stratified splitting. - - Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream - code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset` - and handle underpopulated classes (e.g. skip phase, log, or fail with a clear message). - - Args: - dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). - split: The split name to check (e.g. ``Split.TRAIN``). - test_size: Proportion used for the test split (must match the value used when splitting). - min_samples_per_class: Minimum number of samples per class required for stratification. - Default 2 matches sklearn's requirement for a 2-way stratified split. - allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples - and this is not set, the function returns ``ready=False`` with a reason. - - Returns: - SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. - """ - if split not in dataset: - return SplitReadinessResult( - ready=False, - underpopulated_classes=[], - min_samples_per_class_required=min_samples_per_class, - reason=f"Dataset has no split '{split}'.", - ) - hf_split = dataset[split] - splitter = StratifiedSplitter( - test_size=test_size, - label_feature=dataset.label_feature, - random_seed=None, - ) - inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) - if inputs.multilabel: - # Multilabel stratification uses IterativeStratification; we do not validate it here. - return SplitReadinessResult( - ready=True, - underpopulated_classes=[], - min_samples_per_class_required=min_samples_per_class, - reason=None, - ) - underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) - ready = len(underpopulated) == 0 - reason = None - if not ready: - parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] - reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format( - min_samples_per_class, "; ".join(parts) - ) - return SplitReadinessResult( - ready=ready, - underpopulated_classes=underpopulated, - min_samples_per_class_required=min_samples_per_class, - reason=reason, - ) - - def split_dataset( dataset: Dataset, split: str, From a10868c3d72e62ff34e290e1a34ddc2cf01ef71e Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 14:22:16 +0300 Subject: [PATCH 02/14] implement safe stratification of multilabel data --- .../_safe_multilabel_stratification.py | 132 ++++++++++++++++++ .../context/data_handler/_stratification.py | 11 +- tests/data/test_check_split_readiness.py | 38 ++++- tests/data/test_stratificaiton.py | 29 ++++ 4 files changed, 197 insertions(+), 13 deletions(-) create mode 100644 src/autointent/context/data_handler/_safe_multilabel_stratification.py diff --git a/src/autointent/context/data_handler/_safe_multilabel_stratification.py b/src/autointent/context/data_handler/_safe_multilabel_stratification.py new file mode 100644 index 00000000..2d703bde --- /dev/null +++ b/src/autointent/context/data_handler/_safe_multilabel_stratification.py @@ -0,0 +1,132 @@ +from __future__ import annotations + +import numpy as np +from skmultilearn.model_selection import IterativeStratification +from transformers import set_seed + +_MULTILABEL_NDIMS = 2 +_RARE_LABEL_COUNT_SINGLETON = 1 +_RARE_LABEL_COUNT_PAIR = 2 +_COIN_FLIP_P = 0.5 + + +def safe_multilabel_split_indices( + y: np.ndarray, test_size: float, random_seed: int | None +) -> tuple[np.ndarray, np.ndarray]: + """Split multilabel data with coverage guarantees for rare labels.""" + _validate_multilabel_matrix(y) + n_samples = int(y.shape[0]) + rng = np.random.default_rng(random_seed) + + train_idx: set[int] = set() + test_idx: set[int] = set() + label_counts = y.sum(axis=0).astype(int) + + _force_singleton_labels(y=y, label_counts=label_counts, train_idx=train_idx) + _force_pair_labels(y=y, label_counts=label_counts, train_idx=train_idx, test_idx=test_idx, rng=rng) + + forced = train_idx | test_idx + remaining = np.array(sorted(set(range(n_samples)) - forced), dtype=int) + _iterative_stratify_remaining( + y=y, + remaining=remaining, + test_size=test_size, + random_seed=random_seed, + train_idx=train_idx, + test_idx=test_idx, + ) + return _finalize_partition(n_samples=n_samples, train_idx=train_idx, test_idx=test_idx) + + +def _validate_multilabel_matrix(y: np.ndarray) -> None: + if y.ndim != _MULTILABEL_NDIMS: + msg = ( + "Expected multilabel data to be a 2D matrix-like structure " + f"(n_samples, n_labels), got shape={getattr(y, 'shape', None)!r}." + ) + raise ValueError(msg) + + +def _assigned_split(sample_idx: int, train_idx: set[int], test_idx: set[int]) -> str | None: + if sample_idx in train_idx: + return "train" + if sample_idx in test_idx: + return "test" + return None + + +def _force_singleton_labels(y: np.ndarray, label_counts: np.ndarray, train_idx: set[int]) -> None: + for label, count in enumerate(label_counts): + if int(count) != _RARE_LABEL_COUNT_SINGLETON: + continue + sample = int(np.flatnonzero(y[:, label])[0]) + train_idx.add(sample) + + +def _force_pair_samples(a: int, b: int, train_idx: set[int], test_idx: set[int], rng: np.random.Generator) -> None: + a_split = _assigned_split(a, train_idx, test_idx) + b_split = _assigned_split(b, train_idx, test_idx) + + if a_split is not None and b_split is None: + (test_idx if a_split == "train" else train_idx).add(b) + return + if b_split is not None and a_split is None: + (test_idx if b_split == "train" else train_idx).add(a) + return + if a_split is None and b_split is None: + if rng.random() < _COIN_FLIP_P: + train_idx.add(a) + test_idx.add(b) + else: + train_idx.add(b) + test_idx.add(a) + + +def _force_pair_labels( + y: np.ndarray, label_counts: np.ndarray, train_idx: set[int], test_idx: set[int], rng: np.random.Generator +) -> None: + for label, count in enumerate(label_counts): + if int(count) != _RARE_LABEL_COUNT_PAIR: + continue + samples = np.flatnonzero(y[:, label]).astype(int) + a, b = sorted(samples.tolist(), key=lambda i: int(y[i].sum())) + _force_pair_samples(a=a, b=b, train_idx=train_idx, test_idx=test_idx, rng=rng) + + +def _iterative_stratify_remaining( + y: np.ndarray, + remaining: np.ndarray, + test_size: float, + random_seed: int | None, + train_idx: set[int], + test_idx: set[int], +) -> None: + if len(remaining) == 0: + return + if random_seed is not None: + # Workaround for buggy nature of IterativeStratification from skmultilearn + set_seed(random_seed) + splitter = IterativeStratification( + n_splits=2, + order=2, + sample_distribution_per_fold=[1.0 - test_size, test_size], + ) + train_r, test_r = next(splitter.split(np.arange(len(remaining)), y[remaining])) + train_idx |= set(remaining[train_r].tolist()) + test_idx |= set(remaining[test_r].tolist()) + + +def _finalize_partition(n_samples: int, train_idx: set[int], test_idx: set[int]) -> tuple[np.ndarray, np.ndarray]: + train_arr = np.array(sorted(train_idx), dtype=int) + test_arr = np.array(sorted(test_idx), dtype=int) + + if len(train_arr) + len(test_arr) != n_samples: + msg = ( + "Multilabel split did not partition all samples: " + f"n_samples={n_samples}, train={len(train_arr)}, test={len(test_arr)}." + ) + raise RuntimeError(msg) + if set(train_arr.tolist()) & set(test_arr.tolist()): + msg = "Multilabel split produced overlapping train/test indices." + raise RuntimeError(msg) + return train_arr, test_arr diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 9b070947..f2a4ac91 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -259,14 +259,9 @@ def _split_multilabel(self, dataset: HFDataset, test_size: float) -> Sequence[np Returns: A sequence containing indices for train and test splits. """ - if self.random_seed is not None: - set_seed(self.random_seed) # workaround for buggy nature of IterativeStratification from skmultilearn - splitter = IterativeStratification( - n_splits=2, - order=2, - sample_distribution_per_fold=[test_size, 1.0 - test_size], - ) - return next(splitter.split(np.arange(len(dataset)), np.array(dataset[self.label_feature]))) + y = np.asarray(dataset[self.label_feature]) + train_arr, test_arr = safe_multilabel_split_indices(y=y, test_size=test_size, random_seed=self.random_seed) + return (train_arr, test_arr) def _map_label( self, sample: dict[str, str | LabelType], old: LabelType, new: LabelType diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py index 962074fc..71500924 100644 --- a/tests/data/test_check_split_readiness.py +++ b/tests/data/test_check_split_readiness.py @@ -175,17 +175,45 @@ def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_b assert result_strict.min_samples_per_class_required == 3 -def test_check_split_readiness_multilabel_returns_ready(dataset_unsplitted): - """Multilabel datasets return ready=True (multilabel stratification is not validated).""" - dataset = dataset_unsplitted.to_multilabel() +def test_check_split_readiness_multilabel_returns_ready(): + """Multilabel datasets are checked by per-label positive counts.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "x1", "label": [1, 0, 1]}, + {"utterance": "x2", "label": [1, 0, 0]}, + {"utterance": "x3", "label": [0, 1, 0]}, + {"utterance": "x4", "label": [0, 0, 1]}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + # label 1 appears only once -> not ready for min_samples_per_class=2 result = check_split_readiness( dataset, split=Split.TRAIN, test_size=0.5, allow_oos_in_train=False, ) - assert result.ready is True - assert result.underpopulated_classes == [] + assert result.ready is False + assert result.underpopulated_classes == [(1, 1)] + assert result.reason is not None + + # With min_samples_per_class=1 it becomes ready. + result_relaxed = check_split_readiness( + dataset, + split=Split.TRAIN, + test_size=0.5, + min_samples_per_class=1, + allow_oos_in_train=False, + ) + assert result_relaxed.ready is True + assert result_relaxed.underpopulated_classes == [] def test_check_split_readiness_consistent_with_split_dataset(dataset_enough_samples): diff --git a/tests/data/test_stratificaiton.py b/tests/data/test_stratificaiton.py index e964c965..3a728c01 100644 --- a/tests/data/test_stratificaiton.py +++ b/tests/data/test_stratificaiton.py @@ -1,5 +1,6 @@ import pytest +from autointent import Dataset from autointent.context.data_handler._stratification import split_dataset from autointent.custom_types import Split @@ -43,6 +44,34 @@ def test_multilabel_train_test_split(dataset_unsplitted): assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST) +def test_multilabel_train_test_split_multi_hot_preserves_label_coverage(): + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "u0", "label": [1, 0, 0]}, + {"utterance": "u1", "label": [1, 1, 0]}, + {"utterance": "u2", "label": [0, 1, 0]}, + {"utterance": "u3", "label": [0, 1, 1]}, + {"utterance": "u4", "label": [0, 0, 1]}, + {"utterance": "u5", "label": [1, 0, 1]}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + dataset[Split.TRAIN], dataset[Split.TEST] = split_dataset( + dataset, + split=Split.TRAIN, + test_size=0.5, + random_seed=42, + allow_oos_in_train=False, + ) + assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST) == dataset.n_classes + + def test_multilabel_train_test_split_few_shot(dataset_unsplitted): dataset = dataset_unsplitted dataset = dataset.to_multilabel() From d6c371b4760ed5751434d6d7e5ce1f576974abb5 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 14:27:32 +0300 Subject: [PATCH 03/14] fix typing --- .../_safe_multilabel_stratification.py | 27 +++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/autointent/context/data_handler/_safe_multilabel_stratification.py b/src/autointent/context/data_handler/_safe_multilabel_stratification.py index 2d703bde..8fa6b500 100644 --- a/src/autointent/context/data_handler/_safe_multilabel_stratification.py +++ b/src/autointent/context/data_handler/_safe_multilabel_stratification.py @@ -1,9 +1,14 @@ from __future__ import annotations +from typing import TYPE_CHECKING, Any + import numpy as np from skmultilearn.model_selection import IterativeStratification from transformers import set_seed +if TYPE_CHECKING: + import numpy.typing as npt + _MULTILABEL_NDIMS = 2 _RARE_LABEL_COUNT_SINGLETON = 1 _RARE_LABEL_COUNT_PAIR = 2 @@ -11,8 +16,8 @@ def safe_multilabel_split_indices( - y: np.ndarray, test_size: float, random_seed: int | None -) -> tuple[np.ndarray, np.ndarray]: + y: npt.NDArray[Any], test_size: float, random_seed: int | None +) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: """Split multilabel data with coverage guarantees for rare labels.""" _validate_multilabel_matrix(y) n_samples = int(y.shape[0]) @@ -38,7 +43,7 @@ def safe_multilabel_split_indices( return _finalize_partition(n_samples=n_samples, train_idx=train_idx, test_idx=test_idx) -def _validate_multilabel_matrix(y: np.ndarray) -> None: +def _validate_multilabel_matrix(y: npt.NDArray[Any]) -> None: if y.ndim != _MULTILABEL_NDIMS: msg = ( "Expected multilabel data to be a 2D matrix-like structure " @@ -55,7 +60,7 @@ def _assigned_split(sample_idx: int, train_idx: set[int], test_idx: set[int]) -> return None -def _force_singleton_labels(y: np.ndarray, label_counts: np.ndarray, train_idx: set[int]) -> None: +def _force_singleton_labels(y: npt.NDArray[Any], label_counts: npt.NDArray[Any], train_idx: set[int]) -> None: for label, count in enumerate(label_counts): if int(count) != _RARE_LABEL_COUNT_SINGLETON: continue @@ -83,7 +88,11 @@ def _force_pair_samples(a: int, b: int, train_idx: set[int], test_idx: set[int], def _force_pair_labels( - y: np.ndarray, label_counts: np.ndarray, train_idx: set[int], test_idx: set[int], rng: np.random.Generator + y: npt.NDArray[Any], + label_counts: npt.NDArray[Any], + train_idx: set[int], + test_idx: set[int], + rng: np.random.Generator, ) -> None: for label, count in enumerate(label_counts): if int(count) != _RARE_LABEL_COUNT_PAIR: @@ -94,8 +103,8 @@ def _force_pair_labels( def _iterative_stratify_remaining( - y: np.ndarray, - remaining: np.ndarray, + y: npt.NDArray[Any], + remaining: npt.NDArray[Any], test_size: float, random_seed: int | None, train_idx: set[int], @@ -116,7 +125,9 @@ def _iterative_stratify_remaining( test_idx |= set(remaining[test_r].tolist()) -def _finalize_partition(n_samples: int, train_idx: set[int], test_idx: set[int]) -> tuple[np.ndarray, np.ndarray]: +def _finalize_partition( + n_samples: int, train_idx: set[int], test_idx: set[int] +) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]: train_arr = np.array(sorted(train_idx), dtype=int) test_arr = np.array(sorted(test_idx), dtype=int) From b26f40a6e582352c2093afe24af547d5f9820f64 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 14:44:14 +0300 Subject: [PATCH 04/14] minor bug fix --- .../context/data_handler/_safe_multilabel_stratification.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/autointent/context/data_handler/_safe_multilabel_stratification.py b/src/autointent/context/data_handler/_safe_multilabel_stratification.py index 8fa6b500..633ffc29 100644 --- a/src/autointent/context/data_handler/_safe_multilabel_stratification.py +++ b/src/autointent/context/data_handler/_safe_multilabel_stratification.py @@ -118,7 +118,10 @@ def _iterative_stratify_remaining( splitter = IterativeStratification( n_splits=2, order=2, - sample_distribution_per_fold=[1.0 - test_size, test_size], + # NOTE: IterativeStratification expects fold distribution in (test, train) order, + # but returns indices as (train, test). This matches the library's behavior and + # keeps backward-compatible train/test sizes with prior implementation. + sample_distribution_per_fold=[test_size, 1.0 - test_size], ) train_r, test_r = next(splitter.split(np.arange(len(remaining)), y[remaining])) train_idx |= set(remaining[train_r].tolist()) From 719ac6738618075dd59b9f8f4ff5ff4aebb91d94 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 15:54:50 +0300 Subject: [PATCH 05/14] update readiness util --- .../context/data_handler/_readiness_util.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index 3a919c2f..272eda85 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -10,6 +10,7 @@ from datasets import Dataset as HFDataset from autointent import Dataset + from autointent.configs import DataConfig from ._safe_multilabel_stratification import _validate_multilabel_matrix from ._stratification import StratifiedSplitter @@ -36,7 +37,7 @@ def check_split_readiness( dataset: Dataset, split: str, test_size: float, - min_samples_per_class: int = 2, + config: DataConfig, allow_oos_in_train: bool | None = None, ) -> SplitReadinessResult: """Check whether the dataset has enough samples per class for stratified splitting. @@ -49,14 +50,14 @@ def check_split_readiness( dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). split: The split name to check (e.g. ``Split.TRAIN``). test_size: Proportion used for the test split (must match the value used when splitting). - min_samples_per_class: Minimum number of samples per class required for stratification. - Default 2 matches sklearn's requirement for a 2-way stratified split. + config: data config allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples and this is not set, the function returns ``ready=False`` with a reason. Returns: SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. """ + min_samples_per_class = _min_samples_per_class_for_config(config=config) if split not in dataset: return SplitReadinessResult( ready=False, @@ -102,6 +103,18 @@ def check_split_readiness( ) +def _min_samples_per_class_for_config(config: DataConfig) -> int: + """Return a recommended minimum samples-per-class for a given data config.""" + # Base requirement for a single stratified split. + # For CV, the canonical lower bound is one example per fold. + base = 2 if config.scheme == "ho" else int(config.n_folds) + + # separation_ratio triggers an extra stratified split of the effective train + # pool (e.g. decision vs scoring), so we double the requirement. + factor = 1 if config.separation_ratio is None else 2 + return base * factor + + def _check_multiclass_counts( dataset: HFDataset, label_feature: str, min_samples_per_class: int ) -> list[tuple[int, int]]: From 309a6cfacb505dc17dbdac63f11c15af259263f2 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 15:57:21 +0300 Subject: [PATCH 06/14] upd again --- src/autointent/context/data_handler/_readiness_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index 272eda85..cf215767 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -36,7 +36,6 @@ class SplitReadinessResult: def check_split_readiness( dataset: Dataset, split: str, - test_size: float, config: DataConfig, allow_oos_in_train: bool | None = None, ) -> SplitReadinessResult: @@ -67,7 +66,7 @@ def check_split_readiness( ) hf_split = dataset[split] splitter = StratifiedSplitter( - test_size=test_size, + test_size=config.validation_size, label_feature=dataset.label_feature, random_seed=None, ) From c9fb793292afb7314db43a7ae9938f49e65270e8 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sat, 14 Mar 2026 16:19:20 +0300 Subject: [PATCH 07/14] update tests --- tests/data/test_check_split_readiness.py | 34 ++++++++---------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py index 71500924..4bf64609 100644 --- a/tests/data/test_check_split_readiness.py +++ b/tests/data/test_check_split_readiness.py @@ -3,6 +3,7 @@ import pytest from autointent import Dataset +from autointent.configs import DataConfig from autointent.context.data_handler import ( SplitReadinessResult, check_split_readiness, @@ -87,7 +88,7 @@ def test_check_split_readiness_ready_when_enough_samples(dataset_enough_samples) result = check_split_readiness( dataset_enough_samples, split=Split.TRAIN, - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), allow_oos_in_train=False, ) assert isinstance(result, SplitReadinessResult) @@ -102,7 +103,7 @@ def test_check_split_readiness_not_ready_underpopulated(dataset_underpopulated): result = check_split_readiness( dataset_underpopulated, split=Split.TRAIN, - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is False @@ -121,7 +122,7 @@ def test_check_split_readiness_missing_split(dataset_enough_samples): result = check_split_readiness( dataset_enough_samples, split="nonexistent_split", - test_size=0.3, + config=DataConfig(validation_size=0.3, separation_ratio=None), ) assert result.ready is False assert result.underpopulated_classes == [] @@ -134,7 +135,7 @@ def test_check_split_readiness_oos_allow_none(dataset_unsplitted): check_split_readiness( dataset_unsplitted, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=None, ) @@ -144,7 +145,7 @@ def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitt result = check_split_readiness( dataset_unsplitted, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is True @@ -157,8 +158,7 @@ def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_b result = check_split_readiness( dataset_two_classes_barely_enough, split=Split.TRAIN, - test_size=0.3, - min_samples_per_class=2, + config=DataConfig(validation_size=0.3, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is True @@ -166,8 +166,7 @@ def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_b result_strict = check_split_readiness( dataset_two_classes_barely_enough, split=Split.TRAIN, - test_size=0.3, - min_samples_per_class=3, + config=DataConfig(scheme="cv", n_folds=3, separation_ratio=None), allow_oos_in_train=False, ) assert result_strict.ready is False @@ -197,31 +196,20 @@ def test_check_split_readiness_multilabel_returns_ready(): result = check_split_readiness( dataset, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is False assert result.underpopulated_classes == [(1, 1)] assert result.reason is not None - # With min_samples_per_class=1 it becomes ready. - result_relaxed = check_split_readiness( - dataset, - split=Split.TRAIN, - test_size=0.5, - min_samples_per_class=1, - allow_oos_in_train=False, - ) - assert result_relaxed.ready is True - assert result_relaxed.underpopulated_classes == [] - def test_check_split_readiness_consistent_with_split_dataset(dataset_enough_samples): """When check_split_readiness says ready, split_dataset does not raise.""" result = check_split_readiness( dataset_enough_samples, split=Split.TRAIN, - test_size=0.5, + config=DataConfig(validation_size=0.5, separation_ratio=None), allow_oos_in_train=False, ) assert result.ready is True @@ -243,7 +231,7 @@ def test_check_split_readiness_underpopulated_implies_split_raises(dataset_under result = check_split_readiness( dataset_underpopulated, split=Split.TRAIN, - test_size=0.3, + config=DataConfig(validation_size=0.3), allow_oos_in_train=False, ) assert result.ready is False From 51ea3d1740344d2a12d250cf130b34bcd9800745 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:16:56 +0300 Subject: [PATCH 08/14] refactor readiness util a little bit --- .../context/data_handler/_readiness_util.py | 55 ++++++++----------- 1 file changed, 23 insertions(+), 32 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index cf215767..daaa2d29 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -2,7 +2,7 @@ from collections import Counter from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, NamedTuple import numpy as np @@ -16,9 +16,17 @@ from ._stratification import StratifiedSplitter +class ClassCount(NamedTuple): + id: int + """Class (intent) index.""" + + count: int + """Number of samples from the class (intent).""" + + @dataclass(frozen=True) class SplitReadinessResult: - """Result of checking whether a dataset can be stratified split. + """Result of checking whether a dataset can be fed to autointent pipeline. Attributes: ready: True if stratification can be performed (enough samples per class). @@ -28,7 +36,7 @@ class SplitReadinessResult: """ ready: bool - underpopulated_classes: list[tuple[int, int]] + underpopulated_classes: list[ClassCount] min_samples_per_class_required: int reason: str | None @@ -39,11 +47,7 @@ def check_split_readiness( config: DataConfig, allow_oos_in_train: bool | None = None, ) -> SplitReadinessResult: - """Check whether the dataset has enough samples per class for stratified splitting. - - Uses the same OOS and stratification logic as :func:`split_dataset`, so downstream - code can call this before creating a :class:`DataHandler` or calling :func:`split_dataset` - and handle underpopulated classes (e.g. skip phase, log, or fail with a clear message). + """Check whether the dataset has enough samples per class for autointent pipeline. Args: dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). @@ -52,9 +56,6 @@ def check_split_readiness( config: data config allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples and this is not set, the function returns ``ready=False`` with a reason. - - Returns: - SplitReadinessResult with ``ready``, ``underpopulated_classes``, and optional ``reason``. """ min_samples_per_class = _min_samples_per_class_for_config(config=config) if split not in dataset: @@ -72,21 +73,9 @@ def check_split_readiness( ) inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) if inputs.multilabel: - underpopulated = _check_multilabel_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) - ready = len(underpopulated) == 0 - reason = None - if not ready: - parts = [f"label {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] - reason = "Multilabel stratification requires at least {} positives per label. Underpopulated: {}.".format( - min_samples_per_class, "; ".join(parts) - ) - return SplitReadinessResult( - ready=ready, - underpopulated_classes=underpopulated, - min_samples_per_class_required=min_samples_per_class, - reason=reason, - ) - underpopulated = _check_multiclass_counts(inputs.dataset, splitter.label_feature, min_samples_per_class) + underpopulated = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, min_samples_per_class) + else: + underpopulated = _find_underpopulated_multiclass(inputs.dataset, splitter.label_feature, min_samples_per_class) ready = len(underpopulated) == 0 reason = None if not ready: @@ -114,20 +103,22 @@ def _min_samples_per_class_for_config(config: DataConfig) -> int: return base * factor -def _check_multiclass_counts( +def _find_underpopulated_multiclass( dataset: HFDataset, label_feature: str, min_samples_per_class: int -) -> list[tuple[int, int]]: +) -> list[ClassCount]: """Return (label, count) for each class with fewer than min_samples_per_class samples.""" labels: list[int] = dataset[label_feature] counts = Counter(labels) - return [(label, count) for label, count in counts.items() if count < min_samples_per_class] + return [ClassCount(id=label, count=count) for label, count in counts.items() if count < min_samples_per_class] -def _check_multilabel_counts( +def _find_underpopulated_multilabel( dataset: HFDataset, label_feature: str, min_samples_per_class: int -) -> list[tuple[int, int]]: +) -> list[ClassCount]: """Return (label_idx, positive_count) for each label with fewer than min_samples_per_class positives.""" y = np.asarray(dataset[label_feature]) _validate_multilabel_matrix(y) counts = y.sum(axis=0).astype(int) - return [(int(idx), int(count)) for idx, count in enumerate(counts) if count < min_samples_per_class] + return [ + ClassCount(id=int(idx), count=int(count)) for idx, count in enumerate(counts) if count < min_samples_per_class + ] From d925fba105dd8e94935a24c67697706468db8ee6 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:18:03 +0300 Subject: [PATCH 09/14] bug fix multilabel stratification --- src/autointent/context/data_handler/_stratification.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index f2a4ac91..6f786ede 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -172,7 +172,7 @@ def unmap_oos_multilabel(train_ds: HFDataset, test_ds: HFDataset) -> tuple[HFDat return StratifyInputs( dataset=mapped_dataset, - multilabel=False, + multilabel=True, test_size=self.test_size, post_split_fn=unmap_oos_multilabel, ) @@ -292,7 +292,9 @@ def _add_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> """ if sample[self.label_feature] is None: sample[self.label_feature] = [0] * n_classes - sample[self.label_feature] += [1] # type: ignore[operator] + sample[self.label_feature] += [1] # type: ignore[operator] + else: + sample[self.label_feature] += [0] # type: ignore[operator] return sample def _remove_oos_label(self, sample: dict[str, str | LabelType], n_classes: int) -> dict[str, str | LabelType]: From fcbd171dc65c6df63ca780ec9b40b500ae417854 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:29:15 +0300 Subject: [PATCH 10/14] widen utility's coverage --- .../context/data_handler/_readiness_util.py | 56 +++++++++++++++++-- 1 file changed, 51 insertions(+), 5 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index daaa2d29..defd59d5 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -52,10 +52,9 @@ def check_split_readiness( Args: dataset: The dataset to check (e.g. the same passed to :func:`split_dataset`). split: The split name to check (e.g. ``Split.TRAIN``). - test_size: Proportion used for the test split (must match the value used when splitting). config: data config - allow_oos_in_train: Same as in :func:`split_dataset`. If the dataset has OOS samples - and this is not set, the function returns ``ready=False`` with a reason. + allow_oos_in_train: Same as in :func:`split_dataset`. If the split contains OOS samples + and this is ``None``, this function raises ``ValueError`` (mirrors splitting behavior). """ min_samples_per_class = _min_samples_per_class_for_config(config=config) if split not in dataset: @@ -77,8 +76,19 @@ def check_split_readiness( else: underpopulated = _find_underpopulated_multiclass(inputs.dataset, splitter.label_feature, min_samples_per_class) ready = len(underpopulated) == 0 - reason = None - if not ready: + reason: str | None = None + + if ready and (not inputs.multilabel): + split_ok, split_reason = _check_multiclass_split_size_feasibility( + dataset=inputs.dataset, + label_feature=splitter.label_feature, + test_size=inputs.test_size, + ) + if not split_ok: + ready = False + reason = split_reason + + if not ready and reason is None: parts = [f"class {label!r}: {count} (need {min_samples_per_class})" for label, count in underpopulated] reason = "Stratification requires at least {} samples per class. Underpopulated: {}.".format( min_samples_per_class, "; ".join(parts) @@ -122,3 +132,39 @@ def _find_underpopulated_multilabel( return [ ClassCount(id=int(idx), count=int(count)) for idx, count in enumerate(counts) if count < min_samples_per_class ] + + +def _check_multiclass_split_size_feasibility( + dataset: HFDataset, label_feature: str, test_size: float +) -> tuple[bool, str | None]: + """Return whether stratified train/test sizes are feasible for multiclass splits. + + Even if each class has >=2 samples, sklearn stratified splitting can fail when + the requested train/test sizes are too small to include all classes. + """ + labels = dataset[label_feature] + n_classes = len(set(labels)) + n_samples = len(labels) + + # Mirror sklearn's float test_size -> n_test calculation (ceil). + n_test = int(np.ceil(float(test_size) * n_samples)) + n_train = n_samples - n_test + + if n_test <= 0 or n_train <= 0: + return ( + False, + f"Requested split sizes are invalid (n_samples={n_samples}, test_size={test_size}).", + ) + if n_test < n_classes: + return ( + False, + f"Stratified split would allocate too few test samples (n_test={n_test}) " + f"for the number of classes (n_classes={n_classes}).", + ) + if n_train < n_classes: + return ( + False, + f"Stratified split would allocate too few train samples (n_train={n_train}) " + f"for the number of classes (n_classes={n_classes}).", + ) + return True, None From 2a0202b43a20f5c8cc4a0fb50b07d405d4c4b88c Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:29:42 +0300 Subject: [PATCH 11/14] annotate stop iteration error when all samples are oos --- src/autointent/context/data_handler/_stratification.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/autointent/context/data_handler/_stratification.py b/src/autointent/context/data_handler/_stratification.py index 6f786ede..81fb376a 100644 --- a/src/autointent/context/data_handler/_stratification.py +++ b/src/autointent/context/data_handler/_stratification.py @@ -160,7 +160,14 @@ def _stratify_inputs_allow_oos(self, dataset: HFDataset, multilabel: bool) -> St OOS is mapped to a class so it is stratified; post_split_fn unmaps it. """ if multilabel: - in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) + try: + in_domain_sample = next(sample for sample in dataset if sample[self.label_feature] is not None) + except StopIteration as e: + msg = ( + "Cannot infer multilabel dimensionality: dataset contains only OOS samples " + f"({self.label_feature}=None for all rows)." + ) + raise ValueError(msg) from e n_classes = len(in_domain_sample[self.label_feature]) mapped_dataset = dataset.map(self._add_oos_label, fn_kwargs={"n_classes": n_classes}) From 4989d020d3330146dd90c1741b38650379c758b9 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:29:48 +0300 Subject: [PATCH 12/14] add more tests --- tests/data/test_check_split_readiness.py | 161 +++++++++++++++++++++++ 1 file changed, 161 insertions(+) diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py index 4bf64609..235fbd14 100644 --- a/tests/data/test_check_split_readiness.py +++ b/tests/data/test_check_split_readiness.py @@ -19,10 +19,13 @@ def dataset_enough_samples(): "train": [ {"utterance": "a1", "label": 0}, {"utterance": "a2", "label": 0}, + {"utterance": "a3", "label": 0}, {"utterance": "b1", "label": 1}, {"utterance": "b2", "label": 1}, + {"utterance": "b3", "label": 1}, {"utterance": "c1", "label": 2}, {"utterance": "c2", "label": 2}, + {"utterance": "c3", "label": 2}, ], "test": [ {"utterance": "t1", "label": 0}, @@ -38,6 +41,28 @@ def dataset_enough_samples(): ) +@pytest.fixture +def dataset_three_classes_two_each(): + """3 classes, 2 samples each (no OOS). Useful for split-size feasibility tests.""" + return Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + {"utterance": "b2", "label": 1}, + {"utterance": "c1", "label": 2}, + {"utterance": "c2", "label": 2}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + @pytest.fixture def dataset_underpopulated(): """Multiclass dataset with one class having only 1 sample. Not ready for stratification.""" @@ -153,6 +178,34 @@ def test_check_split_readiness_oos_allow_false_enough_in_domain(dataset_unsplitt assert result.reason is None +def test_check_split_readiness_multiclass_too_small_test_split(dataset_three_classes_two_each): + """Even with >=2/class, stratification can fail if test split can't include all classes.""" + result = check_split_readiness( + dataset_three_classes_two_each, + split=Split.TRAIN, + config=DataConfig(validation_size=0.1, separation_ratio=None), + allow_oos_in_train=False, + ) + assert result.ready is False + assert result.underpopulated_classes == [] + assert result.reason is not None + assert "too few test samples" in result.reason + + +def test_check_split_readiness_multiclass_too_small_train_split(dataset_three_classes_two_each): + """Even with >=2/class, stratification can fail if train split can't include all classes.""" + result = check_split_readiness( + dataset_three_classes_two_each, + split=Split.TRAIN, + config=DataConfig(validation_size=0.8, separation_ratio=None), + allow_oos_in_train=False, + ) + assert result.ready is False + assert result.underpopulated_classes == [] + assert result.reason is not None + assert "too few train samples" in result.reason + + def test_check_split_readiness_min_samples_per_class_param(dataset_two_classes_barely_enough): """Custom min_samples_per_class is respected.""" result = check_split_readiness( @@ -204,6 +257,97 @@ def test_check_split_readiness_multilabel_returns_ready(): assert result.reason is not None +def test_check_split_readiness_multilabel_oos_allow_true_checks_oos_label(): + """Multilabel + OOS + allow_oos_in_train=True should not crash and should include OOS label.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "x1", "label": [1, 0]}, + {"utterance": "x2", "label": [1, 0]}, + {"utterance": "x3", "label": [0, 1]}, + {"utterance": "x4", "label": [0, 1]}, + {"utterance": "oos1", "label": None}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + result = check_split_readiness( + dataset, + split=Split.TRAIN, + config=DataConfig(validation_size=0.5, separation_ratio=None), + allow_oos_in_train=True, + ) + assert result.ready is False + # OOS indicator label is appended -> index == n_classes == 2 + assert (2, 1) in result.underpopulated_classes + assert result.reason is not None + + +def test_check_split_readiness_multilabel_oos_allow_true_ready_when_oos_sufficient(): + """When OOS count meets minimum, multilabel readiness can be true.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "x1", "label": [1, 0]}, + {"utterance": "x2", "label": [1, 0]}, + {"utterance": "x3", "label": [0, 1]}, + {"utterance": "x4", "label": [0, 1]}, + {"utterance": "oos1", "label": None}, + {"utterance": "oos2", "label": None}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + result = check_split_readiness( + dataset, + split=Split.TRAIN, + config=DataConfig(validation_size=0.5, separation_ratio=None), + allow_oos_in_train=True, + ) + assert result.ready is True + assert result.underpopulated_classes == [] + assert result.reason is None + + +def test_split_dataset_multilabel_oos_allow_true_does_not_raise(): + """Sanity-check: split_dataset supports multilabel+OOS when allow_oos_in_train=True.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "x1", "label": [1, 0]}, + {"utterance": "x2", "label": [1, 0]}, + {"utterance": "x3", "label": [0, 1]}, + {"utterance": "x4", "label": [0, 1]}, + {"utterance": "oos1", "label": None}, + {"utterance": "oos2", "label": None}, + ], + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + from autointent.context.data_handler import split_dataset + + train, test = split_dataset( + dataset, + split=Split.TRAIN, + test_size=0.5, + random_seed=42, + allow_oos_in_train=True, + ) + assert len(train) > 0 + assert len(test) > 0 + + def test_check_split_readiness_consistent_with_split_dataset(dataset_enough_samples): """When check_split_readiness says ready, split_dataset does not raise.""" result = check_split_readiness( @@ -245,3 +389,20 @@ def test_check_split_readiness_underpopulated_implies_split_raises(dataset_under random_seed=42, allow_oos_in_train=False, ) + + +def test_stratified_splitter_multilabel_allow_oos_all_oos_raises_value_error(): + """Multilabel OOS mapping needs an in-domain row to infer label dimensionality.""" + from datasets import Dataset as HFDataset + + from autointent.context.data_handler._stratification import StratifiedSplitter + + hf_ds = HFDataset.from_list( + [ + {"utterance": "oos1", "label": None}, + {"utterance": "oos2", "label": None}, + ] + ) + splitter = StratifiedSplitter(test_size=0.5, label_feature="label", random_seed=0) + with pytest.raises(ValueError, match=r"only OOS|infer multilabel dimensionality"): + splitter.get_stratify_inputs(hf_ds, multilabel=True, allow_oos_in_train=True) From 1d9c1bac6b84be81c1d15a0d76a0641fdf41c2fa Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 13:33:44 +0300 Subject: [PATCH 13/14] fix typing --- .../context/data_handler/_readiness_util.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index defd59d5..24eaf67b 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -20,7 +20,7 @@ class ClassCount(NamedTuple): id: int """Class (intent) index.""" - count: int + n_samples: int """Number of samples from the class (intent).""" @@ -30,7 +30,7 @@ class SplitReadinessResult: Attributes: ready: True if stratification can be performed (enough samples per class). - underpopulated_classes: List of (label, count) for classes below the minimum. + underpopulated_classes: List of (label, n_samples) for classes below the minimum. min_samples_per_class_required: Minimum samples per class used for the check. reason: Human-readable reason when not ready (e.g. OOS not configured). """ @@ -119,7 +119,11 @@ def _find_underpopulated_multiclass( """Return (label, count) for each class with fewer than min_samples_per_class samples.""" labels: list[int] = dataset[label_feature] counts = Counter(labels) - return [ClassCount(id=label, count=count) for label, count in counts.items() if count < min_samples_per_class] + return [ + ClassCount(id=label, n_samples=n_samples) + for label, n_samples in counts.items() + if n_samples < min_samples_per_class + ] def _find_underpopulated_multilabel( @@ -130,7 +134,9 @@ def _find_underpopulated_multilabel( _validate_multilabel_matrix(y) counts = y.sum(axis=0).astype(int) return [ - ClassCount(id=int(idx), count=int(count)) for idx, count in enumerate(counts) if count < min_samples_per_class + ClassCount(id=int(idx), n_samples=int(n_samples)) + for idx, n_samples in enumerate(counts) + if n_samples < min_samples_per_class ] From 32eb96c15553283b8c32a83253d7e0cf0167bfc7 Mon Sep 17 00:00:00 2001 From: voorhs Date: Sun, 15 Mar 2026 14:33:44 +0300 Subject: [PATCH 14/14] detect 0-samples classes too --- .../context/data_handler/_readiness_util.py | 37 ++++++++++++++----- tests/data/test_check_split_readiness.py | 30 +++++++++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/src/autointent/context/data_handler/_readiness_util.py b/src/autointent/context/data_handler/_readiness_util.py index 24eaf67b..f0e70521 100644 --- a/src/autointent/context/data_handler/_readiness_util.py +++ b/src/autointent/context/data_handler/_readiness_util.py @@ -71,10 +71,17 @@ def check_split_readiness( random_seed=None, ) inputs = splitter.get_stratify_inputs(hf_split, dataset.multilabel, allow_oos_in_train) + expected_n_classes = _expected_n_classes(dataset, inputs.dataset, splitter.label_feature) + if inputs.multilabel: underpopulated = _find_underpopulated_multilabel(inputs.dataset, splitter.label_feature, min_samples_per_class) else: - underpopulated = _find_underpopulated_multiclass(inputs.dataset, splitter.label_feature, min_samples_per_class) + underpopulated = _find_underpopulated_multiclass( + inputs.dataset, + splitter.label_feature, + min_samples_per_class, + expected_n_classes=expected_n_classes, + ) ready = len(underpopulated) == 0 reason: str | None = None @@ -83,6 +90,7 @@ def check_split_readiness( dataset=inputs.dataset, label_feature=splitter.label_feature, test_size=inputs.test_size, + expected_n_classes=expected_n_classes, ) if not split_ok: ready = False @@ -114,16 +122,19 @@ def _min_samples_per_class_for_config(config: DataConfig) -> int: def _find_underpopulated_multiclass( - dataset: HFDataset, label_feature: str, min_samples_per_class: int + dataset: HFDataset, label_feature: str, min_samples_per_class: int, expected_n_classes: int ) -> list[ClassCount]: """Return (label, count) for each class with fewer than min_samples_per_class samples.""" labels: list[int] = dataset[label_feature] counts = Counter(labels) - return [ - ClassCount(id=label, n_samples=n_samples) - for label, n_samples in counts.items() - if n_samples < min_samples_per_class - ] + + # Ensure "missing" classes are treated as 0-count (underpopulated) + result: list[ClassCount] = [] + for label in range(int(expected_n_classes)): + n_samples = int(counts.get(label, 0)) + if n_samples < min_samples_per_class: + result.append(ClassCount(id=int(label), n_samples=n_samples)) + return result def _find_underpopulated_multilabel( @@ -141,7 +152,7 @@ def _find_underpopulated_multilabel( def _check_multiclass_split_size_feasibility( - dataset: HFDataset, label_feature: str, test_size: float + dataset: HFDataset, label_feature: str, test_size: float, expected_n_classes: int ) -> tuple[bool, str | None]: """Return whether stratified train/test sizes are feasible for multiclass splits. @@ -149,7 +160,7 @@ def _check_multiclass_split_size_feasibility( the requested train/test sizes are too small to include all classes. """ labels = dataset[label_feature] - n_classes = len(set(labels)) + n_classes = expected_n_classes n_samples = len(labels) # Mirror sklearn's float test_size -> n_test calculation (ceil). @@ -174,3 +185,11 @@ def _check_multiclass_split_size_feasibility( f"for the number of classes (n_classes={n_classes}).", ) return True, None + + +def _expected_n_classes(dataset: Dataset, prepared: HFDataset, label_feature: str) -> int: + if dataset.multilabel: + return len(prepared[label_feature][0]) + labels: list[int] = prepared[label_feature] + max_seen = max(labels) if labels else -1 + return max(dataset.n_classes, int(max_seen) + 1) diff --git a/tests/data/test_check_split_readiness.py b/tests/data/test_check_split_readiness.py index 235fbd14..36877225 100644 --- a/tests/data/test_check_split_readiness.py +++ b/tests/data/test_check_split_readiness.py @@ -257,6 +257,36 @@ def test_check_split_readiness_multilabel_returns_ready(): assert result.reason is not None +def test_check_split_readiness_marks_declared_but_unseen_intent_as_underpopulated(): + """Intents with 0 samples should be flagged so callers can filter them out.""" + dataset = Dataset.from_dict( + { + "train": [ + {"utterance": "a1", "label": 0}, + {"utterance": "a2", "label": 0}, + {"utterance": "b1", "label": 1}, + {"utterance": "b2", "label": 1}, + ], + # Declare 3 intents, but only provide samples for ids 0 and 1. + "intents": [ + {"id": 0, "regex_full_match": [], "regex_partial_match": []}, + {"id": 1, "regex_full_match": [], "regex_partial_match": []}, + {"id": 2, "regex_full_match": [], "regex_partial_match": []}, + ], + } + ) + + result = check_split_readiness( + dataset, + split=Split.TRAIN, + config=DataConfig(validation_size=0.5, separation_ratio=None), + allow_oos_in_train=False, + ) + assert result.ready is False + assert (2, 0) in result.underpopulated_classes + assert result.reason is not None + + def test_check_split_readiness_multilabel_oos_allow_true_checks_oos_label(): """Multilabel + OOS + allow_oos_in_train=True should not crash and should include OOS label.""" dataset = Dataset.from_dict(