From 01e709867e2d0c3f35f63a77dfd0e97b20c88147 Mon Sep 17 00:00:00 2001 From: Samuel Stevens Date: Mon, 16 Mar 2026 13:55:15 -0400 Subject: [PATCH 1/4] Add special-token support to ordered and shuffled loaders --- .../specs/special-token-dataloaders.md | 107 ++++++++++++++++++ src/saev/data/ordered.py | 33 ++++-- src/saev/data/shards.py | 7 +- src/saev/data/shuffled.py | 64 +++++++++-- tests/test_ordered_dataloader.py | 67 +++++++++++ tests/test_shards_math.py | 10 ++ tests/test_shuffled_dataloader.py | 82 +++++++++++++- 7 files changed, 349 insertions(+), 21 deletions(-) create mode 100644 docs/internal/specs/special-token-dataloaders.md diff --git a/docs/internal/specs/special-token-dataloaders.md b/docs/internal/specs/special-token-dataloaders.md new file mode 100644 index 00000000..04960e93 --- /dev/null +++ b/docs/internal/specs/special-token-dataloaders.md @@ -0,0 +1,107 @@ +# Special Token Dataloaders + +## Problem + +`IndexedDataset` already supports `Config.tokens = "special"` and returns one activation per example for the CLS token. `OrderedDataLoader` and `ShuffledDataLoader` still reject that mode, even though the public docs and downstream plans already assume CLS-only loading is available. + +This blocks a simple workflow for training or analyzing models on only special tokens, for example training an SAE on CLS activations. + +## Goal + +Add `tokens = "special"` support to the ordered and shuffled activation loaders for a fixed transformer layer. + +## Non-goals + +- No `tokens = "all"` support in either loader. +- No support for `layer = "all"` in either loader. +- No patch-label filtering for special tokens. +- No change to on-disk shard layout or indexing semantics. + +## Requirements + +### Functional + +1. `OrderedConfig.tokens` must accept `"special"` as well as `"content"`. + +2. `OrderedDataLoader` and `ShuffledDataLoader` must both accept: + - `tokens = "special"` + - `layer = <int>` where the layer is present in metadata + +3. 
In special-token mode, each yielded sample corresponds to exactly one example: + - `example_idx` is the example index + - `token_idx` is `-1` + - `act` is the activation stored at token position `0` in the shard + +4. Epoch sizes in special-token mode: + - `n_samples == metadata.n_examples` + - `len(loader)` follows the existing `batch_size` and `drop_last` logic + +5. Ordered loader ordering in special-token mode: + - samples are yielded in increasing `example_idx` + - `token_idx` is always `-1` + +6. Shuffled loader semantics in special-token mode: + - each example appears once per epoch + - order remains deterministic for a fixed seed + - batches still expose the same keys: `act`, `example_idx`, `token_idx` + +7. Token labels: + - ordered loader must not attach `token_labels` for special tokens, even if `labels.bin` exists + - shuffled loader must reject `ignore_labels` when `tokens != "content"` + +### Non-functional + +1. Reuse the existing shard protocol. Special-token mode must read token position `0` from each example when `metadata.cls_token` is true. + +2. Keep the implementation small. The existing content-token path should remain unchanged except where a shared branch is cleaner than duplicate code. + +3. Preserve the current meaning of `token_idx = -1` for special tokens so loader outputs match `IndexedDataset` and `shards.IndexMap`. + +## Design + +### Ordered loader + +Continue to use `shards.IndexMap` to translate a global sample index into a shard location. This already knows that special tokens map to: + +- `content_token_idx = -1` +- `token_idx_in_shard = 0` + +The ordered manager only needs two behavior changes: + +1. permit `tokens = "special"` in the fixed-layer path +2. skip label lookup when `content_token_idx < 0` + +### Shuffled loader + +The shuffled loader currently iterates over every content token in a shard chunk. 
In special-token mode it should instead emit exactly one activation per example in the chunk: + +- activation source: `mmap[start:end, layer_i, 0]` +- metadata: + - column 0: global example indices + - column 1: `-1` + +`ignore_labels` remains content-token-only because `labels.bin` is defined over content tokens. + +## Test Plan + +Add red tests before implementation: + +1. Ordered loader special-token smoke test on fake shards: + - batch iterates successfully + - all `token_idx == -1` + - first batch has sequential `example_idx` + +2. Ordered loader matches `IndexedDataset` in special-token mode. + +3. Shuffled loader special-token epoch test on fake shards: + - all `token_idx == -1` + - every example appears exactly once in a full epoch + - activations match `IndexedDataset` for the sampled `example_idx` + +4. Shuffled loader rejects `ignore_labels` when `tokens = "special"`. + +## Acceptance Criteria + +- The new tests fail before the implementation change. +- The new tests pass after the implementation change. +- Existing content-token tests continue to pass. diff --git a/src/saev/data/ordered.py b/src/saev/data/ordered.py index 733002f6..db9e538f 100644 --- a/src/saev/data/ordered.py +++ b/src/saev/data/ordered.py @@ -59,7 +59,7 @@ class Config: """ shards: pathlib.Path = pathlib.Path("$SAEV_SCRATCH/saev/shards/abcdefg") - tokens: tp.Literal["content"] = "content" + tokens: tp.Literal["special", "content"] = "content" layer: int | tp.Literal["all"] = -2 batch_size: int = 1024 * 16 batch_timeout_s: float = 30.0 @@ -92,9 +92,9 @@ def _manager_main( ) # 0. PRE-CONDITIONS - if cfg.tokens != "content" or not isinstance(cfg.layer, int): + if cfg.tokens not in ("special", "content") or not isinstance(cfg.layer, int): raise NotImplementedError( - "High-throughput loader only supports `content` and fixed `layer` mode for now." + "High-throughput loader only supports `special` or `content` with fixed `layer` mode for now." 
) assert cfg.layer in md.layers, f"Layer {cfg.layer} not in {md.layers}" @@ -107,7 +107,7 @@ def _manager_main( # Check if labels.bin exists labels_mmap = None labels_path = cfg.shards / "labels.bin" - if labels_path.exists(): + if labels_path.exists() and cfg.tokens == "content": labels_mmap = np.memmap( labels_path, mode="r", @@ -121,7 +121,7 @@ def _manager_main( assert shard.n_examples == shard_info[0].n_examples == md.examples_per_shard # Calculate total number of samples - n_samples = md.n_examples * md.content_tokens_per_example + n_samples = len(index_map) logger.debug("Found %d samples.", n_samples) @@ -162,7 +162,7 @@ def _manager_main( batch_token_i.append(idx.content_token_idx) # Add patch label if available - if labels_mmap is not None: + if labels_mmap is not None and idx.content_token_idx >= 0: batch_token_labels.append( labels_mmap[idx.example_idx, idx.content_token_idx] ) @@ -176,7 +176,7 @@ def _manager_main( } # Add labels if available - if labels_mmap is not None: + if labels_mmap is not None and batch_token_labels: batch["token_labels"] = torch.tensor( batch_token_labels, dtype=torch.long ) @@ -218,6 +218,10 @@ def __init__(self, cfg: Config): self.cfg = cfg if not os.path.isdir(self.cfg.shards): raise RuntimeError(f"Activations are not saved at '{self.cfg.shards}'.") + if self.cfg.layer == "all": + raise NotImplementedError( + "High-throughput loader only supports a fixed integer `layer`." + ) self.md = shards.Metadata.load(self.cfg.shards) @@ -279,9 +283,13 @@ def _start_manager(self): ) self.manager_proc.start() - def __iter__(self) -> collections.abc.Iterable[ExampleBatch]: + def __iter__(self) -> collections.abc.Iterator[ExampleBatch]: """Yields batches in order.""" self._start_manager() + msg = "Manager state did not initialize correctly." 
+ assert self.batch_queue is not None, msg + assert self.err_queue is not None, msg + assert self.manager_proc is not None, msg n = 0 try: @@ -352,6 +360,10 @@ def __del__(self): def _calculate_n_samples(self) -> int: """Helper to calculate total number of examples based on config.""" + if self.cfg.tokens == "special": + msg = "tokens='special' requires shards with a CLS token." + assert self.md.cls_token, msg + match (self.cfg.tokens, self.cfg.layer): case ("special", "all"): return self.md.n_examples * len(self.md.layers) @@ -366,7 +378,10 @@ def _calculate_n_samples(self) -> int: * self.md.content_tokens_per_example ) case _: - tp.assert_never((self.cfg.tokens, self.cfg.layer)) + msg = ( + f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." + ) + raise AssertionError(msg) def __len__(self) -> int: """Returns the number of batches in an epoch.""" diff --git a/src/saev/data/shards.py b/src/saev/data/shards.py index 96f58bf3..bcb697f6 100644 --- a/src/saev/data/shards.py +++ b/src/saev/data/shards.py @@ -995,7 +995,7 @@ class IndexMap: md: Metadata tokens: tp.Literal["special", "content", "all"] - layer: int + layer: int | tp.Literal["all"] layer_idx_lookup: dict[int, int] def __init__( @@ -1029,7 +1029,7 @@ def from_global(self, idx: int | np.int_) -> Index: # [CLS] tokens only right now example_idx = idx shard_idx = idx // self.md.examples_per_shard - example_idx_in_shard = idx // self.md.examples_per_shard + example_idx_in_shard = idx % self.md.examples_per_shard return Index( idx=idx, example_idx=example_idx, @@ -1101,4 +1101,5 @@ def __len__(self) -> int: * self.md.tokens_per_example ) case _: - tp.assert_never((self.cfg.tokens, self.cfg.layer)) + msg = f"Unsupported index map config: {self.tokens=}, {self.layer=}." 
+ raise AssertionError(msg) diff --git a/src/saev/data/shuffled.py b/src/saev/data/shuffled.py index 43e88e38..49513e0a 100644 --- a/src/saev/data/shuffled.py +++ b/src/saev/data/shuffled.py @@ -161,7 +161,7 @@ def _io_worker( shard_info = shards.ShardInfo.load(shards_path) # Pre-conditions - assert cfg.tokens == "content" + assert cfg.tokens in ("special", "content") assert isinstance(cfg.layer, int) # If we need to filter by labels, ensure we have the labels @@ -200,6 +200,45 @@ def _io_worker( for start, end in helpers.batched_idx( shard_info[shard_i].n_examples, chunk_size ): + if cfg.tokens == "special": + t0 = time.perf_counter() + acts = torch.from_numpy(mmap[start:end, layer_i, 0]) + t1 = time.perf_counter() + + meta = torch.full((end - start, 2), -1, dtype=torch.int32) + meta[:, 0] = ex_i_offset + torch.arange(start, end) + + last_ex_i = int(meta[:, 0].max().item()) + if last_ex_i >= md.n_examples: + err = ExampleOutOfBoundsError(md, last_ex_i) + logger.warning(err.message) + raise err + + fill_before = reservoir.fill() + reservoir.put(acts, meta) + t2 = time.perf_counter() + fill_after = reservoir.fill() + + n_reads += 1 + bytes_sent += ( + acts.numel() * acts.element_size() + + meta.numel() * meta.element_size() + ) + + now = time.time() + if now - t_last_report >= cfg.log_every_s: + logger.debug( + "shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f", + shard_i, + bytes_sent / 1e6, + (t1 - t0) * 1e3, + (t2 - t1) * 1e3, + fill_before, + fill_after, + ) + t_last_report = now + continue + for t in range(md.content_tokens_per_example): token_idx = t + int(md.cls_token) @@ -240,7 +279,7 @@ def _io_worker( meta = torch.full((end - start, 2), t, dtype=torch.int32) meta[:, 0] = ex_i_offset + torch.arange(start, end) - last_ex_i = meta[:, 0].max().item() + last_ex_i = int(meta[:, 0].max().item()) if last_ex_i >= md.n_examples: err = ExampleOutOfBoundsError(md, last_ex_i) logger.warning(err.message) @@ -315,9 +354,9 @@ def 
_manager_main( ) # 0. PRE-CONDITIONS - if cfg.tokens != "content" or not isinstance(cfg.layer, int): + if cfg.tokens not in ("special", "content") or not isinstance(cfg.layer, int): raise NotImplementedError( - "High-throughput loader only supports `content` and fixed `layer` mode for now." + "High-throughput loader only supports `special` or `content` with fixed `layer` mode for now." ) assert cfg.layer in metadata.layers, f"Layer {cfg.layer} not in {metadata.layers}" @@ -506,6 +545,10 @@ def _start_manager(self): def __iter__(self) -> collections.abc.Iterator[ExampleBatch]: """Yields batches.""" self._start_manager() + msg = "Manager state did not initialize correctly." + assert self.reservoir is not None, msg + assert self.err_queue is not None, msg + assert self.manager_proc is not None, msg n, b = 0, 0 try: @@ -641,12 +684,14 @@ def _calculate_n_samples(self) -> int: When ignore_labels is specified, this counts the actual number of patches that remain after filtering out the ignored labels. """ + if self.cfg.tokens == "special": + msg = "tokens='special' requires shards with a CLS token." + assert self.metadata.cls_token, msg + # First calculate the maximum possible samples max_samples = 0 match (self.cfg.tokens, self.cfg.layer): - case ("cls", "all"): - max_samples = self.metadata.n_examples * len(self.metadata.layers) - case ("cls", int()): + case ("special", int()): max_samples = self.metadata.n_examples case ("content", int()): max_samples = ( @@ -659,7 +704,10 @@ def _calculate_n_samples(self) -> int: * self.metadata.content_tokens_per_example ) case _: - tp.assert_never((self.cfg.tokens, self.cfg.layer)) + msg = ( + f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." 
+ ) + raise AssertionError(msg) # If no filtering, return max samples if not self.cfg.ignore_labels: diff --git a/tests/test_ordered_dataloader.py b/tests/test_ordered_dataloader.py index 7da913ca..4919d177 100644 --- a/tests/test_ordered_dataloader.py +++ b/tests/test_ordered_dataloader.py @@ -247,6 +247,19 @@ def test_constructor_validation(ordered_cfg): OrderedDataLoader(cfg) +@pytest.mark.slow +@pytest.mark.parametrize("tokens", ["content", "special"]) +def test_constructor_rejects_layer_all(tokens): + with pytest.helpers.tmp_shards_root() as shards_root: + shards_dir = pytest.helpers.write_shards( + shards_root, data=datasets.FakeImg(n_examples=4) + ) + cfg = OrderedConfig(shards=shards_dir, tokens=tokens, layer="all") + + with pytest.raises(NotImplementedError, match="fixed integer `layer`"): + OrderedDataLoader(cfg) + + def test_properties(ordered_cfg): """Test OrderedDataLoader properties.""" dl = OrderedDataLoader(ordered_cfg) @@ -360,6 +373,60 @@ def test_timeout_handling(ordered_cfg): assert batch["act"].shape[0] > 0 +@pytest.mark.slow +def test_special_tokens_smoke(): + with pytest.helpers.tmp_shards_root() as shards_root: + shards_dir = pytest.helpers.write_shards( + shards_root, data=datasets.FakeImg(n_examples=5) + ) + cfg = OrderedConfig(shards=shards_dir, tokens="special", layer=0, batch_size=3) + dl = OrderedDataLoader(cfg) + + batch = next(iter(dl)) + + torch.testing.assert_close( + batch["example_idx"], torch.tensor([0, 1, 2], dtype=torch.long) + ) + torch.testing.assert_close( + batch["token_idx"], torch.full((3,), -1, dtype=torch.long) + ) + + +@pytest.mark.slow +def test_special_tokens_match_indexed_dataset(): + with pytest.helpers.tmp_shards_root() as shards_root: + shards_dir = pytest.helpers.write_shards( + shards_root, data=datasets.FakeImg(n_examples=6) + ) + ordered_cfg = OrderedConfig( + shards=shards_dir, tokens="special", layer=0, batch_size=4 + ) + dl = OrderedDataLoader(ordered_cfg) + indexed_cfg = IndexedConfig(shards=shards_dir, 
tokens="special", layer=0) + ds = IndexedDataset(indexed_cfg) + + seen = 0 + for batch in dl: + for i in range(batch["act"].shape[0]): + example_idx = batch["example_idx"][i].item() + token_idx = batch["token_idx"][i].item() + indexed_example = ds[seen] + + assert example_idx == seen + assert token_idx == -1 + assert indexed_example["example_idx"] == example_idx + assert indexed_example["token_idx"] == token_idx + torch.testing.assert_close( + indexed_example["act"], + batch["act"][i], + rtol=1e-5, + atol=1e-6, + ) + seen += 1 + + assert seen == len(ds) + + @pytest.mark.slow def test_ordered_dataloader_with_tiny_fake_dataset(): """Test OrderedDataLoader with a very small fake dataset to ensure end behavior works.""" diff --git a/tests/test_shards_math.py b/tests/test_shards_math.py index 6ad5228c..5ef1b307 100644 --- a/tests/test_shards_math.py +++ b/tests/test_shards_math.py @@ -93,3 +93,13 @@ def test_special_tokens_with_cls_token(custom_shards_dir): assert index.example_idx_in_shard == 0 assert index.layer_idx_in_shard == 0 assert index.token_idx_in_shard == 0 + + idx = 5 + index = index_map.from_global(idx) + assert index.idx == idx + assert index.example_idx == 5 + assert index.content_token_idx == -1 + assert index.shard_idx == 1 + assert index.example_idx_in_shard == 2 + assert index.layer_idx_in_shard == 0 + assert index.token_idx_in_shard == 0 diff --git a/tests/test_shuffled_dataloader.py b/tests/test_shuffled_dataloader.py index 3d2deb68..c57c02c2 100644 --- a/tests/test_shuffled_dataloader.py +++ b/tests/test_shuffled_dataloader.py @@ -15,7 +15,13 @@ import torch.multiprocessing as mp import saev.data -from saev.data import ShuffledConfig, ShuffledDataLoader, datasets +from saev.data import ( + IndexedConfig, + IndexedDataset, + ShuffledConfig, + ShuffledDataLoader, + datasets, +) mp.set_start_method("spawn", force=True) @@ -66,6 +72,80 @@ def test_iter_smoke(cfg): assert "token_idx" in batch +@pytest.mark.slow +def 
test_special_tokens_cover_each_example_once(): + with tmp_shards_root() as shards_root: + shards_dir = saev.data.shards.worker_fn( + family="fake-clip", + ckpt="hf-hub:hf-internal-testing/tiny-open-clip-model", + content_tokens_per_example=16, + cls_token=True, + d_model=128, + layers=[0], + data=datasets.FakeImg(n_examples=7), + batch_size=4, + n_workers=0, + max_tokens_per_shard=64, + shards_root=shards_root, + device="cpu", + ) + + cfg = ShuffledConfig( + shards=shards_dir, tokens="special", layer=0, batch_size=3, seed=17 + ) + dl = ShuffledDataLoader(cfg) + ds = IndexedDataset(IndexedConfig(shards=shards_dir, tokens="special", layer=0)) + + seen_example_i = set() + n_seen = 0 + for batch in dl: + assert torch.all(batch["token_idx"] == -1) + for i in range(batch["act"].shape[0]): + example_idx = int(batch["example_idx"][i].item()) + seen_example_i.add(example_idx) + indexed_example = ds[example_idx] + assert indexed_example["token_idx"] == -1 + torch.testing.assert_close( + indexed_example["act"], + batch["act"][i], + rtol=1e-5, + atol=1e-6, + ) + n_seen += 1 + + assert n_seen == len(ds) + assert seen_example_i == set(range(len(ds))) + + +@pytest.mark.slow +def test_special_tokens_reject_ignore_labels(): + with tmp_shards_root() as shards_root: + shards_dir = saev.data.shards.worker_fn( + family="fake-clip", + ckpt="hf-hub:hf-internal-testing/tiny-open-clip-model", + content_tokens_per_example=16, + cls_token=True, + d_model=128, + layers=[0], + data=datasets.FakeImg(n_examples=4), + batch_size=4, + n_workers=0, + max_tokens_per_shard=64, + shards_root=shards_root, + device="cpu", + ) + cfg = ShuffledConfig( + shards=shards_dir, + tokens="special", + layer=0, + batch_size=2, + ignore_labels=[0], + ) + + with pytest.raises(NotImplementedError, match="content"): + ShuffledDataLoader(cfg) + + def test_batches(cfg): dl = ShuffledDataLoader(cfg) it = iter(dl) From 0c3e2acbadb01715bb9d36c5232de33fd5abf905 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 
12:03:10 -0700 Subject: [PATCH 2/4] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/saev/data/ordered.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/saev/data/ordered.py b/src/saev/data/ordered.py index db9e538f..f10678b8 100644 --- a/src/saev/data/ordered.py +++ b/src/saev/data/ordered.py @@ -381,7 +381,7 @@ def _calculate_n_samples(self) -> int: msg = ( f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." ) - raise AssertionError(msg) + raise ValueError(msg) def __len__(self) -> int: """Returns the number of batches in an epoch.""" From fa20cf6730524c5a7d301c70723972f03082c910 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 16 Mar 2026 12:03:27 -0700 Subject: [PATCH 3/4] change error type Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- src/saev/data/shuffled.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/saev/data/shuffled.py b/src/saev/data/shuffled.py index 49513e0a..7f65e2cc 100644 --- a/src/saev/data/shuffled.py +++ b/src/saev/data/shuffled.py @@ -707,7 +707,7 @@ def _calculate_n_samples(self) -> int: msg = ( f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." 
) - raise AssertionError(msg) + raise ValueError(msg) # If no filtering, return max samples if not self.cfg.ignore_labels: From 39c455460c04a915e46e9c522fc6e1d0454d33ed Mon Sep 17 00:00:00 2001 From: Samuel Stevens Date: Mon, 16 Mar 2026 20:06:53 -0400 Subject: [PATCH 4/4] Refactor shuffled chunk writes and restore assert_never --- src/saev/data/ordered.py | 5 +- src/saev/data/shards.py | 3 +- src/saev/data/shuffled.py | 115 +++++++++++++++++--------------------- 3 files changed, 52 insertions(+), 71 deletions(-) diff --git a/src/saev/data/ordered.py b/src/saev/data/ordered.py index f10678b8..04869026 100644 --- a/src/saev/data/ordered.py +++ b/src/saev/data/ordered.py @@ -378,10 +378,7 @@ def _calculate_n_samples(self) -> int: * self.md.content_tokens_per_example ) case _: - msg = ( - f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." - ) - raise ValueError(msg) + tp.assert_never((self.cfg.tokens, self.cfg.layer)) def __len__(self) -> int: """Returns the number of batches in an epoch.""" diff --git a/src/saev/data/shards.py b/src/saev/data/shards.py index bcb697f6..1daca331 100644 --- a/src/saev/data/shards.py +++ b/src/saev/data/shards.py @@ -1101,5 +1101,4 @@ def __len__(self) -> int: * self.md.tokens_per_example ) case _: - msg = f"Unsupported index map config: {self.tokens=}, {self.layer=}." 
- raise AssertionError(msg) + tp.assert_never((self.tokens, self.layer)) diff --git a/src/saev/data/shuffled.py b/src/saev/data/shuffled.py index 7f65e2cc..f21b6114 100644 --- a/src/saev/data/shuffled.py +++ b/src/saev/data/shuffled.py @@ -174,6 +174,53 @@ def _io_worker( chunk_size = min(1024, math.ceil(cfg.batch_size * cfg.buffer_size / cfg.n_threads)) + def put_chunk( + acts: Tensor, + meta: Int[Tensor, " n 2"], + *, + shard_i: int, + t0: float, + t1: float, + ) -> None: + nonlocal bytes_sent, n_reads, t_last_report + + n_examples = acts.shape[0] + msg = f"{n_examples} != {meta.shape[0]}" + assert n_examples == meta.shape[0], msg + msg = f"Expected metadata shape {(n_examples, 2)}, got {tuple(meta.shape)}" + assert tuple(meta.shape) == (n_examples, 2), msg + + last_ex_i = int(meta[:, 0].max().item()) + if last_ex_i >= md.n_examples: + err = ExampleOutOfBoundsError(md, last_ex_i) + logger.warning(err.message) + raise err + + fill_before = reservoir.fill() + reservoir.put(acts, meta) + t2 = time.perf_counter() + fill_after = reservoir.fill() + + n_reads += 1 + bytes_sent += ( + acts.numel() * acts.element_size() + meta.numel() * meta.element_size() + ) + + now = time.time() + if now - t_last_report < cfg.log_every_s: + return + + logger.debug( + "shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f", + shard_i, + bytes_sent / 1e6, + (t1 - t0) * 1e3, + (t2 - t1) * 1e3, + fill_before, + fill_after, + ) + t_last_report = now + reason = "" while not stop_event.is_set(): @@ -183,7 +230,6 @@ def _io_worker( logger.debug("Got 'None' from work_queue; exiting.") reason = "poison_pill" break - t1 = time.perf_counter() fname = f"acts{shard_i:06}.bin" logger.info("Opening %s.", fname) @@ -194,7 +240,6 @@ def _io_worker( mmap = np.memmap( acts_fpath, mode="r", dtype=np.float32, shape=md.shard_shape ) - t2 = time.perf_counter() # Only iterate over the actual number of examples in this shard for start, end in helpers.batched_idx( @@ -207,36 +252,7 @@ 
def _io_worker( meta = torch.full((end - start, 2), -1, dtype=torch.int32) meta[:, 0] = ex_i_offset + torch.arange(start, end) - - last_ex_i = int(meta[:, 0].max().item()) - if last_ex_i >= md.n_examples: - err = ExampleOutOfBoundsError(md, last_ex_i) - logger.warning(err.message) - raise err - - fill_before = reservoir.fill() - reservoir.put(acts, meta) - t2 = time.perf_counter() - fill_after = reservoir.fill() - - n_reads += 1 - bytes_sent += ( - acts.numel() * acts.element_size() - + meta.numel() * meta.element_size() - ) - - now = time.time() - if now - t_last_report >= cfg.log_every_s: - logger.debug( - "shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f", - shard_i, - bytes_sent / 1e6, - (t1 - t0) * 1e3, - (t2 - t1) * 1e3, - fill_before, - fill_after, - ) - t_last_report = now + put_chunk(acts, meta, shard_i=shard_i, t0=t0, t1=t1) continue for t in range(md.content_tokens_per_example): @@ -279,35 +295,7 @@ def _io_worker( meta = torch.full((end - start, 2), t, dtype=torch.int32) meta[:, 0] = ex_i_offset + torch.arange(start, end) - last_ex_i = int(meta[:, 0].max().item()) - if last_ex_i >= md.n_examples: - err = ExampleOutOfBoundsError(md, last_ex_i) - logger.warning(err.message) - raise err - - fill_before = reservoir.fill() - reservoir.put(acts, meta) - t2 = time.perf_counter() - fill_after = reservoir.fill() - - n_reads += 1 - bytes_sent += ( - acts.numel() * acts.element_size() - + meta.numel() * meta.element_size() - ) - - now = time.time() - if now - t_last_report >= cfg.log_every_s: - logger.debug( - "shard=%s mb_sent=%.1f read_ms=%.2f put_ms=%.2f fill-before=%.3f fill-after=%.3f", - shard_i, - bytes_sent / 1e6, - (t1 - t0) * 1e3, - (t2 - t1) * 1e3, - fill_before, - fill_after, - ) - t_last_report = now + put_chunk(acts, meta, shard_i=shard_i, t0=t0, t1=t1) except queue.Empty: # Wait 0.1 seconds for new data. 
time.sleep(0.1) @@ -704,10 +692,7 @@ def _calculate_n_samples(self) -> int: * self.metadata.content_tokens_per_example ) case _: - msg = ( - f"Unsupported loader config: {self.cfg.tokens=}, {self.cfg.layer=}." - ) - raise ValueError(msg) + tp.assert_never((self.cfg.tokens, self.cfg.layer)) # If no filtering, return max samples if not self.cfg.ignore_labels: