Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 123 additions & 0 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from __future__ import annotations

import copy as copy_module
import functools
from dataclasses import MISSING
from typing import Any, Callable, Type, TypeVar, cast
Expand Down Expand Up @@ -122,6 +123,39 @@ def _get_all_fields(type_info: TypeInfo) -> list[TypeField]:
return fields


def _classify_fields_for_copy(
type_info: TypeInfo,
) -> tuple[list[str], list[str], list[str]]:
"""Classify fields for copy/replace operations.

Returns:
Tuple of (ffi_arg_order, init_fields, non_init_fields):
- ffi_arg_order: Fields passed to FFI constructor
- init_fields: Fields with init=True (replaceable)
- non_init_fields: Fields with init=False

"""
fields = _get_all_fields(type_info)
ffi_arg_order: list[str] = []
init_fields: list[str] = []
non_init_fields: list[str] = []

for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field

if dataclass_field.init:
init_fields.append(field.name)
ffi_arg_order.append(field.name)
elif dataclass_field.default_factory is not MISSING:
ffi_arg_order.append(field.name)
else:
non_init_fields.append(field.name)
Comment thread
guan404ming marked this conversation as resolved.

return ffi_arg_order, init_fields, non_init_fields


def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
"""Generate a ``__repr__`` method for the dataclass.

Expand Down Expand Up @@ -243,3 +277,92 @@ def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
exec(source, exec_globals, namespace)
__init__ = namespace["__init__"]
return __init__


def method_copy(_type_cls: type, type_info: TypeInfo) -> Callable[..., Any]:
"""Generate a ``__copy__`` method for the dataclass (shallow copy).

The generated method creates a shallow copy by calling the FFI constructor
directly with the current field values (bypassing custom Python __init__).
"""
ffi_arg_order, _, non_init_fields = _classify_fields_for_copy(type_info)

body_lines: list[str] = []
if ffi_arg_order:
ffi_args = ", ".join(f"self.{name}" for name in ffi_arg_order)
body_lines.append(f"new_obj = type(self).__c_ffi_init__({ffi_args})")
else:
body_lines.append("new_obj = type(self).__c_ffi_init__()")
for name in non_init_fields:
body_lines.append(f"new_obj.{name} = self.{name}")
body_lines.append("return new_obj")

source_lines = ["def __copy__(self):"]
source_lines.extend(f" {line}" for line in body_lines)
source = "\n".join(source_lines)

namespace: dict[str, Any] = {}
exec(source, {}, namespace)
return namespace["__copy__"]


def method_deepcopy(_type_cls: type, type_info: TypeInfo) -> Callable[..., Any]:
"""Generate a ``__deepcopy__`` method for the dataclass.

The generated method creates a deep copy using copy.deepcopy for field values,
handling circular references via the memo dictionary.
"""
ffi_arg_order, _, non_init_fields = _classify_fields_for_copy(type_info)

body_lines: list[str] = []
if ffi_arg_order:
ffi_args = ", ".join(f"_copy_deepcopy(self.{name}, memo)" for name in ffi_arg_order)
body_lines.append(f"new_obj = type(self).__c_ffi_init__({ffi_args})")
else:
body_lines.append("new_obj = type(self).__c_ffi_init__()")
body_lines.append("memo[id(self)] = new_obj")
for name in non_init_fields:
body_lines.append(f"new_obj.{name} = _copy_deepcopy(self.{name}, memo)")
body_lines.append("return new_obj")

source_lines = ["def __deepcopy__(self, memo):"]
source_lines.extend(f" {line}" for line in body_lines)
source = "\n".join(source_lines)

exec_globals: dict[str, Any] = {"_copy_deepcopy": copy_module.deepcopy}
namespace: dict[str, Any] = {}
exec(source, exec_globals, namespace)
return namespace["__deepcopy__"]


def method_replace(_type_cls: type, type_info: TypeInfo) -> Callable[..., Any]:
"""Generate a ``__replace__`` method for the dataclass.

