Skip to content

Commit 9dddcb3

Browse files
authored
Fix compatibility issue with transformers 5.0+ (#2328)
## Describe your changes This pull request introduces compatibility updates for Hugging Face Transformers 5.0 and improves handling of dynamic cache and input formats in Olive's ONNX conversion and training utilities. It also updates tests and requirements to reflect these changes and ensure robust model export and training workflows. ### Transformers 5.0 Compatibility * Added patching and conversion utilities for `DynamicLayer.lazy_initialization`, `past_key_values`, and dynamic shapes to support the new DynamicCache format in Transformers >= 5.0. This ensures models using dynamic cache export correctly with `torch.export`. * Updated `_export_pytorch_model` logic to apply the new patches and conversions only for Transformers >= 5.0, while maintaining legacy support for older versions. ### Training Argument Handling * Improved filtering of training arguments in `create_training_args` to remove fields not valid for Transformers 5.0 and exclude `None` values, allowing Transformers to use its own defaults. ### Test Suite Updates * Modified model loading and metadata tests to remove `trust_remote_code` parameter and update expected file counts and tokenizer types for Transformers 5.0. [[1]](diffhunk://#diff-af681b2feed22286034d304b653185d2a4dc5d680e7d715a6ad41a1c731ff0fcL30-L45) [[2]](diffhunk://#diff-af681b2feed22286034d304b653185d2a4dc5d680e7d715a6ad41a1c731ff0fcL76-R80) [[3]](diffhunk://#diff-af681b2feed22286034d304b653185d2a4dc5d680e7d715a6ad41a1c731ff0fcL97-R90) [[4]](diffhunk://#diff-af681b2feed22286034d304b653185d2a4dc5d680e7d715a6ad41a1c731ff0fcL129-R121) * Updated model output comparison in rotation tests to cast logits to `float` before comparison, ensuring consistency across dtypes. ### Requirements Adjustments * Restricted `onnxscript` version to `<0.6.1` and removed the Transformers version pin, reflecting confidence in test suite compatibility with Transformers 5.0. 
[[1]](diffhunk://#diff-1ce09e5a57d7791711f12f84ecb7e089e925a2929b719d587561e8e58c7e4b90L24-R24) [[2]](diffhunk://#diff-1ce09e5a57d7791711f12f84ecb7e089e925a2929b719d587561e8e58c7e4b90L40-L41) ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. ## (Optional) Issue link
1 parent 86f9469 commit 9dddcb3

7 files changed

Lines changed: 167 additions & 40 deletions

File tree

olive/passes/onnx/conversion.py

Lines changed: 146 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
from olive.passes.onnx.common import get_external_data_config, ir_model_to_olive_model
3939
from olive.passes.pass_config import BasePassConfig, PassConfigParam, get_user_script_data_config
4040

41+
# pylint: disable=W0212
42+
4143
logger = logging.getLogger(__name__)
4244

4345

@@ -57,6 +59,128 @@ def forward(self, *input_data, **input_dict):
5759
return self.model(*input_data, **input_dict)
5860

5961

62+
def _register_dynamic_cache_export_support():
63+
"""Utilities for `DynamicCache` <> torch.export support."""
64+
from transformers.cache_utils import DynamicCache, DynamicLayer, DynamicSlidingWindowLayer
65+
66+
def _get_cache_dict(cache: DynamicCache):
67+
"""Convert cache to dictionary format for pytree operations."""
68+
if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
69+
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
70+
71+
return {
72+
"cache": [(layer.keys, layer.values) for layer in cache.layers if layer.keys is not None],
73+
}
74+
75+
try:
76+
torch.utils._pytree.register_pytree_node(
77+
DynamicCache,
78+
lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
79+
_unflatten_dynamic_cache,
80+
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
81+
flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
82+
_get_cache_dict(dynamic_cache)
83+
),
84+
)
85+
# TODO (team): This won't be needed in torch 2.7.
86+
torch.fx._pytree.register_pytree_flatten_spec(
87+
DynamicCache,
88+
lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
89+
)
90+
# Catching this in case there are multiple runs for some test runs
91+
except ValueError as e:
92+
if "already registered as pytree node" not in str(e):
93+
raise
94+
95+
96+
def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
97+
from transformers.cache_utils import DynamicCache
98+
99+
dictionary = torch.utils._pytree._dict_unflatten(values, context)
100+
cache = DynamicCache()
101+
# Reconstruct layers from keys and values lists
102+
cache_list = dictionary.get("cache", [])
103+
for i, (key, value) in enumerate(cache_list):
104+
cache.update(key, value, i)
105+
return cache
106+
107+
108+
def _patch_dynamic_layer_for_export():
109+
"""Patch DynamicLayer.lazy_initialization for torch.export compatibility (transformers >= 5.0).
110+
111+
The original uses torch.tensor([]) which creates a 1D empty tensor (shape [0]).
112+
torch.export needs consistent tensor ranks, so we use torch.narrow + torch.empty_like
113+
to preserve the full shape (e.g. [batch, heads, 0, head_dim]).
114+
"""
115+
from transformers.cache_utils import DynamicLayer
116+
117+
if not hasattr(DynamicLayer, "lazy_initialization"):
118+
return
119+
120+
def patched_lazy_initialization(self, key_states: torch.Tensor, value_states: torch.Tensor = None):
121+
self.dtype, self.device = key_states.dtype, key_states.device
122+
like = torch.narrow(key_states, dim=-2, start=0, length=0)
123+
if hasattr(key_states, "fake_mode"):
124+
with key_states.fake_mode:
125+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
126+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
127+
else:
128+
self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
129+
self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
130+
self.is_initialized = True
131+
132+
DynamicLayer.lazy_initialization = patched_lazy_initialization
133+
logger.debug("Patched DynamicLayer.lazy_initialization for torch.export compatibility.")
134+
135+
136+
def _convert_past_key_values_to_dynamic_cache(dummy_kwargs: dict, config=None) -> dict:
137+
"""Convert legacy list-format past_key_values to DynamicCache (transformers >= 5.0).
138+
139+
Transformers 5.0 models expect DynamicCache objects, not lists of (key, value) tensors.
140+
When config is provided, the DynamicCache will create correct layer types (e.g.
141+
DynamicSlidingWindowLayer for models using sliding window attention).
142+
"""
143+
pkv = dummy_kwargs.get("past_key_values")
144+
if pkv is None or not isinstance(pkv, (list, tuple)):
145+
return dummy_kwargs
146+
147+
# Check if it's legacy format: list of [key, value] pairs (each with exactly 2 elements)
148+
if not pkv or not isinstance(pkv[0], (list, tuple)) or len(pkv[0]) != 2:
149+
return dummy_kwargs
150+
151+
from transformers.cache_utils import DynamicCache
152+
153+
dc = DynamicCache(config=config)
154+
for layer_idx, kv in enumerate(pkv):
155+
dc.update(kv[0], kv[1], layer_idx=layer_idx)
156+
dummy_kwargs["past_key_values"] = dc
157+
logger.debug("Converted past_key_values from legacy list format to DynamicCache.")
158+
return dummy_kwargs
159+
160+
161+
def _convert_dynamic_shapes_for_dynamic_cache(dynamic_shapes: dict) -> dict:
162+
"""Convert dynamic_shapes for past_key_values from nested list to DynamicCache pytree format.
163+
164+
The old format is: [[key_shape, val_shape], ...] (one pair per layer)
165+
The DynamicCache pytree is: {"cache": [(key0, val0), (key1, val1), ...]}
166+
matching the structure from _register_dynamic_cache_export_support().
167+
"""
168+
pkv_shapes = dynamic_shapes.get("past_key_values")
169+
if pkv_shapes is None or not isinstance(pkv_shapes, (list, tuple)):
170+
return dynamic_shapes
171+
172+
if not pkv_shapes or not isinstance(pkv_shapes[0], (list, tuple)) or len(pkv_shapes[0]) != 2:
173+
return dynamic_shapes
174+
175+
# Convert [[key0, val0], [key1, val1], ...] -> {"cache": [(key0, val0), (key1, val1), ...]}
176+
# matching DynamicCache pytree: _dict_flatten({"cache": [(keys, values), ...]})
177+
dynamic_shapes["past_key_values"] = {
178+
"cache": [tuple(layer) for layer in pkv_shapes],
179+
}
180+
logger.debug("Converted dynamic_shapes for past_key_values to DynamicCache pytree format.")
181+
return dynamic_shapes
182+
183+
60184
def _patch_model_if_necessary(pytorch_model: torch.nn.Module):
61185
if not isinstance(pytorch_model, PreTrainedModel):
62186
return
@@ -179,9 +303,6 @@ def _export_pytorch_model(
179303
if torch_dtype:
180304
pytorch_model = pytorch_model.to(torch_dtype)
181305

182-
# Apply any necessary patches
183-
_patch_model_if_necessary(pytorch_model)
184-
185306
# get input and output names, and dynamic axes
186307
assert io_config is not None, "Cannot get io_config for the model."
187308
io_config = validate_config(io_config, IoConfig)
@@ -194,8 +315,6 @@ def _export_pytorch_model(
194315
# is taken, the old export always writes a model to the disk. When that happens we need to
195316
# load the model back into IR and load all the external tensor to memory
196317
with tempfile.TemporaryDirectory(prefix="olive_tmp") as tmp_dir:
197-
tmp_model_path = resolve_onnx_path(tmp_dir)
198-
199318
if dynamo:
200319
# Take the "release" version so that dev builds like 2.5.0dev1234 are treated as 2.5.0
201320
if _torch_is_older_than("2.7.0") and (
@@ -212,24 +331,39 @@ def _export_pytorch_model(
212331
"Please upgrade PyTorch to 2.6.0 or above."
213332
)
214333

215-
# Register DynamicCache export support
216-
from transformers.integrations.executorch import register_dynamic_cache_export_support
217-
218-
register_dynamic_cache_export_support()
219-
220334
if isinstance(dummy_inputs, dict):
221335
dummy_kwargs = dummy_inputs
222336
dummy_inputs = ()
223337
else:
224338
dummy_kwargs = {}
225339
dummy_inputs = tuple(dummy_inputs)
226340

341+
# Apply patches for DynamicCache / past_key_values compatibility
342+
if version.parse(transformers.__version__) >= version.parse("5.0"):
343+
# transformers >= 5.0: DynamicCache refactored to use DynamicLayer
344+
345+
_register_dynamic_cache_export_support()
346+
_patch_dynamic_layer_for_export()
347+
model_config = getattr(pytorch_model, "config", None)
348+
dummy_kwargs = _convert_past_key_values_to_dynamic_cache(dummy_kwargs, config=model_config)
349+
if io_config.dynamic_shapes:
350+
io_config.dynamic_shapes = _convert_dynamic_shapes_for_dynamic_cache(io_config.dynamic_shapes)
351+
else:
352+
# transformers < 5.0: patch forward to convert list <-> DynamicCache
353+
_patch_model_if_necessary(pytorch_model)
354+
227355
# NOTE: Usually validation is done in io_config.py, but because
228356
# dynamic_shapes has nested complexity, and it can't be validated multiple
229357
# times like others, we validate it here.
230358
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs = _validate_dynamic_shapes(
231359
io_config.dynamic_shapes, dummy_inputs, dummy_kwargs, pytorch_model
232360
)
361+
# torch.export requires strict type match between inputs and dynamic_shapes;
362+
# _validate_dynamic_shapes may return OrderedDict, so convert back to plain dict
363+
if isinstance(io_config.dynamic_shapes, collections.OrderedDict):
364+
io_config.dynamic_shapes = dict(io_config.dynamic_shapes)
365+
if isinstance(dummy_kwargs, collections.OrderedDict):
366+
dummy_kwargs = dict(dummy_kwargs)
233367

234368
# When dynamo=True, PyTorch prefers dynamic_shapes over dynamic_axes.
235369
# If dynamic_shapes is None and fallback is enabled, don't pass dynamic_axes
@@ -239,15 +373,13 @@ def _export_pytorch_model(
239373
onnx_program = torch.onnx.export( # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
240374
pytorch_model,
241375
dummy_inputs,
242-
tmp_model_path, # needed for fallback=True
243376
kwargs=dummy_kwargs,
244377
opset_version=config.target_opset,
245378
input_names=io_config.input_names,
246379
output_names=io_config.output_names,
247380
dynamic_axes=dynamic_axes_for_export,
248381
dynamic_shapes=io_config.dynamic_shapes,
249382
dynamo=True,
250-
fallback=False,
251383
optimize=config.optimize,
252384
report=logger.isEnabledFor(logging.DEBUG),
253385
)
@@ -264,6 +396,8 @@ def _export_pytorch_model(
264396
# default is True in 2.9.0 and later
265397
dynamo_args["dynamo"] = False
266398

399+
tmp_model_path = resolve_onnx_path(tmp_dir)
400+
267401
torch.onnx.export(
268402
pytorch_model,
269403
dummy_inputs,

olive/passes/pytorch/train_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ def create_training_args(self) -> transformers.TrainingArguments:
8383
if version.parse(transformers_version) < version.parse("4.41") and "eval_strategy" in args:
8484
args["evaluation_strategy"] = args.pop("eval_strategy")
8585
extra_args = args.pop("extra_args")
86+
# Filter out fields that are not valid TrainingArguments parameters (e.g. overwrite_output_dir
87+
# was removed in transformers 5.0 but is still used by Olive's own logic) and None values
88+
# so that transformers uses its own defaults
89+
training_args_fields = {f.name for f in dataclasses.fields(transformers.TrainingArguments) if f.init}
90+
args = {k: v for k, v in args.items() if k in training_args_fields and v is not None}
8691
return transformers.TrainingArguments(**args, **extra_args)
8792

8893

test/model/test_hf_model.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,16 @@ def setup(self):
2727
self.local_path = huggingface_hub.snapshot_download(self.model_name, revision=self.revision)
2828

2929
@pytest.mark.parametrize("local", [True, False])
30-
@pytest.mark.parametrize("trust_remote_code", [True, False])
31-
def test_load_model(self, local, trust_remote_code):
30+
def test_load_model(self, local):
3231
olive_model = HfModelHandler(
3332
model_path=self.local_path if local else self.model_name,
3433
task=self.task,
35-
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
34+
load_kwargs={"revision": self.revision},
3635
)
3736

3837
pytorch_model = olive_model.load_model()
3938
actual_class_path = f"{pytorch_model.__module__}.{pytorch_model.__class__.__name__}"
40-
if trust_remote_code:
41-
# When using remote code, the model is loaded from transformers_modules
42-
assert actual_class_path.startswith("transformers_modules.")
43-
assert actual_class_path.endswith(".modeling_phi3.Phi3ForCausalLM")
44-
else:
45-
# When not using remote code, the model is loaded from transformers
46-
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"
39+
assert actual_class_path == "transformers.models.phi3.modeling_phi3.Phi3ForCausalLM"
4740

4841
@pytest.mark.parametrize("local", [True, False])
4942
def test_load_model_with_kwargs(self, local):
@@ -73,19 +66,18 @@ def test_save_metadata(self, local, trust_remote_code, tokenizer_exists, tmp_pat
7366
if tokenizer_exists:
7467
olive_model.get_hf_tokenizer().save_pretrained(tmp_path)
7568
saved_filepaths = olive_model.save_metadata(tmp_path)
76-
# transformers>=4.53.x
77-
assert len(saved_filepaths) == (4 if tokenizer_exists else 10)
69+
# transformers>=5.0.0
70+
assert len(saved_filepaths) == (4 if tokenizer_exists else 7)
7871
assert all(Path(fp).exists() for fp in saved_filepaths)
7972
assert isinstance(transformers.AutoConfig.from_pretrained(tmp_path), transformers.Phi3Config)
80-
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.LlamaTokenizerFast)
73+
assert isinstance(transformers.AutoTokenizer.from_pretrained(tmp_path), transformers.PreTrainedTokenizerBase)
8174

8275
@pytest.mark.parametrize("local", [True, False])
83-
@pytest.mark.parametrize("trust_remote_code", [True, False])
84-
def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
76+
def test_save_pretrained_metadata(self, local, tmp_path):
8577
olive_model = HfModelHandler(
8678
model_path=self.local_path if local else self.model_name,
8779
task=self.task,
88-
load_kwargs={"trust_remote_code": trust_remote_code, "revision": self.revision},
80+
load_kwargs={"revision": self.revision},
8981
)
9082

9183
# modify the config and save the model
@@ -94,8 +86,8 @@ def test_save_pretrained_metadata(self, local, trust_remote_code, tmp_path):
9486
loaded_model.save_pretrained(tmp_path)
9587

9688
saved_filepaths = olive_model.save_metadata(tmp_path)
97-
# generation config is also saved, transformers>=4.53.x
98-
assert len(saved_filepaths) == 9
89+
# generation config is also saved, transformers>=5.0.0
90+
assert len(saved_filepaths) == 6
9991

10092
with open(tmp_path / "config.json") as f:
10193
config = json.load(f)
@@ -126,7 +118,7 @@ def test_save_metadata_with_module_files(trust_remote_code, tmp_path):
126118
assert f"{config.__module__}.{config.__class__.__name__}" == expected_class_name
127119
assert isinstance(
128120
transformers.AutoTokenizer.from_pretrained(tmp_path, **load_kwargs),
129-
transformers.LlamaTokenizerFast,
121+
transformers.PreTrainedTokenizerBase,
130122
)
131123

132124

test/passes/onnx/test_conversion.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
55
import platform
6-
import shutil
76
from itertools import chain
87
from pathlib import Path
98
from unittest.mock import patch
@@ -211,9 +210,7 @@ def mock_onnx_export_func(*args, **kwargs):
211210
nonlocal dummy_kwargs
212211
# For dynamo export, inputs are passed via kwargs parameter
213212
dummy_kwargs = kwargs.get("kwargs", {})
214-
_, _, output_path = args
215-
shutil.copyfile(ONNX_MODEL_PATH, output_path)
216-
return MockOnnxProgram(output_path)
213+
return MockOnnxProgram(ONNX_MODEL_PATH)
217214

218215
output_folder = tmp_path / "onnx"
219216
output_folder.mkdir(parents=True, exist_ok=True)

test/passes/pytorch/test_rotate.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ def common_test_rotate(rotate_pass, tmp_path, model_path, rotate_mode, atol, **c
3535
with torch.no_grad():
3636
original_output = original_model(i)
3737
rotated_output = rotated_model(i)
38-
assert torch.allclose(original_output.logits, rotated_output.logits, atol=atol)
38+
# Cast to same dtype before comparison since rotated model may be saved/loaded in a different dtype
39+
assert torch.allclose(original_output.logits.float(), rotated_output.logits.float(), atol=atol)
3940

4041

4142
@pytest.mark.parametrize("model_path", ["tiny-phi3", "tiny-llama"])

test/requirements-test.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,5 +37,3 @@ sentencepiece
3737
soundfile
3838
tabulate
3939
torchvision
40-
# Remove version pin when the tests are fixed
41-
transformers<5.0.0

test/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def get_pytorch_model(batch_size=1):
7777
)
7878

7979

80-
def get_hf_model(model_path="hf-internal-testing/tiny-random-gptj"):
81-
return HfModelHandler(model_path=model_path)
80+
def get_hf_model(model_path="hf-internal-testing/tiny-random-LlamaForCausalLM"):
81+
return HfModelHandler(model_path=model_path, task="text-generation")
8282

8383

8484
def get_hf_model_config():

0 commit comments

Comments
 (0)