From eb5fc90aa4060b6a367041d45cf4073b24b017dc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Thu, 5 Feb 2026 17:50:51 +0100
Subject: [PATCH 1/8] patch doc

---
 _doc/status/patches_diff.rst                           | 4 +++-
 onnx_diagnostic/torch_export_patches/patch_details.py  | 9 +++++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/_doc/status/patches_diff.rst b/_doc/status/patches_diff.rst
index c2b84241..4045c25c 100644
--- a/_doc/status/patches_diff.rst
+++ b/_doc/status/patches_diff.rst
@@ -62,7 +62,9 @@ Those two versions leads to the following list of patches.
     ):
         pass
     for patch in details.patched:
-        print(f"* {patch.family} - {getattr(patch.function_to_patch, '__name__', patch.function_to_patch)}")
+        if patch.function_to_patch == patch.patch:
+            continue
+        print(f"* :ref:`{patch.refid}`")
     print()
     print()
     for patch in details.patched:
diff --git a/onnx_diagnostic/torch_export_patches/patch_details.py b/onnx_diagnostic/torch_export_patches/patch_details.py
index 3cb3a2ea..67dcad24 100644
--- a/onnx_diagnostic/torch_export_patches/patch_details.py
+++ b/onnx_diagnostic/torch_export_patches/patch_details.py
@@ -117,6 +117,12 @@ def make_diff(self) -> str:
     def function_name(cls, f: Callable) -> str:
         return f.__qualname__
 
+    @property
+    def refid(self) -> str:
+        kind = self.family or ""
+        patch_name = self.function_name(self.patch)
+        return f"patch-{kind}-{patch_name}"
+
     def format_diff(self, format: str = "raw") -> str:
         """
         Format a diff between two function as a string.
@@ -154,6 +160,9 @@ def format_diff(self, format: str = "raw") -> str:
             return f"{title}\n{diff}"
 
         rows = [
+            "",
+            f".. _{self.refid}",
+            "",
             title,
             "=" * len(title),
             "",

From 2698987193c0ee884ad7845c0d5e57ffefbae9c1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 6 Feb 2026 10:27:08 +0100
Subject: [PATCH 2/8] update CI for 5.1.0

---
 .github/workflows/check-release.yml                   |  2 +-
 .github/workflows/ci.yml                              |  6 +++---
 _unittests/ut_tasks/test_tasks.py                     |  2 +-
 _unittests/ut_tasks/test_tasks_image_text_to_text.py  |  6 +++---
 .../torch_export_patches/onnx_export_errors.py        |  4 ++--
 onnx_diagnostic/torch_export_patches/patch_details.py | 11 +++++++++--
 6 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/.github/workflows/check-release.yml b/.github/workflows/check-release.yml
index 7fbafa41..2d9be2fa 100644
--- a/.github/workflows/check-release.yml
+++ b/.github/workflows/check-release.yml
@@ -16,7 +16,7 @@ jobs:
       matrix:
         os: [ubuntu-latest, macOS-latest, windows-latest]
         python: ['3.13']
-        transformers: ['5.0', 'main']
+        transformers: ['5.1.0', 'main']
         torch: ['2.10', 'main']
 
     steps:
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index b33d1031..ed4f416b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
       matrix:
         os: [ubuntu-latest]
         python: ['3.10', '3.11', '3.12', '3.13']
-        transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.0', 'main']
+        transformers: ['4.48.3', '4.51.3', '4.55.4', '4.57.6', '5.1.0', 'main']
         torch: ['2.10', 'main']
         exclude:
           # 3.10 - torch
@@ -29,7 +29,7 @@
           - python: '3.10'
             transformers: '4.57.6'
          - python: '3.10'
-            transformers: '5.0'
+            transformers: '5.1.0'
           - python: '3.10'
             transformers: 'main'
           # 3.11 - torch
@@ -41,7 +41,7 @@
           - python: '3.11'
             transformers: '4.57.6'
           - python: '3.11'
-            transformers: '5.0'
+            transformers: '5.1.0'
           - python: '3.11'
             transformers: 'main'
           # 3.13 - torch
diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index 63a58358..f3e6a205 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self):
         model(**inputs)
         model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-        if not has_transformers("5.0.99"):
+        if not has_transformers("5.1.99"):
             raise unittest.SkipTest("The model has control flow.")
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
             torch.export.export(
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index 891aa9d1..96f4f152 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -15,7 +15,7 @@
 
 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.1.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
@@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self):
         self.assertEqualAny(expected, ep.module()(**inputs), atol=1)
 
     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.1.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_tiny_gemma3(self):
         """
