Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions onnxscript/version_converter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def __init__(self, target_version: int, fallback: bool = False) -> None:
super().__init__()
self.target_version = target_version
self.fallback = fallback
self._convert_pass = _ConvertVersionPass(
# NOTE: The current version converter only supports inlined models.
self._inline_pass = common_passes.InlinePass()
self._convert_pass = _ConvertVersionPassRequiresInline(
target_version=target_version,
fallback=fallback,
)
Expand All @@ -51,15 +53,16 @@ def call(self, model: ir.Model) -> ir.passes.PassResult:
# Run the conversion pass outside of Sequential so that errors
# (e.g. VersionConverterError) propagate directly without being
# wrapped in PassError.
result = self._convert_pass(model)
cleanup_result = self._cleanup_passes(result)
inline_result = self._inline_pass(model)
result = self._convert_pass(inline_result.model)
cleanup_result = self._cleanup_passes(result.model)
return ir.passes.PassResult(
cleanup_result.model,
result.modified or cleanup_result.modified,
result.modified or cleanup_result.modified or inline_result.modified,
)


class _ConvertVersionPass(ir.passes.InPlacePass):
class _ConvertVersionPassRequiresInline(ir.passes.InPlacePass):
"""Convert the model to the specified ONNX opset version.

This pass leverages the onnxscript version converter to convert the model. If
Expand Down
3 changes: 3 additions & 0 deletions onnxscript/version_converter/_version_converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ def test_version_convert_gridsample_cubic(self):
self.assertEqual(model.graph.node(4).version, 20)
self.assertEqual(model.graph.node(4).attributes["mode"].value, "cubic")

@pytest.mark.xfail(reason="Version converter does not currently support local-function.")
def test_version_convert_function_nodes(self):
"""Test that version converter processes nodes inside model functions."""
model = ir.from_onnx_text(
Expand Down Expand Up @@ -253,6 +254,7 @@ def test_version_convert_function_nodes(self):
self.assertEqual(func[3].version, 20)
self.assertEqual(len(func[3].inputs), 3) # DFT 19->20 adds dft_length input

@pytest.mark.xfail(reason="Version converter does not currently support local-function.")
def test_version_convert_function_with_control_flow_subgraph(self):
"""Test that version converter processes subgraphs inside control flow nodes in functions."""
model = ir.from_onnx_text(
Expand Down Expand Up @@ -455,6 +457,7 @@ def test_metadata_is_copied_to_multiple_replacement_nodes(self):
f"Node {i} ({node.op_type}) should have metadata copied",
)

@pytest.mark.xfail(reason="Version converter does not currently support local-function.")
def test_version_convert_raises_on_function_node_with_ref_attribute(self):
"""Test that version conversion raises when a function contains a node with a ref attribute."""
# Build a function with a LeakyRelu node that uses a RefAttr for 'alpha'
Expand Down
Loading