Commit beec19f: Refactor autoregressive model tests (#1486)
Authored by Peter St. John

Moves shared auto-regressive model tests to modeling_common, making it easier to add new auto-regressive models.

## Summary by CodeRabbit

* **Tests**
  * Added generation tests with KV-cache support across model architectures (ESM2, LLaMA3, Mixtral), including batched and beam search variants.
  * Expanded test infrastructure to support flexible configuration overrides in quantized model tests.
* **Bug Fixes**
  * Improved numerical stability in Mixtral by adjusting tensor dtype handling for expert routing operations.
* **Refactor**
  * Consolidated generation test implementations across models into a shared base test class to reduce duplication.

Signed-off-by: Peter St. John <pstjohn@nvidia.com>
1 parent 8d92081 commit beec19f
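
In outline, the pattern this commit establishes: a causal-LM recipe opts into the shared generation tests by setting one class attribute and providing a KV-cache allocator. A minimal sketch, assuming a hypothetical model class (`MyTEForCausalLM` is a placeholder, not from this commit; a concrete allocator sketch follows the llama3 diff below):

```python
class MyCausalLMTester(BaseModelTest):
    is_autoregressive = True  # enables the generation / KV-cache smoke tests

    def get_model_class(self):
        return MyTEForCausalLM  # placeholder causal LM class, not from this commit

    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
        # Must return a model-specific HFInferenceParams with pre-allocated
        # KV-cache memory; see the concrete sketch after the llama3 diff below.
        ...
```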

Showing 7 changed files with 422 additions and 389 deletions.

File tree

bionemo-recipes/models/esm2/tests/common/test_modeling_common.py

Lines changed: 130 additions & 6 deletions
````diff
@@ -20,7 +20,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Dict, List, Literal, Type
+from typing import Any, Callable, Dict, List, Literal, Type
 
 import pytest
 import torch
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
     Subclasses must implement all abstract methods to provide model-specific
     configuration, data preparation, and conversion functions.
 
+    Set ``is_autoregressive = True`` in subclasses for causal LM models to
+    enable generation / KV-cache smoke tests. Non-autoregressive models
+    (e.g. ESM2) leave the default ``False`` and those tests are skipped.
+
     Example:
     ```python
     class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
     ```
     """
 
+    is_autoregressive: bool = False
+
     @abstractmethod
     def get_model_class(self) -> Type[PreTrainedModel]:
         """Return the TransformerEngine model class to test.
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format):
             msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}",
         )
 
-    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
+    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs):
         """Test that model initialized with FP8 works correctly."""
         if input_format == "thd" and not HAS_DATA_CENTER_GPU:
             pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.")
 
         model_class = self.get_model_class()
-        config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal")
+        config = self.create_test_config(
+            attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs
+        )
 
         # Initialize with FP8
         with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe):
@@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
         input_data["labels"] = input_data["input_ids"].clone()
 
         # Forward and backward pass with FP8
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-            with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
-                outputs = model(**input_data)
+        with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
+            outputs = model(**input_data)
 
         loss = outputs.loss
         assert torch.isfinite(loss)
@@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe):
         model.init_empty_weights()
         self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
 
+    # ==================== Generation Tests (Autoregressive Models Only) ====================
+    @abstractmethod
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any:
+        """Create inference params for KV-cache generation tests.
+
+        Autoregressive model tests must override this method to provide
+        model-specific ``HFInferenceParams`` with allocated KV-cache memory.
+
+        Args:
+            config: Model configuration.
+            batch_size: Batch size.
+            max_seq_len: Maximum sequence length.
+            num_beams: Number of beams for beam search.
+
+        Returns:
+            HFInferenceParams instance with allocated memory.
+        """
+        pass
+
+    def test_generate_without_cache(self):
+        """Test basic generation without KV-cache (BSHD, use_cache=False)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self.create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self.create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
     # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
````
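
The `**config_kwargs` pass-through added to `test_quantized_model_init_forward_and_backward` lets a subclass re-run the FP8 test with model-specific config overrides. A minimal sketch of the pattern; `moe_router_dtype` is a hypothetical stand-in for whatever override a model needs (the release notes mention a Mixtral routing-dtype fix, but this commit does not show the actual keyword):

```python
import torch


class MixtralModelTester(BaseModelTest):
    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
        # Forward a model-specific config override through the new
        # **config_kwargs parameter; the keyword below is a placeholder.
        super().test_quantized_model_init_forward_and_backward(
            fp8_recipe, input_format, moe_router_dtype=torch.float32
        )
```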

bionemo-recipes/models/esm2/tests/test_modeling_esm_te.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -270,3 +270,7 @@ def test_convert_state_dict_explicit_check(self):
             model_te.state_dict()["esm.embeddings.word_embeddings.weight"].data_ptr()
             == model_te.state_dict()["lm_head.decoder.weight"].data_ptr()
         )
+
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
+        """These are unused for non-autoregressive models."""
+        pass
```
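
Because `create_inference_params` is declared with `@abstractmethod` on the base class, even non-autoregressive testers like ESM2 must supply a concrete override; the no-op above satisfies the ABC contract, while the default `is_autoregressive = False` keeps the generation tests skipped for this model.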

bionemo-recipes/models/llama3/tests/common/test_modeling_common.py

Lines changed: 130 additions & 6 deletions
````diff
@@ -20,7 +20,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, Dict, List, Literal, Type
+from typing import Any, Callable, Dict, List, Literal, Type
 
 import pytest
 import torch
@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
     Subclasses must implement all abstract methods to provide model-specific
     configuration, data preparation, and conversion functions.
 
+    Set ``is_autoregressive = True`` in subclasses for causal LM models to
+    enable generation / KV-cache smoke tests. Non-autoregressive models
+    (e.g. ESM2) leave the default ``False`` and those tests are skipped.
+
     Example:
     ```python
     class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
     ```
     """
 
+    is_autoregressive: bool = False
+
     @abstractmethod
     def get_model_class(self) -> Type[PreTrainedModel]:
         """Return the TransformerEngine model class to test.
@@ -885,13 +891,15 @@ def test_fp8_forward_and_backward_pass(self, fp8_recipe, input_format):
             msg=lambda x: f"FP8 loss differs too much from BF16 loss: {x}",
         )
 
-    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
+    def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format, **config_kwargs):
         """Test that model initialized with FP8 works correctly."""
         if input_format == "thd" and not HAS_DATA_CENTER_GPU:
             pytest.xfail("Padded sequences are not supported on non-datacenter hardware for THD.")
 
         model_class = self.get_model_class()
-        config = self.create_test_config(attn_input_format=input_format, self_attn_mask_type="padding_causal")
+        config = self.create_test_config(
+            attn_input_format=input_format, self_attn_mask_type="padding_causal", **config_kwargs
+        )
 
         # Initialize with FP8
         with transformer_engine.pytorch.quantized_model_init(recipe=fp8_recipe):
@@ -906,9 +914,8 @@ def test_quantized_model_init_forward_and_backward(self, fp8_recipe, input_format):
         input_data["labels"] = input_data["input_ids"].clone()
 
         # Forward and backward pass with FP8
-        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
-            with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
-                outputs = model(**input_data)
+        with transformer_engine.pytorch.autocast(recipe=fp8_recipe):
+            outputs = model(**input_data)
 
         loss = outputs.loss
         assert torch.isfinite(loss)
@@ -979,4 +986,121 @@ def test_meta_fp8_init(self, fp8_recipe):
         model.init_empty_weights()
         self.verify_model_parameters_initialized_correctly(model, should_be_fp8=True)
 
+    # ==================== Generation Tests (Autoregressive Models Only) ====================
+    @abstractmethod
+    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1) -> Any:
+        """Create inference params for KV-cache generation tests.
+
+        Autoregressive model tests must override this method to provide
+        model-specific ``HFInferenceParams`` with allocated KV-cache memory.
+
+        Args:
+            config: Model configuration.
+            batch_size: Batch size.
+            max_seq_len: Maximum sequence length.
+            num_beams: Number of beams for beam search.
+
+        Returns:
+            HFInferenceParams instance with allocated memory.
+        """
+        pass
+
+    def test_generate_without_cache(self):
+        """Test basic generation without KV-cache (BSHD, use_cache=False)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=False)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache(self):
+        """Test single-prompt generation with KV-cache (THD format)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompt = "The quick brown fox jumps over"
+        inputs = tokenizer(prompt, return_tensors="pt")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self.create_inference_params(config, batch_size=1)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_batched(self):
+        """Test batched generation with KV-cache (left-padded BSHD converted to THD)."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        past_key_values = self.create_inference_params(config, batch_size=2)
+
+        with torch.no_grad():
+            output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True, past_key_values=past_key_values)
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
+    def test_generate_with_cache_beam_search(self):
+        """Test batched generation with KV-cache and beam search."""
+        if not self.is_autoregressive:
+            pytest.skip("Not an autoregressive model")
+
+        config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
+        model = self.get_model_class()(config).to("cuda").to(torch.bfloat16)
+        model.eval()
+
+        tokenizer = self.get_tokenizer()
+        prompts = (
+            "The quick brown fox jumps over the lazy dog.",
+            "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
+        )
+        inputs = tokenizer(prompts, return_tensors="pt", padding=True, padding_side="left")
+        inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+        num_beams = 2
+        past_key_values = self.create_inference_params(config, batch_size=2, num_beams=num_beams)
+
+        with torch.no_grad():
+            output_ids = model.generate(
+                **inputs,
+                max_new_tokens=16,
+                use_cache=True,
+                past_key_values=past_key_values,
+                num_beams=num_beams,
+                do_sample=True,
+            )
+
+        assert output_ids.shape[0] == 2
+        assert output_ids.shape[1] > inputs["input_ids"].shape[1]
+
     # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
````
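
For a concrete override, the `HFInferenceParams` named in the docstring above plausibly follows the shape of TransformerEngine's `InferenceParams(max_batch_size, max_sequence_length)` KV-cache container; a hedged sketch under that assumption (the constructor arguments are not shown in this commit):

```python
class Llama3ModelTester(BaseModelTest):
    is_autoregressive = True

    def create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
        # Beam search keeps num_beams live hypotheses per prompt, so the
        # KV-cache must be sized for the expanded batch. The arguments here
        # mirror TE's InferenceParams and are assumed, not from this commit.
        return HFInferenceParams(
            max_batch_size=batch_size * num_beams,
            max_sequence_length=max_seq_len,
        )
```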
