From 5958a1ef5aefea4263f873b1b52b0fc25976f9fd Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Tue, 12 May 2026 21:43:24 +0800 Subject: [PATCH 1/3] feat: generate dispatch entrypoints for bindings --- scripts/generate_wrappers.py | 229 +++++++++++++++++++++++++++++++---- src/CMakeLists.txt | 3 +- 2 files changed, 207 insertions(+), 25 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index effc0787..8d855ce6 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -169,12 +169,21 @@ def _generate_arguments(node): return ", ".join(args) op_name = operator.name + pascal_case_op_name = _snake_to_pascal(op_name) + + def _overload_suffix(nodes, node): + if len(nodes) == 1: + return "" + + return str(nodes.index(node)) def _generate_init(constructor): constructor_params = _generate_params(constructor) + suffix = _overload_suffix(operator.constructors, constructor) return f""" .def(py::init([]({constructor_params}) {{ - return std::unique_ptr{{static_cast(Self::Make({_generate_arguments(constructor)}).release())}}; + Config config; + return std::unique_ptr{{static_cast(generated_dispatch::Make{pascal_case_op_name}{suffix}(config, {_generate_arguments(constructor)}).release())}}; }}))""" def _generate_py_args(node): @@ -194,6 +203,7 @@ def _generate_py_args(node): def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) + suffix = _overload_suffix(operator.calls, call) if not method: params = ( @@ -212,12 +222,12 @@ def _generate_call(op_name, call, method=True): f" }}\n" f" Config config;\n" f" config.set_implementation_index(implementation_index);\n" - f" return Self::Call(handle, config, {call_args});\n" + f" return generated_dispatch::Call{pascal_case_op_name}{suffix}(handle, config, {call_args});\n" f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return static_cast&>(self)({call_args}); + return generated_dispatch::Invoke{pascal_case_op_name}{suffix}(self, {call_args}); }})""" inits = "\n".join( @@ -228,8 +238,6 @@ def _generate_call(op_name, call, method=True): _generate_call(operator.name, call, method=False) for call in operator.calls ) - pascal_case_op_name = _snake_to_pascal(op_name) - return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -238,8 +246,8 @@ def _generate_call(op_name, call, method=True): #include "base/{op_name}.h" #include "config.h" +#include "generated/bindings/generated_dispatch.h" #include "handle.h" -#include "operator.h" #include "pybind11_utils.h" namespace py = pybind11; @@ -253,9 +261,9 @@ def _generate_call(op_name, call, method=True): {inits} {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ - return Self::active_implementation_indices(DeviceTypeFromString(device)); + return generated_dispatch::ActiveImplementationIndices{pascal_case_op_name}(DeviceTypeFromString(device)); }}) - .def_static("clear_cache", &Self::clear_cache); + .def_static("clear_cache", &generated_dispatch::ClearCache{pascal_case_op_name}); {callers} }} @@ -440,6 +448,176 @@ def _generate_tensor_caster(name, is_data=False): return _generate_source(operator), _generate_header(operator) +def _generate_generated_dispatch(operators, ops, devices): + def _generate_params(node): + return ", ".join( + f"{arg.type.spelling} {arg.spelling}" + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + def _generate_arguments(node): + return ", ".join( + arg.spelling + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + def _append_optional_args(prefix, args): + if args: + return f"{prefix}, {args}" + + return prefix + + def _append_optional_params(prefix, params): + if params: + return f"{prefix}, {params}" + + return prefix + + def _overload_suffix(nodes, node): + if len(nodes) == 1: + return "" + + return str(nodes.index(node)) + + header_base_includes = "\n".join( + f'#include "base/{operator.name}.h"' for operator in operators + ) + header_device_includes = "\n".join( + f'#include "{path}"' for path in _device_marker_headers(devices) + ) + impl_includes = "\n".join( + f'#include "{impl_path}"' + for impl_paths in ops.values() + for impl_path in impl_paths + ) + + declarations = [] + definitions = [] + + for operator in operators: + pascal_case_op_name = _snake_to_pascal(operator.name) + + declarations.append( + f"std::vector ActiveImplementationIndices" + f"{pascal_case_op_name}(Device::Type dev_type);" + ) + definitions.append( + f"""std::vector ActiveImplementationIndices{pascal_case_op_name}(Device::Type dev_type) {{ + return Operator<{pascal_case_op_name}>::active_implementation_indices(dev_type); +}}""" + ) + + declarations.append(f"void ClearCache{pascal_case_op_name}();") + definitions.append( + f"""void ClearCache{pascal_case_op_name}() {{ + Operator<{pascal_case_op_name}>::clear_cache(); +}}""" + ) + + for constructor in operator.constructors: + suffix = _overload_suffix(operator.constructors, constructor) + params = _generate_params(constructor) + args = _generate_arguments(constructor) + + declarations.append( + f"std::unique_ptr> " + f"Make{pascal_case_op_name}{suffix}(" + f"{_append_optional_params('const Config& config', params)});" + ) + definitions.append( + f"""std::unique_ptr> Make{pascal_case_op_name}{suffix}({_append_optional_params("const Config& config", params)}) {{ + return Operator<{pascal_case_op_name}>::Make({_append_optional_args("config", args)}); +}}""" + ) + + for call in operator.calls: + suffix = _overload_suffix(operator.calls, call) + params = _generate_params(call) + args = _generate_arguments(call) + + declarations.append( + f"void Invoke{pascal_case_op_name}{suffix}(const " + f"{_append_optional_params(f'{pascal_case_op_name}& op', params)});" + ) + definitions.append( + f"""void Invoke{pascal_case_op_name}{suffix}(const {_append_optional_params(f'{pascal_case_op_name}& op', params)}) {{ + return static_cast&>(op)({args}); +}}""" + ) + + declarations.append( + f"void Call{pascal_case_op_name}{suffix}(const Handle& handle, " + f"{_append_optional_params('const Config& config', params)});" + ) + definitions.append( + f"""void Call{pascal_case_op_name}{suffix}(const Handle& handle, {_append_optional_params("const Config& config", params)}) {{ + return Operator<{pascal_case_op_name}>::Call({_append_optional_args("handle, config", args)}); +}}""" + ) + + header = f"""#ifndef INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_ +#define INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_ + +#include +#include +#include +#include + +#include "config.h" +#include "device.h" +#include "handle.h" +#include "operator.h" + +{header_device_includes} + +{header_base_includes} + +namespace infini::ops::generated_dispatch {{ + +{chr(10).join(declarations)} + +}} // namespace infini::ops::generated_dispatch + +#endif +""" + + source = f"""#include "generated_dispatch.h" + +// clang-format off +{impl_includes} +// clang-format on + +namespace infini::ops::generated_dispatch {{ + +{chr(10).join(definitions)} + +}} // namespace infini::ops::generated_dispatch +""" + + return header, source + + +def _device_marker_headers(devices): + paths = { + "cpu": "native/cpu/device_.h", + "nvidia": "native/cuda/nvidia/device_.h", + "cambricon": "native/cambricon/device_.h", + "ascend": "native/ascend/device_.h", + "metax": "native/cuda/metax/device_.h", + "moore": "native/cuda/moore/device_.h", + "iluvatar": "native/cuda/iluvatar/device_.h", + } + + return [paths[device] for device in devices if device in paths] + + +def _generate_binding_source(op_name): + return f"""#include "{op_name}.h" +""" + + def _snake_to_pascal(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) @@ -489,9 +667,11 @@ def _get_all_ops(devices, with_torch=False): args = parser.parse_args() - _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) - _GENERATED_SRC_DIR.mkdir(parents=True, exist_ok=True) - _INCLUDE_DIR.mkdir(parents=True, exist_ok=True) + for directory in (_BINDINGS_DIR, _GENERATED_SRC_DIR, _INCLUDE_DIR): + if directory.exists(): + shutil.rmtree(directory) + + directory.mkdir(parents=True) ops_json = pathlib.Path("ops.json") @@ -500,47 +680,48 @@ def _get_all_ops(devices, with_torch=False): else: ops = _get_all_ops(args.devices, with_torch=args.with_torch) - header_paths = [] bind_func_names = [] + operators = [] for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) + operators.append(operator) source_path = _GENERATED_SRC_DIR / op_name header_name = f"{op_name}.h" bind_func_name = f"Bind{_snake_to_pascal(op_name)}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + (_BINDINGS_DIR / f"{op_name}.cc").write_text(_generate_binding_source(op_name)) legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) - header_paths.append(header_name) bind_func_names.append(bind_func_name) - impl_includes = "\n".join( - f'#include "{impl_path}"' - for impl_paths in ops.values() - for impl_path in impl_paths + dispatch_header, dispatch_source = _generate_generated_dispatch( + operators, ops, args.devices + ) + (_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header) + (_BINDINGS_DIR / "generated_dispatch.cc").write_text(dispatch_source) + + bind_func_declarations = "\n".join( + f"void {bind_func_name}(pybind11::module& m);" + for bind_func_name in bind_func_names ) - op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths) bind_func_calls = "\n".join( f"{bind_func_name}(m);" for bind_func_name in bind_func_names ) (_BINDINGS_DIR / "ops.cc").write_text(f"""#include -// clang-format off -{impl_includes} -// clang-format on - -{op_includes} - namespace infini::ops {{ +{bind_func_declarations} + PYBIND11_MODULE(ops, m) {{ {textwrap.indent(bind_func_calls, _INDENTATION)} }} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ce888b4b..35d63ff8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -340,7 +340,8 @@ if(GENERATE_PYTHON_BINDINGS) message(STATUS "Generating wrappers - done") endif() - set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/generated/bindings/*.cc") # TODO: There might be a better solution. if(WITH_NVIDIA OR WITH_ILUVATAR) From 9e749f3b7e958e4c2475f26a6eac12099f7a3682 Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 13 May 2026 10:51:51 +0800 Subject: [PATCH 2/3] refactor: clarify generated dispatch names --- scripts/generate_wrappers.py | 38 ++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index 8d855ce6..acffd071 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -172,7 +172,15 @@ def _generate_arguments(node): pascal_case_op_name = _snake_to_pascal(op_name) def _overload_suffix(nodes, node): - if len(nodes) == 1: + def _signature(item): + return tuple( + arg.type.spelling + for arg in item.get_arguments() + if arg.spelling != "stream" + ) + + signature = _signature(node) + if [_signature(item) for item in nodes].count(signature) == 1: return "" return str(nodes.index(node)) @@ -261,9 +269,9 @@ def _generate_call(op_name, call, method=True): {inits} {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ - return generated_dispatch::ActiveImplementationIndices{pascal_case_op_name}(DeviceTypeFromString(device)); + return generated_dispatch::ActiveImplementationIndicesFor{pascal_case_op_name}(DeviceTypeFromString(device)); }}) - .def_static("clear_cache", &generated_dispatch::ClearCache{pascal_case_op_name}); + .def_static("clear_cache", &generated_dispatch::ClearCacheFor{pascal_case_op_name}); {callers} }} @@ -458,9 +466,7 @@ def _generate_params(node): def _generate_arguments(node): return ", ".join( - arg.spelling - for arg in node.get_arguments() - if arg.spelling != "stream" + arg.spelling for arg in node.get_arguments() if arg.spelling != "stream" ) def _append_optional_args(prefix, args): @@ -476,7 +482,15 @@ def _append_optional_params(prefix, params): return prefix def _overload_suffix(nodes, node): - if len(nodes) == 1: + def _signature(item): + return tuple( + arg.type.spelling + for arg in item.get_arguments() + if arg.spelling != "stream" + ) + + signature = _signature(node) + if [_signature(item) for item in nodes].count(signature) == 1: return "" return str(nodes.index(node)) @@ -500,18 +514,18 @@ def _overload_suffix(nodes, node): pascal_case_op_name = _snake_to_pascal(operator.name) declarations.append( - f"std::vector ActiveImplementationIndices" + f"std::vector ActiveImplementationIndicesFor" f"{pascal_case_op_name}(Device::Type dev_type);" ) definitions.append( - f"""std::vector ActiveImplementationIndices{pascal_case_op_name}(Device::Type dev_type) {{ + f"""std::vector ActiveImplementationIndicesFor{pascal_case_op_name}(Device::Type dev_type) {{ return Operator<{pascal_case_op_name}>::active_implementation_indices(dev_type); }}""" ) - declarations.append(f"void ClearCache{pascal_case_op_name}();") + declarations.append(f"void ClearCacheFor{pascal_case_op_name}();") definitions.append( - f"""void ClearCache{pascal_case_op_name}() {{ + f"""void ClearCacheFor{pascal_case_op_name}() {{ Operator<{pascal_case_op_name}>::clear_cache(); }}""" ) @@ -542,7 +556,7 @@ def _overload_suffix(nodes, node): f"{_append_optional_params(f'{pascal_case_op_name}& op', params)});" ) definitions.append( - f"""void Invoke{pascal_case_op_name}{suffix}(const {_append_optional_params(f'{pascal_case_op_name}& op', params)}) {{ + f"""void Invoke{pascal_case_op_name}{suffix}(const {_append_optional_params(f"{pascal_case_op_name}& op", params)}) {{ return static_cast&>(op)({args}); }}""" ) From a00c8e51126bb7e5e74a9956aaf211f22374fc6e Mon Sep 17 00:00:00 2001 From: Jiacheng Huang Date: Wed, 13 May 2026 15:00:43 +0800 Subject: [PATCH 3/3] refactor: rely on overloads in generated dispatch --- scripts/generate_wrappers.py | 52 +++++++----------------------------- 1 file changed, 9 insertions(+), 43 deletions(-) diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index acffd071..8a47c1f7 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -171,27 +171,11 @@ def _generate_arguments(node): op_name = operator.name pascal_case_op_name = _snake_to_pascal(op_name) - def _overload_suffix(nodes, node): - def _signature(item): - return tuple( - arg.type.spelling - for arg in item.get_arguments() - if arg.spelling != "stream" - ) - - signature = _signature(node) - if [_signature(item) for item in nodes].count(signature) == 1: - return "" - - return str(nodes.index(node)) - def _generate_init(constructor): constructor_params = _generate_params(constructor) - suffix = _overload_suffix(operator.constructors, constructor) - return f""" .def(py::init([]({constructor_params}) {{ Config config; - return std::unique_ptr{{static_cast(generated_dispatch::Make{pascal_case_op_name}{suffix}(config, {_generate_arguments(constructor)}).release())}}; + return std::unique_ptr{{static_cast(generated_dispatch::Make{pascal_case_op_name}(config, {_generate_arguments(constructor)}).release())}}; }}))""" def _generate_py_args(node): @@ -211,8 +195,6 @@ def _generate_py_args(node): def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) - suffix = _overload_suffix(operator.calls, call) - if not method: params = ( f"{call_params}, std::uintptr_t stream, std::size_t implementation_index" @@ -230,12 +212,12 @@ def _generate_call(op_name, call, method=True): f" }}\n" f" Config config;\n" f" config.set_implementation_index(implementation_index);\n" - f" return generated_dispatch::Call{pascal_case_op_name}{suffix}(handle, config, {call_args});\n" + f" return generated_dispatch::Call{pascal_case_op_name}(handle, config, {call_args});\n" f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return generated_dispatch::Invoke{pascal_case_op_name}{suffix}(self, {call_args}); + return generated_dispatch::Invoke{pascal_case_op_name}(self, {call_args}); }})""" inits = "\n".join( @@ -481,20 +463,6 @@ def _append_optional_params(prefix, params): return prefix - def _overload_suffix(nodes, node): - def _signature(item): - return tuple( - arg.type.spelling - for arg in item.get_arguments() - if arg.spelling != "stream" - ) - - signature = _signature(node) - if [_signature(item) for item in nodes].count(signature) == 1: - return "" - - return str(nodes.index(node)) - header_base_includes = "\n".join( f'#include "base/{operator.name}.h"' for operator in operators ) @@ -531,42 +499,40 @@ def _signature(item): ) for constructor in operator.constructors: - suffix = _overload_suffix(operator.constructors, constructor) params = _generate_params(constructor) args = _generate_arguments(constructor) declarations.append( f"std::unique_ptr> " - f"Make{pascal_case_op_name}{suffix}(" + f"Make{pascal_case_op_name}(" f"{_append_optional_params('const Config& config', params)});" ) definitions.append( - f"""std::unique_ptr> Make{pascal_case_op_name}{suffix}({_append_optional_params("const Config& config", params)}) {{ + f"""std::unique_ptr> Make{pascal_case_op_name}({_append_optional_params("const Config& config", params)}) {{ return Operator<{pascal_case_op_name}>::Make({_append_optional_args("config", args)}); }}""" ) for call in operator.calls: - suffix = _overload_suffix(operator.calls, call) params = _generate_params(call) args = _generate_arguments(call) declarations.append( - f"void Invoke{pascal_case_op_name}{suffix}(const " + f"void Invoke{pascal_case_op_name}(const " f"{_append_optional_params(f'{pascal_case_op_name}& op', params)});" ) definitions.append( - f"""void Invoke{pascal_case_op_name}{suffix}(const {_append_optional_params(f"{pascal_case_op_name}& op", params)}) {{ + f"""void Invoke{pascal_case_op_name}(const {_append_optional_params(f"{pascal_case_op_name}& op", params)}) {{ return static_cast&>(op)({args}); }}""" ) declarations.append( - f"void Call{pascal_case_op_name}{suffix}(const Handle& handle, " + f"void Call{pascal_case_op_name}(const Handle& handle, " f"{_append_optional_params('const Config& config', params)});" ) definitions.append( - f"""void Call{pascal_case_op_name}{suffix}(const Handle& handle, {_append_optional_params("const Config& config", params)}) {{ + f"""void Call{pascal_case_op_name}(const Handle& handle, {_append_optional_params("const Config& config", params)}) {{ return Operator<{pascal_case_op_name}>::Call({_append_optional_args("handle, config", args)}); }}""" )