The generated method returns a new instance with specified fields replaced.
Only fields with init=True can be changed. Fields with init=False are copied unchanged.
"""
ffi_arg_order, init_fields, non_init_fields = _classify_fields_for_copy(type_info)

body_lines: list[str] = []
body_lines.append("for key in changes:")
body_lines.append(" if key not in _valid_fields:")
body_lines.append(
" raise TypeError(f\"__replace__() got an unexpected keyword argument '{key}'\")"
)
if ffi_arg_order:
ffi_args = ", ".join(f"changes.get('{name}', self.{name})" for name in ffi_arg_order)
body_lines.append(f"new_obj = type(self).__c_ffi_init__({ffi_args})")
else:
body_lines.append("new_obj = type(self).__c_ffi_init__()")
for name in non_init_fields:
body_lines.append(f"new_obj.{name} = self.{name}")
body_lines.append("return new_obj")

source_lines = ["def __replace__(self, **changes):"]
source_lines.extend(f" {line}" for line in body_lines)
source = "\n".join(source_lines)

exec_globals: dict[str, Any] = {"_valid_fields": frozenset(init_fields)}
namespace: dict[str, Any] = {}
exec(source, exec_globals, namespace)
return namespace["__replace__"]
20 changes: 19 additions & 1 deletion python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,10 +147,28 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else None
fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None

# Generate copy methods unless user has defined them
fn_copy = None
fn_deepcopy = None
fn_replace = None
if "__copy__" not in super_type_cls.__dict__:
fn_copy = _utils.method_copy(super_type_cls, type_info)
if "__deepcopy__" not in super_type_cls.__dict__:
fn_deepcopy = _utils.method_deepcopy(super_type_cls, type_info)
if "__replace__" not in super_type_cls.__dict__:
fn_replace = _utils.method_replace(super_type_cls, type_info)

type_cls: Type[_InputClsType] = _utils.type_info_to_cls( # noqa: UP006
type_info=type_info,
cls=super_type_cls,
methods={"__init__": fn_init, "__repr__": fn_repr},
methods={
"__init__": fn_init,
"__repr__": fn_repr,
"__copy__": fn_copy,
"__deepcopy__": fn_deepcopy,
"__replace__": fn_replace,
},
)
_set_type_cls(type_info, type_cls)
return type_cls
Expand Down
84 changes: 84 additions & 0 deletions tests/python/test_dataclasses_c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import inspect
from dataclasses import MISSING

Expand Down Expand Up @@ -184,3 +185,86 @@ def test_field_kw_only_with_default() -> None:

def test_kw_only_sentinel_exists() -> None:
assert isinstance(KW_ONLY, _KW_ONLY_TYPE)


def test_cxx_class_copy() -> None:
obj = _TestCxxClassDerivedDerived(
v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="hello", v_bool=True
)
obj_copy = copy.copy(obj)

assert obj_copy.v_i64 == 1
assert obj_copy.v_i32 == 2
assert obj_copy.v_f64 == 3.0
assert obj_copy.v_f32 == 4.0
assert obj_copy.v_str == "hello"
assert obj_copy.v_bool is True
assert obj is not obj_copy


def test_cxx_class_deepcopy() -> None:
obj = _TestCxxClassDerivedDerived(
v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="hello", v_bool=True
)
obj_copy = copy.deepcopy(obj)

assert obj_copy.v_i64 == 1
assert obj_copy.v_i32 == 2
assert obj_copy.v_f64 == 3.0
assert obj_copy.v_f32 == 4.0
assert obj_copy.v_str == "hello"
assert obj_copy.v_bool is True
assert obj is not obj_copy


def test_cxx_class_replace() -> None:
obj = _TestCxxClassDerivedDerived(
v_i64=1, v_i32=2, v_f64=3.0, v_f32=4.0, v_str="hello", v_bool=True
)
obj_new = obj.__replace__(v_i64=100, v_str="world") # type: ignore[attr-defined]

assert obj_new.v_i64 == 100
assert obj_new.v_i32 == 2
assert obj_new.v_f64 == 3.0
assert obj_new.v_f32 == 4.0
assert obj_new.v_str == "world"
assert obj_new.v_bool is True
assert obj is not obj_new


def test_cxx_class_replace_invalid_field() -> None:
obj = _TestCxxClassDerived(v_i64=123, v_i32=456, v_f64=4.0, v_f32=8.0)

with pytest.raises(TypeError, match="unexpected keyword argument"):
obj.__replace__(nonexistent_field=42) # type: ignore[attr-defined]


def test_cxx_class_copy_init_false_field() -> None:
obj = _TestCxxInitSubset(required_field=42)
obj.optional_field = 100 # Modify the init=False field

obj_copy = copy.copy(obj)

assert obj_copy.required_field == 42
assert obj_copy.optional_field == 100
assert obj_copy.note == "py-default"
assert obj is not obj_copy


def test_cxx_class_deepcopy_init_false_field() -> None:
obj = _TestCxxInitSubset(required_field=42)
obj.optional_field = 100 # Modify the init=False field

obj_copy = copy.deepcopy(obj)

assert obj_copy.required_field == 42
assert obj_copy.optional_field == 100
assert obj_copy.note == "py-default"
assert obj is not obj_copy


def test_cxx_class_replace_rejects_init_false_field() -> None:
obj = _TestCxxInitSubset(required_field=42)

with pytest.raises(TypeError, match="unexpected keyword argument"):
obj.__replace__(optional_field=100) # type: ignore[attr-defined]