Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 187 additions & 26 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,12 +169,13 @@ def _generate_arguments(node):
return ", ".join(args)

op_name = operator.name
pascal_case_op_name = _snake_to_pascal(op_name)

def _generate_init(constructor):
constructor_params = _generate_params(constructor)

return f""" .def(py::init([]({constructor_params}) {{
return std::unique_ptr<Self>{{static_cast<Self*>(Self::Make({_generate_arguments(constructor)}).release())}};
Config config;
return std::unique_ptr<Self>{{static_cast<Self*>(generated_dispatch::Make{pascal_case_op_name}(config, {_generate_arguments(constructor)}).release())}};
}}))"""

def _generate_py_args(node):
Expand All @@ -194,7 +195,6 @@ def _generate_py_args(node):
def _generate_call(op_name, call, method=True):
call_params = _generate_params(call)
call_args = _generate_arguments(call)

if not method:
params = (
f"{call_params}, std::uintptr_t stream, std::size_t implementation_index"
Expand All @@ -212,12 +212,12 @@ def _generate_call(op_name, call, method=True):
f" }}\n"
f" Config config;\n"
f" config.set_implementation_index(implementation_index);\n"
f" return Self::Call(handle, config, {call_args});\n"
f" return generated_dispatch::Call{pascal_case_op_name}(handle, config, {call_args});\n"
f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = 0);'
)

return f""" .def("__call__", [](const Self& self, {call_params}) {{
return static_cast<const Operator<Self>&>(self)({call_args});
return generated_dispatch::Invoke{pascal_case_op_name}(self, {call_args});
}})"""

inits = "\n".join(
Expand All @@ -228,8 +228,6 @@ def _generate_call(op_name, call, method=True):
_generate_call(operator.name, call, method=False) for call in operator.calls
)

pascal_case_op_name = _snake_to_pascal(op_name)

return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_

Expand All @@ -238,8 +236,8 @@ def _generate_call(op_name, call, method=True):

#include "base/{op_name}.h"
#include "config.h"
#include "generated/bindings/generated_dispatch.h"
#include "handle.h"
#include "operator.h"
#include "pybind11_utils.h"

namespace py = pybind11;
Expand All @@ -253,9 +251,9 @@ def _generate_call(op_name, call, method=True):
{inits}
{calls}
.def_static("active_implementation_indices", [](const std::string& device) {{
return Self::active_implementation_indices(DeviceTypeFromString(device));
return generated_dispatch::ActiveImplementationIndicesFor{pascal_case_op_name}(DeviceTypeFromString(device));
}})
.def_static("clear_cache", &Self::clear_cache);
.def_static("clear_cache", &generated_dispatch::ClearCacheFor{pascal_case_op_name});

{callers}
}}
Expand Down Expand Up @@ -440,6 +438,166 @@ def _generate_tensor_caster(name, is_data=False):
return _generate_source(operator), _generate_header(operator)


def _generate_generated_dispatch(operators, ops, devices):
def _generate_params(node):
return ", ".join(
f"{arg.type.spelling} {arg.spelling}"
for arg in node.get_arguments()
if arg.spelling != "stream"
)

def _generate_arguments(node):
return ", ".join(
arg.spelling for arg in node.get_arguments() if arg.spelling != "stream"
)

def _append_optional_args(prefix, args):
if args:
return f"{prefix}, {args}"

return prefix

def _append_optional_params(prefix, params):
if params:
return f"{prefix}, {params}"

return prefix

header_base_includes = "\n".join(
f'#include "base/{operator.name}.h"' for operator in operators
)
header_device_includes = "\n".join(
f'#include "{path}"' for path in _device_marker_headers(devices)
)
impl_includes = "\n".join(
f'#include "{impl_path}"'
for impl_paths in ops.values()
for impl_path in impl_paths
)

declarations = []
definitions = []

for operator in operators:
pascal_case_op_name = _snake_to_pascal(operator.name)

declarations.append(
f"std::vector<std::size_t> ActiveImplementationIndicesFor"
f"{pascal_case_op_name}(Device::Type dev_type);"
)
definitions.append(
f"""std::vector<std::size_t> ActiveImplementationIndicesFor{pascal_case_op_name}(Device::Type dev_type) {{
return Operator<{pascal_case_op_name}>::active_implementation_indices(dev_type);
}}"""
)

declarations.append(f"void ClearCacheFor{pascal_case_op_name}();")
definitions.append(
f"""void ClearCacheFor{pascal_case_op_name}() {{
Operator<{pascal_case_op_name}>::clear_cache();
}}"""
)

for constructor in operator.constructors:
params = _generate_params(constructor)
args = _generate_arguments(constructor)

declarations.append(
f"std::unique_ptr<Operator<{pascal_case_op_name}>> "
f"Make{pascal_case_op_name}("
f"{_append_optional_params('const Config& config', params)});"
)
definitions.append(
f"""std::unique_ptr<Operator<{pascal_case_op_name}>> Make{pascal_case_op_name}({_append_optional_params("const Config& config", params)}) {{
return Operator<{pascal_case_op_name}>::Make({_append_optional_args("config", args)});
}}"""
)

for call in operator.calls:
params = _generate_params(call)
args = _generate_arguments(call)

