From ca2c4d83a2b4e854ff7a19a751655735e9aee9da Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 12 May 2026 17:28:49 +0800 Subject: [PATCH] feat(bindings): support scalar default wrapper generation --- scripts/generate_wrappers.py | 32 ++++++++- tests/test_generate_wrappers.py | 114 ++++++++++++++++++++++++++++++++ 2 files changed, 144 insertions(+), 2 deletions(-) create mode 100644 tests/test_generate_wrappers.py diff --git a/scripts/generate_wrappers.py b/scripts/generate_wrappers.py index effc0787d..5469a0953 100644 --- a/scripts/generate_wrappers.py +++ b/scripts/generate_wrappers.py @@ -53,8 +53,10 @@ def _get_compilers(): system_include_flags = _get_system_include_flags() index = clang.cindex.Index.create() - args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags) - translation_unit = index.parse(f"src/base/{op_name}.h", args=args) + args = ("-std=c++17", "-x", "c++", "-I", str(_SRC_DIR)) + tuple( + system_include_flags + ) + translation_unit = index.parse(str(_BASE_DIR / f"{op_name}.h"), args=args) nodes = tuple(type(self)._find(translation_unit.cursor, op_name)) @@ -112,9 +114,31 @@ def _find_vector_tensor_params(op_name): return set(re.findall(r"std::vector\s+(\w+)", source)) +def _find_params_with_defaults(op_name): + """Return `{param_name: default_literal}` for scalar params with defaults. + + `libclang`'s cursor API does not expose defaults reliably, so we regex-scan + the source. Only used for plain scalar defaults such as + `bool pre_gathered = false`. + """ + source = (_BASE_DIR / f"{op_name}.h").read_text() + + mapping = {} + + for name, default in re.findall( + r"\b(?:bool|int(?:64_t|32_t|8_t|16_t)?|std::size_t|std::uint\w+_t|" + r"float|double)\s+(\w+)\s*=\s*([^,\)]+?)\s*(?:,|\))", + source, + ): + mapping[name] = default.strip() + + return mapping + + def _generate_pybind11(operator): optional_tensor_params = _find_optional_tensor_params(operator.name) vector_tensor_params = _find_vector_tensor_params(operator.name) + params_with_defaults = _find_params_with_defaults(operator.name) def _is_optional_tensor(arg): if arg.spelling in optional_tensor_params: @@ -186,6 +210,10 @@ def _generate_py_args(node): if _is_optional(arg): parts.append(f'py::arg("{arg.spelling}") = py::none()') + elif arg.spelling in params_with_defaults: + parts.append( + f'py::arg("{arg.spelling}") = {params_with_defaults[arg.spelling]}' + ) else: parts.append(f'py::arg("{arg.spelling}")') diff --git a/tests/test_generate_wrappers.py b/tests/test_generate_wrappers.py new file mode 100644 index 000000000..0531059fe --- /dev/null +++ b/tests/test_generate_wrappers.py @@ -0,0 +1,114 @@ +import functools +import importlib.util +from pathlib import Path + +import pytest + + +pytest.importorskip("clang.cindex") + + +@functools.lru_cache(maxsize=1) +def _load_generator(): + script = Path(__file__).parents[1] / "scripts" / "generate_wrappers.py" + spec = importlib.util.spec_from_file_location("generate_wrappers", script) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + return module + + +def _generate_binding(op_name, tmp_path, monkeypatch, source): + generator = _load_generator() + src_dir = tmp_path / "src" + base_dir = src_dir / "base" + base_dir.mkdir(parents=True) + (base_dir / f"{op_name}.h").write_text(source) + monkeypatch.setattr(generator, "_SRC_DIR", src_dir) + monkeypatch.setattr(generator, "_BASE_DIR", base_dir) + operator = generator._OperatorExtractor()(op_name) + + return generator._generate_pybind11(operator) + + +def test_mha_varlen_fwd_requires_out_binding(tmp_path, monkeypatch): + text = _generate_binding( + "mha_varlen_fwd", + tmp_path, + monkeypatch, + """ +#include +#include + +namespace infini::ops { + +struct Tensor {}; + +template +class Operator {}; + +class MhaVarlenFwd : public Operator { + public: + MhaVarlenFwd(const Tensor q, const Tensor k, const Tensor v, Tensor out, + const Tensor cu_seqlens_q, const Tensor cu_seqlens_k, + std::optional block_table, float softmax_scale, + bool is_causal, int64_t num_splits = 0) {} + + virtual void operator()(const Tensor q, const Tensor k, const Tensor v, + Tensor out, const Tensor cu_seqlens_q, + const Tensor cu_seqlens_k, + std::optional block_table, + float softmax_scale, bool is_causal, + int64_t num_splits = 0) const = 0; +}; + +} // namespace infini::ops +""", + ) + + assert 'py::arg("out"), py::arg("cu_seqlens_q")' in text + assert 'py::arg("num_splits") = 0' in text + assert 'py::arg("out") = py::none()' not in text + assert "std::optional out" not in text + assert "OptionalTensorFromPybind11Handle(out)" not in text + + +def test_mha_fwd_kvcache_requires_out_binding(tmp_path, monkeypatch): + text = _generate_binding( + "mha_fwd_kvcache", + tmp_path, + monkeypatch, + """ +#include +#include + +namespace infini::ops { + +struct Tensor {}; + +template +class Operator {}; + +class MhaFwdKvcache : public Operator { + public: + MhaFwdKvcache(const Tensor q, const Tensor kcache, const Tensor vcache, + std::optional k, std::optional v, Tensor out, + float softmax_scale, bool is_causal, + int64_t num_splits = 0) {} + + virtual void operator()(const Tensor q, const Tensor kcache, + const Tensor vcache, std::optional k, + std::optional v, Tensor out, + float softmax_scale, bool is_causal, + int64_t num_splits = 0) const = 0; +}; + +} // namespace infini::ops +""", + ) + + assert 'py::arg("out"), py::arg("softmax_scale")' in text + assert 'py::arg("num_splits") = 0' in text + assert 'py::arg("out") = py::none()' not in text + assert "std::optional out" not in text + assert "OptionalTensorFromPybind11Handle(out)" not in text