diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fd1479f3..68707605 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,33 +17,31 @@ jobs: matrix: os: [ubuntu-latest] python: ['3.10', '3.11', '3.12', '3.13'] - transformers: ['4.48.3', '4.51.3', '4.52.4', '4.55.4', '4.56.2', '4.57.6', 'main'] + transformers: ['4.48.3', '4.51.3', '4.55.4', '4.56.2', '4.57.6', 'main'] torch: ['2.9', 'main'] exclude: - python: '3.10' # 3.10 torch: 'main' - python: '3.10' torch: '2.9' - - python: '3.10' - transformers: 'main' - - python: '3.10' - transformers: '4.52.4' - python: '3.10' transformers: '4.55.4' - python: '3.10' transformers: '4.56.2' - python: '3.10' transformers: '4.57.6' + - python: '3.10' + transformers: 'main' - python: '3.11' # 3.11 torch: 'main' - - python: '3.11' - transformers: 'main' - python: '3.11' transformers: '4.55.4' - python: '3.11' transformers: '4.56.2' - python: '3.11' transformers: '4.57.6' + - python: '3.11' + transformers: 'main' - python: '3.13' # 3.11 torch: '2.9' - python: '3.13' diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index eee79a0d..d600abe3 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,8 @@ Change Logs 0.8.11 ++++++ +* :pr:`394`: add function make_model_with_local_functions to partition a model into local functions + 0.8.10 ++++++ diff --git a/_doc/api/api.rst b/_doc/api/api.rst deleted file mode 100644 index 47811c34..00000000 --- a/_doc/api/api.rst +++ /dev/null @@ -1,7 +0,0 @@ - -onnx_diagnostic.api -=================== - -.. automodule:: onnx_diagnostic.api - :members: - :no-undoc-members: diff --git a/_doc/api/index.rst b/_doc/api/index.rst index 52b01674..3b32e2ce 100644 --- a/_doc/api/index.rst +++ b/_doc/api/index.rst @@ -20,7 +20,7 @@ API of onnx_diagnostic :maxdepth: 1 :caption: modules - api + typing ext_test_case .. 
automodule:: onnx_diagnostic diff --git a/_doc/api/typing.rst b/_doc/api/typing.rst new file mode 100644 index 00000000..962cf1b1 --- /dev/null +++ b/_doc/api/typing.rst @@ -0,0 +1,6 @@ + +onnx_diagnostic.typing +====================== + +.. automodule:: onnx_diagnostic.typing + :members: diff --git a/_doc/cmds/_img_partition.png b/_doc/cmds/_img_partition.png new file mode 100644 index 00000000..6b0f8e7c Binary files /dev/null and b/_doc/cmds/_img_partition.png differ diff --git a/_doc/cmds/index.rst b/_doc/cmds/index.rst index a357777a..a211ba04 100644 --- a/_doc/cmds/index.rst +++ b/_doc/cmds/index.rst @@ -11,5 +11,6 @@ Command Lines compare config optimize + partition sbs validate diff --git a/_doc/cmds/partition.rst b/_doc/cmds/partition.rst new file mode 100644 index 00000000..dcde17be --- /dev/null +++ b/_doc/cmds/partition.rst @@ -0,0 +1,47 @@ +-m onnx_diagnostic partition ... move layer nodes in local functions +==================================================================== + +The command line leverages the metadata added by the exporter. +Every node is tagged with information indicating which part of the model +it comes from. In particular the key `namespace`: + +:: + + transformers.models.llama.modeling_llama.LlamaForCausalLM/model: + transformers.models.llama.modeling_llama.LlamaModel/model.layers.0: + transformers.models.llama.modeling_llama.LlamaDecoderLayer/model.layers.0.self_attn: + transformers.models.llama.modeling_llama.LlamaAttention/unsqueeze_15: + aten.unsqueeze.default + +Description ++++++++++++ + +See :func:`onnx_diagnostic.helpers.onnx_helper.make_model_with_local_functions`. + +.. runpython:: + + from onnx_diagnostic._command_lines_parser import get_parser_partition + + get_parser_partition().print_help() + +Example ++++++++ + +.. 
code-block:: bash + + python -m onnx_diagnostic partition arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx partition.onnx -r ".*[.]layers[.][0-9]+$" -v 1 + +This produces the following output: + +:: + + -- load 'arnir0_Tiny-LLM-onnx-dynamo-ir-f16-cuda-op18.onnx' + -- partition + [make_model_with_local_functions] matched 1 partitions + [make_model_with_local_functions] move 89 nodes in partition 'transformers_models_llama_modeling_llama_LlamaModel/model_layers_0' + -- save into 'partition.onnx' + -- done + +The partitioned model includes the following node: + +.. image:: _img_partition.png diff --git a/_unittests/ut_helpers/test_onnx_helper.py b/_unittests/ut_helpers/test_onnx_helper.py index e9f0a1fb..5d83a711 100644 --- a/_unittests/ut_helpers/test_onnx_helper.py +++ b/_unittests/ut_helpers/test_onnx_helper.py @@ -28,9 +28,12 @@ shadowing_names, onnx_dtype_name, extract_subset_of_nodes, + make_subfunction, make_submodel, + make_model_with_local_functions, select_model_inputs_outputs, _enumerate_model_node_outputs, + pretty_onnx, ) TFLOAT = TensorProto.FLOAT @@ -537,6 +540,46 @@ def _type_rank_fn(name): check_model(new_model) self.check_ort(new_model) + def test_make_subfunction(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 
1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + new_function = make_subfunction( + "localf", + model.graph.node[:4], + opset_imports=model.opset_import, + output_names=["xm1", "xm2c"], + ) + self.assertIsInstance(new_function, FunctionProto) + self.assertEqual(len(new_function.node), 4) + self.assertEqual(new_function.output, ["xm1", "xm2c"]) + self.assertEqual(new_function.input, ["X", "Y", "shape1", "shape2", "un", "zero"]) + def test_extract_subset_of_nodes_bigger(self): model = onnx.load( os.path.join( @@ -670,6 +713,153 @@ def enumerate_model_tensors(model): got = sess.run(None, {"X": x})[0] self.assertEqual((x**2 + y).tolist(), got.tolist()) + def test_make_model_with_local_functions(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node 
in [0, 1, 2, 3]: + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = "namespace" + meta.value = "LLL" + new_model = make_model_with_local_functions(model, "^LLL$") + check_model(model) + self.assertEqual(len(new_model.functions), 1) + self.assertEqual( + ["X", "Y", "shape1", "shape2", "un", "zero"], new_model.functions[0].input + ) + self.assertEqual(["xm1", "xm2c"], new_model.functions[0].output) + self.assertEqual("LLL", new_model.functions[0].name) + self.assertEqual("local_function", new_model.functions[0].domain) + self.assertIn("LLL[local_function]", pretty_onnx(new_model)) + check_model(new_model) + + def test_make_model_with_local_functions_bug(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node in [0, 2, 3, 4]: + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = "namespace" + meta.value = "LLL" + self.assertRaise( + lambda: make_model_with_local_functions(model, "^LLL$"), + ValueError, + 
"Results {'xu1'} are needed for inputs ['X', 'Y', 'shape1', " + "'shape2', 'xu2', 'zero'] but also requires ['xm1', 'xm2', 'xu1'] " + "which is not allowed.", + ) + check_model(model) + + @hide_stdout() + def test_make_model_with_local_functions_2(self): + model = oh.make_model( + oh.make_graph( + [ + oh.make_node("Unsqueeze", ["X", "zero"], ["xu1"]), + oh.make_node("Unsqueeze", ["xu1", "un"], ["xu2"]), + oh.make_node("Reshape", ["xu2", "shape1"], ["xm1"]), + oh.make_node("Reshape", ["Y", "shape2"], ["xm2c"]), + oh.make_node("Cast", ["xm2c"], ["xm2"], to=1), + oh.make_node("MatMul", ["xm1", "xm2"], ["xm"]), + oh.make_node("Reshape", ["xm", "shape3"], ["Z"]), + ], + "dummy", + [oh.make_tensor_value_info("X", TFLOAT, [320, 1280])], + [oh.make_tensor_value_info("Z", TFLOAT, [3, 5, 320, 640])], + [ + onh.from_array( + np.random.rand(3, 5, 1280, 640).astype(np.float32), name="Y" + ), + onh.from_array(np.array([0], dtype=np.int64), name="zero"), + onh.from_array(np.array([1], dtype=np.int64), name="un"), + onh.from_array(np.array([1, 320, 1280], dtype=np.int64), name="shape1"), + onh.from_array(np.array([15, 1280, 640], dtype=np.int64), name="shape2"), + onh.from_array(np.array([3, 5, 320, 640], dtype=np.int64), name="shape3"), + ], + ), + opset_imports=[oh.make_opsetid("", 18)], + ir_version=9, + ) + for i_node in [0, 1, 2, 3]: + node = model.graph.node[i_node] + meta = node.metadata_props.add() + meta.key = f"source[{i_node}]" + meta.value = f"LLL{i_node//3}" + new_model = make_model_with_local_functions( + model, "^LLL[01]$", metadata_key_prefix="source[", verbose=1 + ) + check_model(model) + self.assertEqual(len(new_model.functions), 2) + p = pretty_onnx(new_model) + self.assertIn("LLL0[local_function]", p) + self.assertIn("LLL1[local_function]", p) + + self.assertEqual(["X", "shape1", "un", "zero"], new_model.functions[0].input) + self.assertEqual(["xm1"], new_model.functions[0].output) + self.assertEqual("LLL0", new_model.functions[0].name) + 
self.assertEqual("local_function", new_model.functions[0].domain) + self.assertEqual(len(new_model.functions[0].node), 3) + + self.assertEqual(["Y", "shape2"], new_model.functions[1].input) + self.assertEqual(["xm2c"], new_model.functions[1].output) + self.assertEqual("LLL1", new_model.functions[1].name) + self.assertEqual("local_function", new_model.functions[1].domain) + self.assertEqual(len(new_model.functions[1].node), 1) + + check_model(new_model) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines.py b/_unittests/ut_xrun_doc/test_command_lines.py index 60f37180..07644e21 100644 --- a/_unittests/ut_xrun_doc/test_command_lines.py +++ b/_unittests/ut_xrun_doc/test_command_lines.py @@ -11,6 +11,7 @@ get_parser_find, get_parser_lighten, get_parser_optimize, + get_parser_partition, get_parser_print, get_parser_sbs, get_parser_stats, @@ -186,6 +187,13 @@ def test_parser_optimize(self): text = st.getvalue() self.assertIn("default", text) + def test_parser_partition(self): + st = StringIO() + with redirect_stdout(st): + get_parser_partition().print_help() + text = st.getvalue() + self.assertIn("regex", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines_exe.py b/_unittests/ut_xrun_doc/test_command_lines_exe.py index e88ecb87..bd78aae7 100644 --- a/_unittests/ut_xrun_doc/test_command_lines_exe.py +++ b/_unittests/ut_xrun_doc/test_command_lines_exe.py @@ -219,6 +219,14 @@ def test_l_parser_optimize(self): self.assertIn("default", text) self.assertExists(output) + def test_m_parser_partition(self): + output = self.get_dump_file("test_parser_partition.onnx") + st = StringIO() + with redirect_stdout(st): + main(["partition", self.dummy_path, output, "-v", "1"]) + text = st.getvalue() + self.assertIn("-- done", text) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_diagnostic/_command_lines_parser.py 
def _split_meta_prefixes(value: str) -> tuple:
    """
    Splits the comma separated value of ``--meta-prefix`` into a tuple of prefixes.

    Surrounding spaces and empty items are dropped: an empty prefix matches
    every metadata key (``str.startswith("")`` is always true), a stray
    trailing comma would then silently disable the filtering on the keys.
    A value without any usable prefix keeps its historical meaning,
    the keys are not filtered at all.

    :param value: value received from the command line
    :return: tuple of prefixes, never empty
    """
    prefixes = tuple(p.strip() for p in value.split(",") if p.strip())
    return prefixes or ("",)


def get_parser_partition() -> ArgumentParser:
    """
    Builds the parser for command line ``partition``.

    :return: parser expecting the input model, the output model and
        options ``--regex``, ``--meta-prefix``, ``--verbose``
    """
    parser = ArgumentParser(
        prog="partition",
        formatter_class=RawTextHelpFormatter,
        description=textwrap.dedent("""
            Partitions an onnx model by moving nodes into local functions.
            Exporters may add metadata to the onnx nodes telling which part
            of the model it comes from (namespace, source, ...).
            These nodes are moved into local functions.
            """),
        epilog=textwrap.dedent("""
            The regular expression may match the following values,
            'model.layers.0.forward', 'model.layers.1.forward', ...
            A local function will be created for each distinct layer.
            """),
    )
    parser.add_argument("input", help="input model")
    parser.add_argument("output", help="output model")
    parser.add_argument(
        "-r",
        "--regex",
        default=".*[.]layers[.][0-9]+[.]forward$",
        help=textwrap.dedent("""
            merges all nodes sharing the same value in node metadata,
            these values must match the regular expression specified by
            this parameter, the default value matches what transformers
            usually uses to define a layer
            """).strip("\n"),
    )
    parser.add_argument(
        "-p",
        "--meta-prefix",
        default="namespace,source[",
        help="allowed prefixes for keys in the metadata",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        default=0,
        required=False,
        type=int,
        help="verbosity",
    )
    return parser


def _cmd_partition(argv: List[Any]):
    """
    Implements command line ``partition``: loads a model, moves the nodes
    selected by the regular expression into local functions, saves the result.

    :param argv: command line arguments, the first one is the command name
        and is skipped
    """
    # local import, same as in the original implementation
    from .helpers.onnx_helper import make_model_with_local_functions

    parser = get_parser_partition()
    args = parser.parse_args(argv[1:])

    if args.verbose:
        print(f"-- load {args.input!r}")
    # The weights stored as external data are not loaded.
    # NOTE(review): the saved model then presumably keeps referring to the
    # external files of the input model, confirm this is the expected
    # behaviour when the output is written into another folder.
    onx = onnx.load(args.input, load_external_data=False)
    if args.verbose:
        print("-- partition")
    onx2 = make_model_with_local_functions(
        onx,
        regex=args.regex,
        metadata_key_prefix=_split_meta_prefixes(args.meta_prefix),
        verbose=args.verbose,
    )
    if args.verbose:
        print(f"-- save into {args.output!r}")
    onnx.save(onx2, args.output)
    if args.verbose:
        print("-- done")
def _node_required_names(node: NodeProto) -> Set[str]:
    """
    Returns the names of all the results a node reads.

    Empty names are ignored, they stand for omitted optional inputs.
    The hidden inputs of the subgraphs of nodes ``Scan``, ``If``, ``Loop``
    are included.

    :param node: node
    :return: set of names
    """
    names = {i for i in node.input if i}
    if node.op_type in {"Scan", "If", "Loop"}:
        # there are hidden inputs
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                names |= get_hidden_inputs(att.g)
    return names


def make_subfunction(
    name: str,
    nodes: List[NodeProto],
    opset_imports: Sequence[OperatorSetIdProto],
    output_names: List[str],
    domain: str = "local_function",
) -> FunctionProto:
    """
    Creates a function with the given list of nodes.
    It computes the minimum list of inputs needed for this function.
    The function assumes the nodes are sorted.

    :param name: function name
    :param nodes: list of nodes
    :param opset_imports: opset import
    :param output_names: desired outputs, the order is preserved
    :param domain: function domain
    :return: function proto, its inputs are sorted by name
    """
    not_known: Set[str] = set()
    # Walking backward: a name is an input of the function if a node reads it
    # and no earlier node of the list produces it.
    for node in reversed(nodes):
        not_known -= {o for o in node.output if o}
        not_known |= _node_required_names(node)

    return oh.make_function(
        domain,
        name,
        nodes=nodes,
        inputs=sorted(not_known),
        outputs=output_names,
        opset_imports=opset_imports,
    )


def _find_used_names(
    node_list: Sequence[Optional[NodeProto]],
    node_indices: Sequence[int],
    required_names: Sequence[str] = (),
) -> List[str]:
    """
    Finds the results produced by a subset of nodes and needed outside of it.

    :param node_list: all the nodes, some positions may be None
    :param node_indices: positions of the nodes of the subset
    :param required_names: names needed whatever the other nodes do,
        typically the graph outputs, without them a graph output directly
        produced by the subset would be dropped
    :return: sorted list of names the subset must expose
    """
    # find all the outputs the subset of nodes produces
    possible_outputs: Set[str] = set()
    for i_node in node_indices:
        node = node_list[i_node]
        if node is None:
            continue
        possible_outputs |= {o for o in node.output if o}
    # find all the inputs required by the other nodes
    set_indices = set(node_indices)
    not_known: Set[str] = {n for n in required_names if n}
    for i_node in reversed(range(len(node_list))):
        if i_node in set_indices:
            continue
        node = node_list[i_node]
        if node is None:
            continue
        not_known -= {o for o in node.output if o}
        not_known |= _node_required_names(node)
    # output
    selection = possible_outputs & not_known
    assert selection, (
        f"No output is needed, possible_outputs={sorted(possible_outputs)}, "
        f"not_known={sorted(not_known)}"
    )
    return sorted(selection)


def check_for_non_recursivity(
    node_list: List[Optional[NodeProto]], inputs: Sequence[str], outputs: Sequence[str]
):
    """
    Checks that none of the outputs of a function is required to compute
    one of its inputs. That would mean one node
    needs an output of the function and is also required by the function:
    it is probably missing from the initial set.

    Empty names are ignored: they stand for omitted optional inputs or
    outputs and do not create any dependency between two nodes.

    :param node_list: list of nodes, None values are skipped
    :param inputs: input names to consider
    :param outputs: output names which cannot be involved in input names
    :raises ValueError: if one output is needed to compute one input
    """
    set_inputs = {i for i in inputs if i}
    set_outputs = {o for o in outputs if o}
    # Walking backward, set_inputs grows with everything needed
    # to compute the inputs of the function.
    for node in node_list[::-1]:
        if node is None:
            continue
        si = {o for o in node.output if o}
        if si & set_inputs:
            set_inputs |= _node_required_names(node)
    if set_outputs & set_inputs:
        raise ValueError(
            f"Results {set_outputs & set_inputs} are needed for inputs {inputs} "
            f"but also requires {outputs} which is not allowed."
        )


def _stable_topological_sort(nodes: Sequence[NodeProto]) -> List[NodeProto]:
    """
    Orders the nodes so that every result is produced before it is consumed.
    A sequence which is already sorted is returned in the same order.

    :param nodes: nodes to sort
    :return: sorted nodes
    :raises ValueError: if the nodes depend on each other
    """
    producers: Dict[str, int] = {}
    for index, node in enumerate(nodes):
        for name in node.output:
            if name:
                producers[name] = index
    # names missing from producers are graph inputs or initializers
    dependencies = [
        sorted({producers[n] for n in _node_required_names(node) if n in producers} - {index})
        for index, node in enumerate(nodes)
    ]
    pending, on_stack, done = 0, 1, 2
    state = [pending] * len(nodes)
    ordered: List[NodeProto] = []
    for root in range(len(nodes)):
        if state[root] != pending:
            continue
        # iterative depth first search, a node is emitted after its producers
        state[root] = on_stack
        stack = [(root, iter(dependencies[root]))]
        while stack:
            index, remaining = stack[-1]
            for dep in remaining:
                if state[dep] == on_stack:
                    raise ValueError(
                        f"Unable to sort the nodes, node {dep} and node {index} "
                        f"depend on each other."
                    )
                if state[dep] == pending:
                    state[dep] = on_stack
                    stack.append((dep, iter(dependencies[dep])))
                    break
            else:
                state[index] = done
                ordered.append(nodes[index])
                stack.pop()
    return ordered


def _first_matching_value(
    node: NodeProto, prefixes: Tuple[str, ...], reg: "re.Pattern", rejected: Set[str]
) -> Optional[str]:
    """
    Returns the first metadata value of a node matching the regular expression.
    Only the metadata whose key starts with one of the prefixes are considered,
    every value is split around ``,`` and ``:``. The values looked at before
    a match and rejected are added to *rejected*.
    """
    for data in node.metadata_props:
        if not data.key.startswith(prefixes):
            continue
        for value in re.split("[,:]", data.value):
            if not value:
                continue
            if reg.match(value):
                return value
            rejected.add(value)
    return None


def make_model_with_local_functions(
    model: ModelProto,
    regex: str = ".*[.]layers[.][0-9]+[.]forward$",
    domain: str = "local_function",
    metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["),
    verbose: int = 0,
) -> ModelProto:
    """
    Selects nodes based on a regular expression, using metadata
    such as ``'namespace'``. It is going to look into every value
    matching the regular expression and partition the nodes based
    on the unique values the regular expression finds.
    Every set of nodes is replaced by a call to a local function.

    :param model: model proto
    :param regex: regular expression
    :param domain: function domain
    :param metadata_key_prefix: prefix or prefixes of the metadata keys
        to consider, every value is split into multiple ones
        around ``,`` and ``:``
    :param verbose: verbosity
    :return: model proto, the input model itself if nothing matches
    :raises ValueError: if a partition depends on one of its own outputs
    """
    prefix = (
        (metadata_key_prefix,)
        if isinstance(metadata_key_prefix, str)
        else tuple(metadata_key_prefix)
    )
    reg = re.compile(regex)
    unique_values: Set[str] = set()
    unique: Dict[str, List[int]] = {}
    for i, node in enumerate(model.graph.node):
        value = _first_matching_value(node, prefix, reg, unique_values)
        if value is not None:
            unique.setdefault(value, []).append(i)
    # sets of nodes.
    if not unique:
        if verbose:
            print(f"[make_model_with_local_functions] no match in {sorted(unique_values)}")
        return model

    if verbose:
        print(f"[make_model_with_local_functions] matched {len(unique)} partitions")
    # a graph output produced inside a partition must remain an output
    graph_outputs = [o.name for o in model.graph.output]
    functions = []
    new_nodes: List[Optional[NodeProto]] = list(model.graph.node)
    for key, node_indices in unique.items():
        function_name = key.strip().replace(".", "_")
        if verbose:
            print(
                f"[make_model_with_local_functions] move {len(node_indices)} "
                f"nodes in partition {function_name!r}"
            )
        outputs = _find_used_names(new_nodes, node_indices, graph_outputs)
        function_nodes = [new_nodes[i] for i in node_indices]
        lf = make_subfunction(
            function_name,
            [n for n in function_nodes if n is not None],
            model.opset_import,
            outputs,
            domain=domain,
        )
        check_for_non_recursivity(new_nodes, lf.input, lf.output)
        functions.append(lf)
        # the call replaces the last node of the partition
        maxi = max(node_indices)
        for i in node_indices:
            new_nodes[i] = None
        new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain)

    # A node outside a partition but located between two of its nodes may
    # consume a result now produced by the call placed after it:
    # the nodes are sorted again, the order is unchanged if already valid.
    sorted_nodes = _stable_topological_sort([n for n in new_nodes if n is not None])

    return oh.make_model(
        oh.make_graph(
            sorted_nodes,
            model.graph.name,
            model.graph.input,
            model.graph.output,
            model.graph.initializer,
            doc_string=model.graph.doc_string,
            value_info=model.graph.value_info,
            sparse_initializer=model.graph.sparse_initializer,
        ),
        ir_version=model.ir_version,
        opset_imports=(
            model.opset_import
            if domain in {d.domain for d in model.opset_import}
            else [*model.opset_import, oh.make_opsetid(domain, 1)]
        ),
        functions=[*model.functions, *functions],
    )