@@ -17,6 +17,7 @@
 
 import datasets
 import datasets.distributed
+import torch
 from torch.utils.data import DataLoader, DistributedSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
@@ -306,3 +307,75 @@ def create_thd_dataloader( |
     )
 
     return train_dataloader, tokenized_dataset
+
+
+class MockTokenDataset(torch.utils.data.Dataset):
+    """Dataset that generates random token sequences for benchmarking.
+
+    All sequences have the same fixed length, so no padding is needed.
+
+    Args:
+        vocab_size: Vocabulary size for random token generation.
+        seq_length: Length of each generated sequence.
+        num_samples: Total number of samples in the dataset.
+    """
+
+    def __init__(self, vocab_size: int, seq_length: int, num_samples: int):
+        """Initialize the mock dataset."""
+        self.vocab_size = vocab_size
+        self.seq_length = seq_length
+        self.num_samples = num_samples
+
+    def __len__(self):
+        """Return the number of samples."""
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        """Return a random token sequence."""
+        input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))
+        return {"input_ids": input_ids}
+
+
+def _mock_collator(features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+    """Collator for MockTokenDataset that stacks fixed-length sequences into a batch."""
+    input_ids = torch.stack([f["input_ids"] for f in features])
+    return {"input_ids": input_ids, "labels": input_ids.clone(), "attention_mask": torch.ones_like(input_ids)}
+
+
+def create_mock_dataloader(
+    distributed_config: DistributedConfig,
+    micro_batch_size: int,
+    max_seq_length: int,
+    vocab_size: int = 128256,
+    num_samples: int = 100_000,
+    **kwargs,
+):
+    """Create a mock dataloader with random tokens for benchmarking.
+
+    Args:
+        distributed_config: The distributed configuration.
+        micro_batch_size: The batch size per device.
+        max_seq_length: The sequence length of each generated sample.
+        vocab_size: Vocabulary size for random token generation. Defaults to Llama 3 vocab size.
+        num_samples: Total number of samples in the dataset.
+        **kwargs: Ignored extra arguments for compatibility with other dataloader configs.
+
+    Returns:
+        A tuple of (dataloader, sampler).
+    """
+    dataset = MockTokenDataset(vocab_size, max_seq_length, num_samples)
+    sampler = DistributedSampler(
+        dataset,
+        rank=distributed_config.rank,
+        num_replicas=distributed_config.world_size,
+        seed=42,
+    )
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        sampler=sampler,
+        collate_fn=_mock_collator,
+        num_workers=0,
+        pin_memory=True,
+    )
+    return train_dataloader, sampler
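
Since the new helper only reads `rank` and `world_size` off the distributed config, it is easy to smoke-test in a single process. Below is a minimal sketch; the `DistributedConfig` dataclass in it is a hypothetical stand-in for the project's real config object, and the expected shapes follow from the collator added in this diff.

```python
import torch
from dataclasses import dataclass


# NOTE: hypothetical stand-in for the repo's real DistributedConfig;
# this code path only needs it to expose `rank` and `world_size`.
@dataclass
class DistributedConfig:
    rank: int = 0
    world_size: int = 1


dataloader, sampler = create_mock_dataloader(
    distributed_config=DistributedConfig(),
    micro_batch_size=4,
    max_seq_length=2048,
)

for epoch in range(2):
    # DistributedSampler only reshuffles across epochs if set_epoch is called.
    sampler.set_epoch(epoch)
    batch = next(iter(dataloader))
    assert batch["input_ids"].shape == (4, 2048)
    assert torch.equal(batch["labels"], batch["input_ids"])
    assert batch["attention_mask"].all()
```

One consequence of this design: because `__getitem__` draws fresh random tokens on every access, the sampler's fixed seed makes the index order reproducible but not the token values. For throughput benchmarking that trade-off is usually acceptable, since no dataloader state needs to be checkpointed or replayed exactly.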