Skip to content

Commit 8cf3756

Browse files
committed
feat: Support direct construction of tvm_ffi.Function from callables
Add `__init__` to the Cython `Function` class so users can write: fadd = tvm_ffi.Function(add) instead of going through `tvm_ffi.convert_to_tvm_func`. The constructor: - Validates that the argument is callable - Copies the handle (with incref) if the input is already a Function - Otherwise converts via `_convert_to_ffi_func_handle` Also adds the corresponding type stub in `core.pyi` and a test in `test_function.py`. Re-lands the original change reverted in #406.
1 parent 934f2d1 commit 8cf3756

3 files changed

Lines changed: 31 additions & 0 deletions

File tree

python/tvm_ffi/core.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ class DLTensorTestWrapper:
184184
def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...
185185

186186
class Function(Object):
187+
def __init__(self, func: Callable[..., Any]) -> None: ...
187188
@property
188189
def release_gil(self) -> bool: ...
189190
@release_gil.setter

python/tvm_ffi/cython/function.pxi

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,27 @@ cdef class Function(Object):
889889
def __cinit__(self) -> None:
890890
self.c_release_gil = _RELEASE_GIL_BY_DEFAULT
891891

892+
def __init__(self, func: Callable[..., Any]) -> None:
893+
"""Initialize a Function from a Python callable.
894+
895+
This constructor allows creating a `tvm_ffi.Function` directly
896+
from a Python function or another `tvm_ffi.Function` instance.
897+
898+
Parameters
899+
----------
900+
func : Callable[..., Any]
901+
The Python callable to wrap.
902+
"""
903+
cdef TVMFFIObjectHandle chandle = NULL
904+
if not callable(func):
905+
raise TypeError(f"func must be callable, got {type(func)}")
906+
if isinstance(func, Function):
907+
chandle = (<Object>func).chandle
908+
TVMFFIObjectIncRef(chandle)
909+
else:
910+
_convert_to_ffi_func_handle(func, &chandle)
911+
self.chandle = chandle
912+
892913
property release_gil:
893914
"""Whether calls release the Python GIL while executing."""
894915

tests/python/test_function.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,15 @@ def fapply(f: Any, *args: Any) -> Any:
153153
assert fapply(add, 1, 3.3) == 4.3
154154

155155

156+
def test_pyfunc_init() -> None:
157+
def add(a: int, b: int) -> int:
158+
return a + b
159+
160+
fadd = tvm_ffi.Function(add)
161+
assert isinstance(fadd, tvm_ffi.Function)
162+
assert fadd(1, 2) == 3
163+
164+
156165
def test_global_func() -> None:
157166
@tvm_ffi.register_global_func("mytest.echo")
158167
def echo(x: Any) -> Any:

0 commit comments

Comments
 (0)