2020from abc import ABC , abstractmethod
2121from dataclasses import dataclass
2222from pathlib import Path
23- from typing import Callable , Dict , List , Literal , Type
23+ from typing import Any , Callable , Dict , List , Literal , Type
2424
2525import pytest
2626import torch
@@ -987,8 +987,8 @@ def test_meta_fp8_init(self, fp8_recipe):
987987 self .verify_model_parameters_initialized_correctly (model , should_be_fp8 = True )
988988
989989 # ==================== Generation Tests (Autoregressive Models Only) ====================
990-
991- def _create_inference_params (self , config , batch_size = 1 , max_seq_len = 256 , num_beams = 1 ):
990+ @ abstractmethod
991+ def create_inference_params (self , config , batch_size = 1 , max_seq_len = 256 , num_beams = 1 ) -> Any :
992992 """Create inference params for KV-cache generation tests.
993993
994994 Autoregressive model tests must override this method to provide
@@ -1003,9 +1003,7 @@ def _create_inference_params(self, config, batch_size=1, max_seq_len=256, num_be
10031003 Returns:
10041004 HFInferenceParams instance with allocated memory.
10051005 """
1006- raise NotImplementedError (
1007- "Autoregressive models must override _create_inference_params to provide model-specific HFInferenceParams."
1008- )
1006+ pass
10091007
10101008 def test_generate_without_cache (self ):
10111009 """Test basic generation without KV-cache (BSHD, use_cache=False)."""
@@ -1040,7 +1038,7 @@ def test_generate_with_cache(self):
10401038 inputs = tokenizer (prompt , return_tensors = "pt" )
10411039 inputs = {k : v .to ("cuda" ) for k , v in inputs .items ()}
10421040
1043- past_key_values = self ._create_inference_params (config , batch_size = 1 )
1041+ past_key_values = self .create_inference_params (config , batch_size = 1 )
10441042
10451043 with torch .no_grad ():
10461044 output_ids = model .generate (** inputs , max_new_tokens = 16 , use_cache = True , past_key_values = past_key_values )
@@ -1064,7 +1062,7 @@ def test_generate_with_cache_batched(self):
10641062 inputs = tokenizer (prompts , return_tensors = "pt" , padding = True , padding_side = "left" )
10651063 inputs = {k : v .to ("cuda" ) for k , v in inputs .items ()}
10661064
1067- past_key_values = self ._create_inference_params (config , batch_size = 2 )
1065+ past_key_values = self .create_inference_params (config , batch_size = 2 )
10681066
10691067 with torch .no_grad ():
10701068 output_ids = model .generate (** inputs , max_new_tokens = 16 , use_cache = True , past_key_values = past_key_values )
@@ -1090,7 +1088,7 @@ def test_generate_with_cache_beam_search(self):
10901088 inputs = {k : v .to ("cuda" ) for k , v in inputs .items ()}
10911089
10921090 num_beams = 2
1093- past_key_values = self ._create_inference_params (config , batch_size = 2 , num_beams = num_beams )
1091+ past_key_values = self .create_inference_params (config , batch_size = 2 , num_beams = num_beams )
10941092
10951093 with torch .no_grad ():
10961094 output_ids = model .generate (
0 commit comments