diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 0384e0423..0716111f7 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -37,6 +37,11 @@ _OP_NAMESPACE_PREFIXES = ("special", "linalg", "fft") +# Keep Python returning overloads explicitly opt-in while the API is being +# rolled out. Some C++ returning-capable ops have extra semantics that need +# separate Python review before exposure. +_PY_RETURNING_OPS = frozenset({"add"}) + def _op_namespace_parts(op_name): parts = op_name.split("_") @@ -454,9 +459,16 @@ def _find_tensor_params(op_name): return params +def _has_make_return_value(op_name): + source = _strip_cpp_comments(_find_base_header(op_name).read_text()) + + return re.search(r"\bMakeReturnValue\s*\(", source) is not None + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) optional_non_tensor_params = _find_optional_non_tensor_params(operator.name) + tensor_params = _find_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) vector_int64_params = _find_vector_int64_params(operator.name) @@ -483,63 +495,81 @@ def _is_vector_tensor(arg): def _is_vector_int64(arg): return arg.spelling in vector_int64_params - def _generate_params(node): - parts = [] + def _is_tensor(arg): + if arg.spelling in optional_non_tensor_params: + return False - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + if arg.spelling in tensor_params: + return True + + return "Tensor" in arg.type.spelling + def _is_plain_tensor(arg): + return ( + not _is_optional_tensor(arg) + and not _is_vector_tensor(arg) + and _is_tensor(arg) + ) + + def _args_without_stream(node): + return [arg for arg in node.get_arguments() if arg.spelling != "stream"] + + def _generate_params_for_args(args): + parts = [] + + for arg in args: if _is_optional_tensor(arg): parts.append(f"std::optional {arg.spelling}") elif _is_vector_tensor(arg): parts.append(f"std::vector {arg.spelling}") elif _is_vector_int64(arg): parts.append(f"const std::vector {arg.spelling}") + elif _is_tensor(arg): + parts.append(f"py::object {arg.spelling}") else: - param = arg.type.spelling.replace("const Tensor", "py::object").replace( - "Tensor", "py::object" - ) - parts.append(f"{param} {arg.spelling}") + parts.append(f"{arg.type.spelling} {arg.spelling}") return ", ".join(parts) - def _generate_arguments(node): - args = [] + def _generate_params(node): + return _generate_params_for_args(_args_without_stream(node)) - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + def _generate_arguments_for_args(args): + generated_args = [] + for arg in args: if _is_optional_tensor(arg): - args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + generated_args.append( + f"OptionalTensorFromPybind11Handle({arg.spelling})" + ) elif _is_vector_tensor(arg): - args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") - elif "Tensor" in arg.type.spelling: - args.append(f"TensorFromPybind11Handle({arg.spelling})") + generated_args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") + elif _is_tensor(arg): + generated_args.append(f"TensorFromPybind11Handle({arg.spelling})") else: - args.append(arg.spelling) + generated_args.append(arg.spelling) - return ", ".join(args) + return ", ".join(generated_args) + + def _generate_arguments(node): + return _generate_arguments_for_args(_args_without_stream(node)) op_name = operator.name op_type = _op_cpp_type(op_name) symbol_name = _op_symbol_name(op_name) - def _first_tensor_arg(node): - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + def _first_tensor_arg_for_args(args): + for arg in args: if _is_optional_tensor(arg): continue if _is_vector_tensor(arg): return f"{arg.spelling}.at(0)" - if "Tensor" in arg.type.spelling: + if _is_tensor(arg): return arg.spelling return None - def _default_impl_index_expr(node): - first_tensor = _first_tensor_arg(node) + def _default_impl_index_expr_for_args(args): + first_tensor = _first_tensor_arg_for_args(args) if first_tensor is None: return "0" return ( @@ -547,6 +577,9 @@ def _default_impl_index_expr(node): f"DeviceFromPybind11Handle({first_tensor}).type())" ) + def _default_impl_index_expr(node): + return _default_impl_index_expr_for_args(_args_without_stream(node)) + def _generate_init(constructor): constructor_params = _generate_params(constructor) default_impl_index = _default_impl_index_expr(constructor) @@ -557,13 +590,10 @@ def _generate_init(constructor): return std::unique_ptr{{static_cast(generated_dispatch::Make{symbol_name}(config, {_generate_arguments(constructor)}).release())}}; }}))""" - def _generate_py_args(node): + def _generate_py_args_for_args(args): parts = [] - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue - + for arg in args: if _is_optional(arg): parts.append(f'py::arg("{arg.spelling}") = py::none()') else: @@ -571,6 +601,9 @@ def _generate_py_args(node): return ", ".join(parts) + def _generate_py_args(node): + return _generate_py_args_for_args(_args_without_stream(node)) + def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) @@ -609,6 +642,74 @@ def _generate_call(op_name, call, method=True): return generated_dispatch::Invoke{symbol_name}(op, {call_args}); }})""" + def _is_returning_candidate(call): + args = _args_without_stream(call) + + if len(args) < 2 or not _is_plain_tensor(args[-1]): + return False + + if not any(_is_plain_tensor(arg) for arg in args[:-1]): + return False + + return not any( + _is_optional_tensor(arg) or _is_vector_tensor(arg) for arg in args[:-1] + ) + + def _returning_local_name(arg): + return f"infini_{arg.spelling}" + + def _generate_returning_call(call): + args = _args_without_stream(call) + input_args = args[:-1] + call_params = _generate_params_for_args(input_args) + params = ( + f"{call_params}, std::uintptr_t stream, " + "std::optional implementation_index" + if call_params + else "std::uintptr_t stream, " + "std::optional implementation_index" + ) + py_args = _generate_py_args_for_args(input_args) + py_args_str = f"{py_args}, " if py_args else "" + default_impl_index = _default_impl_index_expr_for_args(input_args) + tensor_locals = "\n".join( + f" Pybind11Tensor {_returning_local_name(arg)}{{{arg.spelling}}};" + for arg in input_args + if _is_plain_tensor(arg) + ) + make_args = ", ".join( + _returning_local_name(arg) if _is_plain_tensor(arg) else arg.spelling + for arg in input_args + ) + dispatch_args = ", ".join( + f"Tensor{{{_returning_local_name(arg)}}}" + if _is_plain_tensor(arg) + else arg.spelling + for arg in input_args + ) + dispatch_args = ( + f"{dispatch_args}, Tensor{{out}}" if dispatch_args else "Tensor{out}" + ) + + if tensor_locals: + tensor_locals += "\n" + + return ( + f' m.def("{op_name}", []({params}) {{\n' + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" Config config;\n" + f" config.set_implementation_index(\n" + f" implementation_index.value_or({default_impl_index}));\n" + f"{tensor_locals}" + f" auto out = Self::MakeReturnValue({make_args});\n" + f" generated_dispatch::Call{symbol_name}(handle, config, {dispatch_args});\n" + f" return out.object();\n" + f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = py::none());' + ) + def _overload_order_key(node): """Sort key that places more-specific overloads first. @@ -630,11 +731,7 @@ def _overload_order_key(node): total += 1 - if ( - _is_optional_tensor(arg) - or _is_vector_tensor(arg) - or "Tensor" in arg.type.spelling - ): + if _is_optional_tensor(arg) or _is_vector_tensor(arg) or _is_tensor(arg): object_like += 1 return (object_like, -total) @@ -644,9 +741,19 @@ def _overload_order_key(node): inits = "\n".join(_generate_init(constructor) for constructor in constructors) calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls) - callers = "\n".join( + returning_calls = [] + if operator.name in _PY_RETURNING_OPS and _has_make_return_value(operator.name): + returning_calls = sorted( + (call for call in operator_calls if _is_returning_candidate(call)), + key=lambda call: len(_args_without_stream(call)), + )[:1] + returning_callers = "\n".join( + _generate_returning_call(call) for call in returning_calls + ) + inplace_callers = "\n".join( _generate_call(operator.name, call, method=False) for call in operator_calls ) + callers = "\n".join(filter(None, (returning_callers, inplace_callers))) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 6610e4ae1..4f619c279 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -5,6 +5,11 @@ #include #include +#include +#include +#include +#include +#include #include "tensor.h" #include "torch/device_.h" @@ -29,6 +34,102 @@ inline DataType DataTypeFromString(const std::string& name) { return kStringToDataType.at(name); } +inline std::string DataTypeToTorchName(DataType dtype) { + switch (dtype) { + case DataType::kInt8: + return "int8"; + case DataType::kInt16: + return "int16"; + case DataType::kInt32: + return "int32"; + case DataType::kInt64: + return "int64"; + case DataType::kUInt8: + return "uint8"; + case DataType::kUInt16: + return "uint16"; + case DataType::kUInt32: + return "uint32"; + case DataType::kUInt64: + return "uint64"; + case DataType::kFloat16: + return "float16"; + case DataType::kBFloat16: + return "bfloat16"; + case DataType::kFloat32: + return "float32"; + case DataType::kFloat64: + return "float64"; + default: + throw py::value_error("unsupported `dtype` for PyTorch tensor creation"); + } +} + +inline py::object TorchDType(DataType dtype) { + auto torch = py::module_::import("torch"); + auto name = DataTypeToTorchName(dtype); + + if (!py::hasattr(torch, name.c_str())) { + throw py::value_error("current PyTorch build does not expose `dtype` `" + + name + "`"); + } + + return torch.attr(name.c_str()); +} + +inline std::string TorchDeviceTypeName(Device::Type type) { + switch (type) { + case Device::Type::kCpu: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kNvidia: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kCambricon: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kAscend: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kMetax: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kMoore: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kIluvatar: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kKunlun: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kHygon: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kQy: + return std::string{detail::TorchDeviceName::kValue}; + default: + throw py::value_error("unsupported `device` for PyTorch tensor creation"); + } +} + +inline py::object TorchDevice(Device device) { + auto torch = py::module_::import("torch"); + auto name = TorchDeviceTypeName(device.type()); + + if (device.type() != Device::Type::kCpu) { + name += ":" + std::to_string(device.index()); + } + + return torch.attr("device")(name); +} + +inline Tensor::Strides ContiguousStrides(const Tensor::Shape& shape) { + Tensor::Strides strides(shape.size(), 1); + + for (std::size_t i = shape.size(); i > 1; --i) { + strides[i - 2] = strides[i - 1] * shape[i - 1]; + } + + return strides; +} + template inline Device::Type DeviceTypeFromString(const std::string& name) { static const auto kTorchNameToTypes{ @@ -154,6 +255,53 @@ inline std::vector VectorTensorFromPybind11Handle( return result; } +class Pybind11Tensor { + public: + explicit Pybind11Tensor(py::handle obj) + : object_{py::reinterpret_borrow(obj)}, + shape_{object_.attr("shape").cast()}, + strides_{object_.attr("stride")().cast()}, + dtype_{TensorFromPybind11Handle(object_).dtype()}, + device_{DeviceFromPybind11Handle(object_)} {} + + static Pybind11Tensor Empty(const Tensor::Shape& shape, DataType dtype, + Device device) { + auto torch = py::module_::import("torch"); + auto strides = ContiguousStrides(shape); + auto object = torch.attr("empty_strided")( + shape, strides, py::arg("dtype") = TorchDType(dtype), + py::arg("device") = TorchDevice(device)); + + return Pybind11Tensor{object}; + } + + void* data() const { + return reinterpret_cast( + object_.attr("data_ptr")().cast()); + } + + const Tensor::Shape& shape() const { return shape_; } + + const Tensor::Strides& strides() const { return strides_; } + + DataType dtype() const { return dtype_; } + + Device device() const { return device_; } + + const py::object& object() const { return object_; } + + private: + py::object object_; + + Tensor::Shape shape_; + + Tensor::Strides strides_; + + DataType dtype_; + + Device device_; +}; + } // namespace infini::ops #endif diff --git a/tests/test_add.py b/tests/test_add.py index e2266c30d..cba60f899 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -75,9 +75,13 @@ def test_add( other = randn_strided(shape, other_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) + call = _add + + if input_strides is None and other_strides is None and out_strides is None: + call = _add_returning return Payload( - lambda *args: _add(*args, implementation_index=implementation_index), + lambda *args: call(*args, implementation_index=implementation_index), _torch_add, (input, other, out), {}, @@ -98,6 +102,15 @@ def _add(input, other, out, implementation_index=0): return out +def _add_returning(input, other, out, implementation_index=0): + return infini.ops.add( + input, + other, + stream=get_stream(input.device), + implementation_index=implementation_index, + ) + + def _torch_add(input, other, out): if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) diff --git a/tests/test_generate_wrappers.py b/tests/test_generate_wrappers.py index 931677553..a8cf4997f 100644 --- a/tests/test_generate_wrappers.py +++ b/tests/test_generate_wrappers.py @@ -176,6 +176,127 @@ class Mul { assert 'py::arg("implementation_index") = py::none()' in text +def test_pybind_generates_returning_call_when_make_return_value_exists( + monkeypatch, tmp_path +): + module = _load_generator_module() + base_header = tmp_path / "add.h" + base_header.write_text( + """ +class Add { + public: + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; + + template + static auto MakeReturnValue(const TensorLike& input, const TensorLike& other) { + return TensorLike::Empty(input.shape(), input.dtype(), input.device()); + } +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "add", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "input"), + module._ParsedArgument("const Tensor", "other"), + module._ParsedArgument("Tensor", "out"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert ( + 'm.def("add", [](py::object input, py::object other, ' + "std::uintptr_t stream, std::optional implementation_index)" + ) in text + assert "Pybind11Tensor infini_input{input};" in text + assert "Pybind11Tensor infini_other{other};" in text + assert "auto out = Self::MakeReturnValue(infini_input, infini_other);" in text + assert ( + "generated_dispatch::CallAdd(handle, config, Tensor{infini_input}, " + "Tensor{infini_other}, Tensor{out});" + ) in text + assert "return out.object();" in text + + +def test_pybind_skips_returning_call_for_unenabled_op(monkeypatch, tmp_path): + module = _load_generator_module() + base_header = tmp_path / "gemm.h" + base_header.write_text( + """ +class Gemm { + public: + virtual void operator()(const Tensor a, const Tensor b, Tensor c) const = 0; + + template + static auto MakeReturnValue(const TensorLike& a, const TensorLike& b) { + return TensorLike::Empty(a.shape(), a.dtype(), a.device()); + } +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "gemm", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "a"), + module._ParsedArgument("const Tensor", "b"), + module._ParsedArgument("Tensor", "c"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert "Self::MakeReturnValue" not in text + assert text.count('m.def("gemm"') == 1 + + +def test_pybind_skips_returning_call_without_make_return_value(monkeypatch, tmp_path): + module = _load_generator_module() + base_header = tmp_path / "mul.h" + base_header.write_text( + """ +class Mul { + public: + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "mul", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "input"), + module._ParsedArgument("const Tensor", "other"), + module._ParsedArgument("Tensor", "out"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert "Self::MakeReturnValue" not in text + assert text.count('m.def("mul"') == 1 + + def test_normalize_op_allowlist_accepts_spaces_and_commas(): module = _load_generator_module()