
Commit 9c7d191

refactoring generation tests
Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 7c33578 commit 9c7d191

9 files changed

Lines changed: 529 additions & 176 deletions


bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 125 additions & 0 deletions
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
     Subclasses must implement all abstract methods to provide model-specific
     configuration, data preparation, and conversion functions.
 
+    Set ``is_autoregressive = True`` in subclasses for causal LM models to
+    enable generation / KV-cache smoke tests. Non-autoregressive models
+    (e.g. ESM2) leave the default ``False`` and those tests are skipped.
+
     Example:
     ```python
     class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
     ```
     """
 
+    is_autoregressive: bool = False
+
     @abstractmethod
     def get_model_class(self) -> Type[PreTrainedModel]:
         """Return the TransformerEngine model class to test.
@@ -980,4 +986,123 @@ def test_meta_fp8_init(self, fp8_recipe):
         model.init_empty_weights()
         self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
 
+    # ==================== Generation Tests (Autoregressive Models Only) ====================
+
+    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create inference params for KV-cache generation tests.
+
+        Autoregressive model tests must override this method to provide
+        model-specific ``HFInferenceParams`` with allocated KV-cache memory.
+
+        Args:
+            config: Model configuration.
+            batch_size: Batch size.
+            max_seq_len: Maximum sequence length.
+            num_beams: Number of beams for beam search.
+
+        Returns:
+            HFInferenceParams instance with allocated memory.
+        """
+        raise NotImplementedError(
+            "Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
+        )
+
+    def test_generate_without_cache(self):
+        """Test basic generation without KV-cache (BSHD, use_cache=False)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
     # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
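For reference, the opt-in pattern described in the docstring above works roughly as sketched below. The class name and the import of `BaseModelTest` are illustrative (not part of this diff); the concrete LLaMA3 override appears later in this commit.

```python
# Hedged sketch: a causal-LM tester opts into the shared generation / KV-cache
# smoke tests, while encoder-style testers (e.g. ESM2) keep the default False
# and the tests skip themselves via pytest.skip. Import of BaseModelTest omitted.
class MyCausalLMTester(BaseModelTest):  # hypothetical subclass name
    is_autoregressive = True  # enables the four generation tests above

    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
        # Must return an HFInferenceParams with KV-cache memory allocated for
        # every layer; see the LLaMA3-specific override later in this commit.
        ...
```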

bionemo-recipes/models/llama3/tests/common/test_modeling_common.py

Lines changed: 125 additions & 0 deletions
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
     Subclasses must implement all abstract methods to provide model-specific
     configuration, data preparation, and conversion functions.
 
+    Set ``is_autoregressive = True`` in subclasses for causal LM models to
+    enable generation / KV-cache smoke tests. Non-autoregressive models
+    (e.g. ESM2) leave the default ``False`` and those tests are skipped.
+
     Example:
     ```python
     class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
     ```
     """
 
+    is_autoregressive: bool = False
+
     @abstractmethod
     def get_model_class(self) -> Type[PreTrainedModel]:
         """Return the TransformerEngine model class to test.
@@ -980,4 +986,123 @@ def test_meta_fp8_init(self, fp8_recipe):
         model.init_empty_weights()
         self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
 
+    # ==================== Generation Tests (Autoregressive Models Only) ====================
+
+    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create inference params for KV-cache generation tests.
+
+        Autoregressive model tests must override this method to provide
+        model-specific ``HFInferenceParams`` with allocated KV-cache memory.
+
+        Args:
+            config: Model configuration.
+            batch_size: Batch size.
+            max_seq_len: Maximum sequence length.
+            num_beams: Number of beams for beam search.
+
+        Returns:
+            HFInferenceParams instance with allocated memory.
+        """
+        raise NotImplementedError(
+            "Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
+        )
+
+    def test_generate_without_cache(self):
+        """Test basic generation without KV-cache (BSHD, use_cache=False)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self._create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self._create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
     # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.

bionemo-recipes/models/llama3/tests/test_modeling_llama_te.py

Lines changed: 19 additions & 1 deletion
@@ -49,6 +49,8 @@ class TestLlama3Model(BaseModelTest):
     This class provides LLaMA3-specific configuration for the common test suite.
     """
 
+    is_autoregressive = True
+
     def get_model_class(self) -> Type[PreTrainedModel]:
         """Return the LLaMA3 TE model class."""
         return NVLlamaForCausalLM
@@ -138,7 +140,23 @@ def get_tolerances(self) -> TestTolerances:
             cp_loss_rtol=0.25,
         )
 
-    # ==================== LLaMA3-Specific Tests ====================
+    # ==================== LLaMA3-Specific Overrides ====================
+
+    def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """Create HFInferenceParams for the given config."""
+        past_key_values = HFInferenceParams(
+            max_batch_size=batch_size * num_beams,
+            max_sequence_length=max_seq_len,
+            num_heads_kv=config.num_key_value_heads,
+            head_dim_k=config.hidden_size // config.num_attention_heads,
+            dtype=torch.bfloat16,
+            qkv_format="thd",
+            max_ctx_len=max_seq_len,
+        )
+        for layer_number in range(1, config.num_hidden_layers + 1):
+            past_key_values.allocate_memory(layer_number)
+        return past_key_values
+
     def test_golden_values(self, input_format): # pyright: ignore[reportIncompatibleMethodOverride]
         """For llama3, we can test both the dynamic sequence packing and native bshd attention formats."""
         model_hf = self.get_reference_model(dtype=torch.bfloat16)
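Taken together, the override above and the common generation tests exercise roughly the flow sketched below. The `HFInferenceParams` import path is not shown in this diff and the model/tokenizer setup is illustrative only; the constructor arguments, `allocate_memory` loop, and `generate(..., past_key_values=...)` call mirror the code in this commit.

```python
# Hedged sketch of the KV-cache generation path the new tests exercise.
import torch


def build_inference_params(config, batch_size=1, max_seq_len=256, num_beams=1):
    # Mirrors TestLlama3Model._create_inference_params: pre-allocate a THD
    # KV-cache large enough for batch_size * num_beams sequences.
    params = HFInferenceParams(  # import path is recipe-specific (not in this diff)
        max_batch_size=batch_size * num_beams,
        max_sequence_length=max_seq_len,
        num_heads_kv=config.num_key_value_heads,
        head_dim_k=config.hidden_size // config.num_attention_heads,
        dtype=torch.bfloat16,
        qkv_format="thd",
        max_ctx_len=max_seq_len,
    )
    for layer_number in range(1, config.num_hidden_layers + 1):
        params.allocate_memory(layer_number)
    return params


# Usage (illustrative): single-prompt generation with the pre-allocated cache.
# inputs = {k: v.to("cuda") for k, v in tokenizer(prompt, return_tensors="pt").items()}
# output_ids = model.generate(
#     **inputs, max_new_tokens=16, use_cache=True,
#     past_key_values=build_inference_params(model.config),
# )
```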

bionemo-recipes/models/mixtral/modeling_mixtral_te.py

Lines changed: 3 additions & 3 deletions
@@ -166,7 +166,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         # Permute tokens by expert using TE moe_permute
         permuted_hidden, row_id_map = transformer_engine.pytorch.moe_permute(
-            hidden_states, selected_experts, map_type="index"
+            hidden_states, selected_experts.to(torch.int32), map_type="index"
        )
 
         # Compute m_splits: number of tokens per expert
@@ -185,11 +185,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # Down projection
         expert_output = self.experts_down(intermediate, m_splits=m_splits)  # [total_tokens, H]
 
-        # Unpermute and combine with routing weights
+        # Unpermute and combine with routing weights (keep probs in float32 for numerical stability)
         output = transformer_engine.pytorch.moe_unpermute(
             expert_output,
             row_id_map,
-            merging_probs=routing_weights.to(expert_output.dtype),
+            merging_probs=routing_weights,
             map_type="index",
         )
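The two Mixtral fixes change the MoE token-routing path: `moe_permute` now receives explicit int32 expert indices, and the routing probabilities are passed to `moe_unpermute` without being downcast to the expert-output dtype. Below is a hedged sketch of that flow; `run_experts` is a stand-in for the module's grouped expert projections (e.g. `experts_down` with `m_splits`) and is not part of the TE API, while the `moe_permute` / `moe_unpermute` calls mirror the diff above.

```python
# Minimal sketch of the permute -> expert MLP -> unpermute flow after this fix.
import torch
import transformer_engine.pytorch as te_pt


def route_through_experts(hidden_states, selected_experts, routing_weights, run_experts):
    # Cast the routing indices to int32 before permuting, as in the fixed forward pass.
    permuted_hidden, row_id_map = te_pt.moe_permute(
        hidden_states, selected_experts.to(torch.int32), map_type="index"
    )

    # run_experts stands in for the grouped expert GEMMs driven by m_splits.
    expert_output = run_experts(permuted_hidden)

    # routing_weights stay in float32 for numerical stability when the expert
    # outputs are merged back into token order.
    return te_pt.moe_unpermute(
        expert_output, row_id_map, merging_probs=routing_weights, map_type="index"
    )
```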