declarations.append(
f"void Invoke{pascal_case_op_name}(const "
f"{_append_optional_params(f'{pascal_case_op_name}& op', params)});"
)
definitions.append(
f"""void Invoke{pascal_case_op_name}(const {_append_optional_params(f"{pascal_case_op_name}& op", params)}) {{
return static_cast<const Operator<{pascal_case_op_name}>&>(op)({args});
}}"""
)

declarations.append(
f"void Call{pascal_case_op_name}(const Handle& handle, "
f"{_append_optional_params('const Config& config', params)});"
)
definitions.append(
f"""void Call{pascal_case_op_name}(const Handle& handle, {_append_optional_params("const Config& config", params)}) {{
return Operator<{pascal_case_op_name}>::Call({_append_optional_args("handle, config", args)});
}}"""
)

header = f"""#ifndef INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_
#define INFINI_OPS_GENERATED_BINDINGS_GENERATED_DISPATCH_H_

#include <cstddef>
#include <memory>
#include <optional>
#include <vector>

#include "config.h"
#include "device.h"
#include "handle.h"
#include "operator.h"

{header_device_includes}

{header_base_includes}

namespace infini::ops::generated_dispatch {{

{chr(10).join(declarations)}

}} // namespace infini::ops::generated_dispatch

#endif
"""

source = f"""#include "generated_dispatch.h"

// clang-format off
{impl_includes}
// clang-format on

namespace infini::ops::generated_dispatch {{

{chr(10).join(definitions)}

}} // namespace infini::ops::generated_dispatch
"""

return header, source


def _device_marker_headers(devices):
paths = {
"cpu": "native/cpu/device_.h",
"nvidia": "native/cuda/nvidia/device_.h",
"cambricon": "native/cambricon/device_.h",
"ascend": "native/ascend/device_.h",
"metax": "native/cuda/metax/device_.h",
"moore": "native/cuda/moore/device_.h",
"iluvatar": "native/cuda/iluvatar/device_.h",
}

return [paths[device] for device in devices if device in paths]


def _generate_binding_source(op_name):
return f"""#include "{op_name}.h"
"""


def _snake_to_pascal(snake_str):
return "".join(word.capitalize() for word in snake_str.split("_"))

Expand Down Expand Up @@ -489,9 +647,11 @@ def _get_all_ops(devices, with_torch=False):

args = parser.parse_args()

_BINDINGS_DIR.mkdir(parents=True, exist_ok=True)
_GENERATED_SRC_DIR.mkdir(parents=True, exist_ok=True)
_INCLUDE_DIR.mkdir(parents=True, exist_ok=True)
for directory in (_BINDINGS_DIR, _GENERATED_SRC_DIR, _INCLUDE_DIR):
if directory.exists():
shutil.rmtree(directory)

directory.mkdir(parents=True)

ops_json = pathlib.Path("ops.json")

Expand All @@ -500,47 +660,48 @@ def _get_all_ops(devices, with_torch=False):
else:
ops = _get_all_ops(args.devices, with_torch=args.with_torch)

header_paths = []
bind_func_names = []
operators = []

for op_name, impl_paths in ops.items():
extractor = _OperatorExtractor()
operator = extractor(op_name)
operators.append(operator)

source_path = _GENERATED_SRC_DIR / op_name
header_name = f"{op_name}.h"
bind_func_name = f"Bind{_snake_to_pascal(op_name)}"

(_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator))
(_BINDINGS_DIR / f"{op_name}.cc").write_text(_generate_binding_source(op_name))

legacy_c_source, legacy_c_header = _generate_legacy_c(operator, impl_paths)
source_path.mkdir(exist_ok=True)
(_GENERATED_SRC_DIR / op_name / "operator.cc").write_text(legacy_c_source)
(_INCLUDE_DIR / header_name).write_text(legacy_c_header)

header_paths.append(header_name)
bind_func_names.append(bind_func_name)

impl_includes = "\n".join(
f'#include "{impl_path}"'
for impl_paths in ops.values()
for impl_path in impl_paths
dispatch_header, dispatch_source = _generate_generated_dispatch(
operators, ops, args.devices
)
(_BINDINGS_DIR / "generated_dispatch.h").write_text(dispatch_header)
(_BINDINGS_DIR / "generated_dispatch.cc").write_text(dispatch_source)

bind_func_declarations = "\n".join(
f"void {bind_func_name}(pybind11::module& m);"
for bind_func_name in bind_func_names
)
op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths)
bind_func_calls = "\n".join(
f"{bind_func_name}(m);" for bind_func_name in bind_func_names
)

(_BINDINGS_DIR / "ops.cc").write_text(f"""#include <pybind11/pybind11.h>

// clang-format off
{impl_includes}
// clang-format on

{op_includes}

namespace infini::ops {{

{bind_func_declarations}

PYBIND11_MODULE(ops, m) {{
{textwrap.indent(bind_func_calls, _INDENTATION)}
}}
Expand Down
3 changes: 2 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,8 @@ if(GENERATE_PYTHON_BINDINGS)
message(STATUS "Generating wrappers - done")
endif()

set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")
file(GLOB_RECURSE PYBIND11_SOURCES CONFIGURE_DEPENDS
"${PROJECT_SOURCE_DIR}/generated/bindings/*.cc")

# TODO: There might be a better solution.
if(WITH_NVIDIA OR WITH_ILUVATAR)
Expand Down
Loading