@@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self):
         self.assertEqualAny(expected, ep.module()(**inputs))
 
     @hide_stdout()
-    @requires_transformers("5.0.99")
+    @requires_transformers("5.1.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_zai_glm(self):
         """
diff --git a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
index 79c4d9e7..30fa9727 100644
--- a/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
+++ b/onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -71,10 +71,10 @@ def patch_module_or_classes(
     if isinstance(mod, list):
         to_patch = mod
         name = "list"
-        list_name = "auto/list"
+        list_name = "_PATCHED_list"
     else:
         name, to_patch = get_patches(mod, verbose)
-        list_name = f"auto/{mod.__name__.split('.')[-1]}"
+        list_name = f"_PATCHED_{mod.__name__.split('.')[-1]}"
 
     res = {}
     for cls in to_patch:
diff --git a/onnx_diagnostic/torch_export_patches/patch_details.py b/onnx_diagnostic/torch_export_patches/patch_details.py
index 67dcad24..77844356 100644
--- a/onnx_diagnostic/torch_export_patches/patch_details.py
+++ b/onnx_diagnostic/torch_export_patches/patch_details.py
@@ -120,7 +120,13 @@ def function_name(cls, f: Callable) -> str:
     @property
     def refid(self) -> str:
         kind = self.family or ""
-        patch_name = self.function_name(self.patch)
+        patch_name = (
+            self.function_name(self.patch)
+            .replace(".", "-")
+            .replace("/", "-")
+            .replace(">", "")
+            .replace("<", "")
+        )
         return f"patch-{kind}-{patch_name}"
 
     def format_diff(self, format: str = "raw") -> str:
@@ -155,13 +161,14 @@ def format_diff(self, format: str = "raw") -> str:
             else self.function_name(self.function_to_patch)
         )
         patch_name = self.function_name(self.patch)
+        kind = kind.replace("_PATCHED_", "")
         title = f"{kind}{function_to_pach_name} -> {patch_name}"
         if format == "raw":
             return f"{title}\n{diff}"
 
         rows = [
             "",
-            f".. _{self.refid}",
+            f".. _{self.refid}:",
             "",
             title,
             "=" * len(title),

From ccc85528acf0cae9f7aaeb55c4694a00d2b9d6de Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 6 Feb 2026 10:37:01 +0100
Subject: [PATCH 3/8] fix version

---
 _scripts/test_backend_onnxruntime.py                    | 2 +-
 .../ut_reference/test_backend_onnxruntime_evaluator.py  | 2 +-
 .../ut_torch_export_patches/test_patch_transformers.py  | 8 ++++----
 _unittests/ut_torch_onnx/test_discrepancies.py          | 2 +-
 4 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/_scripts/test_backend_onnxruntime.py b/_scripts/test_backend_onnxruntime.py
index 222df32f..df5f3a71 100644
--- a/_scripts/test_backend_onnxruntime.py
+++ b/_scripts/test_backend_onnxruntime.py
@@ -141,7 +141,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
 backend_test.exclude("(test_adagrad|test_adam|test_add_uint8)")
 
 
-if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"):
+if pv.Version(onnxruntime.__version__) <= pv.Version("1.25"):
     backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)")
 
 
diff --git a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
index 1d5131c6..b588221a 100644
--- a/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
+++ b/_unittests/ut_reference/test_backend_onnxruntime_evaluator.py
@@ -299,7 +299,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs):
 )
 
 
-if pv.Version(onnxruntime.__version__) <= pv.Version("1.24"):
+if pv.Version(onnxruntime.__version__) <= pv.Version("1.25"):
     backend_test.exclude("(test_attention_4d_with|test_attention_4d_gqa)")
 
 # import all test cases at global scope to make them visible to python.unittest
diff --git a/_unittests/ut_torch_export_patches/test_patch_transformers.py b/_unittests/ut_torch_export_patches/test_patch_transformers.py
index 4982d06c..0651090e 100644
--- a/_unittests/ut_torch_export_patches/test_patch_transformers.py
+++ b/_unittests/ut_torch_export_patches/test_patch_transformers.py
@@ -703,7 +703,7 @@ def test_plug_multi_head_attention_qwen25_packed_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -738,7 +738,7 @@ def test_plug_multi_head_attention_qwen25_loopmha_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.01)
         self.assertLess(results.diffs[0]["abs"], 0.01)
 
-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
    @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -773,7 +773,7 @@ def test_plug_multi_head_attention_qwen25_loopmha_float32(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=1e-5)
         self.assertLess(results.diffs[0]["abs"], 1e-5)
 
-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
@@ -801,7 +801,7 @@ def test_plug_multi_head_attention_qwen25_loopa24_float16(self):
         self.assertEqualArray(results.eager_outputs[0], results.onnx_outputs[0], atol=0.005)
         self.assertLess(results.diffs[0]["abs"], 0.005)
 
-    @requires_onnxruntime("1.24")
+    @requires_onnxruntime("1.25")
     @unittest.skipIf(not patch_qwen2_5, "Qwen25 not part of this transformers")
     def test_plug_multi_head_attention_qwen25_loopa24_float32(self):
         from onnx_diagnostic.torch_export_patches.patches._patch_transformers_qwen2_5 import (
diff --git a/_unittests/ut_torch_onnx/test_discrepancies.py b/_unittests/ut_torch_onnx/test_discrepancies.py
index 9d344ff4..dd73ddfe 100644
--- a/_unittests/ut_torch_onnx/test_discrepancies.py
+++ b/_unittests/ut_torch_onnx/test_discrepancies.py
@@ -46,7 +46,7 @@ def qwen_sdpa_attention(
             return attn_output
 
         for model_name in ["attention_loopa24.onnx", "attention_loopmha.onnx"]:
-            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.24"):
+            if model_name == "attention_loopa24.onnx" and not has_onnxruntime("1.25"):
                 # not available
                 continue
             with self.subTest(model=model_name):

From 1748fbcd9785fb8be4bd06e7e291a935bf5016bb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Fri, 6 Feb 2026 14:31:16 +0100
Subject: [PATCH 4/8] fix dynamiccache

---
 CHANGELOGS.rst                          | 2 ++
 onnx_diagnostic/helpers/torch_helper.py | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 97a3309a..2273b14c 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,6 +4,8 @@ Change Logs
 0.9.1
 +++++
 
+* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0
+
 0.9.0
 +++++
 
diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py
index 8888dbef..1d8d7f40 100644
--- a/onnx_diagnostic/helpers/torch_helper.py
+++ b/onnx_diagnostic/helpers/torch_helper.py
@@ -850,6 +850,15 @@ def torch_deepcopy(value: Any) -> Any:
     if value.__class__.__name__ == "DynamicCache":
         from .cache_helper import CacheKeyValue
 
+        if (
+            hasattr(value, "layers")
+            and len(value.layers) == 1
+            and value.layers[0].keys is None
+        ):
+            import transformers
+
+            return transformers.cache_utils.DynamicCache(None)
+
         ca = CacheKeyValue(value)
         pairs = list(zip(ca.key_cache, ca.value_cache))
         assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (

From 5ea6d2d5a1cf6a56c22d75b8a77dc778186bcf31 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 7 Feb 2026 00:10:49 +0100
Subject: [PATCH 5/8] fix unt

---
 _unittests/ut_tasks/test_tasks_image_text_to_text.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index 96f4f152..81615212 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -15,7 +15,7 @@
 
 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("5.1.99")
+    @requires_transformers("5.1.999")
     @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"

From 953f4eb009bd3b032c2eae5732e72a798501b760 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 7 Feb 2026 01:14:00 +0100
Subject: [PATCH 6/8] fix for 5.2.0

---
 CHANGELOGS.rst                                         |  2 +-
 .../ut_tasks/test_tasks_image_text_to_text.py          |  6 ++---
 .../_patch_transformers_output_capturing.py            | 27 +++++++++++++++++++
 .../patches/patch_transformers.py                      |  6 +++++
 4 files changed, 37 insertions(+), 4 deletions(-)
 create mode 100644 onnx_diagnostic/torch_export_patches/patches/_patch_transformers_output_capturing.py

diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index 2273b14c..09ecf385 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,7 +4,7 @@ Change Logs
 0.9.1
 +++++
 
-* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0
+* :pr:`408`: fix torch_deepcopy for empty DynamicCache and transformers==5.1.0, 5.2.0 (see https://github.com/huggingface/transformers/pull/43765/)
 
 0.9.0
 +++++
diff --git a/_unittests/ut_tasks/test_tasks_image_text_to_text.py b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
index 81615212..3d487180 100644
--- a/_unittests/ut_tasks/test_tasks_image_text_to_text.py
+++ b/_unittests/ut_tasks/test_tasks_image_text_to_text.py
@@ -15,7 +15,7 @@
 
 class TestTasksImageTextToText(ExtTestCase):
     @hide_stdout()
-    @requires_transformers("5.1.999")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_idefics(self):
         mid = "HuggingFaceM4/tiny-random-idefics"
@@ -32,7 +32,7 @@ def test_image_text_to_text_idefics(self):
         self.assertEqualAny(expected, ep.module()(**inputs), atol=1)
 
     @hide_stdout()
-    @requires_transformers("5.1.99")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_tiny_gemma3(self):
         """
@@ -88,7 +88,7 @@ def test_image_text_to_text_gemma3_4b_it(self):
         self.assertEqualAny(expected, ep.module()(**inputs))
 
     @hide_stdout()
-    @requires_transformers("5.1.99")
+    @requires_transformers("5.2.99")
     @requires_torch("2.7.99")
     def test_image_text_to_text_zai_glm(self):
         """
diff --git a/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_output_capturing.py b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_output_capturing.py
new file mode 100644
index 00000000..b35e27f3
--- /dev/null
+++ b/onnx_diagnostic/torch_export_patches/patches/_patch_transformers_output_capturing.py
@@ -0,0 +1,27 @@
+try:
+    import transformers.utils.output_capturing  # noqa: F401
+
+    patch_output_capturing = True
+except ImportError:
+    patch_output_capturing = False
+
+
+if patch_output_capturing:
+    # Introduced in 5.2.0
+    # https://github.com/huggingface/transformers/pull/43765/
+    # changes#diff-b5f9fdbe43ffd89fbdf2b246dc78dd32aa4bdb587e7a53e4dad37b7efd79ab0a
+    import torch
+    import transformers
+    from transformers.utils.import_utils import is_torchdynamo_compiling
+
+    class patched_CompileableContextVar:
+        _PATCHES_ = ["set"]
+        _PATCHED_CLASS_ = transformers.utils.output_capturing.CompileableContextVar
+
+        def set(self, value):
+            if is_torchdynamo_compiling() and not torch.compiler.is_exporting():
+                self.global_var = value
+                self.compiling = True
+                return None
+            else:
+                return self.context_var.set(value)
diff --git a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
index 89b71250..20623543 100644
--- a/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
+++ b/onnx_diagnostic/torch_export_patches/patches/patch_transformers.py
@@ -42,6 +42,12 @@
     patched_sdpa_mask_recent_torch,
 )
 
+from ._patch_transformers_output_capturing import patch_output_capturing
+
+if patch_output_capturing:
+    from ._patch_transformers_output_capturing import patched_CompileableContextVar
+
+
 
 # transformers models dependent patches
 if _has_transformers("4.51"):

From 75eb58e6deb944c124a261eb7aa485645cafd9f8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 7 Feb 2026 01:27:49 +0100
Subject: [PATCH 7/8] fix ut

---
 _unittests/ut_tasks/test_tasks.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/_unittests/ut_tasks/test_tasks.py b/_unittests/ut_tasks/test_tasks.py
index f3e6a205..c0c666ae 100644
--- a/_unittests/ut_tasks/test_tasks.py
+++ b/_unittests/ut_tasks/test_tasks.py
@@ -266,7 +266,7 @@ def test_falcon_mamba_dev(self):
         model(**inputs)
         model(**data["inputs2"])
         self.assertIn((data["size"], data["n_weights"]), [(274958336, 68739584)])
-        if not has_transformers("5.1.99"):
+        if not has_transformers("5.2.99"):
             raise unittest.SkipTest("The model has control flow.")
         with torch_export_patches(patch_transformers=True, verbose=10, stop_if_static=1):
             torch.export.export(

From f8fc83653c1e58364a76bcb56a7619f61f612d21 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Sat, 7 Feb 2026 11:22:08 +0100
Subject: [PATCH 8/8] fix documentation

---
 _doc/status/patches_diff.rst | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/_doc/status/patches_diff.rst b/_doc/status/patches_diff.rst
index 4045c25c..3effb8ad 100644
--- a/_doc/status/patches_diff.rst
+++ b/_doc/status/patches_diff.rst
@@ -61,13 +61,21 @@ Those two versions leads to the following list of patches.
         patch_details=details,
     ):
         pass
+    done = set()
     for patch in details.patched:
         if patch.function_to_patch == patch.patch:
             continue
+        if patch.refid in done:
+            continue
+        done.add(patch.refid)
         print(f"* :ref:`{patch.refid}`")
     print()
     print()
+    done = set()
     for patch in details.patched:
+        if patch.refid in done:
+            continue
+        done.add(patch.refid)
         if patch.function_to_patch == patch.patch:
             continue
         rst = patch.format_diff(format="rst")