Initial builder #2808
Changes from all commits

New file (+101 lines) — evidently the inference helper module imported by the builder below as `onnxscript._internal._inference`:

```python
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

from __future__ import annotations

import numpy as np
import onnx
import onnx_ir as ir


def _get_numpy_value(
    val: ir.Value | None, dtype: ir.DataType | None = None, size_limit: int | None = None
) -> np.ndarray | None:
    """Returns the numpy value of a constant value, if available.

    Returns None if the value is not a constant, if it is not of the specified
    element dtype, or if its size exceeds the specified size_limit.
    """
    if val is None:
        return None
    const_value = val.const_value
    if const_value is not None:
        if dtype is not None and const_value.dtype != dtype:
            return None
        if size_limit is not None and const_value.size > size_limit:
            return None
        try:
            # Turn the constant value into a numpy array representation, with
            # the specifics of this conversion handled by the tensor type.
            array = const_value.numpy()
            # Strings cannot (and need not) be reinterpreted via .view(); doing
            # so raises "TypeError: Cannot change data-type for array of
            # references." Reinterpretation is only relevant for some
            # arithmetic types.
            if const_value.dtype != ir.DataType.STRING:
                # Reinterpret the array with `.view()` because some
                # implementations of ir.TensorProtocol (e.g. PyTorch<=2.7) do
                # not use ml_dtypes for bfloat16 etc.
                array = array.view(const_value.dtype.numpy())
        except FileNotFoundError:
            # External data is not available.
            # logger.warning(
            #     "External data for value '%s' is not available. "
            #     "This may lead to incorrect constant folding.",
            #     val.name,
            # )
            return None
        assert isinstance(array, np.ndarray)
        return array
    return None


def _do_onnx_inference(node: ir.Node) -> None:
    def get_constant_value(x: ir.Value) -> onnx.TensorProto | None:
        value = _get_numpy_value(x, size_limit=20)
        if value is not None:
            assert x.const_value is not None
            return ir.serde.serialize_tensor(x.const_value)
        return None

    def get_type(index: int, value: ir.Value) -> onnx.TypeProto:
        if value.type is None:
            raise ValueError(
                f"Type of input {index} value {value.name} of node {node.name} not known"
            )
        type_proto = ir.serde.serialize_type(value.type)
        if value.shape is not None:
            ir.serde.serialize_shape_into(type_proto, value.shape)
        return type_proto

    input_types = {x.name: get_type(i, x) for i, x in enumerate(node.inputs) if x is not None}
    input_data = {x.name: get_constant_value(x) for x in node.inputs if x is not None}
    input_data = {k: v for k, v in input_data.items() if v is not None}

    # TODO: pass in constant values, ir_version
    schema = onnx.defs.get_schema(node.op_type, node.version, node.domain)
    output_types = onnx.shape_inference.infer_node_outputs(
        schema,
        ir.serde.serialize_node(node),
        input_types,  # type: ignore[arg-type]
        input_data,  # type: ignore[arg-type]
    )
    for output in node.outputs:
        if output.name in output_types:
            inferred_type = output_types[output.name]
            # TODO: merge types, check for conflicts
            inferred_shape = ir.serde.deserialize_type_proto_for_shape(inferred_type)
            # NOTE: forward shape inference
            output.merge_shapes(inferred_shape)
            output.type = ir.serde.deserialize_type_proto_for_type(inferred_type)


def infer_outputs(node: ir.Node) -> None:
    try:
        _do_onnx_inference(node)
    except Exception as e:  # pylint: disable=broad-exception-caught
        # Intentionally broad: record any inference failure on the node rather
        # than aborting graph construction.
        # TODO: compose with any existing error
        node.metadata_props["inference_error"] = str(e)
```
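
A minimal usage sketch of `infer_outputs` (not part of this PR), assuming the `onnx_ir` constructors used in the diff above and that `infer_outputs` is in scope; value names, shapes, and the opset version are illustrative:

```python
import onnx_ir as ir

# Two typed inputs and a named output for a single Add node.
a = ir.Value(name="A", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
b = ir.Value(name="B", shape=ir.Shape([2, 3]), type=ir.TensorType(ir.DataType.FLOAT))
node = ir.Node("", "Add", [a, b], outputs=[ir.Value(name="C")], version=18, name="add_0")

infer_outputs(node)
# On success the output now carries the inferred type and shape; on failure the
# error text is recorded in node.metadata_props["inference_error"].
print(node.outputs[0].type, node.outputs[0].shape)
```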

New file (+291 lines) — the builder itself:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from __future__ import annotations

from typing import Any, Callable, Sequence

import numpy as np
import onnx
import onnx_ir as ir

import onnxscript._internal._inference as inference
import onnxscript.optimizer


class GraphBuilder:
    def __init__(self, graph: ir.Graph, is_function: bool) -> None:
        self._graph = graph

        # Get the opset version for "" (the default domain) from the graph.
        if "" not in graph.opset_imports:
            raise ValueError('Input graph does not have an import for domain ""')
        opset_version = graph.opset_imports[""]

        self._op_builder = self.opset("", opset_version)

        # Context stack to manage hierarchical naming. Each module/layer can
        # push a new context and pop it when done. The current context is used
        # as a prefix for naming values and nodes, which allows us to generate
        # names like "layer1.attention.query".
        self._context_stack: list[str] = [""]

    def opset(self, domain: str = "", version: int = 1) -> OpBuilder:
        return OpBuilder(self, domain, version)

    @property
    def op(self) -> OpBuilder:
        return self._op_builder

    @property
    def graph(self) -> ir.Graph:
        return self._graph

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        if name is None:
            name = tensor.name
        prefix = self.context_name()
        if prefix:
            name = f"{prefix}.{name}"
        # TODO: set tensor name as well
        shape = ir.Shape((d if isinstance(d, int) else d.value) for d in tensor.shape.dims)
        value = ir.Value(
            name=name, shape=shape, type=ir.TensorType(tensor.dtype), const_value=tensor
        )
        self._graph.register_initializer(value)
        return value

    def _input_to_ir_value(
        self,
        value: ir.Value | ir.TensorProtocol | int | float | bool | str,
        like_type: ir.Value | None = None,
    ) -> ir.Value:
        if isinstance(value, ir.Value):
            return value
        elif isinstance(value, (int, float, bool, str)):
            # Scalar constant: promote it to a tensor initializer.
            if like_type is not None and like_type.type is not None:
                dtype = like_type.type.dtype
            else:
                # Infer the element type from the Python type.
                if isinstance(value, bool):
                    dtype = ir.DataType.BOOL
                elif isinstance(value, int):
                    dtype = ir.DataType.INT64
                elif isinstance(value, float):
                    dtype = ir.DataType.FLOAT
                elif isinstance(value, str):
                    dtype = ir.DataType.STRING
                else:
                    raise TypeError(f"Unsupported scalar type: {type(value)}")
            tensor = ir.Tensor(
                np.array(value, dtype=dtype.numpy()),
                name="const_scalar",
            )
            return self.initializer(tensor)
        else:
            # assert isinstance(value, ir.TensorProtocol)
            # TODO: We could use caching to avoid duplicate initializers.
            # However, it seems unlikely to be useful in practice, as shared
            # use of a stateful module is rare.
            return self.initializer(value)

    def _adapt_outputs(
        self, outputs: int | Sequence[str | ir.Value], op_type: str = ""
    ) -> Sequence[ir.Value]:
        prefix = self.context_name()
        if isinstance(outputs, int):
            if outputs == 1:
                name = f"{op_type}_output" if op_type else "output"
                if prefix:
                    name = f"{prefix}.{name}"
                return [ir.Value(name=name)]
            else:
                names = [
                    f"{op_type}_output{i}" if op_type else f"output{i}" for i in range(outputs)
                ]
                if prefix:
                    names = [f"{prefix}.{n}" for n in names]
                return [ir.Value(name=n) for n in names]
        adapted_outputs = []
        for output in outputs:
            if isinstance(output, ir.Value):
                if prefix and output.name:
                    output.name = f"{prefix}.{output.name}"
                adapted_outputs.append(output)
            elif isinstance(output, str):
                name = f"{prefix}.{output}" if prefix else output
                adapted_outputs.append(ir.Value(name=name))
            else:
                raise TypeError("Output type not supported.")
        return adapted_outputs

    def _get_schema(
        self, op_type: str, domain: str, version: int | None
    ) -> onnx.defs.OpSchema | None:
        if version is not None:
            try:
                return onnx.defs.get_schema(op_type, version, domain)
            except onnx.defs.SchemaError:
                pass
        return None

    def _partition_inputs_attributes(  # pylint: disable=unused-argument
        self,
        schema: onnx.defs.OpSchema | None,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
        kwargs: dict[str, Any],
    ) -> tuple[Sequence[ir.Value | ir.TensorProtocol], dict[str, Any]]:
        # Not implemented yet
        return inputs, kwargs

    def _cast_inputs(
        self,
        schema: onnx.defs.OpSchema | None,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
    ) -> Sequence[ir.Value]:
        """Uses the schema specification to support a limited form of auto-casting.

        * Scalars are promoted to tensors.
        * Further, scalars are cast to the required type when used in ops whose
          other tensor inputs are required to be of the same type. Thus, in
          "A + 1" or "Add(A, 1)", the value 1 is converted to the same type as A.

        This is used by the converter in static mode, as well as by eager-mode
        execution in dynamic mode.
        """
        if schema is None:
            return [self._input_to_ir_value(i) for i in inputs]

        expected_inputs = schema.inputs
        # We make two passes. The first pass identifies known type bindings for
        # type variables, e.g. {"T1": <value of element type float32>}. The
        # second pass uses these bindings to cast scalar values to tensors of
        # the appropriate types. Two passes are needed to handle cases like
        # "Add(1, X)", where 1 must be cast to the same type as X.
        type_bindings: dict[str, ir.Value] = {}
        args_typevars: list[tuple[ir.Value | None, str | None]] = []
        for i, x in enumerate(inputs):
            if i < len(expected_inputs):
                expected = expected_inputs[i]
            elif (
                expected_inputs[-1].option == onnx.defs.OpSchema.FormalParameterOption.Variadic
            ):
                expected = expected_inputs[-1]
                if not expected.is_homogeneous:
                    args_typevars.append((x, None))
                    continue
            else:
                raise ValueError(
                    f"Number of actual parameters {len(inputs)} "
                    f"exceeds number of formal parameters {len(expected_inputs)}."
                )
            typevar = expected.type_str
            if ("(" not in typevar) and (typevar not in type_bindings):
                # typevar is an identifier, like "T"
                if isinstance(x, ir.Value):
                    type_bindings[typevar] = x
            args_typevars.append((x, typevar))

        def adapt(x, typevar: str | None) -> ir.Value | None:
            if x is None:
                return None
            if typevar is None:
                return self._input_to_ir_value(x)
            return self._input_to_ir_value(x, type_bindings.get(typevar))

        return [adapt(x, typevar) for x, typevar in args_typevars]
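
    # Example (illustrative): in Add(A, 1.0), both inputs share the type
    # variable "T". The first pass binds T -> A; the second pass promotes the
    # Python scalar 1.0 to a constant tensor with A's element type, so both
    # inputs of Add agree.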

    def _cast_attributes(  # pylint: disable=unused-argument
        self,
        schema: onnx.defs.OpSchema | None,
        attributes: dict[str, Any] | None,
    ) -> Sequence[ir.Attr]:
        if attributes is None:
            attrs: Sequence[ir.Attr] = ()
        else:
            attrs = ir._convenience.convert_attributes(attributes)  # pylint: disable=protected-access
        return attrs

    def add_node(self, node: ir.Node) -> None:
        self.graph.append(node)
        onnxscript.optimizer.basic_constant_propagation([node])
        inference.infer_outputs(node)

    def call_op(
        self,
        op_type: str,
        inputs: Sequence[ir.Value | ir.TensorProtocol],
        kwargs: dict[str, Any],
    ):
        # Reserved keyword arguments that control node construction:
        domain = kwargs.pop("_domain", "")
        version = kwargs.pop("_version", None)
        outputs = kwargs.pop("_outputs", 1)

        count = self.graph.num_nodes()
        prefix = self.context_name()
        node_name = f"{op_type}_node_{count}"
        if prefix:
            node_name = f"{prefix}.{node_name}"

        output_values = self._adapt_outputs(outputs, op_type)

        schema = self._get_schema(op_type, domain, version)
        inputs, attributes = self._partition_inputs_attributes(schema, inputs, kwargs)
        inputs = self._cast_inputs(schema, inputs)
        attr_sequence = self._cast_attributes(schema, attributes)

        node = ir.Node(
            domain,
            op_type,
            inputs,
            attr_sequence,
            outputs=output_values,
            version=version,
            name=node_name,
        )
        self.add_node(node)

        return node.outputs if len(node.outputs) > 1 else node.outputs[0]

    def push_module(self, module: str) -> None:
        current = self.context_name()
        if module:
            new_context = f"{current}.{module}" if current else module
        else:
            new_context = current
        self._context_stack.append(new_context)

    def pop_module(self) -> None:
        self._context_stack.pop()

    def context_name(self) -> str:
        return self._context_stack[-1] if self._context_stack else ""


class OpBuilder:
    def __init__(
        self, builder: GraphBuilder, domain: str = "", version: int | None = None
    ) -> None:
        self._builder = builder
        self._domain = domain
        self._version = version

    @property
    def builder(self) -> GraphBuilder:
        return self._builder

    def _call_op(self, op_type: str, inputs: Sequence[Any], kwargs: dict[str, Any]):
        if "_domain" not in kwargs:
            kwargs["_domain"] = self._domain
        if self._version is not None and "_version" not in kwargs:
            kwargs["_version"] = self._version
        return self._builder.call_op(op_type, inputs, kwargs)

    def __getattr__(self, op_type: str) -> Callable:
        return lambda *args, **kwargs: self._call_op(op_type, args, kwargs)

    def initializer(self, tensor: ir.TensorProtocol, name: str | None = None) -> ir.Value:
        return self._builder.initializer(tensor, name)

    def call(self, function, *args, **kwargs):
        return self._builder.call(function, *args, **kwargs)
```
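
A hedged end-to-end sketch of the builder (not part of this PR). The `ir.Graph` construction and the `ir.Tensor` call mirror the usage visible above, but the exact constructor signatures are assumptions; names are illustrative:

```python
import numpy as np
import onnx_ir as ir

x = ir.Value(name="X", shape=ir.Shape([2, 2]), type=ir.TensorType(ir.DataType.FLOAT))
graph = ir.Graph(inputs=[x], outputs=[], nodes=[], opset_imports={"": 18})
builder = GraphBuilder(graph, is_function=False)
op = builder.op

builder.push_module("layer1")
# Values and nodes created here are prefixed with the current context:
# "layer1.weight", "layer1.MatMul_node_0", ...
w = op.initializer(ir.Tensor(np.eye(2, dtype=np.float32), name="weight"))
y = op.MatMul(x, w)
z = op.Add(y, 1.0)  # the scalar is promoted to a FLOAT initializer matching y
builder.pop_module()
```

Each call also accepts the reserved keyword arguments `_domain`, `_version`, and `_outputs` to target a custom opset or request multiple outputs.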