Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,33 +17,31 @@ jobs:
matrix:
os: [ubuntu-latest]
python: ['3.10', '3.11', '3.12', '3.13']
transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.6', 'main']
transformers: ['4.48.3', '4.51.3', '4.55.4', '4.56.2', '4.57.6', 'main']
torch: ['2.9', 'main']
exclude:
- python: '3.10' # 3.10
torch: 'main'
- python: '3.10'
torch: '2.9'
- python: '3.10'
transformers: 'main'
- python: '3.10'
transformers: '4.52.4'
- python: '3.10'
transformers: '4.55.4'
- python: '3.10'
transformers: '4.56.2'
- python: '3.10'
transformers: '4.57.6'
- python: '3.10'
transformers: 'main'
- python: '3.11' # 3.11
torch: 'main'
- python: '3.11'
transformers: 'main'
- python: '3.11'
transformers: '4.55.4'
- python: '3.11'
transformers: '4.56.2'
- python: '3.11'
transformers: '4.57.6'
- python: '3.11'
transformers: 'main'
- python: '3.13' # 3.13
torch: '2.9'
- python: '3.13'
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Change Logs
0.8.11
++++++

* :pr:`394`: add function make_model_with_local_functions to partition a model into local functions

0.8.10
++++++

Expand Down
7 changes: 0 additions & 7 deletions _doc/api/api.rst

This file was deleted.

2 changes: 1 addition & 1 deletion _doc/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ API of onnx_diagnostic
:maxdepth: 1
:caption: modules

api
typing
ext_test_case

.. automodule:: onnx_diagnostic
Expand Down
6 changes: 6 additions & 0 deletions _doc/api/typing.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

onnx_diagnostic.typing
======================

.. automodule:: onnx_diagnostic.typing
:members:
Binary file added _doc/cmds/_img_partition.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions _doc/cmds/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ Command Lines
compare
config
optimize
partition
sbs
validate
47 changes: 47 additions & 0 deletions _doc/cmds/partition.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
-m onnx_diagnostic partition ... move layer nodes in local functions
====================================================================

The command line leverages the metadata added by the exporter.
Every node is tagged with information indicating which part of the model
it comes from. In particular, the key `namespace`:

::

transformers.models.llama.modeling_llama.LlamaForCausalLM/model:
transformers.models.llama.modeling_llama.LlamaModel/model.layers.0:
transformers.models.llama.modeling_llama.LlamaDecoderLayer/model.layers.0.self_attn:
transformers.models.llama.modeling_llama.LlamaAttention/unsqueeze_15:
aten.unsqueeze.default

Description
+++++++++++

See :func:`onnx_diagnostic.helpers.onnx_helper.make_model_with_local_functions`.

.. runpython::

from onnx_diagnostic._command_lines_parser import get_parser_partition

get_parser_partition().print_help()

Example
+++++++

.. code-block:: bash

python -m onnx_diagnostic partition arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx partition.onnx -r ".*[.]layers[.][0-9]+$" -v 1

This produces the following output:

::

-- load 'arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx'
-- partition
[make_model_with_local_functions] matched 1 partitions
[make_model_with_local_functions] move 89 nodes in partition 'transformers_models_llama_modeling_llama_LlamaModel/model_layers_0'
-- save into 'partition.onnx'
-- done

The partitioned model includes the following node:

.. image:: _img_partition.png
190 changes: 190 additions & 0 deletions _unittests/ut_helpers/test_onnx_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
shadowing_names,
onnx_dtype_name,
extract_subset_of_nodes,
make_subfunction,
make_submodel,
make_model_with_local_functions,
select_model_inputs_outputs,
_enumerate_model_node_outputs,
pretty_onnx,
)

TFLOAT = TensorProto.FLOAT
Expand Down Expand Up @@ -537,6 +540,46 @@ def _type_rank_fn(name):
check_model(new_model)
self.check_ort(new_model)

