diff --git a/onnxscript/_internal/_inference.py b/onnxscript/_internal/_inference.py
new file mode 100644
index 0000000000..35b3ef3228
--- /dev/null
+++ b/onnxscript/_internal/_inference.py
@@ -0,0 +1,101 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import logging
+
+import numpy as np
+import onnx
+import onnx_ir as ir
+
+logger = logging.getLogger(__name__)
+
+
+def _get_numpy_value(
+    val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
+) -> np.ndarray | None:
+    """Returns the numpy value of a constant value, if available.
+
+    Returns None if the value is not a constant, if it is not of the specified
+    element dtype, or if its size exceeds the specified size_limit.
+    """
+    if val is None:
+        return None
+    const_value = val.const_value
+    if const_value is not None:
+        if dtype is not None and const_value.dtype != dtype:
+            return None
+        if size_limit is not None and const_value.size > size_limit:
+            return None
+        try:
+            # Turn the constant value into a numpy array representation, with the
+            # specifics of this conversion handled by the tensor type.
+            array = const_value.numpy()
+            # Strings cannot (and need not) be reinterpreted via .view(): doing so
+            # raises "TypeError: Cannot change data-type for array of references."
+            # Reinterpretation is only relevant for arithmetic types.
+            if const_value.dtype != ir.DataType.STRING:
+                # Reinterpret the array with `.view()` because some
+                # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
+                # not use ml_dtypes for bfloat16 etc.
+                array = array.view(const_value.dtype.numpy())
+        except FileNotFoundError:
+            # External data is not available.
+            logger.warning(
+                "External data for value '%s' is not available. "
+                "This may lead to incorrect constant folding.",
+                val.name,
+            )
+            return None
+        assert isinstance(array, np.ndarray)
+        return array
+    return None
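+
+
+# A minimal illustration of the intended behavior of _get_numpy_value,
+# assuming a value backed by a 2x2 float32 constant tensor:
+#
+#     val = ir.Value(name="c", const_value=ir.tensor(np.ones((2, 2), np.float32)))
+#     _get_numpy_value(val)                           # -> 2x2 ndarray
+#     _get_numpy_value(val, dtype=ir.DataType.INT64)  # -> None (dtype mismatch)
+#     _get_numpy_value(val, size_limit=2)             # -> None (4 elements > 2)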
" + # "This may lead to incorrect constant folding.", + # val.name, + # ) + return None + assert isinstance(array, np.ndarray) + return array + return None + + +def _do_onnx_inference(node: ir.Node) -> None: + output_types = {} + + def get_constant_value(x: ir.Value) -> onnx.TensorProto | None: + value = _get_numpy_value(x, size_limit=20) + if value is not None: + assert x.const_value is not None + return ir.serde.serialize_tensor(x.const_value) + return None + + def get_type(index: int, value: ir.Value) -> onnx.TypeProto: + if value.type is None: + raise ValueError( + f"Type of input {index} value {value.name} of node {node.name} not known" + ) + type_proto = ir.serde.serialize_type(value.type) + if value.shape is not None: + ir.serde.serialize_shape_into(type_proto, value.shape) + return type_proto + + input_types = {x.name: get_type(i, x) for i, x in enumerate(node.inputs) if x is not None} + input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None} + input_data = {k: v for k, v in input_data.items() if v is not None} + + # TODO: pass in constant values, ir_version + schema = onnx.defs.get_schema(node.op_type, node.version, node.domain) + output_types = onnx.shape_inference.infer_node_outputs( + schema, + ir.serde.serialize_node(node), + input_types, # type: ignore[arg-type] + input_data, # type: ignore[arg-type] + ) + for output in node.outputs: + if output.name in output_types: + inferred_type = output_types[output.name] + # TODO: merge types, check for conflicts + inferred_shape = ir.serde.deserialize_type_proto_for_shape(inferred_type) + # NOTE: forward shape inference + output.merge_shapes(inferred_shape) + output.type = ir.serde.deserialize_type_proto_for_type(inferred_type) + + +def infer_outputs(node: ir.Node) -> None: + try: + _do_onnx_inference(node) + except Exception as e: + # TODO: compose with any existing error + node.metadata_props["inference_error"] = str(e) diff --git a/onnxscript/_internal/builder.py b/onnxscript/_internal/builder.py new file mode 100644 index 0000000000..dd580555a3 --- /dev/null +++ b/onnxscript/_internal/builder.py @@ -0,0 +1,291 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +from typing import Any, Callable, Sequence + +import onnx +import onnx_ir as ir + +import onnxscript._internal._inference as inference +import onnxscript.optimizer + + +class GraphBuilder: + def __init__(self, graph: ir.Graph, is_function: bool) -> None: + self._graph = graph + + # Get the opset version for "" (default domain) from the graph + if "" not in graph.opset_imports: + raise ValueError('Input graph does not have an import for domain ""') + opset_version = graph.opset_imports[""] + + self._op_builder = self.opset("", opset_version) + + # Context stack to manage hierarchical naming. Each module/layer can push a new context, and pop it when done. + # The current context is used as a prefix for naming values and nodes. 
+    def __init__(self, graph: ir.Graph, is_function: bool) -> None:
+        self._graph = graph
+
+        # Get the opset version for "" (default domain) from the graph
+        if "" not in graph.opset_imports:
+            raise ValueError('Input graph does not have an import for domain ""')
+        opset_version = graph.opset_imports[""]
+
+        self._op_builder = self.opset("", opset_version)
+
+        # Context stack to manage hierarchical naming. Each module/layer can push
+        # a new context, and pop it when done. The current context is used as a
+        # prefix for naming values and nodes. This allows us to generate names
+        # like "layer1.attention.query".
+        self._context_stack: list[str] = [""]
+
+    def opset(self, domain: str = "", version: int = 1) -> OpBuilder:
+        return OpBuilder(self, domain, version)
+
+    @property
+    def op(self) -> OpBuilder:
+        return self._op_builder
+
+    @property
+    def graph(self) -> ir.Graph:
+        return self._graph
+
+    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
+        if name is None:
+            name = tensor.name
+        prefix = self.context_name()
+        if prefix:
+            name = f"{prefix}.{name}"
+        # TODO: set tensor name as well
+        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
+        value = ir.Value(
+            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
+        )
+        self._graph.register_initializer(value)
+        return value
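+
+    # Illustrative example: with module context "layer1" pushed (see push_module
+    # below), initializer(tensor, name="weight") registers the value under the
+    # name "layer1.weight".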
+
+    def _input_to_ir_value(
+        self,
+        value: ir.Value | ir.TensorProtocol | int | float | bool | str,
+        like_type: ir.Value | None = None,
+    ) -> ir.Value:
+        if isinstance(value, ir.Value):
+            return value
+        elif isinstance(value, (int, float, bool, str)):
+            # Scalar constant
+            if like_type is not None and like_type.type is not None:
+                dtype = like_type.type.dtype
+            else:
+                # Infer type from Python type (bool must be checked before int,
+                # since bool is a subclass of int)
+                if isinstance(value, bool):
+                    dtype = ir.DataType.BOOL
+                elif isinstance(value, int):
+                    dtype = ir.DataType.INT64
+                elif isinstance(value, float):
+                    dtype = ir.DataType.FLOAT
+                elif isinstance(value, str):
+                    dtype = ir.DataType.STRING
+                else:
+                    raise TypeError(f"Unsupported scalar type: {type(value)}")
+            tensor = ir.Tensor(
+                data=np.array(value, dtype=dtype.numpy()),
+                name="const_scalar",
+            )
+            return self.initializer(tensor)
+        else:
+            # Otherwise, value is expected to implement ir.TensorProtocol.
+            # TODO: We could use caching to avoid duplicate initializers. However,
+            # it seems unlikely to be useful in practice, as shared use of a
+            # stateful module is rare.
+            return self.initializer(value)
+
+    def _adapt_outputs(
+        self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
+    ) -> Sequence[ir.Value]:
+        prefix = self.context_name()
+        if isinstance(outputs, int):
+            if outputs == 1:
+                name = f"{op_type}_output" if op_type else "output"
+                if prefix:
+                    name = f"{prefix}.{name}"
+                return [ir.Value(name=name)]
+            else:
+                names = [
+                    f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
+                ]
+                if prefix:
+                    names = [f"{prefix}.{n}" for n in names]
+                return [ir.Value(name=n) for n in names]
+        adapted_outputs = []
+        for output in outputs:
+            if isinstance(output, ir.Value):
+                if prefix and output.name:
+                    output.name = f"{prefix}.{output.name}"
+                adapted_outputs.append(output)
+            elif isinstance(output, str):
+                name = f"{prefix}.{output}" if prefix else output
+                adapted_outputs.append(ir.Value(name=name))
+            else:
+                raise TypeError(f"Output type {type(output)} not supported.")
+        return adapted_outputs
+
+    def _get_schema(
+        self, op_type: str, domain: str, version: int | None
+    ) -> onnx.defs.OpSchema | None:
+        if version is not None:
+            try:
+                return onnx.defs.get_schema(op_type, version, domain)
+            except onnx.defs.SchemaError:
+                pass
+        return None
+
+    def _partition_inputs_attributes(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        inputs: Sequence[ir.Value | ir.TensorProtocol],
+        kwargs: dict[str, Any],
+    ) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
+        # Not implemented yet
+        return inputs, kwargs
+
+    def _cast_inputs(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        inputs: Sequence[ir.Value | ir.TensorProtocol],
+    ) -> Sequence[ir.Value | None]:
+        """Uses the schema specification to support a limited form of auto-casting.
+
+        * Scalars are promoted to tensors.
+        * Further, they are cast to the required type when used in ops whose
+          other tensor inputs are required to be of the same type. Thus, in
+          "A+1" or "Add(A, 1)", the value 1 will be converted to the same type
+          as A.
+
+        This is used by the converter in static mode, as well as by eager-mode
+        execution in dynamic mode.
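+
+        For example (illustrative): if x is a FLOAT tensor value, then in
+        Add(x, 1) the first pass binds the type variable "T" to x, and the
+        second pass converts 1 into a FLOAT scalar tensor. Add(1, x) behaves
+        the same way, which is why two passes are needed.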
+        """
+        if schema is None:
+            return [self._input_to_ir_value(i) for i in inputs]
+
+        expected_inputs = schema.inputs
+        # We make two passes. In the first pass, we identify known type-bindings
+        # for type-variables: e.g., binding "T1" to a FLOAT input and "T2" to an
+        # INT32 input. In the second pass, we use these bindings to cast
+        # scalar-values to tensors of appropriate types. The two passes are
+        # needed to handle cases like "Add(1, X)" where 1 must be cast to the
+        # same type as X.
+        type_bindings: dict[str, ir.Value] = {}
+        args_typevars: list[tuple[Any, str | None]] = []
+        for i, x in enumerate(inputs):
+            if i < len(expected_inputs):
+                expected = expected_inputs[i]
+            elif (
+                expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic
+            ):
+                expected = expected_inputs[-1]
+                if not expected.is_homogeneous:
+                    args_typevars.append((x, None))
+                    continue
+            else:
+                raise ValueError(
+                    f"Number of actual parameters {len(inputs)} "
+                    f"exceeds number of formal parameters {len(expected_inputs)}."
+                )
+            typevar = expected.type_str
+            if ("(" not in typevar) and (typevar not in type_bindings):
+                # typevar is an identifier, like "T"
+                if isinstance(x, ir.Value):
+                    type_bindings[typevar] = x
+            args_typevars.append((x, typevar))
+
+        def adapt(x, typevar: str | None) -> ir.Value | None:
+            if x is None:
+                return None
+            if typevar is None:
+                return self._input_to_ir_value(x)
+            type_like = type_bindings.get(typevar)
+            return self._input_to_ir_value(x, type_like)
+
+        return [adapt(x, typevar) for x, typevar in args_typevars]
+
+    def _cast_attributes(
+        self,
+        schema: onnx.defs.OpSchema | None,
+        attributes: dict[str, Any] | None,
+    ) -> Sequence[ir.Attr]:
+        if attributes is None:
+            attrs: Sequence[ir.Attr] = ()
+        else:
+            attrs = ir._convenience.convert_attributes(attributes)
+        return attrs
+
+    def add_node(self, node: ir.Node) -> None:
+        self.graph.append(node)
+        onnxscript.optimizer.basic_constant_propagation([node])
+        inference.infer_outputs(node)
+
+    def call_op(
+        self,
+        op_type: str,
+        inputs: Sequence[ir.Value | ir.TensorProtocol],
+        kwargs: dict[str, Any],
+    ):
+        domain = kwargs.pop("_domain", "")
+        version = kwargs.pop("_version", None)
+        outputs = kwargs.pop("_outputs", 1)
+
+        count = self.graph.num_nodes()
+        prefix = self.context_name()
+        node_name = f"{op_type}_node_{count}"
+        if prefix:
+            node_name = f"{prefix}.{node_name}"
+
+        output_values = self._adapt_outputs(outputs, op_type)
+
+        schema = self._get_schema(op_type, domain, version)
+        inputs, attributes = self._partition_inputs_attributes(schema, inputs, kwargs)
+        inputs = self._cast_inputs(schema, inputs)
+        attr_sequence = self._cast_attributes(schema, attributes)
+
+        node = ir.Node(
+            domain,
+            op_type,
+            inputs,
+            attr_sequence,
+            outputs=output_values,
+            version=version,
+            name=node_name,
+        )
+        self.add_node(node)
+
+        return node.outputs if len(node.outputs) > 1 else node.outputs[0]
+
+    def push_module(self, module: str) -> None:
+        current = self.context_name()
+        if module:
+            new_context = f"{current}.{module}" if current else module
+        else:
+            new_context = current
+        self._context_stack.append(new_context)
+
+    def pop_module(self) -> None:
+        self._context_stack.pop()
+
+    def context_name(self) -> str:
+        return self._context_stack[-1] if self._context_stack else ""
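+
+
+# Illustrative sketch of hierarchical naming (mirroring builder_test.py), where
+# `op` is a GraphBuilder's default OpBuilder, x and y are values, and the node
+# counter reflects how many nodes the graph already contains:
+#
+#     op.builder.push_module("layer1")
+#     t = op.Add(x, y)  # value "layer1.Add_output", node "layer1.Add_node_<count>"
+#     op.builder.pop_module()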
+
+
+class OpBuilder:
+    def __init__(
+        self, builder: GraphBuilder, domain: str = "", version: int | None = None
+    ) -> None:
+        self._builder = builder
+        self._domain = domain
+        self._version = version
+
+    @property
+    def builder(self) -> GraphBuilder:
+        return self._builder
+
+    def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
+        if "_domain" not in kwargs:
+            kwargs["_domain"] = self._domain
+        if self._version is not None and "_version" not in kwargs:
+            kwargs["_version"] = self._version
+        return self._builder.call_op(op_type, inputs, kwargs)
+
+    def __getattr__(self, op_type: str) -> Callable:
+        return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)
+
+    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
+        return self._builder.initializer(tensor, name)
+
+    def call(self, function, *args, **kwargs):
+        # TODO: GraphBuilder does not define a `call` method yet; this
+        # delegation will fail until one is added.
+        return self._builder.call(function, *args, **kwargs)
diff --git a/onnxscript/_internal/builder_test.py b/onnxscript/_internal/builder_test.py
new file mode 100644
index 0000000000..5d13709f4f
--- /dev/null
+++ b/onnxscript/_internal/builder_test.py
@@ -0,0 +1,397 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+from __future__ import annotations
+
+import unittest
+from typing import Sequence
+
+import onnx_ir as ir
+
+import onnxscript._internal.builder as builder
+
+_default_opset_version = 23
+
+
+def _build(
+    trace_function,
+    input_types: Sequence[ir.TypeAndShape],
+    output_types: Sequence[ir.TypeAndShape],
+) -> ir.Model:
+    graph = ir.Graph(
+        name="test_model",
+        inputs=[],
+        outputs=[],
+        nodes=[],
+        opset_imports={"": _default_opset_version},
+    )
+
+    onnx_model = ir.Model(graph=graph, ir_version=10)
+
+    for i, input_type in enumerate(input_types):
+        input_name = f"input_{i}"
+        graph.inputs.append(
+            ir.Value(name=input_name, type=input_type.type, shape=input_type.shape)
+        )
+
+    graph_builder = builder.GraphBuilder(graph, is_function=False)
+    outputs = trace_function(graph_builder.op, *graph.inputs)
+    if not isinstance(outputs, Sequence):
+        outputs = [outputs]
+    if len(outputs) != len(output_types):
+        raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.")
+    for output, output_type in zip(outputs, output_types):
+        output.type = output_type.type  # TODO: need merge_type method in ir.Value
+        output.merge_shapes(output_type.shape)
+
+    graph.outputs.extend(outputs)
+
+    return onnx_model
+
+
+def _create_builder_with_inputs() -> tuple[builder.OpBuilder, ir.Value, ir.Value]:
+    """Create a graph builder with two float tensor inputs (shape [2, 3, 4]).
+
+    Returns:
+        A tuple of (op_builder, input_x, input_y).
+    """
+    graph = ir.Graph(
+        name="test_model",
+        inputs=[],
+        outputs=[],
+        nodes=[],
+        opset_imports={"": _default_opset_version},
+    )
+
+    for i in range(2):
+        input_name = f"input_{i}"
+        graph.inputs.append(
+            ir.Value(
+                name=input_name,
+                type=ir.TensorType(ir.DataType.FLOAT),
+                shape=ir.Shape([2, 3, 4]),
+            )
+        )
+
+    graph_builder = builder.GraphBuilder(graph, is_function=False)
+    x, y = graph.inputs
+    return graph_builder.op, x, y
+
+
+class GraphBuilderTest(unittest.TestCase):
+    def test_builder_basic(self):
+        def _add_mul_add(op: builder.OpBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
+            t1 = op.Add(x, y)
+            t2 = op.Mul(x, y)
+            z = op.Add(t1, t2)
+            return z
+
+        float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
+        model = _build(
+            _add_mul_add,
+            input_types=[float_2d, float_2d],
+            output_types=[float_2d],
+        )
+        graph = model.graph
+        # Expect exactly 3 nodes: Add, Mul, Add
+        op_types = [node.op_type for node in graph]
+        self.assertEqual(op_types, ["Add", "Mul", "Add"])
+
+        # Verify inputs and outputs
+        self.assertEqual(len(graph.inputs), 2)
+        self.assertEqual(len(graph.outputs), 1)
+
+        # Verify the connectivity: the final Add takes the outputs of the first Add and Mul
+        nodes = list(graph)
+        add1, mul, add2 = nodes
+        self.assertEqual(list(add2.inputs), [add1.outputs[0], mul.outputs[0]])
+
+    def test_value_naming(self):
+        """Test that output names can be specified via the _outputs option."""
+
+        def _add_with_custom_names(
+            op: builder.OpBuilder, x: ir.Value, y: ir.Value
+        ) -> ir.Value:
+            # Specify custom names for output values
+            t1 = op.Add(x, y, _outputs=["add_result"])
+            t2 = op.Mul(x, y, _outputs=["mul_result"])
+            z = op.Add(t1, t2, _outputs=["final_add"])
+            return z
+
+        float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
+        model = _build(
+            _add_with_custom_names,
+            input_types=[float_2d, float_2d],
+            output_types=[float_2d],
+        )
+        graph = model.graph
+
+        # Verify that the nodes have outputs with the specified names
+        nodes = list(graph)
+        self.assertEqual(len(nodes), 3)
+
+        # Check output names
+        self.assertEqual(nodes[0].outputs[0].name, "add_result")
+        self.assertEqual(nodes[1].outputs[0].name, "mul_result")
+        self.assertEqual(nodes[2].outputs[0].name, "final_add")
+
+        # Verify the final output has the correct name
+        self.assertEqual(len(graph.outputs), 1)
+        self.assertEqual(graph.outputs[0].name, "final_add")
+
+    def test_value_naming_with_hierarchy(self):
+        """Test that hierarchical naming works with user-specified output names."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Test custom names at root level
+        t1 = op.Add(x, y, _outputs=["my_add"])
+        self.assertEqual(t1.name, "my_add")
+
+        # Test custom names with hierarchical context
+        op.builder.push_module("layer1")
+        t2 = op.Mul(t1, y, _outputs=["my_mul"])
+        self.assertEqual(t2.name, "layer1.my_mul")
+
+        # Test nested hierarchical context with custom names
+        op.builder.push_module("attention")
+        t3 = op.Add(t2, x, _outputs=["my_nested_add"])
+        self.assertEqual(t3.name, "layer1.attention.my_nested_add")
+
+        # Pop back and verify the prefix is applied correctly
+        op.builder.pop_module()
+        t4 = op.Mul(t3, y, _outputs=["another_mul"])
+        self.assertEqual(t4.name, "layer1.another_mul")
+
+        op.builder.pop_module()
+        t5 = op.Add(t4, x, _outputs=["final_result"])
+        self.assertEqual(t5.name, "final_result")
+
+    def test_value_naming_with_ir_value_objects(self):
+        """Test that hierarchical naming works when passing ir.Value objects as _outputs."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Create pre-named ir.Value objects
+        out1 = ir.Value(name="my_output")
+        out2 = ir.Value(name="layer_output")
+        out3 = ir.Value(name="nested_output")
+
+        # Test at root level
+        t1 = op.Add(x, y, _outputs=[out1])
+        self.assertEqual(t1.name, "my_output")
+        self.assertIs(t1, out1)
+
+        # Test with hierarchical context
+        op.builder.push_module("layer1")
+        t2 = op.Mul(t1, y, _outputs=[out2])
+        self.assertEqual(t2.name, "layer1.layer_output")
+        self.assertIs(t2, out2)
+
+        # Test nested hierarchical context
+        op.builder.push_module("attention")
+        t3 = op.Add(t2, x, _outputs=[out3])
+        self.assertEqual(t3.name, "layer1.attention.nested_output")
+        self.assertIs(t3, out3)
+
+    def test_default_output_naming_strategy(self):
+        """Test the default naming strategy for generated output values ({op_type}_output format)."""
+
+        def _ops_with_default_names(
+            op: builder.OpBuilder, x: ir.Value, y: ir.Value
+        ) -> ir.Value:
+            # Single-output operations should be named {op_type}_output
+            t1 = op.Add(x, y)
+            t2 = op.Mul(x, y)
+            z = op.Add(t1, t2)
+            return z
+
+        float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
+        model = _build(
+            _ops_with_default_names,
+            input_types=[float_2d, float_2d],
+            output_types=[float_2d],
+        )
+        graph = model.graph
+
+        # Verify the nodes use the new naming strategy
+        nodes = list(graph)
+        self.assertEqual(len(nodes), 3)
+
+        # Check output names follow the {op_type}_output pattern for single outputs
+        self.assertEqual(nodes[0].outputs[0].name, "Add_output")
+        self.assertEqual(nodes[1].outputs[0].name, "Mul_output")
+        self.assertEqual(nodes[2].outputs[0].name, "Add_output")
+
+        # Verify the final output has the correct name
+        self.assertEqual(len(graph.outputs), 1)
+        self.assertEqual(graph.outputs[0].name, "Add_output")
+
+    def test_hierarchical_naming(self):
+        """Test the hierarchical naming strategy (for value and node names)."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Test node and value naming at root level
+        t1 = op.Add(x, y)
+        self.assertEqual(t1.name, "Add_output")
+        self.assertEqual(t1.producer().name, "Add_node_0")
+
+        t2 = op.Mul(t1, y)
+        self.assertEqual(t2.name, "Mul_output")
+        self.assertEqual(t2.producer().name, "Mul_node_1")
+
+        # Test node and value naming with hierarchical context prefix
+        op.builder.push_module("layer1")
+        t3 = op.Add(t2, x)
+        self.assertEqual(t3.name, "layer1.Add_output")
+        self.assertEqual(t3.producer().name, "layer1.Add_node_2")
+
+        # Test nested hierarchical context
+        op.builder.push_module("attention")
+        t4 = op.Mul(t3, y)
+        self.assertEqual(t4.name, "layer1.attention.Mul_output")
+        self.assertEqual(t4.producer().name, "layer1.attention.Mul_node_3")
+
+        # Pop back to layer1 and verify naming continues correctly
+        op.builder.pop_module()
+        t5 = op.Add(t4, x)
+        self.assertEqual(t5.name, "layer1.Add_output")
+        self.assertEqual(t5.producer().name, "layer1.Add_node_4")
+
+        # Pop back to root context
+        op.builder.pop_module()
+        t6 = op.Mul(t5, y)
+        self.assertEqual(t6.name, "Mul_output")
+        self.assertEqual(t6.producer().name, "Mul_node_5")
+
+    def test_shape_inference_add(self):
+        """Test that shape inference works correctly for the Add operation."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Create an Add node without explicitly setting the output type/shape
+        result = op.Add(x, y)
+
+        # Verify the output type is inferred correctly
+        self.assertIsNotNone(result.type)
+        self.assertEqual(result.type.dtype, ir.DataType.FLOAT)
+
+        # Verify the output shape is inferred correctly
+        self.assertIsNotNone(result.shape)
+        self.assertEqual(list(result.shape), [2, 3, 4])
+
+    def test_custom_domain_explicit(self):
+        """Test using operations from custom domains with an explicit _domain parameter."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Create a custom domain operation with an explicit _domain parameter,
+        # using "com.microsoft" as an example domain
+        result = op.CustomOp(x, y, _domain="com.microsoft")
+
+        # Verify the node was created with the correct domain
+        nodes = list(op.builder.graph)
+        self.assertEqual(len(nodes), 1)
+        node = nodes[0]
+        self.assertEqual(node.domain, "com.microsoft")
+        self.assertEqual(node.op_type, "CustomOp")
+
+        # Verify inputs and outputs are connected correctly
+        self.assertEqual(list(node.inputs), [x, y])
+        self.assertEqual(node.outputs[0], result)
+
+    def test_custom_domain_with_version(self):
+        """Test using operations from custom domains with explicit _domain and _version parameters."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Create a custom domain operation with explicit _domain and _version parameters
+        result = op.MicrosoftOp(x, y, _domain="com.microsoft", _version=10)
+
+        # Verify the node was created with the correct domain and version
+        nodes = list(op.builder.graph)
+        self.assertEqual(len(nodes), 1)
+        node = nodes[0]
+        self.assertEqual(node.domain, "com.microsoft")
+        self.assertEqual(node.op_type, "MicrosoftOp")
+        self.assertEqual(node.version, 10)
+
+        # Verify the output value is created
+        self.assertIsNotNone(result)
+        self.assertEqual(result.name, "MicrosoftOp_output")
+
+    def test_multiple_custom_domain_operations(self):
+        """Test mixing operations from multiple domains."""
+        op, x, y = _create_builder_with_inputs()
+
+        # Create a standard domain operation
+        t1 = op.Add(x, y)
+
+        # Create a custom domain operation
+        t2 = op.CustomOp(t1, y, _domain="com.microsoft")
+
+        # Create another custom domain operation with a different domain
+        t3 = op.AnotherOp(t2, x, _domain="com.custom")
+
+        # Verify all nodes were created with correct domains
+        nodes = list(op.builder.graph)
+        self.assertEqual(len(nodes), 3)
+
+        self.assertEqual(nodes[0].domain, "")
+        self.assertEqual(nodes[0].op_type, "Add")
"Add") + + self.assertEqual(nodes[1].domain, "com.microsoft") + self.assertEqual(nodes[1].op_type, "CustomOp") + + self.assertEqual(nodes[2].domain, "com.custom") + self.assertEqual(nodes[2].op_type, "AnotherOp") + + def test_opset_builder_for_custom_domain(self): + """Test creating and using an opset builder for a custom domain.""" + op, x, y = _create_builder_with_inputs() + + # Create an OpBuilder for the "com.microsoft" domain with version 1 + ms_op = op.builder.opset("com.microsoft", 1) + + # Use operations through the custom domain opset builder + t1 = ms_op.CustomOp(x, y) + t2 = ms_op.AnotherOp(t1, x) + + # Verify all nodes were created with the correct domain + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 2) + + # Verify first operation + self.assertEqual(nodes[0].domain, "com.microsoft") + self.assertEqual(nodes[0].op_type, "CustomOp") + self.assertEqual(nodes[0].version, 1) + self.assertEqual(list(nodes[0].inputs), [x, y]) + + # Verify second operation + self.assertEqual(nodes[1].domain, "com.microsoft") + self.assertEqual(nodes[1].op_type, "AnotherOp") + self.assertEqual(nodes[1].version, 1) + self.assertEqual(list(nodes[1].inputs), [t1, x]) + + def test_mixed_domain_opsets(self): + """Test using both standard domain and custom domain opset builders together.""" + op, x, y = _create_builder_with_inputs() + + # Create custom domain opset builder + ms_op = op.builder.opset("com.microsoft", 2) + + # Mix operations from different domains + t1 = op.Add(x, y) # Standard domain operation + t2 = ms_op.MsAdd(t1, y) # Custom domain operation + t3 = op.Mul(t2, x) # Back to standard domain + + # Verify nodes were created with correct domains + nodes = list(op.builder.graph) + self.assertEqual(len(nodes), 3) + + self.assertEqual(nodes[0].domain, "") + self.assertEqual(nodes[0].op_type, "Add") + + self.assertEqual(nodes[1].domain, "com.microsoft") + self.assertEqual(nodes[1].op_type, "MsAdd") + self.assertEqual(nodes[1].version, 2) + + self.assertEqual(nodes[2].domain, "") + self.assertEqual(nodes[2].op_type, "Mul") + + +if __name__ == "__main__": + unittest.main()