Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 145 additions & 38 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@

_OP_NAMESPACE_PREFIXES = ("special", "linalg", "fft")

# Keep Python returning overloads explicitly opt-in while the API is being
# rolled out. Some C++ returning-capable ops have extra semantics that need
# separate Python review before exposure.
_PY_RETURNING_OPS = frozenset({"add"})


def _op_namespace_parts(op_name):
parts = op_name.split("_")
Expand Down Expand Up @@ -454,9 +459,16 @@ def _find_tensor_params(op_name):
return params


def _has_make_return_value(op_name):
source = _strip_cpp_comments(_find_base_header(op_name).read_text())

return re.search(r"\bMakeReturnValue\s*\(", source) is not None


def _generate_pybind11(operator):
optional_tensor_params = _find_optional_tensor_params(operator.name)
optional_non_tensor_params = _find_optional_non_tensor_params(operator.name)
tensor_params = _find_tensor_params(operator.name)
vector_tensor_params = _find_vector_tensor_params(operator.name)
vector_int64_params = _find_vector_int64_params(operator.name)

Expand All @@ -483,70 +495,91 @@ def _is_vector_tensor(arg):
def _is_vector_int64(arg):
return arg.spelling in vector_int64_params

def _generate_params(node):
parts = []
def _is_tensor(arg):
if arg.spelling in optional_non_tensor_params:
return False

for arg in node.get_arguments():
if arg.spelling == "stream":
continue
if arg.spelling in tensor_params:
return True

return "Tensor" in arg.type.spelling

def _is_plain_tensor(arg):
return (
not _is_optional_tensor(arg)
and not _is_vector_tensor(arg)
and _is_tensor(arg)
)

def _args_without_stream(node):
return [arg for arg in node.get_arguments() if arg.spelling != "stream"]

def _generate_params_for_args(args):
parts = []

for arg in args:
if _is_optional_tensor(arg):
parts.append(f"std::optional<py::object> {arg.spelling}")
elif _is_vector_tensor(arg):
parts.append(f"std::vector<py::object> {arg.spelling}")
elif _is_vector_int64(arg):
parts.append(f"const std::vector<int64_t> {arg.spelling}")
elif _is_tensor(arg):
parts.append(f"py::object {arg.spelling}")
else:
param = arg.type.spelling.replace("const Tensor", "py::object").replace(
"Tensor", "py::object"
)
parts.append(f"{param} {arg.spelling}")
parts.append(f"{arg.type.spelling} {arg.spelling}")

return ", ".join(parts)

def _generate_arguments(node):
args = []
def _generate_params(node):
return _generate_params_for_args(_args_without_stream(node))

for arg in node.get_arguments():
if arg.spelling == "stream":
continue
def _generate_arguments_for_args(args):
generated_args = []

for arg in args:
if _is_optional_tensor(arg):
args.append(f"OptionalTensorFromPybind11Handle({arg.spelling})")
generated_args.append(
f"OptionalTensorFromPybind11Handle({arg.spelling})"
)
elif _is_vector_tensor(arg):
args.append(f"VectorTensorFromPybind11Handle({arg.spelling})")
elif "Tensor" in arg.type.spelling:
args.append(f"TensorFromPybind11Handle({arg.spelling})")
generated_args.append(f"VectorTensorFromPybind11Handle({arg.spelling})")
elif _is_tensor(arg):
generated_args.append(f"TensorFromPybind11Handle({arg.spelling})")
else:
args.append(arg.spelling)
generated_args.append(arg.spelling)

return ", ".join(args)
return ", ".join(generated_args)

def _generate_arguments(node):
return _generate_arguments_for_args(_args_without_stream(node))

op_name = operator.name
op_type = _op_cpp_type(op_name)
symbol_name = _op_symbol_name(op_name)

def _first_tensor_arg(node):
for arg in node.get_arguments():
if arg.spelling == "stream":
continue
def _first_tensor_arg_for_args(args):
for arg in args:
if _is_optional_tensor(arg):
continue
if _is_vector_tensor(arg):
return f"{arg.spelling}.at(0)"
if "Tensor" in arg.type.spelling:
if _is_tensor(arg):
return arg.spelling
return None

def _default_impl_index_expr(node):
first_tensor = _first_tensor_arg(node)
def _default_impl_index_expr_for_args(args):
first_tensor = _first_tensor_arg_for_args(args)
if first_tensor is None:
return "0"
return (
f"DefaultImplementationIndexFor{symbol_name}("
f"DeviceFromPybind11Handle({first_tensor}).type())"
)

def _default_impl_index_expr(node):
return _default_impl_index_expr_for_args(_args_without_stream(node))