def test_make_subfunction(self):
    """Wraps the first four nodes of a small graph into a local function.

    The resulting ``FunctionProto`` must hold the four nodes, expose
    ``xm1`` and ``xm2c`` as outputs and list
    ``X, Y, shape1, shape2, un, zero`` as inputs.
    """
    nodes = [
        oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
        oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
        oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
        oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
        oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
        oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
        oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])]

    # Initializers: one random float32 weight followed by int64 constants.
    initializers = [
        onh.from_array(np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y")
    ]
    int64_constants = (
        ("zero", [0]),
        ("un", [1]),
        ("shape1", [1, 320, 1280]),
        ("shape2", [15, 1280, 640]),
        ("shape3", [3, 5, 320, 640]),
    )
    for cst_name, cst_value in int64_constants:
        initializers.append(
            onh.from_array(np.array(cst_value, dtype=np.int64), name=cst_name)
        )

    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs, initializers)
    model = oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=9
    )

    fproto = make_subfunction(
        "localf",
        model.graph.node[:4],
        opset_imports=model.opset_import,
        output_names=["xm1", "xm2c"],
    )
    self.assertIsInstance(fproto, FunctionProto)
    self.assertEqual(len(fproto.node), 4)
    self.assertEqual(fproto.output, ["xm1", "xm2c"])
    self.assertEqual(fproto.input, ["X", "Y", "shape1", "shape2", "un", "zero"])

def test_extract_subset_of_nodes_bigger(self):
model = onnx.load(
os.path.join(
Expand Down Expand Up @@ -670,6 +713,153 @@ def enumerate_model_tensors(model):
got = sess.run(None, {"X": x})[0]
self.assertEqual((x**2 + y).tolist(), got.tolist())

def test_make_model_with_local_functions(self):
    """Partitions nodes tagged with namespace ``LLL`` into one local function.

    The first four nodes receive the metadata ``namespace=LLL``; the
    partitioned model must contain a single function named ``LLL`` in
    domain ``local_function`` and both models must pass ``check_model``.
    """
    nodes = [
        oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
        oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
        oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
        oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
        oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
        oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
        oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])]

    # Initializers: one random float32 weight followed by int64 constants.
    initializers = [
        onh.from_array(np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y")
    ]
    int64_constants = (
        ("zero", [0]),
        ("un", [1]),
        ("shape1", [1, 320, 1280]),
        ("shape2", [15, 1280, 640]),
        ("shape3", [3, 5, 320, 640]),
    )
    for cst_name, cst_value in int64_constants:
        initializers.append(
            onh.from_array(np.array(cst_value, dtype=np.int64), name=cst_name)
        )

    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs, initializers)
    model = oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=9
    )

    # Tag nodes 0..3 so that the regular expression below selects them.
    for i_node in range(4):
        prop = model.graph.node[i_node].metadata_props.add()
        prop.key = "namespace"
        prop.value = "LLL"

    new_model = make_model_with_local_functions(model, "^LLL$")
    # The input model is validated after the call as well.
    check_model(model)
    self.assertEqual(len(new_model.functions), 1)

    local_fct = new_model.functions[0]
    self.assertEqual(["X", "Y", "shape1", "shape2", "un", "zero"], local_fct.input)
    self.assertEqual(["xm1", "xm2c"], local_fct.output)
    self.assertEqual("LLL", local_fct.name)
    self.assertEqual("local_function", local_fct.domain)
    self.assertIn("LLL[local_function]", pretty_onnx(new_model))
    check_model(new_model)

def test_make_model_with_local_functions_bug(self):
    """Checks the error raised when the tagged nodes cannot form a function.

    Nodes 0, 2, 3 and 4 are tagged with namespace ``LLL`` while node 1 is
    left out; the partitioning is expected to fail with a ``ValueError``
    carrying the message asserted below.
    """
    nodes = [
        oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
        oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
        oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
        oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
        oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
        oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
        oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])]

    # Initializers: one random float32 weight followed by int64 constants.
    initializers = [
        onh.from_array(np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y")
    ]
    int64_constants = (
        ("zero", [0]),
        ("un", [1]),
        ("shape1", [1, 320, 1280]),
        ("shape2", [15, 1280, 640]),
        ("shape3", [3, 5, 320, 640]),
    )
    for cst_name, cst_value in int64_constants:
        initializers.append(
            onh.from_array(np.array(cst_value, dtype=np.int64), name=cst_name)
        )

    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs, initializers)
    model = oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=9
    )

    # Node 1 (xu1 -> xu2) is deliberately not tagged.
    for i_node in (0, 2, 3, 4):
        prop = model.graph.node[i_node].metadata_props.add()
        prop.key = "namespace"
        prop.value = "LLL"

    expected_message = (
        "Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', "
        "'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] "
        "which is not allowed."
    )
    self.assertRaise(
        lambda: make_model_with_local_functions(model, "^LLL$"),
        ValueError,
        expected_message,
    )
    # The input model is validated after the failed attempt.
    check_model(model)

