@@ -82,6 +82,10 @@ class BaseModelTest(ABC):
8282 Subclasses must implement all abstract methods to provide model-specific
8383 configuration, data preparation, and conversion functions.
8484
85+ Set ``is_autoregressive = True`` in subclasses for causal LM models to
86+ enable generation / KV-cache smoke tests. Non-autoregressive models
87+ (e.g. ESM2) leave the default ``False`` and those tests are skipped.
88+
8589 Example:
8690 ```python
8791 class ESM2ModelTester(BioNeMoModelTester):
@@ -98,6 +102,8 @@ def get_upstream_model_id(self):
98102 ```
99103 """
100104
105+ is_autoregressive : bool = False
106+
101107 @abstractmethod
102108 def get_model_class (self ) -> Type [PreTrainedModel ]:
103109 """Return the TransformerEngine model class to test.
@@ -980,4 +986,123 @@ def test_meta_fp8_init(self, fp8_recipe):
980986 model .init_empty_weights ()
981987 self .verify_model_parameters_initialized_correctly (model , should_be_fp8 = True )
982988
989+ # ==================== Generation Tests (Autoregressive Models Only) ====================
990+
def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_beams=1):
    """Build the inference params object used by the KV-cache generation tests.

    This base implementation is only a placeholder: it unconditionally
    raises. Test classes for autoregressive models are expected to override
    it and return a model-specific ``HFInferenceParams`` whose KV-cache
    memory has already been allocated.

    Args:
        config: Model configuration.
        batch_size: Batch size the cache should be sized for.
        max_seq_len: Maximum sequence length the cache should hold.
        num_beams: Number of beams when beam search is used.

    Returns:
        Nothing here; overrides should return an ``HFInferenceParams``
        instance with allocated memory.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = (
        "Autoregressive models must override _create_inference_params "
        "to provide model-specific HFInferenceParams."
    )
    raise NotImplementedError(message)
1009+
def test_generate_without_cache(self):
    """Smoke-test ``generate`` with the KV-cache turned off.

    Skipped unless the test class sets ``is_autoregressive``. The model is
    built from a test config requesting the ``"bshd"`` attention input
    format with a ``"causal"`` mask, then asked for up to 16 new tokens
    with ``use_cache=False``. The only check is that the returned sequence
    is longer than the prompt.
    """
    if not self.is_autoregressive:
        pytest.skip("Not an autoregressive model")

    test_config = self.create_test_config(attn_input_format="bshd", self_attn_mask_type="causal")
    model_cls = self.get_model_class()
    # Move to the GPU first, then cast to bfloat16, and switch to eval mode.
    net = model_cls(test_config).to("cuda").to(torch.bfloat16)
    net.eval()

    tok = self.get_tokenizer()
    encoded = tok("The quick brown fox jumps over", return_tensors="pt")
    gpu_inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    with torch.no_grad():
        generated = net.generate(**gpu_inputs, max_new_tokens=16, use_cache=False)

    # At least one token must have been appended to the prompt.
    prompt_len = gpu_inputs["input_ids"].shape[1]
    assert generated.shape[1] > prompt_len
1028+
def test_generate_with_cache(self):
    """Smoke-test single-prompt ``generate`` with a KV-cache.

    Skipped unless the test class sets ``is_autoregressive``. The model is
    built from a test config requesting the ``"thd"`` attention input
    format with a ``"padding_causal"`` mask. The cache object comes from
    ``_create_inference_params`` (sized for a batch of one) and is handed
    to ``generate`` as ``past_key_values``. The only check is that the
    returned sequence is longer than the prompt.
    """
    if not self.is_autoregressive:
        pytest.skip("Not an autoregressive model")

    test_config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model_cls = self.get_model_class()
    # Move to the GPU first, then cast to bfloat16, and switch to eval mode.
    net = model_cls(test_config).to("cuda").to(torch.bfloat16)
    net.eval()

    tok = self.get_tokenizer()
    encoded = tok("The quick brown fox jumps over", return_tensors="pt")
    gpu_inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    # Cache is created after tokenization and before generation.
    kv_cache = self._create_inference_params(test_config, batch_size=1)

    with torch.no_grad():
        generated = net.generate(
            **gpu_inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=kv_cache,
        )

    # At least one token must have been appended to the prompt.
    prompt_len = gpu_inputs["input_ids"].shape[1]
    assert generated.shape[1] > prompt_len
1049+
def test_generate_with_cache_batched(self):
    """Smoke-test batched ``generate`` with a KV-cache on left-padded prompts.

    Skipped unless the test class sets ``is_autoregressive``. Two prompts
    are tokenized together with left padding and fed to a model whose test
    config requests the ``"thd"`` attention input format with a
    ``"padding_causal"`` mask.
    NOTE(review): the padded batch is presumably converted to THD inside
    the model; that conversion is not visible here.

    Checks that two sequences come back and that they are longer than the
    padded prompt length.
    """
    if not self.is_autoregressive:
        pytest.skip("Not an autoregressive model")

    test_config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model_cls = self.get_model_class()
    # Move to the GPU first, then cast to bfloat16, and switch to eval mode.
    net = model_cls(test_config).to("cuda").to(torch.bfloat16)
    net.eval()

    tok = self.get_tokenizer()
    batch_prompts = (
        "The quick brown fox jumps over the lazy dog.",
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
    )
    encoded = tok(batch_prompts, return_tensors="pt", padding=True, padding_side="left")
    gpu_inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    # Cache is sized for both prompts in the batch.
    kv_cache = self._create_inference_params(test_config, batch_size=2)

    with torch.no_grad():
        generated = net.generate(
            **gpu_inputs,
            max_new_tokens=16,
            use_cache=True,
            past_key_values=kv_cache,
        )

    # One output row per prompt, each longer than the padded prompt.
    padded_prompt_len = gpu_inputs["input_ids"].shape[1]
    assert generated.shape[0] == 2
    assert generated.shape[1] > padded_prompt_len
1074+
def test_generate_with_cache_beam_search(self):
    """Smoke-test batched ``generate`` with a KV-cache and multiple beams.

    Skipped unless the test class sets ``is_autoregressive``. Two
    left-padded prompts are generated with ``num_beams=2`` and
    ``do_sample=True``, using a cache from ``_create_inference_params``
    created with the same beam count. Checks that two sequences come back
    and that they are longer than the padded prompt length.
    """
    if not self.is_autoregressive:
        pytest.skip("Not an autoregressive model")

    test_config = self.create_test_config(attn_input_format="thd", self_attn_mask_type="padding_causal")
    model_cls = self.get_model_class()
    # Move to the GPU first, then cast to bfloat16, and switch to eval mode.
    net = model_cls(test_config).to("cuda").to(torch.bfloat16)
    net.eval()

    tok = self.get_tokenizer()
    batch_prompts = (
        "The quick brown fox jumps over the lazy dog.",
        "Lorem ipsum dolor sit amet, consectetur adipiscing elit.",
    )
    encoded = tok(batch_prompts, return_tensors="pt", padding=True, padding_side="left")
    gpu_inputs = {name: tensor.to("cuda") for name, tensor in encoded.items()}

    # The same beam count is used for the cache and for generation.
    beam_count = 2
    kv_cache = self._create_inference_params(test_config, batch_size=2, num_beams=beam_count)

    generation_options = {
        "max_new_tokens": 16,
        "use_cache": True,
        "past_key_values": kv_cache,
        "num_beams": beam_count,
        "do_sample": True,
    }
    with torch.no_grad():
        generated = net.generate(**gpu_inputs, **generation_options)

    # One output row per prompt, each longer than the padded prompt.
    padded_prompt_len = gpu_inputs["input_ids"].shape[1]
    assert generated.shape[0] == 2
    assert generated.shape[1] > padded_prompt_len
1107+
9831108 # TODO: add multi-GPU tests, e.g., meta-device init after fully_shard, cp tests, etc.
0 commit comments