Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 30 additions & 2 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,10 @@ def _get_compilers():
system_include_flags = _get_system_include_flags()

index = clang.cindex.Index.create()
args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags)
translation_unit = index.parse(f"src/base/{op_name}.h", args=args)
args = ("-std=c++17", "-x", "c++", "-I", str(_SRC_DIR)) + tuple(
system_include_flags
)
translation_unit = index.parse(str(_BASE_DIR / f"{op_name}.h"), args=args)

nodes = tuple(type(self)._find(translation_unit.cursor, op_name))

Expand Down Expand Up @@ -112,9 +114,31 @@ def _find_vector_tensor_params(op_name):
return set(re.findall(r"std::vector<Tensor>\s+(\w+)", source))


def _find_params_with_defaults(op_name):
    """Return `{param_name: default_expr}` for scalar params with defaults.

    `libclang`'s cursor API does not expose defaults reliably, so we regex-scan
    the operator header `<op_name>.h` under `_BASE_DIR`. Only parameters
    declared with one of the scalar types listed in the pattern (`bool`, `int`,
    `int8_t`..`int64_t`, `std::size_t`, `std::uint*_t`, `float`, `double`) are
    recognized, e.g. `bool pre_gathered = false`.

    The default expression may contain one level of balanced parentheses, so
    a default such as `std::sqrt(2.0f)` is captured whole rather than being
    cut at its own closing parenthesis. A captured default never extends
    across a `;`, so a statement like `int local = 5;` in an inline body is
    not recorded as a parameter default.

    The scan covers the whole file and is keyed by name only: when the same
    name is declared more than once, the last occurrence wins.
    """
    source = (_BASE_DIR / f"{op_name}.h").read_text()

    pattern = re.compile(
        r"\b(?:bool|int(?:64_t|32_t|8_t|16_t)?|std::size_t|std::uint\w+_t|"
        r"float|double)\s+(\w+)\s*=\s*"
        # Default expression: plain characters or one balanced `(...)` group,
        # lazily extended up to the `,` or `)` that ends the parameter.
        r"((?:[^,();]|\([^()]*\))+?)\s*(?:,|\))"
    )

    return {name: default.strip() for name, default in pattern.findall(source)}


def _generate_pybind11(operator):
optional_tensor_params = _find_optional_tensor_params(operator.name)
vector_tensor_params = _find_vector_tensor_params(operator.name)
params_with_defaults = _find_params_with_defaults(operator.name)

def _is_optional_tensor(arg):
if arg.spelling in optional_tensor_params:
Expand Down Expand Up @@ -186,6 +210,10 @@ def _generate_py_args(node):

if _is_optional(arg):
parts.append(f'py::arg("{arg.spelling}") = py::none()')
elif arg.spelling in params_with_defaults:
parts.append(
f'py::arg("{arg.spelling}") = {params_with_defaults[arg.spelling]}'
)
else:
parts.append(f'py::arg("{arg.spelling}")')

Expand Down
114 changes: 114 additions & 0 deletions tests/test_generate_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
import functools
import importlib.util
from pathlib import Path

import pytest


# Skip every test in this module when the `clang.cindex` bindings cannot be
# imported.
pytest.importorskip("clang.cindex")


@functools.lru_cache(maxsize=1)
def _load_generator():
    """Import `scripts/generate_wrappers.py` as a module and return it.

    The script is located relative to this test file (one directory up, then
    `scripts/`) and executed under the module name `generate_wrappers`. The
    result is cached, so repeated calls return the same module object.
    """
    repo_root = Path(__file__).parents[1]
    script_path = repo_root / "scripts" / "generate_wrappers.py"

    spec = importlib.util.spec_from_file_location("generate_wrappers", script_path)
    generator = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(generator)

    return generator


