@@ -204,6 +204,8 @@ pub enum FnType {
204204 FnNew ,
205205 /// Represents a pymethod annotated with both `#[new]` and `#[classmethod]` (in either order)
206206 FnNewClass ( Span ) ,
207+ /// Represents a pymethod annotated with `#[init]`, i.e. the `__init__` dunder.
208+ FnInit ( SelfType ) ,
207209 /// Represents a pymethod annotated with `#[classmethod]`, like a `@classmethod`
208210 FnClass ( Span ) ,
209211 /// Represents a pyfunction or a pymethod annotated with `#[staticmethod]`, like a `@staticmethod`
@@ -220,6 +222,7 @@ impl FnType {
220222 FnType :: Getter ( _)
221223 | FnType :: Setter ( _)
222224 | FnType :: Fn ( _)
225+ | FnType :: FnInit ( _)
223226 | FnType :: FnClass ( _)
224227 | FnType :: FnNewClass ( _)
225228 | FnType :: FnModule ( _) => true ,
@@ -231,6 +234,7 @@ impl FnType {
231234 match self {
232235 FnType :: Fn ( _)
233236 | FnType :: FnNew
237+ | FnType :: FnInit ( _)
234238 | FnType :: FnStatic
235239 | FnType :: FnClass ( _)
236240 | FnType :: FnNewClass ( _)
@@ -250,7 +254,7 @@ impl FnType {
250254 ) -> Option < TokenStream > {
251255 let Ctx { pyo3_path, .. } = ctx;
252256 match self {
253- FnType :: Getter ( st) | FnType :: Setter ( st) | FnType :: Fn ( st) => {
257+ FnType :: Getter ( st) | FnType :: Setter ( st) | FnType :: Fn ( st) | FnType :: FnInit ( st ) => {
254258 let mut receiver = st. receiver (
255259 cls. expect ( "no class given for Fn with a \" self\" receiver" ) ,
256260 error_mode,
@@ -378,6 +382,7 @@ pub enum CallingConvention {
378382 Varargs , // METH_VARARGS | METH_KEYWORDS
379383 Fastcall , // METH_FASTCALL | METH_KEYWORDS (not compatible with `abi3` feature before 3.10)
380384 TpNew , // special convention for tp_new
385+ TpInit , // special convention for tp_init
381386}
382387
383388impl CallingConvention {
@@ -476,10 +481,10 @@ impl<'a> FnSpec<'a> {
476481 FunctionSignature :: from_arguments ( arguments)
477482 } ;
478483
479- let convention = if matches ! ( fn_type, FnType :: FnNew | FnType :: FnNewClass ( _ ) ) {
480- CallingConvention :: TpNew
481- } else {
482- CallingConvention :: from_signature ( & signature)
484+ let convention = match fn_type {
485+ FnType :: FnNew | FnType :: FnNewClass ( _ ) => CallingConvention :: TpNew ,
486+ FnType :: FnInit ( _ ) => CallingConvention :: TpInit ,
487+ _ => CallingConvention :: from_signature ( & signature) ,
483488 } ;
484489
485490 Ok ( FnSpec {
@@ -524,11 +529,14 @@ impl<'a> FnSpec<'a> {
524529 . map ( |stripped| syn:: Ident :: new ( stripped, name. span ( ) ) )
525530 } ;
526531
527- let mut set_name_to_new = || {
528- if let Some ( name ) = & python_name {
529- bail_spanned ! ( name . span( ) => "`name` not allowed with `#[new ]`" ) ;
532+ let mut set_fn_name = |name | {
533+ if let Some ( ident ) = python_name {
534+ bail_spanned ! ( ident . span( ) => format! ( "`name` not allowed with `#[{name} ]`" ) ) ;
530535 }
531- * python_name = Some ( syn:: Ident :: new ( "__new__" , Span :: call_site ( ) ) ) ;
536+ * python_name = Some ( syn:: Ident :: new (
537+ format ! ( "__{name}__" ) . as_str ( ) ,
538+ Span :: call_site ( ) ,
539+ ) ) ;
532540 Ok ( ( ) )
533541 } ;
534542
@@ -539,14 +547,18 @@ impl<'a> FnSpec<'a> {
539547 [ MethodTypeAttribute :: StaticMethod ( _) ] => FnType :: FnStatic ,
540548 [ MethodTypeAttribute :: ClassAttribute ( _) ] => FnType :: ClassAttribute ,
541549 [ MethodTypeAttribute :: New ( _) ] => {
542- set_name_to_new ( ) ?;
550+ set_fn_name ( "new" ) ?;
543551 FnType :: FnNew
544552 }
545553 [ MethodTypeAttribute :: New ( _) , MethodTypeAttribute :: ClassMethod ( span) ]
546554 | [ MethodTypeAttribute :: ClassMethod ( span) , MethodTypeAttribute :: New ( _) ] => {
547- set_name_to_new ( ) ?;
555+ set_fn_name ( "new" ) ?;
548556 FnType :: FnNewClass ( * span)
549557 }
558+ [ MethodTypeAttribute :: Init ( _) ] => {
559+ set_fn_name ( "init" ) ?;
560+ FnType :: FnInit ( parse_receiver ( "expected receiver for `#[init]`" ) ?)
561+ }
550562 [ MethodTypeAttribute :: ClassMethod ( _) ] => {
551563 // Add a helpful hint if the classmethod doesn't look like a classmethod
552564 let span = match sig. inputs . first ( ) {
@@ -830,7 +842,6 @@ impl<'a> FnSpec<'a> {
830842 _kwargs: * mut #pyo3_path:: ffi:: PyObject
831843 ) -> #pyo3_path:: PyResult <* mut #pyo3_path:: ffi:: PyObject > {
832844 use #pyo3_path:: impl_:: callback:: IntoPyCallbackOutput ;
833- let function = #rust_name; // Shadow the function name to avoid #3017
834845 #arg_convert
835846 #init_holders
836847 let result = #call;
@@ -839,6 +850,29 @@ impl<'a> FnSpec<'a> {
839850 }
840851 }
841852 }
853+ CallingConvention :: TpInit => {
854+ let mut holders = Holders :: new ( ) ;
855+ let ( arg_convert, args) = impl_arg_params ( self , cls, false , & mut holders, ctx) ;
856+ let self_arg = self
857+ . tp
858+ . self_arg ( cls, ExtractErrorMode :: Raise , & mut holders, ctx) ;
859+ let call = quote_spanned ! { * output_span=> #rust_name( #self_arg #( #args) , * ) } ;
860+ let init_holders = holders. init_holders ( ctx) ;
861+ quote ! {
862+ unsafe fn #ident(
863+ py: #pyo3_path:: Python <' _>,
864+ _slf: * mut #pyo3_path:: ffi:: PyObject ,
865+ _args: * mut #pyo3_path:: ffi:: PyObject ,
866+ _kwargs: * mut #pyo3_path:: ffi:: PyObject
867+ ) -> #pyo3_path:: PyResult <:: std:: os:: raw:: c_int> {
868+ use #pyo3_path:: impl_:: callback:: IntoPyCallbackOutput ;
869+ #arg_convert
870+ #init_holders
871+ #call?;
872+ Ok ( 0 )
873+ }
874+ }
875+ }
842876 } )
843877 }
844878
@@ -917,6 +951,7 @@ impl<'a> FnSpec<'a> {
917951 )
918952 } ,
919953 CallingConvention :: TpNew => unreachable ! ( "tp_new cannot get a methoddef" ) ,
954+ CallingConvention :: TpInit => unreachable ! ( "tp_init cannot get a methoddef" ) ,
920955 }
921956 }
922957
@@ -934,7 +969,7 @@ impl<'a> FnSpec<'a> {
934969 let self_argument = match & self . tp {
935970 // Getters / Setters / ClassAttribute are not callables on the Python side
936971 FnType :: Getter ( _) | FnType :: Setter ( _) | FnType :: ClassAttribute => return None ,
937- FnType :: Fn ( _) => Some ( "self" ) ,
972+ FnType :: Fn ( _) | FnType :: FnInit ( _ ) => Some ( "self" ) ,
938973 FnType :: FnModule ( _) => Some ( "module" ) ,
939974 FnType :: FnClass ( _) | FnType :: FnNewClass ( _) => Some ( "cls" ) ,
940975 FnType :: FnStatic | FnType :: FnNew => None ,
@@ -950,6 +985,7 @@ impl<'a> FnSpec<'a> {
950985
951986enum MethodTypeAttribute {
952987 New ( Span ) ,
988+ Init ( Span ) ,
953989 ClassMethod ( Span ) ,
954990 StaticMethod ( Span ) ,
955991 Getter ( Span , Option < Ident > ) ,
@@ -961,6 +997,7 @@ impl MethodTypeAttribute {
961997 fn span ( & self ) -> Span {
962998 match self {
963999 MethodTypeAttribute :: New ( span)
1000+ | MethodTypeAttribute :: Init ( span)
9641001 | MethodTypeAttribute :: ClassMethod ( span)
9651002 | MethodTypeAttribute :: StaticMethod ( span)
9661003 | MethodTypeAttribute :: Getter ( span, _)
@@ -1018,6 +1055,9 @@ impl MethodTypeAttribute {
10181055 if path. is_ident ( "new" ) {
10191056 ensure_no_arguments ( meta, "new" ) ?;
10201057 Ok ( Some ( MethodTypeAttribute :: New ( path. span ( ) ) ) )
1058+ } else if path. is_ident ( "init" ) {
1059+ ensure_no_arguments ( meta, "init" ) ?;
1060+ Ok ( Some ( MethodTypeAttribute :: Init ( path. span ( ) ) ) )
10211061 } else if path. is_ident ( "classmethod" ) {
10221062 ensure_no_arguments ( meta, "classmethod" ) ?;
10231063 Ok ( Some ( MethodTypeAttribute :: ClassMethod ( path. span ( ) ) ) )
@@ -1043,6 +1083,7 @@ impl Display for MethodTypeAttribute {
10431083 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
10441084 match self {
10451085 MethodTypeAttribute :: New ( _) => "#[new]" . fmt ( f) ,
1086+ MethodTypeAttribute :: Init ( _) => "#[init]" . fmt ( f) ,
10461087 MethodTypeAttribute :: ClassMethod ( _) => "#[classmethod]" . fmt ( f) ,
10471088 MethodTypeAttribute :: StaticMethod ( _) => "#[staticmethod]" . fmt ( f) ,
10481089 MethodTypeAttribute :: Getter ( _, _) => "#[getter]" . fmt ( f) ,
0 commit comments