From 829b8fb88850ab4a87dedfc65ff68e777d37425a Mon Sep 17 00:00:00 2001 From: vaggelisd Date: Fri, 27 Feb 2026 11:37:59 +0200 Subject: [PATCH] Fix init_subclass running before ClassVar instatiation --- mypyc/irbuild/classdef.py | 5 + mypyc/lib-rt/CPy.h | 1 + mypyc/lib-rt/misc_ops.c | 13 ++- mypyc/primitives/misc_ops.py | 9 ++ mypyc/test-data/fixtures/ir.py | 1 + mypyc/test-data/irbuild-classes.test | 142 ++++++++++++++------------- mypyc/test-data/run-classes.test | 33 +++++++ 7 files changed, 133 insertions(+), 71 deletions(-) diff --git a/mypyc/irbuild/classdef.py b/mypyc/irbuild/classdef.py index d59355c33a500..afafd4063e384 100644 --- a/mypyc/irbuild/classdef.py +++ b/mypyc/irbuild/classdef.py @@ -80,6 +80,7 @@ import_op, not_implemented_op, py_calc_meta_op, + py_init_subclass_op, pytype_from_template_op, type_object_op, ) @@ -315,6 +316,10 @@ def add_attr(self, lvalue: NameExpr, stmt: AssignmentStmt) -> None: self.builder.init_final_static(lvalue, value, self.cdef.name) def finalize(self, ir: ClassIR) -> None: + # Call __init_subclass__ after class attributes have been set + if self.type_obj is not None: + self.builder.call_c(py_init_subclass_op, [self.type_obj], self.cdef.line) + attrs_with_defaults, default_assignments = find_attr_initializers( self.builder, self.cdef, self.skip_attr_default ) diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 10f1448a2dde9..c6d8e1b01eb5f 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -916,6 +916,7 @@ PyObject *CPyType_FromTemplate(PyObject *template_, PyObject *CPyType_FromTemplateWrapper(PyObject *template_, PyObject *orig_bases, PyObject *modname); +bool CPy_InitSubclass(PyObject *type); int CPyDataclass_SleightOfHand(PyObject *dataclass_dec, PyObject *tp, PyObject *dict, PyObject *annotations, PyObject *dataclass_type); diff --git a/mypyc/lib-rt/misc_ops.c b/mypyc/lib-rt/misc_ops.c index 64b4ff67b942d..733f7ad444876 100644 --- a/mypyc/lib-rt/misc_ops.c +++ b/mypyc/lib-rt/misc_ops.c @@ -303,9 +303,6 @@ PyObject *CPyType_FromTemplate(PyObject *template, if (PyObject_SetAttr((PyObject *)t, mypyc_interned_str.__module__, modname) < 0) goto error; - if (init_subclass((PyTypeObject *)t, NULL)) - goto error; - Py_XDECREF(dummy_class); // Unlike the tp_doc slots of most other object, a heap type's tp_doc @@ -338,6 +335,16 @@ PyObject *CPyType_FromTemplate(PyObject *template, return NULL; } +// Call __init_subclass__ on the appropriate base class of type. +// This is separated from CPyType_FromTemplate so that class attributes +// can be set before __init_subclass__ is called. +bool CPy_InitSubclass(PyObject *type) { + if (init_subclass((PyTypeObject *)type, NULL)) { + return false; + } + return true; +} + static int _CPy_UpdateObjFromDict(PyObject *obj, PyObject *dict) { Py_ssize_t pos = 0; diff --git a/mypyc/primitives/misc_ops.py b/mypyc/primitives/misc_ops.py index 01853341b8bf9..c617691f514b0 100644 --- a/mypyc/primitives/misc_ops.py +++ b/mypyc/primitives/misc_ops.py @@ -239,6 +239,15 @@ error_kind=ERR_MAGIC, ) +# Call __init_subclass__ on a type. Separated from CPyType_FromTemplate +# so that class attributes can be set before __init_subclass__ is called. +py_init_subclass_op = custom_op( + arg_types=[object_rprimitive], + return_type=bool_rprimitive, + c_function_name="CPy_InitSubclass", + error_kind=ERR_FALSE, +) + # Create a dataclass from an extension class. See # CPyDataclass_SleightOfHand for more docs. dataclass_sleight_of_hand = custom_op( diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index ee68f7b5a6110..131d1ef554a06 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -42,6 +42,7 @@ class object: __class__: type def __new__(cls) -> Self: pass def __init__(self) -> None: pass + def __init_subclass__(cls, **kwargs: object) -> None: pass def __eq__(self, x: object) -> bool: pass def __ne__(self, x: object) -> bool: pass def __str__(self) -> str: pass diff --git a/mypyc/test-data/irbuild-classes.test b/mypyc/test-data/irbuild-classes.test index eacb7413e4d09..c515d057f2a08 100644 --- a/mypyc/test-data/irbuild-classes.test +++ b/mypyc/test-data/irbuild-classes.test @@ -229,36 +229,39 @@ def __top_level__(): r35 :: str r36 :: i32 r37 :: bit - r38 :: object - r39 :: str - r40, r41 :: object - r42 :: str - r43 :: tuple - r44 :: i32 - r45 :: bit - r46 :: dict - r47 :: str - r48 :: i32 - r49 :: bit - r50, r51 :: object - r52 :: dict - r53 :: str - r54 :: object - r55 :: dict - r56 :: str - r57, r58 :: object - r59 :: tuple - r60 :: str - r61, r62 :: object - r63, r64 :: bool - r65, r66 :: str - r67 :: tuple - r68 :: i32 - r69 :: bit - r70 :: dict - r71 :: str - r72 :: i32 - r73 :: bit + r38 :: bool + r39 :: object + r40 :: str + r41, r42 :: object + r43 :: str + r44 :: tuple + r45 :: i32 + r46 :: bit + r47 :: dict + r48 :: str + r49 :: i32 + r50 :: bit + r51 :: bool + r52, r53 :: object + r54 :: dict + r55 :: str + r56 :: object + r57 :: dict + r58 :: str + r59, r60 :: object + r61 :: tuple + r62 :: str + r63, r64 :: object + r65, r66 :: bool + r67, r68 :: str + r69 :: tuple + r70 :: i32 + r71 :: bit + r72 :: dict + r73 :: str + r74 :: i32 + r75 :: bit + r76 :: bool L0: r0 = builtins :: module r1 = load_address _Py_NoneStruct @@ -306,44 +309,47 @@ L2: r35 = 'C' r36 = PyDict_SetItem(r34, r35, r27) r37 = r36 >= 0 :: signed - r38 = :: object - r39 = '__main__' - r40 = __main__.S_template :: type - r41 = CPyType_FromTemplate(r40, r38, r39) - r42 = '__mypyc_attrs__' - r43 = CPyTuple_LoadEmptyTupleConstant() - r44 = PyObject_SetAttr(r41, r42, r43) - r45 = r44 >= 0 :: signed - __main__.S = r41 :: type - r46 = __main__.globals :: static - r47 = 'S' - r48 = PyDict_SetItem(r46, r47, r41) - r49 = r48 >= 0 :: signed - r50 = __main__.C :: type - r51 = __main__.S :: type - r52 = __main__.globals :: static - r53 = 'Generic' - r54 = CPyDict_GetItem(r52, r53) - r55 = __main__.globals :: static - r56 = 'T' - r57 = CPyDict_GetItem(r55, r56) - r58 = PyObject_GetItem(r54, r57) - r59 = PyTuple_Pack(3, r50, r51, r58) - r60 = '__main__' - r61 = __main__.D_template :: type - r62 = CPyType_FromTemplate(r61, r59, r60) - r63 = D_trait_vtable_setup() - r64 = D_coroutine_setup(r62) - r65 = '__mypyc_attrs__' - r66 = '__dict__' - r67 = PyTuple_Pack(1, r66) - r68 = PyObject_SetAttr(r62, r65, r67) - r69 = r68 >= 0 :: signed - __main__.D = r62 :: type - r70 = __main__.globals :: static - r71 = 'D' - r72 = PyDict_SetItem(r70, r71, r62) - r73 = r72 >= 0 :: signed + r38 = CPy_InitSubclass(r27) + r39 = :: object + r40 = '__main__' + r41 = __main__.S_template :: type + r42 = CPyType_FromTemplate(r41, r39, r40) + r43 = '__mypyc_attrs__' + r44 = CPyTuple_LoadEmptyTupleConstant() + r45 = PyObject_SetAttr(r42, r43, r44) + r46 = r45 >= 0 :: signed + __main__.S = r42 :: type + r47 = __main__.globals :: static + r48 = 'S' + r49 = PyDict_SetItem(r47, r48, r42) + r50 = r49 >= 0 :: signed + r51 = CPy_InitSubclass(r42) + r52 = __main__.C :: type + r53 = __main__.S :: type + r54 = __main__.globals :: static + r55 = 'Generic' + r56 = CPyDict_GetItem(r54, r55) + r57 = __main__.globals :: static + r58 = 'T' + r59 = CPyDict_GetItem(r57, r58) + r60 = PyObject_GetItem(r56, r59) + r61 = PyTuple_Pack(3, r52, r53, r60) + r62 = '__main__' + r63 = __main__.D_template :: type + r64 = CPyType_FromTemplate(r63, r61, r62) + r65 = D_trait_vtable_setup() + r66 = D_coroutine_setup(r64) + r67 = '__mypyc_attrs__' + r68 = '__dict__' + r69 = PyTuple_Pack(1, r68) + r70 = PyObject_SetAttr(r64, r67, r69) + r71 = r70 >= 0 :: signed + __main__.D = r64 :: type + r72 = __main__.globals :: static + r73 = 'D' + r74 = PyDict_SetItem(r72, r73, r64) + r75 = r74 >= 0 :: signed + r76 = CPy_InitSubclass(r64) return 1 [case testIsInstance] diff --git a/mypyc/test-data/run-classes.test b/mypyc/test-data/run-classes.test index cd3a0bf349b71..1fdc8470a62bd 100644 --- a/mypyc/test-data/run-classes.test +++ b/mypyc/test-data/run-classes.test @@ -1113,6 +1113,39 @@ assert f() == 10 A.x = 200 assert f() == 200 +[case testInitSubclassWithClassVar] +from typing import ClassVar + +class Base: + name: ClassVar[str] = "base" + required: ClassVar[int] = -1 + + def __init_subclass__(cls, **kwargs: object) -> None: + cls.required = len(cls.name) + +class Child(Base): + name: ClassVar[str] = "child" + +class GrandChild(Child): + name: ClassVar[str] = "grandchild" + +class NoOverride(Base): + pass + +[file driver.py] +from native import Base, Child, GrandChild, NoOverride + +# __init_subclass__ should see the subclass's own ClassVar values +assert Child.name == "child" +assert Child.required == 5, f"expected 5, got {Child.required}" + +assert GrandChild.name == "grandchild" +assert GrandChild.required == 10, f"expected 10, got {GrandChild.required}" + +# No override should use inherited value +assert NoOverride.name == "base" +assert NoOverride.required == 4, f"expected 4, got {NoOverride.required}" + [case testDefaultVars] from typing import Optional class A: