Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions pyrefly/lib/alt/class/class_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1468,6 +1468,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
..
} => {
let direct_annotation = annot.map(|a| self.get_idx(a).annotation.clone());
let mut annotation_flags = annot.and_then(|annot| {
self.extract_pydantic_field_from_annotation(annot, &metadata)
});
if metadata.is_protocol()
&& direct_annotation.is_none()
&& !is_dunder(name.as_str())
Expand All @@ -1486,8 +1489,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
&& let Some(dm) = metadata.dataclass_metadata()
&& let Expr::Call(call) = e
{
let flags = self.compute_dataclass_field_initialization(call, dm);
ClassFieldInitialization::ClassBody(flags.map(Box::new))
if let Some(mut flags) = self.compute_dataclass_field_initialization(call, dm) {
if let Some(annotation_flags) = annotation_flags.as_ref()
&& flags.strict.is_none()
{
flags.strict = annotation_flags.strict;
}
ClassFieldInitialization::ClassBody(Some(Box::new(flags)))
} else if let Some(mut flags) = annotation_flags.take() {
flags.default = Some(self.heap.mk_any_implicit());
ClassFieldInitialization::ClassBody(Some(Box::new(flags)))
} else {
ClassFieldInitialization::ClassBody(None)
}
} else if let Some(mut flags) = annotation_flags.take() {
flags.default = Some(self.heap.mk_any_implicit());
ClassFieldInitialization::ClassBody(Some(Box::new(flags)))
} else {
ClassFieldInitialization::ClassBody(None)
};
Expand Down
34 changes: 32 additions & 2 deletions pyrefly/lib/alt/class/pydantic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,18 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

pub fn is_pydantic_strict_type_alias(&self, ty: &Type) -> bool {
if let Type::TypeAlias(ta) = ty {
let alias = self.get_type_alias(ta);
if let Type::Annotated(_, metadata) = alias.as_type() {
return metadata
.iter()
.any(|metadata| self.is_pydantic_strict_metadata(metadata));
}
}
false
}

/// Helper function to find inherited keyword values from parent pydantic model metadata.
/// Only inherits from parents that are themselves pydantic models, not from arbitrary
/// dataclass parents whose config values (e.g. strict) may have different defaults.
Expand Down Expand Up @@ -563,6 +575,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
return None;
}
if let BindingAnnotation::AnnotateExpr(_, annotation_expr, _) = self.bindings().get(annot) {
let mut keywords = None;
let metadata_items = self.get_annotated_metadata(
annotation_expr,
TypeFormContext::ClassVarAnnotation,
Expand All @@ -571,11 +584,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
// Look through metadata items and find a Field(...) call, then extract its keywords
for metadata_item in &metadata_items {
if let Expr::Call(call) = metadata_item
&& let Some(keywords) = self.compute_dataclass_field_initialization(call, dm)
&& let Some(field_keywords) =
self.compute_dataclass_field_initialization(call, dm)
{
return Some(keywords);
keywords = Some(field_keywords);
break;
}
}
let errors = self.error_swallower();
let has_strict_metadata = metadata_items.iter().any(|metadata| {
self.is_pydantic_strict_metadata(&self.expr_infer(metadata, &errors))
}) || self
.is_pydantic_strict_type_alias(&self.expr_infer(annotation_expr, &errors));
if has_strict_metadata
&& keywords
.as_ref()
.is_none_or(|keywords| keywords.strict.is_none())
{
keywords
.get_or_insert_with(DataclassFieldKeywords::new)
.strict = Some(true);
}
return keywords;
}
None
}
Expand Down
14 changes: 1 addition & 13 deletions pyrefly/lib/alt/types/class_bases.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
}
}

fn is_type_alias_with_pydantic_strict_metadata(&self, ty: &Type) -> bool {
if let Type::TypeAlias(ta) = ty {
let alias = self.get_type_alias(ta);
if let Type::Annotated(_, metadata) = alias.as_type() {
return metadata
.iter()
.any(|metadata| self.is_pydantic_strict_metadata(metadata));
}
}
false
}

/// Get the untyped form (in other words, the instance type, after applying
/// any type arguments) for a base class.
///
Expand All @@ -229,7 +217,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> {
let range = base_expr.range();
let (inferred_ty, has_strict_from_infer) = self.base_class_expr_infer(base_expr, errors);
let has_pydantic_strict_metadata =
self.is_type_alias_with_pydantic_strict_metadata(&inferred_ty) || has_strict_from_infer;
self.is_pydantic_strict_type_alias(&inferred_ty) || has_strict_from_infer;
let ty = self.untype(inferred_ty, range, errors);
(
self.validate_type_form(ty, range, type_form_context, errors),
Expand Down
16 changes: 16 additions & 0 deletions pyrefly/lib/test/pydantic/strictness.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ Model(x='0', y='1') # E: Argument `Literal['1']` is not assignable to parameter
"#,
);

pydantic_testcase!(
test_strict_annotated_types,
r#"
from typing import Annotated
from pydantic import BaseModel, Strict, StrictInt

class Model(BaseModel):
x: StrictInt = 1
y: Annotated[int, Strict()]

Model(x=1, y=1)
Model(x='1', y=1) # E: Argument `Literal['1']` is not assignable to parameter `x`
Model(x=1, y='1') # E: Argument `Literal['1']` is not assignable to parameter `y`
"#,
);

pydantic_testcase!(
test_class_keyword,
r#"
Expand Down
Loading