From 723704ea5f7f76a46c656dacc751025eb60fe417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 22 Jan 2026 11:17:57 +0100 Subject: [PATCH 1/5] rename api.py into typing.py --- _unittests/ut_xrun_doc/test_unit_test.py | 2 +- onnx_diagnostic/api.py | 15 -- onnx_diagnostic/helpers/onnx_helper.py | 176 +++++++----------- .../reference/torch_ops/_op_run.py | 2 +- onnx_diagnostic/torch_onnx/runtime_info.py | 2 +- onnx_diagnostic/typing.py | 15 ++ pyproject.toml | 2 + 7 files changed, 90 insertions(+), 124 deletions(-) delete mode 100644 onnx_diagnostic/api.py create mode 100644 onnx_diagnostic/typing.py diff --git a/_unittests/ut_xrun_doc/test_unit_test.py b/_unittests/ut_xrun_doc/test_unit_test.py index 61faeaf8..e225c775 100644 --- a/_unittests/ut_xrun_doc/test_unit_test.py +++ b/_unittests/ut_xrun_doc/test_unit_test.py @@ -16,7 +16,7 @@ has_cuda, has_onnxscript, ) -from onnx_diagnostic.api import TensorLike +from onnx_diagnostic.typing import TensorLike class TestUnitTest(ExtTestCase): diff --git a/onnx_diagnostic/api.py b/onnx_diagnostic/api.py deleted file mode 100644 index 1bbfb083..00000000 --- a/onnx_diagnostic/api.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Any - - -class TensorLike: - """Mocks a tensor.""" - - @property - def dtype(self) -> Any: - "Must be overwritten." - raise NotImplementedError("dtype must be overwritten.") - - @property - def shape(self) -> Any: - "Must be overwritten." - raise NotImplementedError("shape must be overwritten.") diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 6e7eede9..4fce4261 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -32,11 +32,10 @@ ValueInfoProto, load as onnx_load, ) +from ..typing import InferenceSessionLike, TensorLike -TensorLike = Union[np.ndarray, "torch.Tensor"] # noqa: F821 - -def _make_stat(init: TensorProto) -> Dict[str, float]: +def _make_stat(init: TensorProto) -> Dict[str, Any]: """ Produces statistics. 
@@ -160,11 +159,11 @@ def _validate_graph( verbose: int = 0, watch: Optional[Set[str]] = None, path: Optional[Sequence[str]] = None, -): - found = [] +) -> List[Union[NodeProto, TensorProto, ValueInfoProto]]: + found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = [] path = path or ["root"] - set_init = set(i.name for i in g.initializer) - set_input = set(i.name for i in g.input) + set_init = {i.name for i in g.initializer} + set_input = {i.name for i in g.input} existing |= set_init | set_input if watch and set_init & watch: if verbose: @@ -215,18 +214,15 @@ def _validate_graph( f"in {'/'.join(path)}/{node.op_type}[{node.name}]" ) found.append(node) - out = set(o.name for o in g.output) + out = {o.name for o in g.output} ins = out & existing - if ins != out: - raise AssertionError( - f"One output is missing, out={node.input}, existing={ins}, path={path}" - ) + assert ins == out, f"One output is missing, out={node.input}, existing={ins}, path={path}" return found def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None): - existing = set(g.input) - found = [] + existing: Set[str] = set(g.input) + found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = [] for node in g.node: ins = set(node.input) & existing if ins != set(node.input): @@ -240,7 +236,7 @@ def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[s for att in node.attribute: if att.type == AttributeProto.GRAPH: found.extend( - _validate_graph(g, existing.copy(), path=[g.name], verbose=verbose) + _validate_graph(att.g, existing.copy(), path=[g.name], verbose=verbose) ) existing |= set(node.output) if watch and set(node.output) & watch: @@ -285,7 +281,7 @@ def check_model_ort( onx: ModelProto, providers: Optional[Union[str, List[Any]]] = None, dump_file: Optional[str] = None, -) -> "onnxruntime.InferenceSession": # noqa: F821 +) -> InferenceSessionLike: """ Loads a model with onnxruntime. @@ -308,10 +304,9 @@ def check_model_ort( if isinstance(onx, str): try: + # pyrefly: ignore[bad-return] return InferenceSession(onx, providers=providers) except Exception as e: - import onnx - if dump_file: onnx.save(onx, dump_file) @@ -319,8 +314,8 @@ def check_model_ort( f"onnxruntime cannot load the model " f"due to {e}\n{pretty_onnx(onnx.load(onx))}" ) - return try: + # pyrefly: ignore[bad-return] return InferenceSession(onx.SerializeToString(), providers=providers) except Exception as e: if dump_file: @@ -358,7 +353,17 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str: def pretty_onnx( - onx: Union[FunctionProto, GraphProto, ModelProto, ValueInfoProto, str], + onx: Union[ + AttributeProto, + FunctionProto, + GraphProto, + ModelProto, + NodeProto, + onnx.SparseTensorProto, + TensorProto, + ValueInfoProto, + str, + ], with_attributes: bool = False, highlight: Optional[Set[str]] = None, shape_inference: bool = False, @@ -377,6 +382,9 @@ def pretty_onnx( assert onx is not None, "onx cannot be None" if shape_inference: + assert isinstance( + onx, ModelProto + ), f"shape inference only works for ModelProto, not {type(onx)}" onx = onnx.shape_inference.infer_shapes(onx) if isinstance(onx, ValueInfoProto): @@ -447,6 +455,8 @@ def _high(n): shape = "x".join(map(str, onx.dims)) return f"TensorProto:{onx.data_type}:{shape}:{onx.name}" + assert not isinstance(onx, onnx.SparseTensorProto), "SparseTensorProto is not handled yet." 
+ try: from onnx_array_api.plotting.text_plot import onnx_simple_text_plot @@ -538,12 +548,6 @@ def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorP return tensor -_STORAGE_TYPE = { - TensorProto.FLOAT16: np.int16, - TensorProto.BFLOAT16: np.int16, -} - - def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto: """ Converts an array into a :class:`onnx.TensorProto`. @@ -561,54 +565,9 @@ def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> Tenso ), f"Unable to convert type {type(tensor)} into TensorProto." return proto_from_tensor(tensor, name=name) - try: - from onnx.reference.ops.op_cast import ( - bfloat16, - float8e4m3fn, - float8e4m3fnuz, - float8e5m2, - float8e5m2fnuz, - ) - except ImportError: - bfloat16 = None - - if bfloat16 is None: - return onh.from_array(tensor, name) - - dt = tensor.dtype - if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn": - to = TensorProto.FLOAT8E4M3FN - dt_to = np.uint8 - elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz": - to = TensorProto.FLOAT8E4M3FNUZ - dt_to = np.uint8 - elif dt == float8e5m2 and dt.descr[0][0] == "e5m2": - to = TensorProto.FLOAT8E5M2 - dt_to = np.uint8 - elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz": - to = TensorProto.FLOAT8E5M2FNUZ - dt_to = np.uint8 - elif dt == bfloat16 and dt.descr[0][0] == "bfloat16": - to = TensorProto.BFLOAT16 - dt_to = np.uint16 - else: - try: - import ml_dtypes - except ImportError: - ml_dtypes = None - if ml_dtypes is not None and ( - tensor.dtype == ml_dtypes.bfloat16 - or tensor.dtype == ml_dtypes.float8_e4m3fn - or tensor.dtype == ml_dtypes.float8_e4m3fnuz - or tensor.dtype == ml_dtypes.float8_e5m2 - or tensor.dtype == ml_dtypes.float8_e5m2fnuz - ): - return from_array_ml_dtypes(tensor, name) - return onh.from_array(tensor, name) - - t = onh.from_array(tensor.astype(dt_to), name) - t.data_type = to - return t + assert isinstance(tensor, np.ndarray) # type checking + # pyrefly: ignore[bad-argument-type] + return onh.from_array(tensor, name) def to_array_extended(proto: TensorProto) -> TensorLike: @@ -666,6 +625,7 @@ def onnx_dtype_to_np_dtype(itype: int) -> Any: ) +# pyrefly: ignore[unknown-name] def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int: # noqa: F821 """ Converts a torch dtype or numpy dtype into a onnx element type. @@ -679,6 +639,7 @@ def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int: # noqa: F pass from .torch_helper import torch_dtype_to_onnx_dtype + # pyrefly: ignore[bad-argument-type] return torch_dtype_to_onnx_dtype(dt) @@ -779,6 +740,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype: f"ml_dtypes can be used." ) from e + # pyrefly: ignore[bad-assignment] mapping: Dict[int, np.dtype] = { TensorProto.BFLOAT16: ml_dtypes.bfloat16, TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn, @@ -798,7 +760,7 @@ def iterator_initializer_constant( model: Union[FunctionProto, GraphProto, ModelProto], use_numpy: bool = True, prefix: str = "", -) -> Iterator[Tuple[str, Union["torch.Tensor", np.ndarray]]]: # noqa: F821 +) -> Iterator[Tuple[str, TensorLike]]: # noqa: F821 """ Iterates on iniatialiers and constant in an onnx model. @@ -814,9 +776,11 @@ def iterator_initializer_constant( if prefix: prefix += "." 
for init in graph.initializer: - yield f"{prefix}{init.name}", ( - to_array_extended(init) if use_numpy else to_tensor(init) - ) + s = f"{prefix}{init.name}" + if use_numpy: + yield s, to_array_extended(init) + # pyrefly: ignore[unbound-name] + yield s, to_tensor(init) nodes = graph.node name = graph.name if isinstance(model, ModelProto): @@ -831,13 +795,14 @@ def iterator_initializer_constant( if node.op_type == "Constant" and node.domain == "": from ..reference import ExtendedReferenceEvaluator as Inference - if not use_numpy: - import torch sess = Inference(node) value = sess.run(None, {})[0] - yield f"{prefix}{node.output[0]}", ( - value if use_numpy else torch.from_numpy(value) - ) + + if not use_numpy: + import torch + + yield f"{prefix}{node.output[0]}", (torch.from_numpy(value)) + yield f"{prefix}{node.output[0]}", (value) if node.op_type in {"Loop", "Body", "Scan"}: for att in node.attribute: @@ -870,7 +835,9 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union from .helper import size_type if isinstance(tensor, TensorProto): + # pyrefly: ignore[bad-assignment] tensor = to_array_extended(tensor) + assert isinstance(tensor, np.ndarray) # type checking itype = np_dtype_to_tensor_dtype(tensor.dtype) stat = dict( mean=float(tensor.mean()), @@ -948,7 +915,7 @@ class NodeCoordinates: def __init__( self, - node: Union[onnx.TensorProto, NodeProto, str], + node: Union[TensorProto, NodeProto, onnx.SparseTensorProto, ValueInfoProto, str], path: Tuple[Tuple[int, str, str], ...], ): assert isinstance(path, tuple), f"Unexpected type {type(path)} for path" @@ -968,9 +935,7 @@ def path_to_str(self) -> str: class ResultFound: - """ - Class returned by :func:`enumerate_results`. - """ + """Class returned by :func:`enumerate_results`.""" __slots__ = ("consumer", "name", "producer") @@ -1060,9 +1025,9 @@ def enumerate_results( print(f"[enumerate_results] {indent}-- {r}") yield r for i in proto.sparse_initializer: - if i.name in name: + if i.values.name in name: r = ResultFound( - i.name, + i.values.name, NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])), # noqa: C409 None, ) @@ -1165,9 +1130,9 @@ def shadowing_names( return shadowing_names( proto.node, verbose=verbose, - existing=set(i.name for i in proto.initializer) - | set(i.name for i in proto.sparse_initializer) - | set(i.name for i in proto.input if i.name), + existing={i.name for i in proto.initializer} + | {i.values.name for i in proto.sparse_initializer} + | {i.name for i in proto.input if i.name}, shadow_context=set(), post_shadow_context=set(), ) @@ -1201,9 +1166,9 @@ def shadowing_names( for att in node.attribute: if att.type == AttributeProto.GRAPH: g = att.g - shadow |= set(i.name for i in g.input) & shadow_context - shadow |= set(i.name for i in g.initializer) & shadow_context - shadow |= set(i.name for i in g.sparse_initializer) & shadow_context + shadow |= {i.name for i in g.input} & shadow_context + shadow |= {i.name for i in g.initializer} & shadow_context + shadow |= {i.values.name for i in g.sparse_initializer} & shadow_context s, _ps, c = shadowing_names( g.node, verbose=verbose, existing=existing, shadow_context=existing ) @@ -1225,9 +1190,9 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]: """ hidden = set() memo = ( - set(i.name for i in graph.initializer) - | set(i.name for i in graph.sparse_initializer) - | set(i.name for i in graph.input) + {i.name for i in graph.initializer} + | {i.values.name for i in graph.sparse_initializer} + | {i.name for i in graph.input} ) for 
node in graph.node: for i in node.input: @@ -1392,9 +1357,7 @@ def _mkv_(name, itype, irank): def get_tensor_shape( obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto], ) -> Optional[List[Optional[Union[int, str]]]]: - """ - Returns the shape if that makes sense for this object. - """ + """Returns the shape if that makes sense for this object.""" if isinstance(obj, ValueInfoProto): return get_tensor_shape(obj.type) elif not isinstance(obj, onnx.TypeProto): @@ -1512,9 +1475,6 @@ def onnx_remove_node_unused( if not ({o for o in node.output if o} & marked_set): removed.add(ind) - if not is_function: - initializers = [i for i in graph.initializer if i.name in marked] - sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked] new_nodes = [node for i, node in enumerate(nodes) if i not in removed] # Finally create the new graph. @@ -1529,13 +1489,16 @@ def onnx_remove_node_unused( attributes=graph.attribute, doc_string=graph.doc_string, ) + + initializers = [i for i in graph.initializer if i.name in marked] + sparse_initializers = [i for i in graph.sparse_initializer if i.values.name in marked] new_graph = oh.make_graph( new_nodes, graph.name, graph.input, graph.output, initializers, - sparse_initializers, + sparse_initializer=sparse_initializers, ) new_graph.value_info.extend(graph.value_info) return new_graph @@ -1549,7 +1512,7 @@ def select_model_inputs_outputs( overwrite: Optional[Dict[str, Any]] = None, remove_unused: bool = True, verbose: int = 0, -): +) -> ModelProto: """ Takes a model and changes its outputs. @@ -1709,6 +1672,7 @@ def select_model_inputs_outputs( ) if remove_unused: graph = onnx_remove_node_unused(graph, recursive=False) + assert isinstance(graph, GraphProto) # type checking onnx_model = oh.make_model(graph, functions=model.functions) onnx_model.ir_version = model.ir_version onnx_model.producer_name = model.producer_name diff --git a/onnx_diagnostic/reference/torch_ops/_op_run.py b/onnx_diagnostic/reference/torch_ops/_op_run.py index 829c1c0c..e614328d 100644 --- a/onnx_diagnostic/reference/torch_ops/_op_run.py +++ b/onnx_diagnostic/reference/torch_ops/_op_run.py @@ -1,7 +1,7 @@ from typing import Any, Dict, List, Optional, Union, Tuple import onnx import torch -from ...api import TensorLike +from ...typing import TensorLike from ...helpers import string_type from ...helpers.torch_helper import to_tensor diff --git a/onnx_diagnostic/torch_onnx/runtime_info.py b/onnx_diagnostic/torch_onnx/runtime_info.py index 5d3ed0a4..2aa30f65 100644 --- a/onnx_diagnostic/torch_onnx/runtime_info.py +++ b/onnx_diagnostic/torch_onnx/runtime_info.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union import onnx import torch -from ..api import TensorLike +from ..typing import TensorLike from ..helpers import string_type from ..helpers.onnx_helper import get_hidden_inputs diff --git a/onnx_diagnostic/typing.py b/onnx_diagnostic/typing.py new file mode 100644 index 00000000..78fb7ea6 --- /dev/null +++ b/onnx_diagnostic/typing.py @@ -0,0 +1,15 @@ +from typing import Any, Dict, List, Protocol, Tuple, runtime_checkable + + +@runtime_checkable +class TensorLike(Protocol): + @property + def shape(self) -> Tuple[int, ...]: ... + @property + def dtype(self) -> object: ... + + +@runtime_checkable +class InferenceSessionLike(Protocol): + def __init__(self, model: Any, **kwargs): ... + def run(self, feeds: Dict[str, TensorLike]) -> List[TensorLike]: ... 
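
Note on the new module above: the deleted onnx_diagnostic/api.py exposed TensorLike as a mock base class whose dtype/shape properties raised NotImplementedError, so callers had to subclass it. The replacement onnx_diagnostic/typing.py defines structural, runtime-checkable Protocols instead: any object exposing the right attributes satisfies them without inheritance, which is presumably why check_model_ort can now be annotated as returning InferenceSessionLike without a hard dependency on onnxruntime types. A minimal sketch of the behavior (illustrative only, not part of the patch):

    import numpy as np

    from onnx_diagnostic.typing import TensorLike

    # Structural check: an ndarray exposes .shape and .dtype, so it matches
    # the runtime_checkable Protocol even though it never subclasses it.
    print(isinstance(np.zeros((2, 3)), TensorLike))  # True
    # A plain list exposes neither attribute, so it does not match.
    print(isinstance([1, 2, 3], TensorLike))         # False
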
diff --git a/pyproject.toml b/pyproject.toml index 2c9d1e5d..e23613ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -133,6 +133,7 @@ disable_error_code = ["arg-type", "assignment", "operator", "var-annotated", "un [tool.pyrefly] project-includes = [ + "onnx_diagnostic/typing.py", "onnx_diagnostic/export/validate.py", "onnx_diagnostic/investigate/**", "onnx_diagnostic/helpers/args_helper.py", @@ -143,6 +144,7 @@ project-includes = [ "onnx_diagnostic/helpers/log_helper.py", "onnx_diagnostic/helpers/memory_peak.py", "onnx_diagnostic/helpers/mini_onnx_builder.py", + "onnx_diagnostic/helpers/onnx_helper.py", "onnx_diagnostic/reference/evaluator.py", "onnx_diagnostic/reference/quantized_evaluator.py", "onnx_diagnostic/reference/report_results_comparison.py", From 372b96932afefbebe1534ae61258073e7e6235e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 22 Jan 2026 11:38:06 +0100 Subject: [PATCH 2/5] fix --- onnx_diagnostic/helpers/torch_helper.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/onnx_diagnostic/helpers/torch_helper.py b/onnx_diagnostic/helpers/torch_helper.py index 01b447e3..1a8f4d0f 100644 --- a/onnx_diagnostic/helpers/torch_helper.py +++ b/onnx_diagnostic/helpers/torch_helper.py @@ -19,12 +19,7 @@ CacheKeyValue, ) from .mini_onnx_builder import create_onnx_model_from_input_tensors -from .onnx_helper import ( - to_array_extended, - tensor_dtype_to_np_dtype, - _STORAGE_TYPE, - onnx_dtype_name, -) +from .onnx_helper import to_array_extended, tensor_dtype_to_np_dtype, onnx_dtype_name def proto_from_tensor( @@ -84,7 +79,11 @@ def proto_from_tensor( byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr()) tensor.raw_data = bytes(byte_data) if sys.byteorder == "big": - np_dtype = _STORAGE_TYPE[tensor.data_type] # type: ignore + storage_type = { + onnx.TensorProto.FLOAT16: np.int16, + onnx.TensorProto.BFLOAT16: np.int16, + } + np_dtype = storage_type[tensor.data_type] # type: ignore np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True) # type: ignore else: tensor.raw_data = np_arr.tobytes() From 378fbf2a2926c3568d8f1753d38dae0759c8533e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 22 Jan 2026 11:57:57 +0100 Subject: [PATCH 3/5] fix --- .github/workflows/pyrefly.yaml | 2 +- _unittests/ut_helpers/test_onnx_helper.py | 40 +++++++++---------- onnx_diagnostic/helpers/onnx_helper.py | 8 ++-- .../reference/report_results_comparison.py | 3 +- .../reference/torch_ops/_op_run.py | 2 +- .../reference/torch_ops/sequence_ops.py | 2 +- 6 files changed, 28 insertions(+), 29 deletions(-) diff --git a/.github/workflows/pyrefly.yaml b/.github/workflows/pyrefly.yaml index a48566b0..06ab7e24 100644 --- a/.github/workflows/pyrefly.yaml +++ b/.github/workflows/pyrefly.yaml @@ -21,7 +21,7 @@ jobs: run: | pip install pyrefly pip install -r requirements.txt - pip install transformers pandas matplotlib openpyxl + pip install transformers pandas matplotlib openpyxl onnx-array-api - name: Run pyrefly run: pyrefly check diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index 1ba5c9fd..e9f0a1fb 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -180,9 +180,7 @@ def _get_cdist_implementation( opsets: Dict[str, int], **kwargs: Any, ) -> FunctionProto: - """ - Returns the CDist implementation as a function. 
- """ + """Returns the CDist implementation as a function.""" assert len(node_inputs) == 2 assert len(node_outputs) == 1 assert opsets @@ -191,12 +189,6 @@ def _get_cdist_implementation( metric = kwargs["metric"] assert metric in ("euclidean", "sqeuclidean") # subgraph - nodes = [ - oh.make_node("Sub", ["next", "next_in"], ["diff"]), - oh.make_node("Constant", [], ["axis"], value_ints=[1]), - oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0), - oh.make_node("Identity", ["next_in"], ["next_out"]), - ] def make_value(name): value = ValueInfoProto() @@ -204,26 +196,32 @@ def make_value(name): return value graph = oh.make_graph( - nodes, + [ + oh.make_node("Sub", ["next", "next_in"], ["diff"]), + oh.make_node("Constant", [], ["axis"], value_ints=[1]), + oh.make_node("ReduceSumSquare", ["diff", "axis"], ["scan_out"], keepdims=0), + oh.make_node("Identity", ["next_in"], ["next_out"]), + ], "loop", [make_value("next_in"), make_value("next")], [make_value("next_out"), make_value("scan_out")], ) - scan = oh.make_node( - "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph - ) - final = ( - oh.make_node("Sqrt", ["zout"], ["z"]) - if metric == "euclidean" - else oh.make_node("Identity", ["zout"], ["z"]) - ) return oh.make_function( "npx", f"CDist_{metric}", ["xa", "xb"], ["z"], - [scan, final], + [ + oh.make_node( + "Scan", ["xb", "xa"], ["next_out", "zout"], num_scan_inputs=1, body=graph + ), + ( + oh.make_node("Sqrt", ["zout"], ["z"]) + if metric == "euclidean" + else oh.make_node("Identity", ["zout"], ["z"]) + ), + ], [oh.make_opsetid("", opsets[""])], ) @@ -234,9 +232,7 @@ def test_iterate_function(self): ) model = oh.make_model( oh.make_graph( - [ - oh.make_node(proto.name, ["X", "Y"], ["Z"]), - ], + [oh.make_node(proto.name, ["X", "Y"], ["Z"])], "dummy", [ oh.make_tensor_value_info("X", itype, [None, None]), diff --git a/onnx_diagnostic/helpers/onnx_helper.py b/onnx_diagnostic/helpers/onnx_helper.py index 4fce4261..83920ddc 100644 --- a/onnx_diagnostic/helpers/onnx_helper.py +++ b/onnx_diagnostic/helpers/onnx_helper.py @@ -779,8 +779,9 @@ def iterator_initializer_constant( s = f"{prefix}{init.name}" if use_numpy: yield s, to_array_extended(init) - # pyrefly: ignore[unbound-name] - yield s, to_tensor(init) + else: + # pyrefly: ignore[unbound-name] + yield s, to_tensor(init) nodes = graph.node name = graph.name if isinstance(model, ModelProto): @@ -802,7 +803,8 @@ def iterator_initializer_constant( import torch yield f"{prefix}{node.output[0]}", (torch.from_numpy(value)) - yield f"{prefix}{node.output[0]}", (value) + else: + yield f"{prefix}{node.output[0]}", (value) if node.op_type in {"Loop", "Body", "Scan"}: for att in node.attribute: diff --git a/onnx_diagnostic/reference/report_results_comparison.py b/onnx_diagnostic/reference/report_results_comparison.py index 72fba178..c13a3f4b 100644 --- a/onnx_diagnostic/reference/report_results_comparison.py +++ b/onnx_diagnostic/reference/report_results_comparison.py @@ -25,6 +25,7 @@ def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]): # noqa: F self.max_diff = max_diff self.tensors = tensors self._build_mapping() + self.unique_run_names: Set[str] = set() # pyrefly: ignore[unknown-name] def key(self, tensor: "torch.Tensor") -> ReportKeyValueType: # noqa: F821 @@ -44,7 +45,7 @@ def _build_mapping(self): def clear(self): """Clears the last report.""" self.report_cmp = {} - self.unique_run_names: Set[str] = set() + self.unique_run_names = set() @property def value( diff --git 
a/onnx_diagnostic/reference/torch_ops/_op_run.py b/onnx_diagnostic/reference/torch_ops/_op_run.py index e614328d..e5818b4c 100644 --- a/onnx_diagnostic/reference/torch_ops/_op_run.py +++ b/onnx_diagnostic/reference/torch_ops/_op_run.py @@ -149,7 +149,7 @@ def insert_at( ) -> "OpRunSequence": "Inserts a value at a given position." assert isinstance(tensor, OpRunTensor), f"Unexpected type {type(tensor)} for tensor" - new_seq = OpRunSequence() + new_seq = OpRunSequence() # type: ignore[abstract] seq = self.sequence.copy() new_seq.sequence = seq if position is None: diff --git a/onnx_diagnostic/reference/torch_ops/sequence_ops.py b/onnx_diagnostic/reference/torch_ops/sequence_ops.py index 08820728..6196c319 100644 --- a/onnx_diagnostic/reference/torch_ops/sequence_ops.py +++ b/onnx_diagnostic/reference/torch_ops/sequence_ops.py @@ -46,7 +46,7 @@ def __init__(self, node: onnx.NodeProto, version: Optional[int] = None, verbose: ) def run(self) -> OpRunSequence: - return OpRunSequence(dtype=self.dtype) + return OpRunSequence(dtype=self.dtype) # type: ignore[abstract] class SequenceInsert_11(OpRunOpSequence): From ff295c0e94546b07053e37c332bc233884551344 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 22 Jan 2026 12:11:16 +0100 Subject: [PATCH 4/5] fix --- _unittests/ut_xrun_doc/test_unit_test.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/_unittests/ut_xrun_doc/test_unit_test.py b/_unittests/ut_xrun_doc/test_unit_test.py index e225c775..ca2dfa08 100644 --- a/_unittests/ut_xrun_doc/test_unit_test.py +++ b/_unittests/ut_xrun_doc/test_unit_test.py @@ -16,7 +16,6 @@ has_cuda, has_onnxscript, ) -from onnx_diagnostic.typing import TensorLike class TestUnitTest(ExtTestCase): @@ -111,10 +110,6 @@ def test_measure_time_max(self): }, ) - def test_exc(self): - self.assertRaise(lambda: TensorLike().dtype, NotImplementedError) - self.assertRaise(lambda: TensorLike().shape, NotImplementedError) - if __name__ == "__main__": unittest.main(verbosity=2) From d72db4f051c703c4f580d88906ea91c78ac17537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Xavier=20Dupr=C3=A9?= Date: Thu, 22 Jan 2026 12:12:18 +0100 Subject: [PATCH 5/5] fix --- onnx_diagnostic/investigate/input_observer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnx_diagnostic/investigate/input_observer.py b/onnx_diagnostic/investigate/input_observer.py index 8fbf4df8..f3d7d9ee 100644 --- a/onnx_diagnostic/investigate/input_observer.py +++ b/onnx_diagnostic/investigate/input_observer.py @@ -115,7 +115,10 @@ def add_inputs(self, args: tuple[Any, ...], kwargs: dict[str, Any]): } flat_args, spec = torch.utils._pytree.tree_flatten((args, kwargs)) self.inputs_specs.append(spec) - cloned = [(None if t is None else t.clone().detach()) for t in flat_args] + cloned = [ + (None if not isinstance(t, torch.Tensor) else t.clone().detach()) + for t in flat_args + ] self.flat_inputs.append(cloned) cloned_args, cloned_kwargs = torch.utils._pytree.tree_unflatten(cloned, spec)
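
On the final hunk: torch.utils._pytree.tree_flatten returns every leaf of the input structure, not only tensors, so the previous check (which only special-cased None) crashed on leaves such as ints or floats that have no .clone(). A minimal sketch of the guarded cloning (illustrative only, not part of the patch):

    import torch

    args = (torch.ones(2),)
    kwargs = {"scale": 0.5, "mask": None}
    flat, spec = torch.utils._pytree.tree_flatten((args, kwargs))
    # flat holds the tensor plus the non-tensor leaves 0.5 and None;
    # calling .clone() on 0.5 would raise AttributeError.
    cloned = [
        (None if not isinstance(t, torch.Tensor) else t.clone().detach())
        for t in flat
    ]
    # As in the patched add_inputs, non-tensor leaves are stored as None
    # placeholders before tree_unflatten rebuilds (args, kwargs) from spec.
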