Skip to content

Commit cc97e05

Browse files
committed
feat: Restrict __slots__=() for subclasses of tvm_ffi.Object by default
1 parent 22f22e8 commit cc97e05

13 files changed

Lines changed: 218 additions & 122 deletions

File tree

docs/conf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ def _import_cls(cls_name: str) -> type | None:
329329
"__ffi_init__",
330330
"__from_extern_c__",
331331
"__from_mlir_packed_safe_call__",
332+
"_move",
333+
"__move_handle_from__",
334+
"__init_handle_by_constructor__",
332335
}
333336
# If a member method comes from one of these native types, hide it in the docs
334337
_py_native_classes: tuple[type, ...] = (

python/tvm_ffi/core.pyi

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ _TRACEBACK_TO_BACKTRACE_STR: Callable[[types.TracebackType | None], str] | None
3232
# DLPack protocol version (defined in tensor.pxi)
3333
__dlpack_version__: tuple[int, int]
3434

35-
class Object:
35+
class CObject:
3636
def __ctypes_handle__(self) -> Any: ...
3737
def __chandle__(self) -> int: ...
3838
def __reduce__(self) -> Any: ...
@@ -46,7 +46,9 @@ class Object:
4646
def __ffi_init__(self, *args: Any) -> None: ...
4747
def same_as(self, other: Any) -> bool: ...
4848
def _move(self) -> ObjectRValueRef: ...
49-
def __move_handle_from__(self, other: Object) -> None: ...
49+
def __move_handle_from__(self, other: CObject) -> None: ...
50+
51+
class Object(CObject): ...
5052

5153
class ObjectConvertible:
5254
def asobject(self) -> Object: ...

python/tvm_ffi/cython/base.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,7 @@ cdef extern from "tvm_ffi_python_helpers.h":
392392

393393

394394
cdef class ByteArrayArg:
395+
__slots__ = ()
395396
cdef TVMFFIByteArray cdata
396397
cdef object py_data
397398

python/tvm_ffi/cython/device.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ cdef class Device:
9191
assert str(dev) == "cuda:0"
9292
9393
"""
94+
__slots__ = ()
9495
cdef DLDevice cdevice
9596

9697
_DEVICE_TYPE_TO_NAME = {

python/tvm_ffi/cython/dtype.pxi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cdef class DataType:
7979
assert str(d) == "int32"
8080
8181
"""
82+
__slots__ = ()
8283
cdef DLDataType cdtype
8384

8485
def __init__(self, dtype_str: str) -> None:

python/tvm_ffi/cython/error.pxi

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ _WITH_APPEND_BACKTRACE: Optional[Callable[[BaseException, str], BaseException]]
2727
_TRACEBACK_TO_BACKTRACE_STR: Optional[Callable[[types.TracebackType | None], str]] = None
2828

2929

30-
cdef class Error(Object):
30+
cdef class Error(CObject):
3131
"""Base class for FFI errors.
3232
3333
An :class:`Error` is a lightweight wrapper around a concrete Python
@@ -43,6 +43,7 @@ cdef class Error(Object):
4343
Do not directly raise this object. Instead, use :py:meth:`py_error`
4444
to convert it to a Python exception and raise that.
4545
"""
46+
__slots__ = ()
4647

4748
def __init__(self, kind: str, message: str, backtrace: str):
4849
"""Construct an error wrapper.
@@ -66,7 +67,7 @@ cdef class Error(Object):
6667
)
6768
if ret != 0:
6869
raise MemoryError("Failed to create error object")
69-
(<Object>self).chandle = out
70+
(<CObject>self).chandle = out
7071

7172
def update_backtrace(self, backtrace: str) -> None:
7273
"""Replace the stored backtrace string with ``backtrace``.
@@ -107,7 +108,7 @@ cdef class Error(Object):
107108
cdef inline Error move_from_last_error():
108109
# raise last error
109110
error = Error.__new__(Error)
110-
TVMFFIErrorMoveFromRaised(&(<Object>error).chandle)
111+
TVMFFIErrorMoveFromRaised(&(<CObject>error).chandle)
111112
return error
112113

113114

python/tvm_ffi/cython/function.pxi

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ cdef int TVMFFIPyArgSetterTensor_(
131131
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
132132
PyObject* arg, TVMFFIAny* out
133133
) except -1:
134-
if (<Object>arg).chandle != NULL:
134+
if (<CObject>arg).chandle != NULL:
135135
out.type_index = kTVMFFITensor
136136
out.v_ptr = (<Tensor>arg).chandle
137137
else:
@@ -144,8 +144,8 @@ cdef int TVMFFIPyArgSetterObject_(
144144
TVMFFIPyArgSetter* handle, TVMFFIPyCallContext* ctx,
145145
PyObject* arg, TVMFFIAny* out
146146
) except -1:
147-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
148-
out.v_ptr = (<Object>arg).chandle
147+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
148+
out.v_ptr = (<CObject>arg).chandle
149149
return 0
150150

151151

@@ -312,7 +312,7 @@ cdef int TVMFFIPyArgSetterFFIObjectProtocol_(
312312
"""Setter for objects that implement the `__tvm_ffi_object__` protocol."""
313313
cdef object arg = <object>py_arg
314314
cdef TVMFFIObjectHandle temp_chandle
315-
cdef Object obj = arg.__tvm_ffi_object__()
315+
cdef CObject obj = arg.__tvm_ffi_object__()
316316
cdef long ref_count = Py_REFCNT(obj)
317317
temp_chandle = obj.chandle
318318
out.type_index = TVMFFIObjectGetTypeIndex(temp_chandle)
@@ -418,8 +418,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectStr_(
418418
# need to check if the arg is a large string returned from ffi
419419
if arg._tvm_ffi_cached_object is not None:
420420
arg = arg._tvm_ffi_cached_object
421-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
422-
out.v_ptr = (<Object>arg).chandle
421+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
422+
out.v_ptr = (<CObject>arg).chandle
423423
return 0
424424
return TVMFFIPyArgSetterStr_(handle, ctx, py_arg, out)
425425

@@ -457,8 +457,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectBytes_(
457457
# need to check if the arg is a large bytes returned from ffi
458458
if arg._tvm_ffi_cached_object is not None:
459459
arg = arg._tvm_ffi_cached_object
460-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
461-
out.v_ptr = (<Object>arg).chandle
460+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
461+
out.v_ptr = (<CObject>arg).chandle
462462
return 0
463463
return TVMFFIPyArgSetterBytes_(handle, ctx, py_arg, out)
464464

@@ -473,8 +473,8 @@ cdef int TVMFFIPyArgSetterPyNativeObjectGeneral_(
473473
raise ValueError(f"_tvm_ffi_cached_object is None for {type(arg)}")
474474
assert arg._tvm_ffi_cached_object is not None
475475
arg = arg._tvm_ffi_cached_object
476-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
477-
out.v_ptr = (<Object>arg).chandle
476+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
477+
out.v_ptr = (<CObject>arg).chandle
478478
return 0
479479

480480

@@ -507,7 +507,7 @@ cdef int TVMFFIPyArgSetterObjectRValueRef_(
507507
"""Setter for ObjectRValueRef"""
508508
cdef object arg = <object>py_arg
509509
out.type_index = kTVMFFIObjectRValueRef
510-
out.v_ptr = &((<Object>(arg.obj)).chandle)
510+
out.v_ptr = &((<CObject>(arg.obj)).chandle)
511511
return 0
512512

513513

@@ -532,8 +532,8 @@ cdef int TVMFFIPyArgSetterException_(
532532
"""Setter for Exception"""
533533
cdef object arg = <object>py_arg
534534
arg = _convert_to_ffi_error(arg)
535-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
536-
out.v_ptr = (<Object>arg).chandle
535+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
536+
out.v_ptr = (<CObject>arg).chandle
537537
TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
538538
return 0
539539

@@ -595,8 +595,8 @@ cdef int TVMFFIPyArgSetterObjectConvertible_(
595595
# recursively construct a new map
596596
cdef object arg = <object>py_arg
597597
arg = arg.asobject()
598-
out.type_index = TVMFFIObjectGetTypeIndex((<Object>arg).chandle)
599-
out.v_ptr = (<Object>arg).chandle
598+
out.type_index = TVMFFIObjectGetTypeIndex((<CObject>arg).chandle)
599+
out.v_ptr = (<CObject>arg).chandle
600600
TVMFFIPyPushTempPyObject(ctx, <PyObject*>arg)
601601

602602

@@ -727,7 +727,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
727727
if isinstance(arg, Tensor):
728728
out.func = TVMFFIPyArgSetterTensor_
729729
return 0
730-
if isinstance(arg, Object):
730+
if isinstance(arg, CObject):
731731
out.func = TVMFFIPyArgSetterObject_
732732
return 0
733733
if isinstance(arg, ObjectRValueRef):
@@ -857,7 +857,7 @@ cdef int TVMFFIPyArgSetterFactory_(PyObject* value, TVMFFIPyArgSetter* out) exce
857857
# ---------------------------------------------------------------------------------------------
858858
# Implementation of function calling
859859
# ---------------------------------------------------------------------------------------------
860-
cdef class Function(Object):
860+
cdef class Function(CObject):
861861
"""Callable wrapper around a TVM FFI function.
862862
863863
Instances are obtained by converting Python callables with
@@ -908,7 +908,7 @@ cdef class Function(Object):
908908
result.v_int64 = 0
909909
TVMFFIPyFuncCall(
910910
TVMFFIPyArgSetterFactory_,
911-
(<Object>self).chandle, <PyObject*>args,
911+
(<CObject>self).chandle, <PyObject*>args,
912912
&result,
913913
&c_api_ret_code,
914914
self.release_gil,
@@ -972,7 +972,7 @@ cdef class Function(Object):
972972

973973
CHECK_CALL(ret_code)
974974
func = Function.__new__(Function)
975-
(<Object>func).chandle = chandle
975+
(<CObject>func).chandle = chandle
976976
return func
977977

978978
@staticmethod
@@ -1026,7 +1026,7 @@ cdef class Function(Object):
10261026
TVMFFIPyMLIRPackedSafeCallDeleter(mlir_packed_safe_call)
10271027
CHECK_CALL(ret_code)
10281028
func = Function.__new__(Function)
1029-
(<Object>func).chandle = chandle
1029+
(<CObject>func).chandle = chandle
10301030
return func
10311031

10321032

@@ -1039,7 +1039,7 @@ def _register_global_func(name: str, pyfunc: Callable[..., Any] | Function, over
10391039
if not isinstance(pyfunc, Function):
10401040
pyfunc = _convert_to_ffi_func(pyfunc)
10411041

1042-
CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (<Object>pyfunc).chandle, ioverride))
1042+
CHECK_CALL(TVMFFIFunctionSetGlobal(name_arg.cptr(), (<CObject>pyfunc).chandle, ioverride))
10431043
return pyfunc
10441044

10451045

@@ -1050,7 +1050,7 @@ def _get_global_func(name: str, allow_missing: bool):
10501050
CHECK_CALL(TVMFFIFunctionGetGlobal(name_arg.cptr(), &chandle))
10511051
if chandle != NULL:
10521052
ret = Function.__new__(Function)
1053-
(<Object>ret).chandle = chandle
1053+
(<CObject>ret).chandle = chandle
10541054
return ret
10551055

10561056
if allow_missing:
@@ -1105,7 +1105,7 @@ def _convert_to_ffi_func(object pyfunc: Callable[..., Any]) -> Function:
11051105
cdef TVMFFIObjectHandle chandle
11061106
_convert_to_ffi_func_handle(pyfunc, &chandle)
11071107
ret = Function.__new__(Function)
1108-
(<Object>ret).chandle = chandle
1108+
(<CObject>ret).chandle = chandle
11091109
return ret
11101110

11111111

@@ -1127,7 +1127,7 @@ def _convert_to_opaque_object(object pyobject: Any) -> OpaquePyObject:
11271127
cdef TVMFFIObjectHandle chandle
11281128
_convert_to_opaque_object_handle(pyobject, &chandle)
11291129
ret = OpaquePyObject.__new__(OpaquePyObject)
1130-
(<Object>ret).chandle = chandle
1130+
(<CObject>ret).chandle = chandle
11311131
return ret
11321132

11331133

0 commit comments

Comments
 (0)