@hide_stdout()
def test_make_model_with_local_functions_2(self):
    """Partitions a model into two local functions using a key prefix.

    Nodes 0..3 receive metadata whose key is ``source[<index>]`` and whose
    value is ``LLL0`` for nodes 0, 1, 2 and ``LLL1`` for node 3. With
    ``metadata_key_prefix="source["`` two functions are expected.
    """
    nodes = [
        oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
        oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
        oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
        oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
        oh.make_node("Cast", ["xm2c"], ["xm2"], to=1),
        oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]),
        oh.make_node("Reshape", ["xm", "shape3"], ["Z"]),
    ]
    graph_inputs = [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])]
    graph_outputs = [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])]

    # Initializers: one random float32 weight followed by int64 constants.
    initializers = [
        onh.from_array(np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y")
    ]
    int64_constants = (
        ("zero", [0]),
        ("un", [1]),
        ("shape1", [1, 320, 1280]),
        ("shape2", [15, 1280, 640]),
        ("shape3", [3, 5, 320, 640]),
    )
    for cst_name, cst_value in int64_constants:
        initializers.append(
            onh.from_array(np.array(cst_value, dtype=np.int64), name=cst_name)
        )

    graph = oh.make_graph(nodes, "dummy", graph_inputs, graph_outputs, initializers)
    model = oh.make_model(
        graph, opset_imports=[oh.make_opsetid("", 18)], ir_version=9
    )

    # i_node // 3 is 0 for nodes 0..2 and 1 for node 3.
    for i_node in range(4):
        prop = model.graph.node[i_node].metadata_props.add()
        prop.key = f"source[{i_node}]"
        prop.value = f"LLL{i_node//3}"

    new_model = make_model_with_local_functions(
        model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1
    )
    # The input model is validated after the call as well.
    check_model(model)
    self.assertEqual(len(new_model.functions), 2)
    rendered = pretty_onnx(new_model)
    self.assertIn("LLL0[local_function]", rendered)
    self.assertIn("LLL1[local_function]", rendered)

    first_fct = new_model.functions[0]
    self.assertEqual(["X", "shape1", "un", "zero"], first_fct.input)
    self.assertEqual(["xm1"], first_fct.output)
    self.assertEqual("LLL0", first_fct.name)
    self.assertEqual("local_function", first_fct.domain)
    self.assertEqual(len(first_fct.node), 3)

    second_fct = new_model.functions[1]
    self.assertEqual(["Y", "shape2"], second_fct.input)
    self.assertEqual(["xm2c"], second_fct.output)
    self.assertEqual("LLL1", second_fct.name)
    self.assertEqual("local_function", second_fct.domain)
    self.assertEqual(len(second_fct.node), 1)

    check_model(new_model)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
get_parser_find,
get_parser_lighten,
get_parser_optimize,
get_parser_partition,
get_parser_print,
get_parser_sbs,
get_parser_stats,
Expand Down Expand Up @@ -186,6 +187,13 @@ def test_parser_optimize(self):
text = st.getvalue()
self.assertIn("default", text)

def test_parser_partition(self):
    """The help text of the ``partition`` parser must mention ``regex``."""
    captured = StringIO()
    with redirect_stdout(captured):
        parser = get_parser_partition()
        parser.print_help()
    self.assertIn("regex", captured.getvalue())


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines_exe.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ def test_l_parser_optimize(self):
self.assertIn("default", text)
self.assertExists(output)

def test_m_parser_partition(self):
    """Runs the ``partition`` command line on the dummy model.

    The command must print ``-- done`` and must have written the
    partitioned model to the requested output path.
    """
    output = self.get_dump_file("test_parser_partition.onnx")
    st = StringIO()
    with redirect_stdout(st):
        main(["partition", self.dummy_path, output, "-v", "1"])
    text = st.getvalue()
    self.assertIn("-- done", text)
    # Printing "-- done" does not prove the file was saved: check it exists.
    self.assertExists(output)


if __name__ == "__main__":
unittest.main(verbosity=2)
Loading
Loading