6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
@@ -18,12 +18,12 @@ jobs:
os: [ubuntu-latest]
python: ['3.10', '3.11', '3.12', '3.13']
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.56.2', '4.57.6', 'main']
torch: ['2.9', 'main']
torch: ['2.10', 'main']
exclude:
- python: '3.10' # 3.10
torch: 'main'
- python: '3.10'
torch: '2.9'
torch: '2.10'
- python: '3.10'
transformers: '4.55.4'
- python: '3.10'
@@ -43,7 +43,7 @@ jobs:
- python: '3.11'
transformers: 'main'
- python: '3.13' # 3.11
torch: '2.9'
torch: '2.10'
- python: '3.13'
transformers: '4.48.3'
- python: '3.13'
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.8.11
++++++

* :pr:`396`: fix serialization for DynamicCache with different layer classes
* :pr:`394`: add function make_model_with_local_functions to partition a model into local functions

0.8.10
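
The tests added below exercise the entry for :pr:`396` above. As a minimal sketch of the behaviour being fixed (import paths are the ones used by this PR's tests and may differ in other versions), a `DynamicCache` built with `DynamicSlidingWindowLayer` layers should now round-trip through `torch_deepcopy` with its layer classes preserved:

```python
# Hedged sketch of what #396 fixes; import paths are assumptions taken from the
# tests in this PR, not a verified public API surface.
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
from onnx_diagnostic.torch_export_patches import torch_export_patches

# a DynamicCache whose layers are DynamicSlidingWindowLayer instead of DynamicLayer
cache = make_dynamic_cache(
    [(torch.rand(4, 5, 6, 7), torch.rand(4, 5, 6, 7))],
    cls_layers="DynamicSlidingWindowLayer",
)

with torch_export_patches(patch_transformers=True):
    copy = torch_deepcopy([cache])[0]

# the copy keeps both the cache type and the per-layer classes
assert type(copy) is type(cache)
assert [type(lay) for lay in copy.layers] == [type(lay) for lay in cache.layers]
```
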
19 changes: 19 additions & 0 deletions _unittests/ut_helpers/test_torch_helper.py
@@ -362,6 +362,25 @@ def test_torch_deepcopy_sliding_windon_cache(self):
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cache), 1)

@unittest.skipIf(make_sliding_window_cache is not None, "transformers<5")
def test_torch_deepcopy_sliding_windon_cache5(self):
cache = make_dynamic_cache(
[
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
(torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
],
cls_layers="DynamicSlidingWindowLayer",
)
at = torch_deepcopy(cache)
self.assertEqual(type(cache), type(at))
self.assertEqual(max_diff(cache, at)["abs"], 0)
hash1 = string_type(at, with_shape=True, with_min_max=True)
CacheKeyValue(cache).key_cache[0] += 1000
hash2 = string_type(at, with_shape=True, with_min_max=True)
self.assertEqual(hash1, hash2)
self.assertGreater(torch_tensor_size(cache), 1)

def test_torch_deepcopy_none(self):
self.assertEmpty(torch_deepcopy(None))
self.assertEqual(torch_tensor_size(None), 0)
2 changes: 1 addition & 1 deletion _unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -15,7 +15,7 @@

class TestTasksImageTextToText(ExtTestCase):
@hide_stdout()
@requires_transformers("4.56")
@requires_transformers("5.0.99")
@requires_torch("2.7.99")
def test_image_text_to_text_idefics(self):
mid = "HuggingFaceM4/tiny-random-idefics"
@@ -192,6 +192,20 @@ def test_base_sliding_window_cache_unflatten_flatten(self):
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)

@ignore_warnings(UserWarning)
@unittest.skipIf(make_sliding_window_cache, "transformers<5")
def test_base_sliding_window_cache_unflatten_flatten5(self):
cache = make_dynamic_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
cls_layers="DynamicSlidingWindowLayer",
)
with torch_export_patches(patch_transformers=True):
cache2 = torch_deepcopy([cache])
self.assertEqualAny([cache], cache2)
self.assertEqual(
[type(lay) for lay in cache.layers], [type(lay) for lay in cache2[0].layers]
)

@ignore_warnings(UserWarning)
@requires_torch("2.7.99")
@unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
@@ -215,6 +229,30 @@ def forward(self, cache):
with torch_export_patches(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
@requires_torch("2.7.99")
@unittest.skipIf(make_sliding_window_cache, "transformers<5")
def test_sliding_window_cache_export5(self):
class Model(torch.nn.Module):
def forward(self, cache):
dc = CacheKeyValue(cache)
return dc.key_cache[0]

cache = make_dynamic_cache(
[
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4))),
],
cls_layers="DynamicSlidingWindowLayer",
)
model = Model()
model(cache)
DYN = torch.export.Dim.DYNAMIC
ds = make_dynamic_shapes_kv_cache(cache, {0: DYN})

with torch_export_patches(patch_transformers=True):
torch.export.export(model, (cache,), dynamic_shapes=(ds,))

@ignore_warnings(UserWarning)
@unittest.skipIf(not make_sliding_window_cache, "SlidingWindowCache was removed")
def test_sliding_window_cache_flatten(self):
@@ -233,6 +271,28 @@ def test_sliding_window_cache_flatten(self):
self.string_type(cache2, with_shape=True, with_min_max=True),
)

@ignore_warnings(UserWarning)
@unittest.skipIf(make_sliding_window_cache, "transformers<5")
def test_sliding_window_cache_flatten5(self):
cache = make_dynamic_cache(
[(torch.rand((4, 4, 4, 4)), torch.rand((4, 4, 4, 4)))],
cls_layers="DynamicSlidingWindowLayer",
)
with torch_export_patches(patch_transformers=True):
flat, _spec = torch.utils._pytree.tree_flatten(cache)
self.assertEqual(
"#2[T1s4x4x4x4,T1s4x4x4x4]",
self.string_type(flat, with_shape=True),
)
cache2 = torch.utils._pytree.tree_unflatten(flat, _spec)
self.assertEqual(
self.string_type(cache, with_shape=True, with_min_max=True),
self.string_type(cache2, with_shape=True, with_min_max=True),
)
self.assertEqual(
[type(lay) for lay in cache.layers], [type(lay) for lay in cache2.layers]
)

@ignore_warnings(UserWarning)
@requires_torch("2.7.99")
def test_static_cache(self):
9 changes: 7 additions & 2 deletions _unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -402,7 +402,12 @@ def test_patched_qwen2_5_vl_get_window_index(self):
self.assertEqualArray(torch.tensor(cu_window_seqlens1), cu_window_seqlens2)

@requires_transformers("4.55")
@unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
# @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
# see https://github.com/huggingface/transformers/pull/42564/files#diff-09bc594f9680f1d042fd485106c68022d77b59831697a00b3b38f12a3e40f395
@unittest.skip(
"vision_outputs = self.visual(pixel_values, "
"grid_thw=image_grid_thw, return_dict=True, **kwargs)"
)
def test_patched_qwen2_5_vl_forward(self):
from onnx_diagnostic.torch_export_patches.patches.patch_transformers import (
patched_Qwen2_5_VisionTransformerPretrainedModel,
@@ -422,7 +427,7 @@ def test_patched_qwen2_5_vl_forward(self):
instance, hidden_states, grid_thw
)
patched_class.get_window_index = f_get_window_index
self.assertEqualArray(expected, got)
self.assertEqualAny(expected, got)

@classmethod
def _get_cu_seqlens(cls):
45 changes: 31 additions & 14 deletions onnx_diagnostic/ext_test_case.py
@@ -1028,6 +1028,19 @@ def assertEqualAny(
rtol=rtol,
msg=msg,
)
elif expected.__class__.__name__ == "BaseModelOutputWithPooling":
if expected.__class__.__name__ == value.__class__.__name__:
self.assertEqual(len(expected), len(value), msg=msg)
self.assertEqual(list(expected), list(value), msg=msg) # checks the order
self.assertEqualAny(
{k: v for k, v in expected.items()}, # noqa: C416
{k: v for k, v in value.items()}, # noqa: C416
atol=atol,
rtol=rtol,
msg=msg,
)
else:
self.assertEqualArray(expected.last_hidden_state, value)
elif isinstance(expected, (tuple, list, dict)):
self.assertIsInstance(value, type(expected), msg=msg)
self.assertEqual(len(expected), len(value), msg=msg)
@@ -1043,24 +1056,28 @@
"SlidingWindowCache",
"HybridCache",
):
from .helpers.cache_helper import CacheKeyValue

self.assertEqual(type(expected), type(value), msg=msg)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
elif expected.__class__.__name__ == "StaticCache":
from .helpers.cache_helper import CacheKeyValue

self.assertEqual(type(expected), type(value), msg=msg)
self.assertEqual(expected.max_cache_len, value.max_cache_len)
atts = ["key_cache", "value_cache"]
self.assertEqualAny(
{k: expected.__dict__.get(k, None) for k in atts},
{k: value.__dict__.get(k, None) for k in atts},
atol=atol,
rtol=rtol,
)
self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
elif expected.__class__.__name__ == "CacheKeyValue":
self.assertEqual(type(expected), type(value), msg=msg)
if expected.cls_layers is None:
self.assertEqual(expected.cls_layers, value.cls_layers)
else:
self.assertEqualAny(
[cls.__name__ for cls in expected.cls_layers],
[cls.__name__ for cls in value.cls_layers],
msg=msg,
)
self.assertEqualAny(expected.key_cache, value.key_cache, msg=msg)
self.assertEqualAny(expected.value_cache, value.value_cache, msg=msg)
elif expected.__class__.__name__ == "EncoderDecoderCache":
self.assertEqual(type(expected), type(value), msg=msg)
atts = ["self_attention_cache", "cross_attention_cache"]
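
The refactored branches above route every cache flavour through `CacheKeyValue` and compare layer classes by name before comparing tensors. A hedged, illustrative sketch of that comparison written as a standalone predicate (`caches_match` is not part of the library; it only mirrors the logic of the new `assertEqualAny` branch):

```python
import torch
from onnx_diagnostic.helpers.cache_helper import CacheKeyValue


def caches_match(expected, value, atol: float = 0.0) -> bool:
    # normalize both caches into key/value tensor lists plus their layer classes
    ke, kv = CacheKeyValue(expected), CacheKeyValue(value)
    # layer classes are compared by name, mirroring the new assertEqualAny branch
    if (ke.cls_layers is None) != (kv.cls_layers is None):
        return False
    if ke.cls_layers is not None and [c.__name__ for c in ke.cls_layers] != [
        c.__name__ for c in kv.cls_layers
    ]:
        return False
    # then the cached tensors, pairwise
    if len(ke.key_cache) != len(kv.key_cache):
        return False
    pairs = list(zip(ke.key_cache, kv.key_cache)) + list(zip(ke.value_cache, kv.value_cache))
    return all(torch.allclose(a, b, atol=atol) for a, b in pairs)
```
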
8 changes: 4 additions & 4 deletions onnx_diagnostic/helpers/cache_helper.py
@@ -348,6 +348,7 @@ def make_dynamic_cache(
def make_static_cache(
key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
max_cache_len: Optional[int] = None,
cls_layers: Optional[Union[str, List[type]]] = None,
) -> transformers.cache_utils.DynamicCache:
"""
Creates an instance of :class:`transformers.cache_utils.StaticCache`.
@@ -379,6 +380,9 @@ def make_static_cache(
)
print(string_type(past_key_values, with_shape=True))
"""
assert not cls_layers or set(cls_layers) == {
transformers.cache_utils.StaticLayer
}, f"Not implemented when cls_layers={cls_layers!r}"
key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)

class _config:
@@ -583,13 +587,9 @@ def get_text_config(self, *args, **kwargs):
)
return finalize_cache(cache)

def get_make_hybrid_cache():
return make_sliding_window_cache

else:
make_sliding_window_cache = None # type: ignore[assignment]


if hasattr(transformers.cache_utils, "HybridCache"):

def make_hybrid_cache(
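
`make_static_cache` now accepts a `cls_layers` argument for signature parity with `make_dynamic_cache`, but the assertion added above only allows `StaticLayer`. A short usage sketch, assuming a transformers version that defines `transformers.cache_utils.StaticLayer` (shapes are arbitrary):

```python
import torch
import transformers
from onnx_diagnostic.helpers.cache_helper import make_static_cache

pairs = [(torch.rand(2, 4, 3, 8), torch.rand(2, 4, 3, 8))]

# cls_layers may be omitted; if given, the new assertion only accepts StaticLayer
cache = make_static_cache(
    pairs,
    max_cache_len=16,
    cls_layers=[transformers.cache_utils.StaticLayer],
)
```
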
@@ -1,6 +1,7 @@
import itertools
from typing import Any, Callable, List, Set, Tuple
import torch
import transformers.cache_utils
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

try:
Expand All @@ -27,16 +28,38 @@
DynamicCache: "4.50",
BaseModelOutput: None,
}
SHORTEN_LAYER_NAMES = {
"DynamicLayer": "D",
"DynamicSlidingWindowLayer": "W",
"StaticLayer": "S",
"StaticSlidingWindowLayer": "X",
"D": "DynamicLayer",
"W": "DynamicSlidingWindowLayer",
"S": "StaticLayer",
"X": "StaticSlidingWindowLayer",
}


def _flatten_key_value_cache(cache: Cache) -> Tuple[List[Any], torch.utils._pytree.Context]:
ca = CacheKeyValue(cache)
flat = list(itertools.chain.from_iterable(zip(ca.key_cache, ca.value_cache)))
keys = list(
itertools.chain.from_iterable(
(f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
unique = set(ca.cls_layers) if ca.cls_layers else None
if (
cache.__class__.__name__ != "DynamicCache"
or unique is None
or (len(unique) == 1 and unique.pop().__name__ == "DynamicLayer")
):
keys = list(
itertools.chain.from_iterable(
(f"key_{i}", f"value_{i}") for i in range(len(ca.key_cache))
)
)
)
return flat, keys

keys = []
for i in range(len(ca.key_cache)):
letter = SHORTEN_LAYER_NAMES[ca.cls_layers[i].__name__]
keys.extend([f"key_{letter}{i}", f"value_{letter}{i}"])
return flat, keys


@@ -54,7 +77,20 @@ def _unflatten_cache(
output_type=None,
) -> DynamicCache:
"""Restores a :class:`transformers.cache_utils.DynamicCache` from python objects."""
res = make_cache(list(zip(values[::2], values[1::2])))
expected = list(
itertools.chain.from_iterable(
(f"key_{i}", f"value_{i}") for i in range(len(values) // 2)
)
)
if expected == context:
res = make_cache(list(zip(values[::2], values[1::2])))
else:
cls_layer_names = [SHORTEN_LAYER_NAMES[name.split("_")[1][0]] for name in context][::2]
cls_layers = [
getattr(transformers.cache_utils, cls_name) for cls_name in cls_layer_names
]
res = make_cache(list(zip(values[::2], values[1::2])), cls_layers=cls_layers)

assert output_type is None or isinstance(
res, output_type
), f"Type mismatch between {output_type} (expected) and {type(res)}"
Expand All @@ -70,29 +106,13 @@ def flatten_dynamic_cache(
dynamic_cache: DynamicCache,
) -> Tuple[List[Any], torch.utils._pytree.Context]:
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
assert (
not hasattr(dynamic_cache, "layers")
or not dynamic_cache.layers
or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
), (
f"The serialization does not work yet on other layers "
f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
)
return _flatten_key_value_cache(dynamic_cache)


def flatten_with_keys_dynamic_cache(
dynamic_cache: DynamicCache,
) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
"""Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
assert (
not hasattr(dynamic_cache, "layers")
or not dynamic_cache.layers
or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
), (
f"The serialization does not work yet on other layers "
f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
)
return _flatten_with_keys_cache(dynamic_cache)


@@ -160,7 +180,9 @@ def unflatten_static_cache(
) -> StaticCache:
"""Restores a :class:`transformers.cache_utils.StaticCache` from python objects."""
return _unflatten_cache(
lambda *args: make_static_cache(*args, max_cache_len=values[0].shape[2]),
lambda *args, **kwargs: make_static_cache(
*args, max_cache_len=values[0].shape[2], **kwargs
),
values,
context,
output_type=output_type,
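
The per-layer letters in `SHORTEN_LAYER_NAMES` are what let `_unflatten_cache` rebuild the original layer classes: a `DynamicSlidingWindowLayer` is flattened with context keys such as `key_W0`/`value_W0` instead of `key_0`/`value_0`. A sketch of the pytree round trip under the patches, with the same import assumptions as the example after the changelog:

```python
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.torch_export_patches import torch_export_patches

cache = make_dynamic_cache(
    [(torch.rand(4, 4, 4, 4), torch.rand(4, 4, 4, 4))],
    cls_layers="DynamicSlidingWindowLayer",
)

with torch_export_patches(patch_transformers=True):
    flat, spec = torch.utils._pytree.tree_flatten(cache)
    # flat holds the two tensors; the spec's context carries names such as
    # "key_W0"/"value_W0", where "W" maps back to DynamicSlidingWindowLayer
    restored = torch.utils._pytree.tree_unflatten(flat, spec)
    assert [type(lay) for lay in restored.layers] == [type(lay) for lay in cache.layers]
```
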