310 changes: 310 additions & 0 deletions onnxscript/_internal/builder.py
@@ -0,0 +1,310 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

from typing import Any, Callable, Mapping, Sequence

import numpy as np
import onnx
import onnx_ir as ir

import onnxscript.optimizer
import onnxscript.values

def _make_node(
    op_type: str,
    inputs: Sequence[ir.Value | None],
    attributes: Mapping[str, ir._convenience.SupportedAttrTypes] | None = None,
    *,
    num_outputs: int | None = None,
    outputs: Sequence[ir.Value] | None = None,
    domain: str = "",
    overload: str = "",
    version: int | None = None,
    graph: ir.Graph | None = None,
    name: str | None = None,
    doc_string: str | None = None,
    metadata_props: dict[str, str] | None = None,
) -> ir.Node:
    """Create an ir.Node, specifying its outputs either by count or by pre-built values."""
    if num_outputs is None and outputs is None:
        raise ValueError("Either num_outputs or outputs must be provided.")
    if num_outputs is not None and outputs is not None:
        raise ValueError("Both num_outputs and outputs cannot be provided simultaneously.")
    output_kwargs: dict[str, Any]
    if outputs is None:
        output_kwargs = dict(num_outputs=num_outputs)
    else:
        output_kwargs = dict(outputs=outputs)
    if attributes is None:
        attrs: Sequence[ir.Attr] = ()
    else:
        attrs = ir._convenience.convert_attributes(attributes)
    node = ir.Node(
        domain,
        op_type,
        inputs,
        attributes=attrs,
        **output_kwargs,
        overload=overload,
        version=version,
        graph=graph,
        name=name,
        doc_string=doc_string,
        metadata_props=metadata_props,
    )
    return node


class GraphBuilder:
    """Builds an ir.Graph incrementally, with hierarchical naming support."""

    def __init__(self, graph: ir.Graph, is_function: bool) -> None:
        # TODO: is_function is currently unused.
        self._graph = graph
        self._op_builder = self.opset("", None)

        # Context stack to manage hierarchical naming. Each module/layer can push a
        # new context, and pop it when done. The current context is used as a prefix
        # for naming values and nodes. This allows us to generate names like
        # "layer1.attention.query".
        self._context_stack: list[str] = [""]
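        # Illustrative (hypothetical) use of the context stack, via the
        # push_module/pop_module methods below:
        #   builder.push_module("layer1")     # context -> "layer1"
        #   builder.push_module("attention")  # context -> "layer1.attention"
        #   ...values created here are named "layer1.attention.<name>"...
        #   builder.pop_module()
        #   builder.pop_module()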


    def opset(self, domain: str = "", version: int | None = None) -> OpBuilder:
        """Return an OpBuilder bound to the given opset domain and version."""
        return OpBuilder(self, domain, version)

    @property
    def op(self) -> OpBuilder:
        return self._op_builder

    @property
    def graph(self) -> ir.Graph:
        return self._graph

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        """Register a tensor as a graph initializer and return the resulting ir.Value.

        The value name is prefixed with the current module context, if any.
        """
        if name is None:
            name = tensor.name
        prefix = self.context_name()
        if prefix:
            name = f"{prefix}.{name}"
        # TODO: set tensor name as well
        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
        value = ir.Value(
            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
        )
        self._graph.register_initializer(value)
        return value
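    # E.g., after push_module("layer1"), initializer(w, name="weight") registers
    # the value under the name "layer1.weight".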

    def _input_to_ir_value(
        self,
        value: ir.Value | ir.TensorProtocol | int | float | bool | str,
        like_type: ir.Value | None = None,
    ) -> ir.Value:
        """Convert a scalar or tensor into an ir.Value, registering constants as initializers."""
        if isinstance(value, ir.Value):
            return value
        elif isinstance(value, (int, float, bool, str)):
            # Scalar constant
            if like_type is not None and like_type.type is not None:
                dtype = like_type.type.dtype
            else:
                # Infer type from Python type (bool must be checked before int)
                if isinstance(value, bool):
                    dtype = ir.DataType.BOOL
                elif isinstance(value, int):
                    dtype = ir.DataType.INT64
                elif isinstance(value, float):
                    dtype = ir.DataType.FLOAT
                elif isinstance(value, str):
                    dtype = ir.DataType.STRING
                else:
                    raise TypeError(f"Unsupported scalar type: {type(value)}")
            tensor = ir.Tensor(
                np.array(value, dtype=dtype.numpy()),
                name="const_scalar",
            )
            return self.initializer(tensor)
        else:
            # Here value is expected to be an ir.TensorProtocol.
            # TODO: We could use caching to avoid duplicate initializers. However, it seems
            # unlikely to be useful in practice, as shared use of a stateful module is rare.
            return self.initializer(value)
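    # E.g., in _cast_inputs below, _input_to_ir_value(1, like_type=x) for a
    # float32-typed x produces a float32 scalar initializer instead of the
    # default INT64 inferred from the Python int.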


    def _adapt_outputs(self, outputs: int | Sequence[str | ir.Value]) -> Sequence[ir.Value]:
        """Normalize the _outputs specification (a count, or names/values) into ir.Values."""
        prefix = self.context_name()
        if isinstance(outputs, int):
            count = self.graph.num_nodes()
            name = f"{prefix}.val_{count}" if prefix else f"val_{count}"
            if outputs == 1:
                return [ir.Value(name=name)]
            else:
                return [ir.Value(name=f"{name}.{i}") for i in range(outputs)]
        adapted_outputs = []
        for output in outputs:
            if isinstance(output, ir.Value):
                adapted_outputs.append(output)
            elif isinstance(output, str):
                adapted_outputs.append(ir.Value(name=output))
            else:
                raise TypeError(f"Output type not supported: {type(output)}.")
        return adapted_outputs

    def _get_schema(
        self, op_type: str, domain: str, version: int | None
    ) -> onnx.defs.OpSchema | None:
        if version is not None:
            try:
                return onnx.defs.get_schema(op_type, version, domain)
            except onnx.defs.SchemaError:
                pass
        return None

    def _partition_inputs_attributes(
        self,
        schema: onnx.defs.OpSchema | None,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
        kwargs: dict[str, Any],
    ) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
        """Split kwargs into node inputs and attributes using the schema.

        Not implemented yet: returns inputs and kwargs unchanged.
        """
        return inputs, kwargs

    def _cast_inputs(
        self,
        schema: onnx.defs.OpSchema | None,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
    ) -> Sequence[ir.Value]:
        """Uses the schema specification to support a limited form of auto-casting.

        * Scalars are promoted to tensors.
        * Further, they are cast to the required type when used in ops with other
          tensor inputs that are required to be of the same type.
          Thus, in "A+1" or "Add(A, 1)", the value 1 will be converted to the same
          type as A.

        This is used by the converter in static mode, as well as by eager-mode
        execution in dynamic mode.
        """
        if schema is None:
            return [self._input_to_ir_value(i) for i in inputs]

        expected_inputs = schema.inputs
        # We make two passes. In the first pass, we identify known type-bindings for
        # type-variables: e.g., {'T1' : np.float32, 'T2' : np.int32}.
        # In the second pass, we use these bindings to cast scalar-values to
        # tensors of appropriate types. The two passes are needed to handle cases
        # like "Add(1, X)" where 1 must be cast to the same type as X.
        type_bindings: dict[str, ir.Value] = {}
        args_typevars: list[tuple[ir.Value | None, str | None]] = []
        for i, x in enumerate(inputs):
            if i < len(expected_inputs):
                expected = expected_inputs[i]
            elif expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic:
                expected = expected_inputs[-1]
                if not expected.is_homogeneous:
                    args_typevars.append((x, None))
                    continue
            else:
                raise ValueError(
                    f"Number of actual parameters {len(inputs)} "
                    f"exceeds number of formal parameters {len(expected_inputs)}."
                )
            typevar = expected.type_str
            if ("(" not in typevar) and (typevar not in type_bindings):
                # typevar is an identifier, like "T"
                if isinstance(x, ir.Value):
                    type_bindings[typevar] = x
            args_typevars.append((x, typevar))

        def adapt(x, typevar: str | None) -> ir.Value | None:
            if x is None:
                return None
            if typevar is None:
                return self._input_to_ir_value(x)
            type_like = type_bindings.get(typevar)
            return self._input_to_ir_value(x, type_like)

        return [adapt(x, typevar) for x, typevar in args_typevars]
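    # Worked example of the two passes above: for Add(1, X) with X a float32
    # ir.Value, pass 1 binds {"T": X}; pass 2 then converts the scalar 1 via
    # _input_to_ir_value(1, like_type=X) into a float32 initializer.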

    def _cast_attributes(
        self,
        schema: onnx.defs.OpSchema | None,
        attributes: dict[str, Any],
    ) -> dict[str, Any]:
        """Cast attribute values to schema-declared types.

        Not implemented yet: returns attributes unchanged.
        """
        return attributes

    def add_node(self, node: ir.Node) -> None:
        self.graph.append(node)
        onnxscript.optimizer.basic_constant_propagation([node])
        # TODO: inference.infer_outputs(node, 23)

    def call_op(
        self,
        op_type: str,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
        kwargs: dict[str, Any],
    ):
        """Build and add a node for op_type, returning its output value(s)."""
        domain = kwargs.pop("_domain", "")
        version = kwargs.pop("_version", None)
        outputs = kwargs.pop("_outputs", 1)

        count = self.graph.num_nodes()
        node_name = f"node_{count}"

        output_values = self._adapt_outputs(outputs)

        schema = self._get_schema(op_type, domain, version)
        inputs, attributes = self._partition_inputs_attributes(schema, inputs, kwargs)
        inputs = self._cast_inputs(schema, inputs)
        attributes = self._cast_attributes(schema, attributes)

        node = _make_node(
            op_type,
            inputs=inputs,
            attributes=attributes,
            domain=domain,
            version=version,
            outputs=output_values,
            graph=self.graph,
            name=node_name,
        )
        self.add_node(node)

        return node.outputs if len(node.outputs) > 1 else node.outputs[0]
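    # Illustrative calls (hypothetical values) showing the reserved keywords
    # consumed above:
    #   g.op.Add(x, y)                          # -> single ir.Value
    #   g.op.TopK(x, k, _outputs=2)             # -> tuple of 2 ir.Values
    #   g.op.Gelu(x, _domain="com.microsoft")   # -> targets a custom op domain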

    def call(self, function, *args, **kwargs):
        """Inline a function's body into the graph, returning its output value(s)."""
        if isinstance(function, ir.Function):
            function_ir = function
        elif isinstance(function, onnxscript.values.OnnxFunction):
            function_proto = function.to_function_proto()
            function_ir = ir.serde.deserialize_function(function_proto)
        else:
            raise TypeError("Function must be an ir.Function or onnxscript.values.OnnxFunction")
        # TODO: `inliner` is not yet defined or imported in this module.
        nodes, outputs = inliner.instantiate(function_ir, args, kwargs)
        for node in nodes:
            self.add_node(node)
        return outputs if len(outputs) > 1 else outputs[0]


    def push_module(self, module: str) -> None:
        """Push a module name onto the context stack, extending the naming prefix."""
        current = self.context_name()
        if module:
            new_context = f"{current}.{module}" if current else module
        else:
            new_context = current
        self._context_stack.append(new_context)

    def pop_module(self) -> None:
        self._context_stack.pop()

    def context_name(self) -> str:
        return self._context_stack[-1] if self._context_stack else ""

class OpBuilder:
    """Exposes ONNX ops as methods, dispatching attribute access to the GraphBuilder."""

    def __init__(
        self, builder: GraphBuilder, domain: str = "", version: int | None = None
    ) -> None:
        self._builder = builder
        self._domain = domain
        self._version = version

    @property
    def builder(self) -> GraphBuilder:
        return self._builder

    def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
        if "_domain" not in kwargs:
            kwargs["_domain"] = self._domain
        if self._version is not None and "_version" not in kwargs:
            kwargs["_version"] = self._version
        return self._builder.call_op(op_type, inputs, kwargs)

    def __getattr__(self, op_type: str) -> Callable:
        # E.g., op.Add(x, y) resolves here and dispatches to _call_op("Add", (x, y), {}).
        return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        return self._builder.initializer(tensor, name)

    def call(self, function, *args, **kwargs):
        return self._builder.call(function, *args, **kwargs)


74 changes: 74 additions & 0 deletions onnxscript/_internal/builder_test.py
@@ -0,0 +1,74 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

import unittest
from typing import Sequence

import onnx_ir as ir

import onnxscript._internal.builder as builder


def _build(
    trace_function,
    input_types: Sequence[ir.TypeAndShape],
    output_types: Sequence[ir.TypeAndShape],
) -> ir.Model:
    """Trace a Python function into an ir.Model with the given input/output types."""
    graph = ir.Graph(
        name="test_model",
        inputs=[],
        outputs=[],
        nodes=[],
        opset_imports={"": 23},
    )

    onnx_model = ir.Model(graph=graph, ir_version=10)

    for i, input_type in enumerate(input_types):
        input_name = f"input_{i}"
        graph.inputs.append(
            ir.Value(name=input_name, type=input_type.type, shape=input_type.shape)
        )

    graph_builder = builder.GraphBuilder(graph, is_function=False)
    outputs = trace_function(graph_builder.op, *graph.inputs)
    if not isinstance(outputs, Sequence):
        outputs = [outputs]
    if len(outputs) != len(output_types):
        raise ValueError(f"Expected {len(output_types)} outputs, but got {len(outputs)}.")
    for output, output_type in zip(outputs, output_types):
        output.type = output_type.type  # TODO: need merge_type method in ir.Value
        output.merge_shapes(output_type.shape)

    graph.outputs.extend(outputs)

    return onnx_model


class GraphBuilderTest(unittest.TestCase):
    def test_builder_basic(self):
        def _add_mul_add(op: builder.OpBuilder, x: ir.Value, y: ir.Value) -> ir.Value:
            t1 = op.Add(x, y)
            t2 = op.Mul(x, y)
            z = op.Add(t1, t2)
            return z

        float_2d = ir.TypeAndShape(ir.TensorType(ir.DataType.FLOAT), ir.Shape([3, 4]))
        model = _build(
            _add_mul_add,
            input_types=[float_2d, float_2d],
            output_types=[float_2d],
        )
        graph = model.graph
        # Expect exactly 3 nodes: Add, Mul, Add
        op_types = [node.op_type for node in graph]
        self.assertEqual(op_types, ["Add", "Mul", "Add"])

        # Verify inputs and outputs
        self.assertEqual(len(graph.inputs), 2)
        self.assertEqual(len(graph.outputs), 1)

        # Verify the connectivity: final Add takes outputs of the first Add and Mul
        nodes = list(graph)
        add1, mul, add2 = nodes
        self.assertEqual(list(add2.inputs), [add1.outputs[0], mul.outputs[0]])


if __name__ == "__main__":
    unittest.main()