From 7a009455f9c1eda9a10f62707999ca1038c2080d Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Mon, 29 Jun 2026 13:16:34 +0800 Subject: [PATCH 1/3] feat: expose returning calls to python --- scripts/generate_wrappers.py | 178 +++++++++++++++++++++++++------- src/base/gemm.h | 5 +- src/pybind11_utils.h | 148 ++++++++++++++++++++++++++ src/torch/ops/gemm/gemm.h | 7 ++ tests/test_add.py | 47 +++++++++ tests/test_gemm.py | 67 ++++++++++++ tests/test_generate_wrappers.py | 83 +++++++++++++++ 7 files changed, 494 insertions(+), 41 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 0384e0423..390919e71 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -454,9 +454,16 @@ def _find_tensor_params(op_name): return params +def _has_make_return_value(op_name): + source = _strip_cpp_comments(_find_base_header(op_name).read_text()) + + return re.search(r"\bMakeReturnValue\s*\(", source) is not None + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) optional_non_tensor_params = _find_optional_non_tensor_params(operator.name) + tensor_params = _find_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) vector_int64_params = _find_vector_int64_params(operator.name) @@ -483,63 +490,81 @@ def _is_vector_tensor(arg): def _is_vector_int64(arg): return arg.spelling in vector_int64_params - def _generate_params(node): - parts = [] + def _is_tensor(arg): + if arg.spelling in optional_non_tensor_params: + return False - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + if arg.spelling in tensor_params: + return True + return "Tensor" in arg.type.spelling + + def _is_plain_tensor(arg): + return ( + not _is_optional_tensor(arg) + and not _is_vector_tensor(arg) + and _is_tensor(arg) + ) + + def _args_without_stream(node): + return [arg for arg in node.get_arguments() if arg.spelling != "stream"] + + def _generate_params_for_args(args): + parts = [] + + for arg in args: if _is_optional_tensor(arg): parts.append(f"std::optional {arg.spelling}") elif _is_vector_tensor(arg): parts.append(f"std::vector {arg.spelling}") elif _is_vector_int64(arg): parts.append(f"const std::vector {arg.spelling}") + elif _is_tensor(arg): + parts.append(f"py::object {arg.spelling}") else: - param = arg.type.spelling.replace("const Tensor", "py::object").replace( - "Tensor", "py::object" - ) - parts.append(f"{param} {arg.spelling}") + parts.append(f"{arg.type.spelling} {arg.spelling}") return ", ".join(parts) - def _generate_arguments(node): - args = [] + def _generate_params(node): + return _generate_params_for_args(_args_without_stream(node)) - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + def _generate_arguments_for_args(args): + generated_args = [] + for arg in args: if _is_optional_tensor(arg): - args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})") + generated_args.append( + f"OptionalTensorFromPybind11Handle({arg.spelling})" + ) elif _is_vector_tensor(arg): - args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") - elif "Tensor" in arg.type.spelling: - args.append(f"TensorFromPybind11Handle({arg.spelling})") + generated_args.append(f"VectorTensorFromPybind11Handle({arg.spelling})") + elif _is_tensor(arg): + generated_args.append(f"TensorFromPybind11Handle({arg.spelling})") else: - args.append(arg.spelling) + generated_args.append(arg.spelling) + + return ", ".join(generated_args) - return ", ".join(args) + def _generate_arguments(node): + return _generate_arguments_for_args(_args_without_stream(node)) op_name = operator.name op_type = _op_cpp_type(op_name) symbol_name = _op_symbol_name(op_name) - def _first_tensor_arg(node): - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue + def _first_tensor_arg_for_args(args): + for arg in args: if _is_optional_tensor(arg): continue if _is_vector_tensor(arg): return f"{arg.spelling}.at(0)" - if "Tensor" in arg.type.spelling: + if _is_tensor(arg): return arg.spelling return None - def _default_impl_index_expr(node): - first_tensor = _first_tensor_arg(node) + def _default_impl_index_expr_for_args(args): + first_tensor = _first_tensor_arg_for_args(args) if first_tensor is None: return "0" return ( @@ -547,6 +572,9 @@ def _default_impl_index_expr(node): f"DeviceFromPybind11Handle({first_tensor}).type())" ) + def _default_impl_index_expr(node): + return _default_impl_index_expr_for_args(_args_without_stream(node)) + def _generate_init(constructor): constructor_params = _generate_params(constructor) default_impl_index = _default_impl_index_expr(constructor) @@ -557,13 +585,10 @@ def _generate_init(constructor): return std::unique_ptr{{static_cast(generated_dispatch::Make{symbol_name}(config, {_generate_arguments(constructor)}).release())}}; }}))""" - def _generate_py_args(node): + def _generate_py_args_for_args(args): parts = [] - for arg in node.get_arguments(): - if arg.spelling == "stream": - continue - + for arg in args: if _is_optional(arg): parts.append(f'py::arg("{arg.spelling}") = py::none()') else: @@ -571,6 +596,9 @@ def _generate_py_args(node): return ", ".join(parts) + def _generate_py_args(node): + return _generate_py_args_for_args(_args_without_stream(node)) + def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) @@ -609,6 +637,74 @@ def _generate_call(op_name, call, method=True): return generated_dispatch::Invoke{symbol_name}(op, {call_args}); }})""" + def _is_returning_candidate(call): + args = _args_without_stream(call) + + if len(args) < 2 or not _is_plain_tensor(args[-1]): + return False + + if not any(_is_plain_tensor(arg) for arg in args[:-1]): + return False + + return not any( + _is_optional_tensor(arg) or _is_vector_tensor(arg) for arg in args[:-1] + ) + + def _returning_local_name(arg): + return f"infini_{arg.spelling}" + + def _generate_returning_call(call): + args = _args_without_stream(call) + input_args = args[:-1] + call_params = _generate_params_for_args(input_args) + params = ( + f"{call_params}, std::uintptr_t stream, " + "std::optional implementation_index" + if call_params + else "std::uintptr_t stream, " + "std::optional implementation_index" + ) + py_args = _generate_py_args_for_args(input_args) + py_args_str = f"{py_args}, " if py_args else "" + default_impl_index = _default_impl_index_expr_for_args(input_args) + tensor_locals = "\n".join( + f" Pybind11Tensor {_returning_local_name(arg)}{{{arg.spelling}}};" + for arg in input_args + if _is_plain_tensor(arg) + ) + make_args = ", ".join( + _returning_local_name(arg) if _is_plain_tensor(arg) else arg.spelling + for arg in input_args + ) + dispatch_args = ", ".join( + f"Tensor{{{_returning_local_name(arg)}}}" + if _is_plain_tensor(arg) + else arg.spelling + for arg in input_args + ) + dispatch_args = ( + f"{dispatch_args}, Tensor{{out}}" if dispatch_args else "Tensor{out}" + ) + + if tensor_locals: + tensor_locals += "\n" + + return ( + f' m.def("{op_name}", []({params}) {{\n' + f" Handle handle;\n" + f" if (stream) {{\n" + f" handle.set_stream(reinterpret_cast(stream));\n" + f" }}\n" + f" Config config;\n" + f" config.set_implementation_index(\n" + f" implementation_index.value_or({default_impl_index}));\n" + f"{tensor_locals}" + f" auto out = Self::MakeReturnValue({make_args});\n" + f" generated_dispatch::Call{symbol_name}(handle, config, {dispatch_args});\n" + f" return out.object();\n" + f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = py::none());' + ) + def _overload_order_key(node): """Sort key that places more-specific overloads first. @@ -630,11 +726,7 @@ def _overload_order_key(node): total += 1 - if ( - _is_optional_tensor(arg) - or _is_vector_tensor(arg) - or "Tensor" in arg.type.spelling - ): + if _is_optional_tensor(arg) or _is_vector_tensor(arg) or _is_tensor(arg): object_like += 1 return (object_like, -total) @@ -644,9 +736,19 @@ def _overload_order_key(node): inits = "\n".join(_generate_init(constructor) for constructor in constructors) calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls) - callers = "\n".join( + returning_calls = [] + if _has_make_return_value(operator.name): + returning_calls = sorted( + (call for call in operator_calls if _is_returning_candidate(call)), + key=lambda call: len(_args_without_stream(call)), + )[:1] + returning_callers = "\n".join( + _generate_returning_call(call) for call in returning_calls + ) + inplace_callers = "\n".join( _generate_call(operator.name, call, method=False) for call in operator_calls ) + callers = "\n".join(filter(None, (returning_callers, inplace_callers))) return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ diff --git a/src/base/gemm.h b/src/base/gemm.h index c0b0fdc35..787644482 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -37,7 +37,7 @@ class Gemm : public Operator { } Gemm(const Tensor a, const Tensor b, Tensor c) - : Gemm{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} + : Gemm{a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c} {} virtual void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, @@ -45,8 +45,7 @@ class Gemm : public Operator { std::optional trans_b, Tensor c) const = 0; virtual void operator()(const Tensor a, const Tensor b, Tensor c) const { - return operator()(a, b, std::nullopt, std::nullopt, std::nullopt, - std::nullopt, c); + return operator()(a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c); } virtual void operator()(const Tensor a, const Tensor b, diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index 6610e4ae1..d8e0503cb 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -5,6 +5,11 @@ #include #include +#include +#include +#include +#include +#include #include "tensor.h" #include "torch/device_.h" @@ -29,6 +34,102 @@ inline DataType DataTypeFromString(const std::string& name) { return kStringToDataType.at(name); } +inline std::string DataTypeToTorchName(DataType dtype) { + switch (dtype) { + case DataType::kInt8: + return "int8"; + case DataType::kInt16: + return "int16"; + case DataType::kInt32: + return "int32"; + case DataType::kInt64: + return "int64"; + case DataType::kUInt8: + return "uint8"; + case DataType::kUInt16: + return "uint16"; + case DataType::kUInt32: + return "uint32"; + case DataType::kUInt64: + return "uint64"; + case DataType::kFloat16: + return "float16"; + case DataType::kBFloat16: + return "bfloat16"; + case DataType::kFloat32: + return "float32"; + case DataType::kFloat64: + return "float64"; + default: + throw py::value_error("Unsupported dtype for PyTorch tensor creation"); + } +} + +inline py::object TorchDType(DataType dtype) { + auto torch = py::module_::import("torch"); + auto name = DataTypeToTorchName(dtype); + + if (!py::hasattr(torch, name.c_str())) { + throw py::value_error("Current PyTorch build does not expose dtype `" + + name + "`"); + } + + return torch.attr(name.c_str()); +} + +inline std::string TorchDeviceTypeName(Device::Type type) { + switch (type) { + case Device::Type::kCpu: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kNvidia: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kCambricon: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kAscend: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kMetax: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kMoore: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kIluvatar: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kKunlun: + return std::string{ + detail::TorchDeviceName::kValue}; + case Device::Type::kHygon: + return std::string{detail::TorchDeviceName::kValue}; + case Device::Type::kQy: + return std::string{detail::TorchDeviceName::kValue}; + default: + throw py::value_error("Unsupported device for PyTorch tensor creation"); + } +} + +inline py::object TorchDevice(Device device) { + auto torch = py::module_::import("torch"); + auto name = TorchDeviceTypeName(device.type()); + + if (device.type() != Device::Type::kCpu) { + name += ":" + std::to_string(device.index()); + } + + return torch.attr("device")(name); +} + +inline Tensor::Strides ContiguousStrides(const Tensor::Shape& shape) { + Tensor::Strides strides(shape.size(), 1); + + for (std::size_t i = shape.size(); i > 1; --i) { + strides[i - 2] = strides[i - 1] * shape[i - 1]; + } + + return strides; +} + template inline Device::Type DeviceTypeFromString(const std::string& name) { static const auto kTorchNameToTypes{ @@ -154,6 +255,53 @@ inline std::vector VectorTensorFromPybind11Handle( return result; } +class Pybind11Tensor { + public: + explicit Pybind11Tensor(py::handle obj) + : object_{py::reinterpret_borrow(obj)}, + shape_{object_.attr("shape").cast()}, + strides_{object_.attr("stride")().cast()}, + dtype_{TensorFromPybind11Handle(object_).dtype()}, + device_{DeviceFromPybind11Handle(object_)} {} + + static Pybind11Tensor Empty(const Tensor::Shape& shape, DataType dtype, + Device device) { + auto torch = py::module_::import("torch"); + auto strides = ContiguousStrides(shape); + auto object = torch.attr("empty_strided")( + shape, strides, py::arg("dtype") = TorchDType(dtype), + py::arg("device") = TorchDevice(device)); + + return Pybind11Tensor{object}; + } + + void* data() const { + return reinterpret_cast( + object_.attr("data_ptr")().cast()); + } + + const Tensor::Shape& shape() const { return shape_; } + + const Tensor::Strides& strides() const { return strides_; } + + DataType dtype() const { return dtype_; } + + Device device() const { return device_; } + + const py::object& object() const { return object_; } + + private: + py::object object_; + + Tensor::Shape shape_; + + Tensor::Strides strides_; + + DataType dtype_; + + Device device_; +}; + } // namespace infini::ops #endif diff --git a/src/torch/ops/gemm/gemm.h b/src/torch/ops/gemm/gemm.h index 4fd22ff36..ae252fdbd 100644 --- a/src/torch/ops/gemm/gemm.h +++ b/src/torch/ops/gemm/gemm.h @@ -12,6 +12,13 @@ class Operator : public Gemm { std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c); + Operator(const Tensor a, const Tensor b, Tensor c) + : Operator{a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c} {} + + Operator(const Tensor a, const Tensor b, std::optional alpha, + std::optional beta, Tensor c) + : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} + using Gemm::operator(); void operator()(const Tensor a, const Tensor b, std::optional alpha, diff --git a/tests/test_add.py b/tests/test_add.py index e2266c30d..23234076e 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -86,6 +86,44 @@ def test_add( ) +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "shape, input_strides, other_strides", + ( + ((13, 4), None, None), + ((13, 4), (0, 1), None), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-7, 1e-7), + (torch.float16, 1e-3, 1e-3), + ), +) +def test_add_returning( + shape, + input_strides, + other_strides, + implementation_index, + dtype, + device, + rtol, + atol, +): + input = randn_strided(shape, input_strides, dtype=dtype, device=device) + other = randn_strided(shape, other_strides, dtype=dtype, device=device) + + return Payload( + lambda *args: _add_returning(*args, implementation_index=implementation_index), + torch.add, + (input, other), + {}, + rtol=rtol, + atol=atol, + ) + + def _add(input, other, out, implementation_index=0): infini.ops.add( input, @@ -98,6 +136,15 @@ def _add(input, other, out, implementation_index=0): return out +def _add_returning(input, other, implementation_index=0): + return infini.ops.add( + input, + other, + stream=get_stream(input.device), + implementation_index=implementation_index, + ) + + def _torch_add(input, other, out): if input.dtype in _UINT_DTYPES: input = input.to(torch.int64) diff --git a/tests/test_gemm.py b/tests/test_gemm.py index 224390d15..b9fe752e0 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -110,6 +110,64 @@ def test_gemm( ) +@pytest.mark.auto_act_and_assert +@pytest.mark.parametrize( + "a_shape, b_shape, a_strides, b_strides", + ( + ((4, 64), (64, 32), None, None), + ((6, 2048), (2048, 2560), (2048, 1), (1, 2048)), + ), +) +@pytest.mark.parametrize( + ("dtype", "rtol", "atol"), + ( + (torch.float32, 1e-3, 1e-3), + (torch.float16, 1e-2, 1e-2), + ), +) +def test_gemm_returning( + a_shape, + b_shape, + a_strides, + b_strides, + implementation_index, + dtype, + device, + rtol, + atol, +): + if implementation_index == 1 and dtype == torch.float16: + pytest.skip("cuBLASLt half-precision exceeds current tolerances") + + if implementation_index == 2 and device == "cpu" and dtype == torch.float16: + pytest.skip("ATen CPU `addmm` does not support half-precision") + + if ( + device == "cuda" + and dtype == torch.float16 + and infini.ops.Gemm.active_implementation_indices("iluvatar") + ): + pytest.skip("Iluvatar GEMM reports fp16 execution failures") + + if implementation_index == 2 and device == "npu": + pytest.skip( + "Gemm impl=2 on Ascend is a torch-fallback stub without an " + "instantiated specialization" + ) + + a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) + b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) + + return Payload( + lambda *args: _gemm_returning(*args, implementation_index=implementation_index), + torch.matmul, + (a, b), + {}, + rtol=rtol, + atol=atol, + ) + + def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): infini.ops.gemm( a, @@ -126,6 +184,15 @@ def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): return c +def _gemm_returning(a, b, implementation_index=0): + return infini.ops.gemm( + a, + b, + stream=get_stream(a.device), + implementation_index=implementation_index, + ) + + def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None): if trans_a: a = a.transpose(-2, -1) diff --git a/tests/test_generate_wrappers.py b/tests/test_generate_wrappers.py index 931677553..f5d1f5be7 100644 --- a/tests/test_generate_wrappers.py +++ b/tests/test_generate_wrappers.py @@ -176,6 +176,89 @@ class Mul { assert 'py::arg("implementation_index") = py::none()' in text +def test_pybind_generates_returning_call_when_make_return_value_exists( + monkeypatch, tmp_path +): + module = _load_generator_module() + base_header = tmp_path / "add.h" + base_header.write_text( + """ +class Add { + public: + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; + + template + static auto MakeReturnValue(const TensorLike& input, const TensorLike& other) { + return TensorLike::Empty(input.shape(), input.dtype(), input.device()); + } +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "add", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "input"), + module._ParsedArgument("const Tensor", "other"), + module._ParsedArgument("Tensor", "out"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert ( + 'm.def("add", [](py::object input, py::object other, ' + "std::uintptr_t stream, std::optional implementation_index)" + ) in text + assert "Pybind11Tensor infini_input{input};" in text + assert "Pybind11Tensor infini_other{other};" in text + assert "auto out = Self::MakeReturnValue(infini_input, infini_other);" in text + assert ( + "generated_dispatch::CallAdd(handle, config, Tensor{infini_input}, " + "Tensor{infini_other}, Tensor{out});" + ) in text + assert "return out.object();" in text + + +def test_pybind_skips_returning_call_without_make_return_value(monkeypatch, tmp_path): + module = _load_generator_module() + base_header = tmp_path / "mul.h" + base_header.write_text( + """ +class Mul { + public: + virtual void operator()(const Tensor input, const Tensor other, Tensor out) const = 0; +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "mul", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "input"), + module._ParsedArgument("const Tensor", "other"), + module._ParsedArgument("Tensor", "out"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert "Self::MakeReturnValue" not in text + assert text.count('m.def("mul"') == 1 + + def test_normalize_op_allowlist_accepts_spaces_and_commas(): module = _load_generator_module() From a23471975438f4501589d9f49c898bebce7192d0 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang <45955067+voltjia@users.noreply.github.com> Date: Tue, 30 Jun 2026 11:14:18 +0800 Subject: [PATCH 2/3] fix: align pybind error messages --- src/pybind11_utils.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pybind11_utils.h b/src/pybind11_utils.h index d8e0503cb..4f619c279 100644 --- a/src/pybind11_utils.h +++ b/src/pybind11_utils.h @@ -61,7 +61,7 @@ inline std::string DataTypeToTorchName(DataType dtype) { case DataType::kFloat64: return "float64"; default: - throw py::value_error("Unsupported dtype for PyTorch tensor creation"); + throw py::value_error("unsupported `dtype` for PyTorch tensor creation"); } } @@ -70,7 +70,7 @@ inline py::object TorchDType(DataType dtype) { auto name = DataTypeToTorchName(dtype); if (!py::hasattr(torch, name.c_str())) { - throw py::value_error("Current PyTorch build does not expose dtype `" + + throw py::value_error("current PyTorch build does not expose `dtype` `" + name + "`"); } @@ -105,7 +105,7 @@ inline std::string TorchDeviceTypeName(Device::Type type) { case Device::Type::kQy: return std::string{detail::TorchDeviceName::kValue}; default: - throw py::value_error("Unsupported device for PyTorch tensor creation"); + throw py::value_error("unsupported `device` for PyTorch tensor creation"); } } From 766e3d6ac77071341eb57f24babe4f70b8ffd0e7 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 30 Jun 2026 03:34:14 +0000 Subject: [PATCH 3/3] feat: expose returning add python call --- scripts/generate_wrappers.py | 7 +++- src/base/gemm.h | 5 ++- src/torch/ops/gemm/gemm.h | 7 ---- tests/test_add.py | 46 +++------------------- tests/test_gemm.py | 67 --------------------------------- tests/test_generate_wrappers.py | 38 +++++++++++++++++++ 6 files changed, 53 insertions(+), 117 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 390919e71..0716111f7 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -37,6 +37,11 @@ _OP_NAMESPACE_PREFIXES = ("special", "linalg", "fft") +# Keep Python returning overloads explicitly opt-in while the API is being +# rolled out. Some C++ returning-capable ops have extra semantics that need +# separate Python review before exposure. +_PY_RETURNING_OPS = frozenset({"add"}) + def _op_namespace_parts(op_name): parts = op_name.split("_") @@ -737,7 +742,7 @@ def _overload_order_key(node): inits = "\n".join(_generate_init(constructor) for constructor in constructors) calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls) returning_calls = [] - if _has_make_return_value(operator.name): + if operator.name in _PY_RETURNING_OPS and _has_make_return_value(operator.name): returning_calls = sorted( (call for call in operator_calls if _is_returning_candidate(call)), key=lambda call: len(_args_without_stream(call)), diff --git a/src/base/gemm.h b/src/base/gemm.h index 787644482..c0b0fdc35 100644 --- a/src/base/gemm.h +++ b/src/base/gemm.h @@ -37,7 +37,7 @@ class Gemm : public Operator { } Gemm(const Tensor a, const Tensor b, Tensor c) - : Gemm{a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c} {} + : Gemm{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {} virtual void operator()(const Tensor a, const Tensor b, std::optional alpha, std::optional beta, @@ -45,7 +45,8 @@ class Gemm : public Operator { std::optional trans_b, Tensor c) const = 0; virtual void operator()(const Tensor a, const Tensor b, Tensor c) const { - return operator()(a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c); + return operator()(a, b, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, c); } virtual void operator()(const Tensor a, const Tensor b, diff --git a/src/torch/ops/gemm/gemm.h b/src/torch/ops/gemm/gemm.h index ae252fdbd..4fd22ff36 100644 --- a/src/torch/ops/gemm/gemm.h +++ b/src/torch/ops/gemm/gemm.h @@ -12,13 +12,6 @@ class Operator : public Gemm { std::optional beta, std::optional trans_a, std::optional trans_b, Tensor c); - Operator(const Tensor a, const Tensor b, Tensor c) - : Operator{a, b, std::nullopt, 0.0F, std::nullopt, std::nullopt, c} {} - - Operator(const Tensor a, const Tensor b, std::optional alpha, - std::optional beta, Tensor c) - : Operator{a, b, alpha, beta, std::nullopt, std::nullopt, c} {} - using Gemm::operator(); void operator()(const Tensor a, const Tensor b, std::optional alpha, diff --git a/tests/test_add.py b/tests/test_add.py index 23234076e..cba60f899 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -75,9 +75,13 @@ def test_add( other = randn_strided(shape, other_strides, dtype=dtype, device=device) out = empty_strided(shape, out_strides, dtype=dtype, device=device) + call = _add + + if input_strides is None and other_strides is None and out_strides is None: + call = _add_returning return Payload( - lambda *args: _add(*args, implementation_index=implementation_index), + lambda *args: call(*args, implementation_index=implementation_index), _torch_add, (input, other, out), {}, @@ -86,44 +90,6 @@ def test_add( ) -@pytest.mark.auto_act_and_assert -@pytest.mark.parametrize( - "shape, input_strides, other_strides", - ( - ((13, 4), None, None), - ((13, 4), (0, 1), None), - ), -) -@pytest.mark.parametrize( - ("dtype", "rtol", "atol"), - ( - (torch.float32, 1e-7, 1e-7), - (torch.float16, 1e-3, 1e-3), - ), -) -def test_add_returning( - shape, - input_strides, - other_strides, - implementation_index, - dtype, - device, - rtol, - atol, -): - input = randn_strided(shape, input_strides, dtype=dtype, device=device) - other = randn_strided(shape, other_strides, dtype=dtype, device=device) - - return Payload( - lambda *args: _add_returning(*args, implementation_index=implementation_index), - torch.add, - (input, other), - {}, - rtol=rtol, - atol=atol, - ) - - def _add(input, other, out, implementation_index=0): infini.ops.add( input, @@ -136,7 +102,7 @@ def _add(input, other, out, implementation_index=0): return out -def _add_returning(input, other, implementation_index=0): +def _add_returning(input, other, out, implementation_index=0): return infini.ops.add( input, other, diff --git a/tests/test_gemm.py b/tests/test_gemm.py index b9fe752e0..224390d15 100644 --- a/tests/test_gemm.py +++ b/tests/test_gemm.py @@ -110,64 +110,6 @@ def test_gemm( ) -@pytest.mark.auto_act_and_assert -@pytest.mark.parametrize( - "a_shape, b_shape, a_strides, b_strides", - ( - ((4, 64), (64, 32), None, None), - ((6, 2048), (2048, 2560), (2048, 1), (1, 2048)), - ), -) -@pytest.mark.parametrize( - ("dtype", "rtol", "atol"), - ( - (torch.float32, 1e-3, 1e-3), - (torch.float16, 1e-2, 1e-2), - ), -) -def test_gemm_returning( - a_shape, - b_shape, - a_strides, - b_strides, - implementation_index, - dtype, - device, - rtol, - atol, -): - if implementation_index == 1 and dtype == torch.float16: - pytest.skip("cuBLASLt half-precision exceeds current tolerances") - - if implementation_index == 2 and device == "cpu" and dtype == torch.float16: - pytest.skip("ATen CPU `addmm` does not support half-precision") - - if ( - device == "cuda" - and dtype == torch.float16 - and infini.ops.Gemm.active_implementation_indices("iluvatar") - ): - pytest.skip("Iluvatar GEMM reports fp16 execution failures") - - if implementation_index == 2 and device == "npu": - pytest.skip( - "Gemm impl=2 on Ascend is a torch-fallback stub without an " - "instantiated specialization" - ) - - a = randn_strided(a_shape, a_strides, dtype=dtype, device=device) - b = randn_strided(b_shape, b_strides, dtype=dtype, device=device) - - return Payload( - lambda *args: _gemm_returning(*args, implementation_index=implementation_index), - torch.matmul, - (a, b), - {}, - rtol=rtol, - atol=atol, - ) - - def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): infini.ops.gemm( a, @@ -184,15 +126,6 @@ def _gemm(a, b, alpha, beta, trans_a, trans_b, c, implementation_index=0): return c -def _gemm_returning(a, b, implementation_index=0): - return infini.ops.gemm( - a, - b, - stream=get_stream(a.device), - implementation_index=implementation_index, - ) - - def _torch_gemm(a, b, alpha=1.0, beta=1.0, trans_a=False, trans_b=False, c=None): if trans_a: a = a.transpose(-2, -1) diff --git a/tests/test_generate_wrappers.py b/tests/test_generate_wrappers.py index f5d1f5be7..a8cf4997f 100644 --- a/tests/test_generate_wrappers.py +++ b/tests/test_generate_wrappers.py @@ -226,6 +226,44 @@ class Add { assert "return out.object();" in text +def test_pybind_skips_returning_call_for_unenabled_op(monkeypatch, tmp_path): + module = _load_generator_module() + base_header = tmp_path / "gemm.h" + base_header.write_text( + """ +class Gemm { + public: + virtual void operator()(const Tensor a, const Tensor b, Tensor c) const = 0; + + template + static auto MakeReturnValue(const TensorLike& a, const TensorLike& b) { + return TensorLike::Empty(a.shape(), a.dtype(), a.device()); + } +}; +""" + ) + monkeypatch.setattr(module, "_find_base_header", lambda op_name: base_header) + + operator = module._Operator( + "gemm", + constructors=[], + calls=[ + module._ParsedFunction( + [ + module._ParsedArgument("const Tensor", "a"), + module._ParsedArgument("const Tensor", "b"), + module._ParsedArgument("Tensor", "c"), + ] + ) + ], + ) + + text = module._generate_pybind11(operator) + + assert "Self::MakeReturnValue" not in text + assert text.count('m.def("gemm"') == 1 + + def test_pybind_skips_returning_call_without_make_return_value(monkeypatch, tmp_path): module = _load_generator_module() base_header = tmp_path / "mul.h"