From 159402c771ec81be2a47256d516178644ca3ef08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Tue, 27 Jan 2026 18:28:22 +0100 Subject: [PATCH 1/4] investigation around input observer --- .../ut_investigate/test_input_observer.py | 86 ++++++++- onnx_diagnostic/investigate/input_observer.py | 177 ++++++++++++++---- 2 files changed, 220 insertions(+), 43 deletions(-) diff --git a/_unittests/ut_investigate/test_input_observer.py b/_unittests/ut_investigate/test_input_observer.py index 37f6688a..b1b8ca2d 100644 --- a/_unittests/ut_investigate/test_input_observer.py +++ b/_unittests/ut_investigate/test_input_observer.py @@ -4,14 +4,14 @@ from onnx_diagnostic.ext_test_case import ExtTestCase from onnx_diagnostic.investigate.input_observer import ( InputObserver, - infer_dynamic_dimensions, + _infer_dynamic_dimensions, ) class TestInputObserver(ExtTestCase): def test_infer_dynamic_dimensions(self): - self.assertEqual([2], infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)])) - self.assertEqual([0, 2], infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)])) + self.assertEqual([2], _infer_dynamic_dimensions([(1, 2, 3), (1, 2, 4)])) + self.assertEqual([0, 2], _infer_dynamic_dimensions([(1, 2, 3), (2, 2, 4)])) def test_io_captured_args(self): class Model(torch.nn.Module): @@ -42,6 +42,36 @@ def forward(self, x, y): self.assertIsInstance(args, tuple) self.assertEqual(2, len(args)) + def test_io_captured_not_forward(self): + class Model(torch.nn.Module): + def notforward(self, w): + return w.abs() + + def forward(self, x, y): + return x + self.notforward(y) + + inputs = [ + (torch.randn((5, 6)), torch.randn((1, 6))), + (torch.randn((7, 7)), torch.randn((1, 7))), + (torch.randn((7, 8)), torch.randn((1, 8))), + (torch.randn((7, 9)), torch.randn((1, 9))), + ] + + model = Model() + observer = InputObserver() + with observer(model, method_name="notforward"): + for args in inputs: + model(*args) + self.assertEqual(len(observer.info), 3) + for i in range(3): + self.assertEqual(len(observer.info.flat_outputs[i]), 1) + + cst = torch.export.Dim.DYNAMIC + self.assertEqual(({1: cst},), observer.infer_dynamic_shapes()) + args = observer.infer_arguments() + self.assertIsInstance(args, tuple) + self.assertEqual(1, len(args)) + def test_io_captured_kwargs(self): class Model(torch.nn.Module): def forward(self, x, y): @@ -516,6 +546,56 @@ def forward(self, x, custom=None): model(*args) self.assertEqual(expected, observer.infer_dynamic_shapes()) + def test_io_captured_args_kwargs_dynamic_batch(self): + class Model(torch.nn.Module): + def forward(self, x, y, z=None, w=None): + r = x + y + if z is not None: + r += z + if w is not None: + r += w + return r + + inputs = [ + ( + (torch.randn((5, 6)), torch.randn((1, 6))), + dict(z=torch.randn((5, 6)), w=torch.randn((1, 6))), + ), + ( + (torch.randn((5, 7)), torch.randn((1, 7))), + dict(z=torch.randn((5, 7)), w=torch.randn((1, 7))), + ), + ( + (torch.randn((5, 8)), torch.randn((1, 8))), + dict(z=torch.randn((5, 8)), w=torch.randn((1, 8))), + ), + ( + (torch.randn((5, 9)), torch.randn((1, 9))), + dict(z=torch.randn((5, 9)), w=torch.randn((1, 9))), + ), + ] + + model = Model() + expected = [model(*args, **kwargs) for args, kwargs in inputs] + observer = InputObserver() + with observer(model): + for args, kwargs in inputs: + model(*args, **kwargs) + self.assertEqual(len(observer.info), 3) + for i in range(3): + self.assertEqual(len(observer.info.flat_outputs[i]), 1) + torch.testing.assert_close(expected[i], observer.info.flat_outputs[i][0]) + + cst = 
torch.export.Dim.DYNAMIC + self.assertEqual( + dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}), + observer.infer_dynamic_shapes(add_batch_dimension_for={0, "z"}), + ) + self.assertEqual( + dict(x={0: cst, 1: cst}, y={1: cst}, z={0: cst, 1: cst}, w={1: cst}), + observer.infer_dynamic_shapes(add_batch_dimension_for={"x", "z"}), + ) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index f3d7d9ee..e9b492df 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -4,7 +4,7 @@ import torch -def flatten_unflatten_for_dynamic_shapes( +def _flatten_unflatten_for_dynamic_shapes( obj: Any, use_dict: bool = True, change_function: Callable[[torch.Tensor], Any] | None = None, @@ -38,7 +38,7 @@ def flatten_unflatten_for_dynamic_shapes( for subspec in (spec.children() if hasattr(spec, "children") else spec.children_specs): end += subspec.num_leaves value = subspec.unflatten(flat[start:end]) - value = flatten_unflatten_for_dynamic_shapes( + value = _flatten_unflatten_for_dynamic_shapes( value, use_dict=use_dict, change_function=change_function ) subtrees.append(value) @@ -66,7 +66,9 @@ def flatten_unflatten_for_dynamic_shapes( return subtrees -def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int]: +def _infer_dynamic_dimensions( + shape_list: Sequence[tuple[int, ...]], add_batch_dimension: bool = False +) -> list[int]: """ Returns the list of dynamic dimensions given a list of shapes corresponding to the same tensor. @@ -74,6 +76,8 @@ def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int] Args: shape_list: list of shapes, they must all have the same length + add_batch_dimension: + make the first dimension dynamic if it is not Returns: list of dynamic dimensions @@ -86,28 +90,44 @@ def infer_dynamic_dimensions(shape_list: Sequence[tuple[int, ...]]) -> list[int] dynamic = [] for i in range(rank): dims = [shape[i] for shape in shape_list] - if len(set(dims)) > 1: + if len(set(dims)) > 1 or (i == 0 and add_batch_dimension): dynamic.append(i) return dynamic class InputObserverInfo: - def __init__(self, signature: inspect.Signature): + """Contains all the necessary information to infer dynamic shapes + and the arguments to send to :func:`torch.export.export`. + + Args: + signature_names: Names of the arguments of the method + the collector tensors come from. They are used if it becomes + necessary to move positional arguments to named ones. + """ + + def __init__(self, signature_names: list[str]): # pyrefly: ignore self.inputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] self.flat_inputs: list[list[torch.Tensor | None]] = [] # pyrefly: ignore self.outputs_specs: list[torch.utils._pytree.PyTreeSpec] = [] - self.flat_outputs: list[torch.Tensor | list[torch.Tensor]] = [] - self.signature = signature + self.flat_outputs: list[list[torch.Tensor]] = [] + self.signature_names = signature_names self._max_args: tuple[Any, torch.Tensor] | None = None self._max_kwargs: dict[str, torch.Tensor] | None = None def __len__(self) -> int: + """Returns the number of collected set of inputs/outputs.""" return len(self.flat_inputs) def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): + """Stores one set of inputs. They are deepcopied. + + Args: + args: Positional arguments. + kwargs: Named arguments. 
+ """ kwargs = { k: v for k, v in kwargs.items() @@ -128,18 +148,29 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): self._max_kwargs = cloned_kwargs def add_outputs(self, res: torch.Tensor | tuple[torch.Tensor, ...]): + """Stores outputs. They are deepcopied.""" flat_res, spec = torch.utils._pytree.tree_flatten(res) self.outputs_specs.append(spec) self.flat_outputs.append([t.clone().detach() for t in flat_res]) - def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]: + def _build_inputs_completed_with_none_values( + self, + ) -> tuple[list[int | str], list[list[torch.Tensor]]]: # Let's compute the sizes of each independently. if not self.flat_inputs or self._max_args is None or self._max_kwargs is None: raise RuntimeError("No inputs were captured.") - arg_sizes = [len(torch.utils._pytree.tree_flatten(a)[0]) for a in self._max_args] - kwarg_sizes = { - k: len(torch.utils._pytree.tree_flatten(v)[0]) for k, v in self._max_kwargs.items() - } + + flat_index_to_args = [] + arg_sizes = [] + for index_args, a in enumerate(self._max_args): + size = len(torch.utils._pytree.tree_flatten(a)[0]) + arg_sizes.append(size) + flat_index_to_args.extend([index_args] * size) + kwarg_sizes = {} + for k, v in self._max_kwargs.items(): + size = len(torch.utils._pytree.tree_flatten(v)[0]) + kwarg_sizes[k] = size + flat_index_to_args.extend([k] * size) # Let's reprocess everything. captured_inputs: dict[int | str, int] = {} @@ -179,10 +210,36 @@ def build_inputs_completed_with_none_values(self) -> list[list[torch.Tensor]]: else: flat.extend([None for _ in range(kwarg_sizes[k])]) new_flat_inputs.append(flat) - return new_flat_inputs + return flat_index_to_args, new_flat_inputs + + def infer_dynamic_shapes( + self, add_batch_dimension_for: set[int | str] | None = None + ) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]: + """ + Infers dynamic shapes. Most of the time, models do support a batch dimension + but this batch dimension has the same value for every input sample. + Instead of running inference on new samples, argument `add_batch_dimension_for` + can be used to tell the first dimension is a dynamic dimension for a particular + set of inputs referenced by their name (str) or their position (int). + """ + + def _add_batch_dimension(name_or_position): + if not add_batch_dimension_for: + return False + if name_or_position in add_batch_dimension_for: + return True + if ( + isinstance(name_or_position, int) + and self.signature_names[name_or_position] in add_batch_dimension_for + ): + return True + return False + + flat_index_to_args, flat_inputs = self._build_inputs_completed_with_none_values() + + def _add_batch_dimension_for_flat_index(index): + return _add_batch_dimension(flat_index_to_args[index]) - def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]: - flat_inputs = self.build_inputs_completed_with_none_values() # This is already checked by build_inputs_completed_with_none_values # but this is not always well captured by tools checking types. assert self._max_args is not None and self._max_kwargs is not None @@ -196,8 +253,9 @@ def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] 
| dict[str, dict[in ] n_tensors = len(shape_lists[0]) dynamic_shapes = [ - infer_dynamic_dimensions( - [s for s in [shapes[index] for shapes in shape_lists] if s is not None] + _infer_dynamic_dimensions( + [s for s in [shapes[index] for shapes in shape_lists] if s is not None], + add_batch_dimension=_add_batch_dimension_for_flat_index(index), ) for index in range(n_tensors) ] @@ -213,7 +271,7 @@ def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[in return dict(zip(list(self._max_kwargs), flat_dynamic_shapes)) # positional arguments needs to be moved to the named arguments n_args = len(self._max_args) - pos_names = list(self.signature.parameters)[:n_args] + pos_names = self.signature_names[:n_args] return { **dict(zip(pos_names, flat_dynamic_shapes[:n_args])), **dict(zip(list(self._max_kwargs), flat_dynamic_shapes[n_args:])), @@ -235,20 +293,21 @@ def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[in ), ) mapping = {id(t): shape for t, shape in zip(flat_inputs, flat_dynamic_shapes)} - ds_args, ds_kwargs = flatten_unflatten_for_dynamic_shapes( + ds_args, ds_kwargs = _flatten_unflatten_for_dynamic_shapes( (self._max_args, self._max_kwargs), change_function=lambda t: mapping[id(t)] ) if not ds_kwargs: return tuple(ds_args) if not ds_args: return tuple(ds_kwargs) - pos_names = list(self.signature.parameters)[: len(ds_args)] + pos_names = self.signature_names[: len(ds_args)] return {**dict(zip(pos_names, ds_args)), **ds_kwargs} def infer_arguments( self, index: int | None = None ) -> tuple[torch.Tensor, ...] | dict[str, torch.Tensor]: - # This is already checked by build_inputs_completed_with_none_values + """Infers arguments based on the collected tensors.""" + # This is already checked by _build_inputs_completed_with_none_values # but this is not always well captured by tools checking types. assert self._max_args is not None and self._max_kwargs is not None candidate = None @@ -269,61 +328,99 @@ def infer_arguments( if not args: return kwargs # We need to move args to kwargs - pos_names = list(self.signature.parameters)[: len(args)] + pos_names = self.signature_names[: len(args)] return {**dict(zip(pos_names, args)), **kwargs} raise NotImplementedError( "We could not find a good set of inputs/outputs. " - "We need to replace none by empty tensors." + "We need to replace none by empty tensors. " + "This will be soon implemented." 
)
 
 
 class InputObserver:
-    def __init__(self, store_n_calls: int = 3):
-        self.store_n_calls = store_n_calls
+    def __init__(self):
         self.info: InputObserverInfo | None = None
 
-    def _forward_captured(self, *args, _captured_forward=None, **kwargs):
-        assert _captured_forward is not None, "_captured_forward cannot be None"
+    def _replaced_method(
+        self,
+        *args,
+        _captured_method: Callable | None = None,
+        _store_n_calls: int = 3,
+        **kwargs,
+    ):
+        assert _captured_method is not None, "_captured_method cannot be None"
         assert self.info is not None, "info cannot be None"
         n_stored = len(self.info)
-        if n_stored < self.store_n_calls:
+        if n_stored < _store_n_calls:
             self.info.add_inputs(args, kwargs)
-        res = _captured_forward(*args, **kwargs)
-        if n_stored < self.store_n_calls:
+        res = _captured_method(*args, **kwargs)
+        if n_stored < _store_n_calls:
             self.info.add_outputs(res)
         return res
 
     @contextlib.contextmanager
-    def __call__(self, model: torch.nn.Module):
+    def __call__(
+        self, model: torch.nn.Module, store_n_calls: int = 3, method_name: str = "forward"
+    ):
+        """Starts collecting inputs and outputs of a specific method.
+        The model method is replaced by a new one collecting tensors
+        before and after the inner one is called.
+        The original method is restored after the collection.
+
+        Args:
+            model: Model to observe.
+            store_n_calls: The collection stops after this many calls
+                to avoid taking too much memory.
+            method_name: Method name to spy on.
+        """
         if self.info is not None:
             raise RuntimeError(
                 "This class was already used to capture a model. Please create a new one."
             )
-        self.info = InputObserverInfo(signature=inspect.signature(model.forward))
-        forward_method = model.forward
-        model.forward = (
-            lambda *args, _captured_forward=forward_method, **kwargs: self._forward_captured(
-                *args, _captured_forward=_captured_forward, **kwargs
-            )
+        if not hasattr(model, method_name):
+            raise ValueError(f"Model type {model} does not have a method {method_name!r}")
+        captured_method = getattr(model, method_name)
+        self.info = InputObserverInfo(
+            signature_names=list(inspect.signature(captured_method).parameters)
+        )
+        setattr(
+            model,
+            method_name,
+            lambda *args, _cm=captured_method, _snc=store_n_calls, **kwargs: self._replaced_method(  # noqa: E501
+                *args,
+                _captured_method=_cm,
+                _store_n_calls=_snc,
+                **kwargs,
+            ),
        )
         try:
             yield self
         finally:
-            model.forward = forward_method
+            setattr(model, method_name, captured_method)
 
     def _check_captured(self):
         if self.info is None:
             raise RuntimeError("No inputs were captured.")
 
-    def infer_dynamic_shapes(self) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+    def infer_dynamic_shapes(
+        self, add_batch_dimension_for: set[int | str] | None = None
+    ) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
+        """
+        Infers dynamic shapes. Most of the time, models do support a batch dimension
+        but this batch dimension has the same value for every input sample.
+        Instead of running inference on new samples, argument `add_batch_dimension_for`
+        can be used to tell the first dimension is a dynamic dimension for a particular
+        set of inputs referenced by their name (str) or their position (int).
+        """
         self._check_captured()
         assert self.info is not None  # missed by type checking
-        return self.info.infer_dynamic_shapes()
+        return self.info.infer_dynamic_shapes(add_batch_dimension_for=add_batch_dimension_for)
 
     def infer_arguments(
         self, index: int | None = None
     ) -> tuple[torch.Tensor, ...] 
| dict[str, torch.Tensor]:
+        """Infers arguments based on the collected tensors."""
         self._check_captured()
         assert self.info is not None  # missed by type checking
         return self.info.infer_arguments(index=index)

From 6f9d0e03a329e2fa25b36e242a9e7e2c1655e28c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 27 Jan 2026 18:31:02 +0100
Subject: [PATCH 2/4] fix

---
 onnx_diagnostic/investigate/input_observer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py
index e9b492df..be5bf7bb 100644
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -160,7 +160,7 @@ def _build_inputs_completed_with_none_values(
         if not self.flat_inputs or self._max_args is None or self._max_kwargs is None:
             raise RuntimeError("No inputs were captured.")
 
-        flat_index_to_args = []
+        flat_index_to_args: list[int | str] = []
         arg_sizes = []
         for index_args, a in enumerate(self._max_args):
             size = len(torch.utils._pytree.tree_flatten(a)[0])

From d71c5f00dc59c74cccc6eb12c1864351119432f7 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Tue, 27 Jan 2026 18:37:45 +0100
Subject: [PATCH 3/4] more

---
 onnx_diagnostic/investigate/input_observer.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py
index be5bf7bb..7fe57cc5 100644
--- a/onnx_diagnostic/investigate/input_observer.py
+++ b/onnx_diagnostic/investigate/input_observer.py
@@ -69,8 +69,7 @@ def _flatten_unflatten_for_dynamic_shapes(
 def _infer_dynamic_dimensions(
     shape_list: Sequence[tuple[int, ...]], add_batch_dimension: bool = False
 ) -> list[int]:
-    """
-    Returns the list of dynamic dimensions given a list of shapes
+    """Returns the list of dynamic dimensions given a list of shapes
     corresponding to the same tensor.
 
     Args:
@@ -215,8 +214,8 @@ def _build_inputs_completed_with_none_values(
     def infer_dynamic_shapes(
         self, add_batch_dimension_for: set[int | str] | None = None
     ) -> tuple[dict[int, Any], ...] | dict[str, dict[int, Any]]:
-        """
-        Infers dynamic shapes. Most of the time, models do support a batch dimension
+        """Infers dynamic shapes based on the collected tensors.
+        Most of the time, models do support a batch dimension
         but this batch dimension has the same value for every input sample.
         Instead of running inference on new samples, argument `add_batch_dimension_for`
         can be used to tell the first dimension is a dynamic dimension for a particular
         set of inputs referenced by their name (str) or their position (int).

From 012bd4fdec65ea2caa4c88d76c257ad1a0b55241 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xavier=20Dupr=C3=A9?=
Date: Wed, 28 Jan 2026 11:21:49 +0100
Subject: [PATCH 4/4] disable two tests

---
 _unittests/ut_ci_models/test_ci_export.py             | 1 +
 _unittests/ut_xrun_doc/test_documentation_examples.py | 6 +++---
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/_unittests/ut_ci_models/test_ci_export.py b/_unittests/ut_ci_models/test_ci_export.py
index d25e4962..0491adac 100644
--- a/_unittests/ut_ci_models/test_ci_export.py
+++ b/_unittests/ut_ci_models/test_ci_export.py
@@ -10,6 +10,7 @@ class TestCiExport(ExtTestCase):
 
     @hide_stdout()
+    @requires_transformers("4.55")
     def test_main_qwen25_tiny_llm(self):
         main_qwen25(
             model_id="arnir0/Tiny-LLM",
diff --git a/_unittests/ut_xrun_doc/test_documentation_examples.py b/_unittests/ut_xrun_doc/test_documentation_examples.py
index 39b8b2ff..1859d71d 100644
--- a/_unittests/ut_xrun_doc/test_documentation_examples.py
+++ b/_unittests/ut_xrun_doc/test_documentation_examples.py
@@ -87,10 +87,10 @@ def add_test_methods(cls):
 
             if (
                 not reason
-                and name in {"plot_export_tiny_llm.py"}
-                and not has_transformers("4.51")
+                and name in {"plot_export_tiny_llm.py", "plot_export_tiny_llm_patched.py"}
+                and not has_transformers("4.55")
             ):
-                reason = "transformers<4.51"
+                reason = "transformers<4.55"
 
             if (
                 not reason
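
For reference, a minimal usage sketch of the observer API reworked in these patches, assuming a toy module with a plain `forward` taking two positional tensors; the module, the shapes and the final export call are illustrative only::

    import torch

    from onnx_diagnostic.investigate.input_observer import InputObserver


    class TinyModel(torch.nn.Module):
        def forward(self, x, y):
            return x + y.abs()


    model = TinyModel()
    observer = InputObserver()

    # Run the model a few times under the observer: the inputs and outputs of
    # forward are copied and stored, up to store_n_calls calls.
    with observer(model, store_n_calls=3, method_name="forward"):
        for cols in (6, 7, 8):
            model(torch.randn((5, cols)), torch.randn((1, cols)))

    # Dimensions whose size changed across the calls are reported as dynamic.
    # The first dimension of x is always 5 here, so it would not be detected;
    # add_batch_dimension_for marks it as dynamic anyway.
    dynamic_shapes = observer.infer_dynamic_shapes(add_batch_dimension_for={"x"})
    arguments = observer.infer_arguments()

    torch.export.export(model, arguments, dynamic_shapes=dynamic_shapes)

The same pattern applies to any other method through `method_name`, as exercised by `test_io_captured_not_forward`.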