
Commit 18b7f9e

starting cp benchmarking
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 1ded8ea commit 18b7f9e

4 files changed: 121 additions & 2 deletions

bionemo-recipes/recipes/llama3_native_te/dataset.py (73 additions & 0 deletions)
@@ -17,6 +17,7 @@
 
 import datasets
 import datasets.distributed
+import torch
 from torch.utils.data import DataLoader, DistributedSampler
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import AutoTokenizer
@@ -306,3 +307,75 @@ def create_thd_dataloader(
     )
 
     return train_dataloader, tokenized_dataset
+
+
+class MockTokenDataset(torch.utils.data.Dataset):
+    """Dataset that generates random token sequences for benchmarking.
+
+    All sequences have the same fixed length, so no padding is needed.
+
+    Args:
+        vocab_size: Vocabulary size for random token generation.
+        seq_length: Length of each generated sequence.
+        num_samples: Total number of samples in the dataset.
+    """
+
+    def __init__(self, vocab_size: int, seq_length: int, num_samples: int):
+        """Initialize the mock dataset."""
+        self.vocab_size = vocab_size
+        self.seq_length = seq_length
+        self.num_samples = num_samples
+
+    def __len__(self):
+        """Return the number of samples."""
+        return self.num_samples
+
+    def __getitem__(self, idx):
+        """Return a random token sequence."""
+        input_ids = torch.randint(0, self.vocab_size, (self.seq_length,))
+        return {"input_ids": input_ids}
+
+
+def _mock_collator(features: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+    """Collator for MockTokenDataset that stacks fixed-length sequences into a batch."""
+    input_ids = torch.stack([f["input_ids"] for f in features])
+    return {"input_ids": input_ids, "labels": input_ids.clone(), "attention_mask": torch.ones_like(input_ids)}
+
+
+def create_mock_dataloader(
+    distributed_config: DistributedConfig,
+    micro_batch_size: int,
+    max_seq_length: int,
+    vocab_size: int = 128256,
+    num_samples: int = 100_000,
+    **kwargs,
+):
+    """Create a mock dataloader with random tokens for benchmarking.
+
+    Args:
+        distributed_config: The distributed configuration.
+        micro_batch_size: The batch size per device.
+        max_seq_length: The sequence length of each generated sample.
+        vocab_size: Vocabulary size for random token generation. Defaults to Llama 3 vocab size.
+        num_samples: Total number of samples in the dataset.
+        **kwargs: Ignored extra arguments for compatibility with other dataloader configs.
+
+    Returns:
+        A tuple of (dataloader, sampler).
+    """
+    dataset = MockTokenDataset(vocab_size, max_seq_length, num_samples)
+    sampler = DistributedSampler(
+        dataset,
+        rank=distributed_config.rank,
+        num_replicas=distributed_config.world_size,
+        seed=42,
+    )
+    train_dataloader = DataLoader(
+        dataset,
+        batch_size=micro_batch_size,
+        sampler=sampler,
+        collate_fn=_mock_collator,
+        num_workers=0,
+        pin_memory=True,
+    )
+    return train_dataloader, sampler
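
The mock path can be smoke-tested outside a distributed run. The sketch below is not part of the commit: it feeds MockTokenDataset and _mock_collator from dataset.py into a plain single-process DataLoader (skipping the DistributedSampler that create_mock_dataloader adds) and checks the batch layout the trainer expects. It assumes the recipe directory is on sys.path so that dataset.py is importable.

# Minimal smoke test of the new mock data path (not part of this commit).
# Assumes bionemo-recipes/recipes/llama3_native_te/ is on sys.path so that
# dataset.py imports; a plain DataLoader stands in for the
# DistributedSampler-backed loader that create_mock_dataloader builds.
from torch.utils.data import DataLoader

from dataset import MockTokenDataset, _mock_collator

mock_dataset = MockTokenDataset(vocab_size=128256, seq_length=8192, num_samples=16)
loader = DataLoader(mock_dataset, batch_size=2, collate_fn=_mock_collator)

batch = next(iter(loader))
assert batch["input_ids"].shape == (2, 8192)  # (micro_batch_size, seq_length)
assert batch["labels"].shape == batch["input_ids"].shape
assert bool(batch["attention_mask"].all())  # fixed-length sequences, so the mask is all ones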
New Hydra config for the context-parallel benchmark run (41 additions & 0 deletions)
@@ -0,0 +1,41 @@
+defaults:
+  - defaults
+  - _self_
+
+config_name_or_path: ./model_configs/meta-llama/Llama-3.2-1B
+
+config_kwargs:
+  attn_input_format: "bshd"
+  self_attn_mask_type: "causal"
+
+cp_size: 1
+
+use_mock_dataset: true
+use_sequence_packing: false
+use_meta_device: true
+use_torch_compile: false
+
+num_train_steps: 100
+
+dataset:
+  tokenizer_name_or_path: null # Not needed for mock dataset
+  micro_batch_size: 1
+  max_seq_length: 8192
+  num_samples: 100_000
+  load_dataset_kwargs: null # Not needed for mock dataset
+
+wandb:
+  name: "llama3-cp-benchmark"
+  mode: "offline"
+
+lr_scheduler_kwargs:
+  num_warmup_steps: 10
+  num_decay_steps: 90
+
+checkpoint:
+  ckpt_dir: null
+  save_final_model: false
+  resume_from_checkpoint: false
+
+logger:
+  frequency: 1

bionemo-recipes/recipes/llama3_native_te/hydra_config/defaults.yaml (1 addition & 0 deletions)
@@ -13,6 +13,7 @@ use_meta_device: true
 use_torch_compile: false
 
 use_sequence_packing: false
+use_mock_dataset: false
 
 dataset:
   tokenizer_name_or_path: ??? # Set to the path of your tokenizer (e.g., meta-llama/Llama-3.1-8B or ./tokenizers/nucleotide_fast_tokenizer)
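
The one-line addition to defaults.yaml matters because train_fsdp2_cp.py (diff below) now reads args.use_mock_dataset unconditionally, so every composed config needs the key to exist; false keeps existing configs on their current data path, and the benchmark config above flips it to true. A rough OmegaConf sketch of that composition, not part of the commit and a simplification of Hydra's defaults-list handling, with only the relevant keys shown:

# Rough sketch (not part of this commit): OmegaConf.merge approximates how
# Hydra layers the benchmark config over defaults.yaml for the keys shown.
from omegaconf import OmegaConf

base = OmegaConf.create({"use_mock_dataset": False, "use_sequence_packing": False})
benchmark_overrides = OmegaConf.create({"use_mock_dataset": True, "use_sequence_packing": False})

composed = OmegaConf.merge(base, benchmark_overrides)
assert composed.use_mock_dataset   # the benchmark run takes the mock branch
assert not base.use_mock_dataset   # configs that never mention the key keep the new default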

bionemo-recipes/recipes/llama3_native_te/train_fsdp2_cp.py (6 additions & 2 deletions)
@@ -30,7 +30,7 @@
 
 from checkpoint import load_checkpoint_fsdp2, save_checkpoint_fsdp2, save_final_model_fsdp2, should_save_checkpoint
 from collator import ContextParallelDataLoaderWrapper, DataCollatorForContextParallel
-from dataset import create_bshd_dataloader, create_thd_dataloader
+from dataset import create_bshd_dataloader, create_mock_dataloader, create_thd_dataloader
 from distributed_config import DistributedConfig
 from modeling_llama_te import NVLlamaConfig, NVLlamaForCausalLM
 from perf_logger import PerfLogger
@@ -119,7 +119,11 @@ def main(args: DictConfig) -> float | None:
         logger.info("pad_sequences_to_be_divisible_by is not provided, using cp_mesh.size() * 2")
         OmegaConf.update(args, "dataset.pad_sequences_to_be_divisible_by", device_mesh["cp"].size() * 2)
     if device_mesh["cp"].get_local_rank() == 0:
-        if args.use_sequence_packing:
+        if args.use_mock_dataset:
+            train_dataloader, dataset_or_sampler = create_mock_dataloader(
+                dist_config, vocab_size=config.vocab_size, **args.dataset
+            )
+        elif args.use_sequence_packing:
             train_dataloader, dataset_or_sampler = create_thd_dataloader(dist_config, **args.dataset)
         else:
             train_dataloader, dataset_or_sampler = create_bshd_dataloader(dist_config, **args.dataset)
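
With the benchmark config shown earlier, the new branch expands roughly as sketched below (not part of the commit). vocab_size comes from the model config rather than from the dataset: block, and the tokenizer_name_or_path / load_dataset_kwargs entries that the real dataloaders need fall harmlessly into create_mock_dataloader's **kwargs. The SimpleNamespace is a hypothetical stand-in for the recipe's DistributedConfig, of which only .rank and .world_size are used here, and 128256 stands in for config.vocab_size.

# Sketch (not part of this commit) of the call the use_mock_dataset branch makes,
# filled in with the benchmark config's dataset: block. Assumes the recipe's
# dependencies are installed and dataset.py is importable.
from types import SimpleNamespace

from dataset import create_mock_dataloader

dist_config = SimpleNamespace(rank=0, world_size=1)  # stand-in for DistributedConfig
dataset_args = {
    "tokenizer_name_or_path": None,   # unused here; absorbed by **kwargs
    "micro_batch_size": 1,
    "max_seq_length": 8192,
    "num_samples": 100_000,
    "load_dataset_kwargs": None,      # unused here; absorbed by **kwargs
}

# Mirrors: create_mock_dataloader(dist_config, vocab_size=config.vocab_size, **args.dataset)
train_dataloader, sampler = create_mock_dataloader(dist_config, vocab_size=128256, **dataset_args)
print(len(sampler), len(train_dataloader))  # 100000 samples -> 100000 micro-batches at batch size 1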
