Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/5331.added.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Introspection: emit base classes.
15 changes: 12 additions & 3 deletions pyo3-introspection/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,12 @@ fn convert_members<'a>(
Chunk::Class {
name,
id,
bases,
decorators,
} => classes.push(convert_class(
id,
name,
bases,
decorators,
chunks_by_id,
chunks_by_parent,
Expand Down Expand Up @@ -186,6 +188,7 @@ fn convert_members<'a>(
fn convert_class(
id: &str,
name: &str,
bases: &[ChunkTypeHint],
decorators: &[ChunkTypeHint],
chunks_by_id: &HashMap<&str, &Chunk>,
chunks_by_parent: &HashMap<&str, Vec<&Chunk>>,
Expand All @@ -205,16 +208,20 @@ fn convert_class(
);
Ok(Class {
name: name.into(),
bases: bases
.iter()
.map(convert_python_identifier)
.collect::<Result<_>>()?,
methods,
attributes,
decorators: decorators
.iter()
.map(convert_decorator)
.map(convert_python_identifier)
.collect::<Result<_>>()?,
})
}

fn convert_decorator(decorator: &ChunkTypeHint) -> Result<PythonIdentifier> {
fn convert_python_identifier(decorator: &ChunkTypeHint) -> Result<PythonIdentifier> {
match convert_type_hint(decorator) {
TypeHint::Plain(id) => Ok(PythonIdentifier {
module: None,
Expand All @@ -240,7 +247,7 @@ fn convert_function(
name: name.into(),
decorators: decorators
.iter()
.map(convert_decorator)
.map(convert_python_identifier)
.collect::<Result<_>>()?,
arguments: Arguments {
positional_only_arguments: arguments.posonlyargs.iter().map(convert_argument).collect(),
Expand Down Expand Up @@ -462,6 +469,8 @@ enum Chunk {
id: String,
name: String,
#[serde(default)]
bases: Vec<ChunkTypeHint>,
#[serde(default)]
decorators: Vec<ChunkTypeHint>,
},
Function {
Expand Down
1 change: 1 addition & 0 deletions pyo3-introspection/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct Module {
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub struct Class {
pub name: String,
pub bases: Vec<PythonIdentifier>,
pub methods: Vec<Function>,
pub attributes: Vec<Attribute>,
/// decorator like 'typing.final'
Expand Down
17 changes: 17 additions & 0 deletions pyo3-introspection/src/stubs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ fn class_stubs(class: &Class, imports: &Imports) -> String {
}
buffer.push_str("class ");
buffer.push_str(&class.name);
if !class.bases.is_empty() {
buffer.push('(');
for (i, base) in class.bases.iter().enumerate() {
if i > 0 {
buffer.push_str(", ");
}
imports.serialize_identifier(base, &mut buffer);
}
buffer.push(')');
}
buffer.push(':');
if class.methods.is_empty() && class.attributes.is_empty() {
buffer.push_str(" ...");
Expand Down Expand Up @@ -441,6 +451,9 @@ impl ElementsUsedInAnnotations {
}

fn walk_class(&mut self, class: &Class) {
for base in &class.bases {
self.walk_identifier(base);
}
for decorator in &class.decorators {
self.walk_identifier(decorator);
}
Expand Down Expand Up @@ -667,6 +680,10 @@ mod tests {
modules: Vec::new(),
classes: vec![Class {
name: "A".into(),
bases: vec![PythonIdentifier {
module: Some("builtins".into()),
name: "dict".into(),
}],
methods: Vec::new(),
attributes: Vec::new(),
decorators: vec![PythonIdentifier {
Expand Down
15 changes: 14 additions & 1 deletion pyo3-macros-backend/src/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use std::hash::{Hash, Hasher};
use std::mem::take;
use std::sync::atomic::{AtomicUsize, Ordering};
use syn::visit_mut::{visit_type_mut, VisitMut};
use syn::{Attribute, Ident, Lifetime, ReturnType, Type, TypePath};
use syn::{Attribute, Ident, Lifetime, Path, ReturnType, Type, TypePath};

static GLOBAL_COUNTER_FOR_UNIQUE_NAMES: AtomicUsize = AtomicUsize::new(0);

Expand Down Expand Up @@ -89,6 +89,7 @@ pub fn class_introspection_code(
pyo3_crate_path: &PyO3CratePath,
ident: &Ident,
name: &str,
extends: Option<&Path>,
is_final: bool,
) -> TokenStream {
let mut desc = HashMap::from([
Expand All @@ -99,6 +100,12 @@ pub fn class_introspection_code(
),
("name", IntrospectionNode::String(name.into())),
]);
if let Some(extends) = extends {
desc.insert(
"bases",
IntrospectionNode::List(vec![IntrospectionNode::BaseType(extends).into()]),
);
}
if is_final {
desc.insert(
"decorators",
Expand Down Expand Up @@ -355,6 +362,7 @@ enum IntrospectionNode<'a> {
IntrospectionId(Option<Cow<'a, Type>>),
InputType(Type),
OutputType { rust_type: Type, is_final: bool },
BaseType(&'a Path),
ConstantType(PythonIdentifier),
Map(HashMap<&'static str, IntrospectionNode<'a>>),
List(Vec<AttributedIntrospectionNode<'a>>),
Expand Down Expand Up @@ -411,6 +419,11 @@ impl IntrospectionNode<'_> {
}
content.push_tokens(serialize_type_hint(annotation, pyo3_crate_path));
}
Self::BaseType(rust_type) => {
let annotation =
quote! { <#rust_type as #pyo3_crate_path::type_object::PyTypeInfo>::TYPE_HINT };
content.push_tokens(serialize_type_hint(annotation, pyo3_crate_path));
}
Self::ConstantType(hint) => {
let name = &hint.name;
let annotation = if let Some(module) = &hint.module {
Expand Down
3 changes: 2 additions & 1 deletion pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ impl FieldPyO3Options {
}
}

fn get_class_python_name<'a>(cls: &'a syn::Ident, args: &'a PyClassArgs) -> Cow<'a, syn::Ident> {
fn get_class_python_name<'a>(cls: &'a Ident, args: &'a PyClassArgs) -> Cow<'a, Ident> {
args.options
.name
.as_ref()
Expand Down Expand Up @@ -2687,6 +2687,7 @@ impl<'a> PyClassImplsBuilder<'a> {
pyo3_path,
ident,
&name,
self.attr.options.extends.as_ref().map(|attr| &attr.value),
self.attr.options.subclass.is_none(),
);
let introspection_id = introspection_id_const();
Expand Down
34 changes: 34 additions & 0 deletions pytests/src/subclassing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use pyo3::prelude::*;
#[pymodule(gil_used = false)]
pub mod subclassing {
use pyo3::prelude::*;
#[cfg(not(Py_LIMITED_API))]
Comment thread
Tpt marked this conversation as resolved.
Outdated
use pyo3::types::PyDict;

#[pyclass(subclass)]
pub struct Subclassable {}
Expand All @@ -20,4 +22,36 @@ pub mod subclassing {
"Subclassable"
}
}

#[pyclass(extends = Subclassable)]
pub struct Subclass {}

#[pymethods]
impl Subclass {
#[new]
fn new() -> (Self, Subclassable) {
(Subclass {}, Subclassable::new())
}

fn __str__(&self) -> &'static str {
"Subclass"
}
}

#[cfg(not(Py_LIMITED_API))]
Comment thread
Tpt marked this conversation as resolved.
Outdated
#[pyclass(extends = PyDict)]
pub struct SubDict {}

#[cfg(not(Py_LIMITED_API))]
Comment thread
Tpt marked this conversation as resolved.
Outdated
#[pymethods]
impl SubDict {
#[new]
fn new() -> Self {
Self {}
}

fn __str__(&self) -> &'static str {
"SubDict"
}
}
}
12 changes: 12 additions & 0 deletions pytests/stubs/subclassing.pyi
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
from typing import final

@final
class SubDict(dict):
def __new__(cls, /) -> SubDict: ...
def __str__(self, /) -> str: ...

@final
class Subclass(Subclassable):
def __new__(cls, /) -> Subclass: ...
def __str__(self, /) -> str: ...

class Subclassable:
def __new__(cls, /) -> Subclassable: ...
def __str__(self, /) -> str: ...
Comment thread
davidhewitt marked this conversation as resolved.
10 changes: 8 additions & 2 deletions pytests/tests/test_subclassing.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from pyo3_pytests.subclassing import Subclassable
from pyo3_pytests.subclassing import Subclassable, Subclass


class SomeSubClass(Subclassable):
def __str__(self):
return "SomeSubclass"


def test_subclassing():
def test_python_subclassing():
a = SomeSubClass()
assert str(a) == "SomeSubclass"
assert type(a) is SomeSubClass


def test_rust_subclassing():
a = Subclass()
assert str(a) == "Subclass"
assert type(a) is Subclass
Loading