def _generate_init(constructor):
constructor_params = _generate_params(constructor)
default_impl_index = _default_impl_index_expr(constructor)
Expand All @@ -557,20 +590,20 @@ def _generate_init(constructor):
return std::unique_ptr<Self>{{static_cast<Self*>(generated_dispatch::Make{symbol_name}(config, {_generate_arguments(constructor)}).release())}};
}}))"""

def _generate_py_args(node):
def _generate_py_args_for_args(args):
parts = []

for arg in node.get_arguments():
if arg.spelling == "stream":
continue

for arg in args:
if _is_optional(arg):
parts.append(f'py::arg("{arg.spelling}") = py::none()')
else:
parts.append(f'py::arg("{arg.spelling}")')

return ", ".join(parts)

def _generate_py_args(node):
return _generate_py_args_for_args(_args_without_stream(node))

def _generate_call(op_name, call, method=True):
call_params = _generate_params(call)
call_args = _generate_arguments(call)
Expand Down Expand Up @@ -609,6 +642,74 @@ def _generate_call(op_name, call, method=True):
return generated_dispatch::Invoke{symbol_name}(op, {call_args});
}})"""

def _is_returning_candidate(call):
args = _args_without_stream(call)

if len(args) < 2 or not _is_plain_tensor(args[-1]):
return False

if not any(_is_plain_tensor(arg) for arg in args[:-1]):
return False

return not any(
_is_optional_tensor(arg) or _is_vector_tensor(arg) for arg in args[:-1]
)

def _returning_local_name(arg):
return f"infini_{arg.spelling}"

def _generate_returning_call(call):
args = _args_without_stream(call)
input_args = args[:-1]
call_params = _generate_params_for_args(input_args)
params = (
f"{call_params}, std::uintptr_t stream, "
"std::optional<std::size_t> implementation_index"
if call_params
else "std::uintptr_t stream, "
"std::optional<std::size_t> implementation_index"
)
py_args = _generate_py_args_for_args(input_args)
py_args_str = f"{py_args}, " if py_args else ""
default_impl_index = _default_impl_index_expr_for_args(input_args)
tensor_locals = "\n".join(
f" Pybind11Tensor {_returning_local_name(arg)}{{{arg.spelling}}};"
for arg in input_args
if _is_plain_tensor(arg)
)
make_args = ", ".join(
_returning_local_name(arg) if _is_plain_tensor(arg) else arg.spelling
for arg in input_args
)
dispatch_args = ", ".join(
f"Tensor{{{_returning_local_name(arg)}}}"
if _is_plain_tensor(arg)
else arg.spelling
for arg in input_args
)
dispatch_args = (
f"{dispatch_args}, Tensor{{out}}" if dispatch_args else "Tensor{out}"
)

if tensor_locals:
tensor_locals += "\n"

return (
f' m.def("{op_name}", []({params}) {{\n'
f" Handle handle;\n"
f" if (stream) {{\n"
f" handle.set_stream(reinterpret_cast<void*>(stream));\n"
f" }}\n"
f" Config config;\n"
f" config.set_implementation_index(\n"
f" implementation_index.value_or({default_impl_index}));\n"
f"{tensor_locals}"
f" auto out = Self::MakeReturnValue({make_args});\n"
f" generated_dispatch::Call{symbol_name}(handle, config, {dispatch_args});\n"
f" return out.object();\n"
f' }}, {py_args_str}py::kw_only(), py::arg("stream") = 0, py::arg("implementation_index") = py::none());'
)

def _overload_order_key(node):
"""Sort key that places more-specific overloads first.

Expand All @@ -630,11 +731,7 @@ def _overload_order_key(node):

total += 1

if (
_is_optional_tensor(arg)
or _is_vector_tensor(arg)
or "Tensor" in arg.type.spelling
):
if _is_optional_tensor(arg) or _is_vector_tensor(arg) or _is_tensor(arg):
object_like += 1

return (object_like, -total)
Expand All @@ -644,9 +741,19 @@ def _overload_order_key(node):

inits = "\n".join(_generate_init(constructor) for constructor in constructors)
calls = "\n".join(_generate_call(operator.name, call) for call in operator_calls)
callers = "\n".join(
returning_calls = []
if operator.name in _PY_RETURNING_OPS and _has_make_return_value(operator.name):
returning_calls = sorted(
(call for call in operator_calls if _is_returning_candidate(call)),
key=lambda call: len(_args_without_stream(call)),
)[:1]
returning_callers = "\n".join(
_generate_returning_call(call) for call in returning_calls
)
inplace_callers = "\n".join(
_generate_call(operator.name, call, method=False) for call in operator_calls
)
callers = "\n".join(filter(None, (returning_callers, inplace_callers)))

return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
#define INFINI_OPS_BINDINGS_{op_name.upper()}_H_
Expand Down
Loading
Loading