diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index effc0787..8a47c1f7 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -169,12 +169,13 @@ def _generate_arguments(node): return ", ".join(args) op_name = operator.name + pascal_case_op_name = _snake_to_pascal(op_name) def _generate_init(constructor): constructor_params = _generate_params(constructor) - return f""" .def(py::init([]({constructor_params}) {{ - return std::unique_ptr{{static_cast(Self::Make({_generate_arguments(constructor)}).release())}}; + Config config; + return std::unique_ptr{{static_cast(generated_dispatch::Make{pascal_case_op_name}(config, {_generate_arguments(constructor)}).release())}}; }}))""" def _generate_py_args(node): @@ -194,7 +195,6 @@ def _generate_py_args(node): def _generate_call(op_name, call, method=True): call_params = _generate_params(call) call_args = _generate_arguments(call) - if not method: params = ( f"{call_params}, std::uintptr_t stream, std::size_t implementation_index" @@ -212,12 +212,12 @@ def _generate_call(op_name, call, method=True): f" }}\n" f" Config config;\n" f" config.set_implementation_index(implementation_index);\n" - f" return Self::Call(handle, config, {call_args});\n" + f" return generated_dispatch::Call{pascal_case_op_name}(handle, config, {call_args});\n" f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);' ) return f""" .def("__call__", [](const Self& self, {call_params}) {{ - return static_cast&>(self)({call_args}); + return generated_dispatch::Invoke{pascal_case_op_name}(self, {call_args}); }})""" inits = "\n".join( @@ -228,8 +228,6 @@ def _generate_call(op_name, call, method=True): _generate_call(operator.name, call, method=False) for call in operator.calls ) - pascal_case_op_name = _snake_to_pascal(op_name) - return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_ #define INFINI_OPS_BINDINGS_{op_name.upper()}_H_ @@ -238,8 +236,8 @@ def _generate_call(op_name, call, method=True): #include "base/{op_name}.h" #include "config.h" +#include "generated/bindings/generated_dispatch.h" #include "handle.h" -#include "operator.h" #include "pybind11_utils.h" namespace py = pybind11; @@ -253,9 +251,9 @@ def _generate_call(op_name, call, method=True): {inits} {calls} .def_static("active_implementation_indices", [](const std::string& device) {{ - return Self::active_implementation_indices(DeviceTypeFromString(device)); + return generated_dispatch::ActiveImplementationIndicesFor{pascal_case_op_name}(DeviceTypeFromString(device)); }}) - .def_static("clear_cache", &Self::clear_cache); + .def_static("clear_cache", &generated_dispatch::ClearCacheFor{pascal_case_op_name}); {callers} }} @@ -440,6 +438,166 @@ def _generate_tensor_caster(name, is_data=False): return _generate_source(operator), _generate_header(operator) +def _generate_generated_dispatch(operators, ops, devices): + def _generate_params(node): + return ", ".join( + f"{arg.type.spelling} {arg.spelling}" + for arg in node.get_arguments() + if arg.spelling != "stream" + ) + + def _generate_arguments(node): + return ", ".join( + arg.spelling for arg in node.get_arguments() if arg.spelling != "stream" + ) + + def _append_optional_args(prefix, args): + if args: + return f"{prefix}, {args}" + + return prefix + + def _append_optional_params(prefix, params): + if params: + return f"{prefix}, {params}" + + return prefix + + header_base_includes = "\n".join( + f'#include "base/{operator.name}.h"' for operator in operators + ) + header_device_includes = "\n".join( + f'#include "{path}"' for path in _device_marker_headers(devices) + ) + impl_includes = "\n".join( + f'#include "{impl_path}"' + for impl_paths in ops.values() + for impl_path in impl_paths + ) + + declarations = [] + definitions = [] + + for operator in operators: + pascal_case_op_name = _snake_to_pascal(operator.name) + + declarations.append( + f"std::vector ActiveImplementationIndicesFor" + f"{pascal_case_op_name}(Device::Type dev_type);" + ) + definitions.append( + f"""std::vector ActiveImplementationIndicesFor{pascal_case_op_name}(Device::Type dev_type) {{ + return Operator<{pascal_case_op_name}>::active_implementation_indices(dev_type); +}}""" + ) + + declarations.append(f"void ClearCacheFor{pascal_case_op_name}();") + definitions.append( + f"""void ClearCacheFor{pascal_case_op_name}() {{ + Operator<{pascal_case_op_name}>::clear_cache(); +}}""" + ) + + for constructor in operator.constructors: + params = _generate_params(constructor) + args = _generate_arguments(constructor) + + declarations.append( + f"std::unique_ptr> " + f"Make{pascal_case_op_name}(" + f"{_append_optional_params('const Config& config', params)});" + ) + definitions.append( + f"""std::unique_ptr> Make{pascal_case_op_name}({_append_optional_params("const Config& config", params)}) {{ + return Operator<{pascal_case_op_name}>::Make({_append_optional_args("config", args)}); +}}""" + ) + + for call in operator.calls: + params = _generate_params(call) + args = _generate_arguments(call) + + declarations.append( + f"void Invoke{pascal_case_op_name}(const " + f"{_append_optional_params(f'{pascal_case_op_name}& op', params)});" + ) + definitions.append( + f"""void Invoke{pascal_case_op_name}(const {_append_optional_params(f"{pascal_case_op_name}& op", params)}) {{ + return static_cast&>(op)({args}); +}}""" + ) + + declarations.append( + f"void Call{pascal_case_op_name}(const Handle& handle, " + f"{_append_optional_params('const Config& config', params)});" + ) + definitions.append( + f"""void Call{pascal_case_op_name}(const Handle& handle, {_append_optional_params("const Config& config", params)}) {{ + return Operator<{pascal_case_op_name}>::Call({_append_optional_args("handle, config", args)}); +}}""" + ) + + header = f"""#ifndef INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_ +#define INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_ + +#include +#include +#include +#include + +#include "config.h" +#include "device.h" +#include "handle.h" +#include "operator.h" + +{header_device_includes} + +{header_base_includes} + +namespace infini::ops::generated_dispatch {{ + +{chr(10).join(declarations)} + +}} // namespace infini::ops::generated_dispatch + +#endif +""" + + source = f"""#include "generated_dispatch.h" + +// clang-format off +{impl_includes} +// clang-format on + +namespace infini::ops::generated_dispatch {{ + +{chr(10).join(definitions)} + +}} // namespace infini::ops::generated_dispatch +""" + + return header, source + + +def _device_marker_headers(devices): + paths = { + "cpu": "native/cpu/device_.h", + "nvidia": "native/cuda/nvidia/device_.h", + "cambricon": "native/cambricon/device_.h", + "ascend": "native/ascend/device_.h", + "metax": "native/cuda/metax/device_.h", + "moore": "native/cuda/moore/device_.h", + "iluvatar": "native/cuda/iluvatar/device_.h", + } + + return [paths[device] for device in devices if device in paths] + + +def _generate_binding_source(op_name): + return f"""#include "{op_name}.h" +""" + + def _snake_to_pascal(snake_str): return "".join(word.capitalize() for word in snake_str.split("_")) @@ -489,9 +647,11 @@ def _get_all_ops(devices, with_torch=False): args = parser.parse_args() - _BINDINGS_DIR.mkdir(parents=True, exist_ok=True) - _GENERATED_SRC_DIR.mkdir(parents=True, exist_ok=True) - _INCLUDE_DIR.mkdir(parents=True, exist_ok=True) + for directory in (_BINDINGS_DIR, _GENERATED_SRC_DIR, _INCLUDE_DIR): + if directory.exists(): + shutil.rmtree(directory) + + directory.mkdir(parents=True) ops_json = pathlib.Path("ops.json") @@ -500,47 +660,48 @@ def _get_all_ops(devices, with_torch=False): else: ops = _get_all_ops(args.devices, with_torch=args.with_torch) - header_paths = [] bind_func_names = [] + operators = [] for op_name, impl_paths in ops.items(): extractor = _OperatorExtractor() operator = extractor(op_name) + operators.append(operator) source_path = _GENERATED_SRC_DIR / op_name header_name = f"{op_name}.h" bind_func_name = f"Bind{_snake_to_pascal(op_name)}" (_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator)) + (_BINDINGS_DIR / f"{op_name}.cc").write_text(_generate_binding_source(op_name)) legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths) source_path.mkdir(exist_ok=True) (_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source) (_INCLUDE_DIR / header_name).write_text(legacy_c_header) - header_paths.append(header_name) bind_func_names.append(bind_func_name) - impl_includes = "\n".join( - f'#include "{impl_path}"' - for impl_paths in ops.values() - for impl_path in impl_paths + dispatch_header, dispatch_source = _generate_generated_dispatch( + operators, ops, args.devices + ) + (_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header) + (_BINDINGS_DIR / "generated_dispatch.cc").write_text(dispatch_source) + + bind_func_declarations = "\n".join( + f"void {bind_func_name}(pybind11::module& m);" + for bind_func_name in bind_func_names ) - op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths) bind_func_calls = "\n".join( f"{bind_func_name}(m);" for bind_func_name in bind_func_names ) (_BINDINGS_DIR / "ops.cc").write_text(f"""#include -// clang-format off -{impl_includes} -// clang-format on - -{op_includes} - namespace infini::ops {{ +{bind_func_declarations} + PYBIND11_MODULE(ops, m) {{ {textwrap.indent(bind_func_calls, _INDENTATION)} }} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ce888b4b..35d63ff8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -340,7 +340,8 @@ if(GENERATE_PYTHON_BINDINGS) message(STATUS "Generating wrappers - done") endif() - set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc") + file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS + "${PROJECT_SOURCE_DIR}/generated/bindings/*.cc") # TODO: There might be a better solution. if(WITH_NVIDIA OR WITH_ILUVATAR)