def _generate_binding(op_name, tmp_path, monkeypatch, source):
    """Generate the pybind11 binding text for a header written to `tmp_path`.

    Writes `source` to `<tmp_path>/src/base/<op_name>.h`, points the
    generator's `_SRC_DIR` and `_BASE_DIR` at that temporary tree via
    `monkeypatch`, extracts the operator with `_OperatorExtractor`, and returns
    the result of `_generate_pybind11` for it.
    """
    generator = _load_generator()

    src_dir = tmp_path / "src"
    base_dir = src_dir / "base"
    base_dir.mkdir(parents=True)

    header_path = base_dir / f"{op_name}.h"
    header_path.write_text(source)

    # Redirect the generator's module-level paths to the temporary tree.
    for attribute, directory in (("_SRC_DIR", src_dir), ("_BASE_DIR", base_dir)):
        monkeypatch.setattr(generator, attribute, directory)

    extractor = generator._OperatorExtractor()

    return generator._generate_pybind11(extractor(op_name))


def test_mha_varlen_fwd_requires_out_binding(tmp_path, monkeypatch):
    """Check that `out` is bound as a required argument for `mha_varlen_fwd`."""
    header = """
#include <cstdint>
#include <optional>

namespace infini::ops {

struct Tensor {};

template <typename T>
class Operator {};

class MhaVarlenFwd : public Operator<MhaVarlenFwd> {
public:
MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out,
const Tensor cu_seqlens_q, const Tensor cu_seqlens_k,
std::optional<Tensor> block_table, float softmax_scale,
bool is_causal, int64_t num_splits = 0) {}

virtual void operator()(const Tensor q, const Tensor k, const Tensor v,
Tensor out, const Tensor cu_seqlens_q,
const Tensor cu_seqlens_k,
std::optional<Tensor> block_table,
float softmax_scale, bool is_causal,
int64_t num_splits = 0) const = 0;
};

} // namespace infini::ops
"""
    binding = _generate_binding("mha_varlen_fwd", tmp_path, monkeypatch, header)

    # `out` appears as a plain argument, directly followed by the next one.
    assert 'py::arg("out"), py::arg("cu_seqlens_q")' in binding
    # The scalar default declared in the header is carried into the binding.
    assert 'py::arg("num_splits") = 0' in binding

    # None of the optional-tensor forms may be generated for `out`.
    unexpected_snippets = (
        'py::arg("out") = py::none()',
        "std::optional<py::object> out",
        "OptionalTensorFromPybind11Handle(out)",
    )
    for snippet in unexpected_snippets:
        assert snippet not in binding


def test_mha_fwd_kvcache_requires_out_binding(tmp_path, monkeypatch):
    """Check that `out` is bound as a required argument for `mha_fwd_kvcache`."""
    header = """
#include <cstdint>
#include <optional>

namespace infini::ops {

struct Tensor {};

template <typename T>
class Operator {};

class MhaFwdKvcache : public Operator<MhaFwdKvcache> {
public:
MhaFwdKvcache(const Tensor q, const Tensor kcache, const Tensor vcache,
std::optional<Tensor> k, std::optional<Tensor> v, Tensor out,
float softmax_scale, bool is_causal,
int64_t num_splits = 0) {}

virtual void operator()(const Tensor q, const Tensor kcache,
const Tensor vcache, std::optional<Tensor> k,
std::optional<Tensor> v, Tensor out,
float softmax_scale, bool is_causal,
int64_t num_splits = 0) const = 0;
};

} // namespace infini::ops
"""
    binding = _generate_binding("mha_fwd_kvcache", tmp_path, monkeypatch, header)

    # `out` appears as a plain argument, directly followed by the next one.
    assert 'py::arg("out"), py::arg("softmax_scale")' in binding
    # The scalar default declared in the header is carried into the binding.
    assert 'py::arg("num_splits") = 0' in binding

    # None of the optional-tensor forms may be generated for `out`.
    unexpected_snippets = (
        'py::arg("out") = py::none()',
        "std::optional<py::object> out",
        "OptionalTensorFromPybind11Handle(out)",
    )
    for snippet in unexpected_snippets:
        assert snippet not in binding
Loading