diff --git a/onnxscript/version_converter/__init__.py b/onnxscript/version_converter/__init__.py index c29fb4c989..deea931d16 100644 --- a/onnxscript/version_converter/__init__.py +++ b/onnxscript/version_converter/__init__.py @@ -37,7 +37,9 @@ def __init__(self, target_version: int, fallback: bool = False) -> None: super().__init__() self.target_version = target_version self.fallback = fallback - self._convert_pass = _ConvertVersionPass( + # NOTE: The current version converter only supports inlined models. + self._inline_pass = common_passes.InlinePass() + self._convert_pass = _ConvertVersionPassRequiresInline( target_version=target_version, fallback=fallback, ) @@ -51,15 +53,16 @@ def call(self, model: ir.Model) -> ir.passes.PassResult: # Run the conversion pass outside of Sequential so that errors # (e.g. VersionConverterError) propagate directly without being # wrapped in PassError. - result = self._convert_pass(model) - cleanup_result = self._cleanup_passes(result) + inline_result = self._inline_pass(model) + result = self._convert_pass(inline_result.model) + cleanup_result = self._cleanup_passes(result.model) return ir.passes.PassResult( cleanup_result.model, - result.modified or cleanup_result.modified, + result.modified or cleanup_result.modified or inline_result.modified, ) -class _ConvertVersionPass(ir.passes.InPlacePass): +class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass): """Convert the model to the specified ONNX opset version. This pass leverages the onnxscript version converter to convert the model. If diff --git a/onnxscript/version_converter/_version_converter_test.py b/onnxscript/version_converter/_version_converter_test.py index c920746d7b..a37e8e262f 100644 --- a/onnxscript/version_converter/_version_converter_test.py +++ b/onnxscript/version_converter/_version_converter_test.py @@ -208,6 +208,7 @@ def test_version_convert_gridsample_cubic(self): self.assertEqual(model.graph.node(4).version, 20) self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic") + @pytest.mark.xfail(reason="Version converter does not currently support local-function.") def test_version_convert_function_nodes(self): """Test that version converter processes nodes inside model functions.""" model = ir.from_onnx_text( @@ -253,6 +254,7 @@ def test_version_convert_function_nodes(self): self.assertEqual(func[3].version, 20) self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input + @pytest.mark.xfail(reason="Version converter does not currently support local-function.") def test_version_convert_function_with_control_flow_subgraph(self): """Test that version converter processes subgraphs inside control flow nodes in functions.""" model = ir.from_onnx_text( @@ -455,6 +457,7 @@ def test_metadata_is_copied_to_multiple_replacement_nodes(self): f"Node {i} ({node.op_type}) should have metadata copied", ) + @pytest.mark.xfail(reason="Version converter does not currently support local-function.") def test_version_convert_raises_on_function_node_with_ref_attribute(self): """Test that version conversion raises when a function contains a node with a ref attribute.""" # Build a function with a LeakyRelu node that uses a RefAttr for 'alpha'