diff --git a/pyrefly/lib/alt/class/class_field.rs b/pyrefly/lib/alt/class/class_field.rs index 045f72959b..bf1741c107 100644 --- a/pyrefly/lib/alt/class/class_field.rs +++ b/pyrefly/lib/alt/class/class_field.rs @@ -1468,6 +1468,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .. } => { let direct_annotation = annot.map(|a| self.get_idx(a).annotation.clone()); + let mut annotation_flags = annot.and_then(|annot| { + self.extract_pydantic_field_from_annotation(annot, &metadata) + }); if metadata.is_protocol() && direct_annotation.is_none() && !is_dunder(name.as_str()) @@ -1486,8 +1489,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { && let Some(dm) = metadata.dataclass_metadata() && let Expr::Call(call) = e { - let flags = self.compute_dataclass_field_initialization(call, dm); - ClassFieldInitialization::ClassBody(flags.map(Box::new)) + if let Some(mut flags) = self.compute_dataclass_field_initialization(call, dm) { + if let Some(annotation_flags) = annotation_flags.as_ref() + && flags.strict.is_none() + { + flags.strict = annotation_flags.strict; + } + ClassFieldInitialization::ClassBody(Some(Box::new(flags))) + } else if let Some(mut flags) = annotation_flags.take() { + flags.default = Some(self.heap.mk_any_implicit()); + ClassFieldInitialization::ClassBody(Some(Box::new(flags))) + } else { + ClassFieldInitialization::ClassBody(None) + } + } else if let Some(mut flags) = annotation_flags.take() { + flags.default = Some(self.heap.mk_any_implicit()); + ClassFieldInitialization::ClassBody(Some(Box::new(flags))) } else { ClassFieldInitialization::ClassBody(None) }; diff --git a/pyrefly/lib/alt/class/pydantic.rs b/pyrefly/lib/alt/class/pydantic.rs index 786602b135..205361201c 100644 --- a/pyrefly/lib/alt/class/pydantic.rs +++ b/pyrefly/lib/alt/class/pydantic.rs @@ -181,6 +181,18 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } + pub fn is_pydantic_strict_type_alias(&self, ty: &Type) -> bool { + if let Type::TypeAlias(ta) = ty { + let alias = self.get_type_alias(ta); + if let Type::Annotated(_, metadata) = alias.as_type() { + return metadata + .iter() + .any(|metadata| self.is_pydantic_strict_metadata(metadata)); + } + } + false + } + /// Helper function to find inherited keyword values from parent pydantic model metadata. /// Only inherits from parents that are themselves pydantic models, not from arbitrary /// dataclass parents whose config values (e.g. strict) may have different defaults. @@ -563,6 +575,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { return None; } if let BindingAnnotation::AnnotateExpr(_, annotation_expr, _) = self.bindings().get(annot) { + let mut keywords = None; let metadata_items = self.get_annotated_metadata( annotation_expr, TypeFormContext::ClassVarAnnotation, @@ -571,11 +584,28 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { // Look through metadata items and find a Field(...) call, then extract its keywords for metadata_item in &metadata_items { if let Expr::Call(call) = metadata_item - && let Some(keywords) = self.compute_dataclass_field_initialization(call, dm) + && let Some(field_keywords) = + self.compute_dataclass_field_initialization(call, dm) { - return Some(keywords); + keywords = Some(field_keywords); + break; } } + let errors = self.error_swallower(); + let has_strict_metadata = metadata_items.iter().any(|metadata| { + self.is_pydantic_strict_metadata(&self.expr_infer(metadata, &errors)) + }) || self + .is_pydantic_strict_type_alias(&self.expr_infer(annotation_expr, &errors)); + if has_strict_metadata + && keywords + .as_ref() + .is_none_or(|keywords| keywords.strict.is_none()) + { + keywords + .get_or_insert_with(DataclassFieldKeywords::new) + .strict = Some(true); + } + return keywords; } None } diff --git a/pyrefly/lib/alt/types/class_bases.rs b/pyrefly/lib/alt/types/class_bases.rs index 22f4dffbd9..c5530be88f 100644 --- a/pyrefly/lib/alt/types/class_bases.rs +++ b/pyrefly/lib/alt/types/class_bases.rs @@ -203,18 +203,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } - fn is_type_alias_with_pydantic_strict_metadata(&self, ty: &Type) -> bool { - if let Type::TypeAlias(ta) = ty { - let alias = self.get_type_alias(ta); - if let Type::Annotated(_, metadata) = alias.as_type() { - return metadata - .iter() - .any(|metadata| self.is_pydantic_strict_metadata(metadata)); - } - } - false - } - /// Get the untyped form (in other words, the instance type, after applying /// any type arguments) for a base class. /// @@ -229,7 +217,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let range = base_expr.range(); let (inferred_ty, has_strict_from_infer) = self.base_class_expr_infer(base_expr, errors); let has_pydantic_strict_metadata = - self.is_type_alias_with_pydantic_strict_metadata(&inferred_ty) || has_strict_from_infer; + self.is_pydantic_strict_type_alias(&inferred_ty) || has_strict_from_infer; let ty = self.untype(inferred_ty, range, errors); ( self.validate_type_form(ty, range, type_form_context, errors), diff --git a/pyrefly/lib/test/pydantic/strictness.rs b/pyrefly/lib/test/pydantic/strictness.rs index e2d954af58..58ee3861bb 100644 --- a/pyrefly/lib/test/pydantic/strictness.rs +++ b/pyrefly/lib/test/pydantic/strictness.rs @@ -32,6 +32,22 @@ Model(x='0', y='1') # E: Argument `Literal['1']` is not assignable to parameter "#, ); +pydantic_testcase!( + test_strict_annotated_types, + r#" +from typing import Annotated +from pydantic import BaseModel, Strict, StrictInt + +class Model(BaseModel): + x: StrictInt = 1 + y: Annotated[int, Strict()] + +Model(x=1, y=1) +Model(x='1', y=1) # E: Argument `Literal['1']` is not assignable to parameter `x` +Model(x=1, y='1') # E: Argument `Literal['1']` is not assignable to parameter `y` + "#, +); + pydantic_testcase!( test_class_keyword, r#"