diff --git a/python/pyfory/format/tests/test_infer.py b/python/pyfory/format/tests/test_infer.py index ea3a517c00..76f60b26b4 100644 --- a/python/pyfory/format/tests/test_infer.py +++ b/python/pyfory/format/tests/test_infer.py @@ -64,6 +64,25 @@ class X: assert result.type.id == TypeId.STRUCT +def test_infer_field_builtin_types_not_treated_as_struct(): + """Built-in types must NOT be routed to visit_customized (regression guard).""" + assert _infer_field("", int).type.id == TypeId.INT64 + assert _infer_field("", float).type.id == TypeId.FLOAT64 + assert _infer_field("", str).type.id == TypeId.STRING + assert _infer_field("", bytes).type.id == TypeId.BINARY + assert _infer_field("", bool).type.id == TypeId.BOOL + + +def test_infer_field_nested_custom_class(): + """Custom class nested inside a List should also be handled correctly.""" + + class Inner: + pass + + result = _infer_field("", List[Inner]) + assert result.type.id == TypeId.LIST + + def test_infer_class_schema(): schema = infer_schema(Foo) assert schema.num_fields == 7 diff --git a/python/pyfory/type_util.py b/python/pyfory/type_util.py index 2b24510272..4d995d75c3 100644 --- a/python/pyfory/type_util.py +++ b/python/pyfory/type_util.py @@ -249,11 +249,17 @@ def infer_field(field_name, type_, visitor: TypeVisitor, types_path=None): else: raise TypeError(f"Collection types should be {list, dict} instead of {type_}") else: - if is_function(origin) or not hasattr(origin, "__annotations__"): + if is_function(origin): return visitor.visit_other(field_name, type_, types_path=types_path) - else: + + if origin in (list, dict, set): + return visitor.visit_other(field_name, type_, types_path=types_path) + + if inspect.isclass(origin) and origin.__module__ not in ("builtins", "datetime"): return visitor.visit_customized(field_name, type_, types_path=types_path) + return visitor.visit_other(field_name, type_, types_path=types_path) + def is_function(func): return inspect.isfunction(func) or is_cython_function(func)