@@ -4,7 +4,7 @@ use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk};
44
55use crate :: {
66 DbIndex , FlowId , FlowNode , FlowTree , InferFailReason , LuaDeclId , LuaFunctionType ,
7- LuaInferCache , LuaType , TypeOps , infer_expr, instantiate_func_generic,
7+ LuaInferCache , LuaSignature , LuaType , TypeOps , infer_expr, instantiate_func_generic,
88 semantic:: infer:: {
99 VarRefId ,
1010 narrow:: { get_single_antecedent, get_type_at_flow:: get_type_at_flow} ,
@@ -23,25 +23,33 @@ struct SearchRootCorrelatedTypes {
2323 deferred_known_call_target_types : Option < Vec < LuaType > > ,
2424}
2525
26+ #[ derive( Debug ) ]
27+ struct CollectedCorrelatedTypes {
28+ matching_target_types : Vec < LuaType > ,
29+ correlated_candidate_types : Vec < LuaType > ,
30+ unmatched_target_types : Vec < LuaType > ,
31+ has_unmatched_discriminant_origin : bool ,
32+ has_opaque_target_origin : bool ,
33+ }
34+
2635impl CorrelatedConditionNarrowing {
2736 pub ( in crate :: semantic:: infer:: narrow) fn apply (
28- self ,
37+ & self ,
2938 db : & DbIndex ,
3039 antecedent_type : LuaType ,
3140 ) -> LuaType {
3241 let mut root_target_types = Vec :: new ( ) ;
3342 let mut found_matching_root = false ;
34- for root_types in self . search_root_correlated_types {
35- let SearchRootCorrelatedTypes {
36- matching_target_types,
37- mut uncorrelated_target_types,
38- deferred_known_call_target_types,
39- } = root_types;
43+ for root_types in & self . search_root_correlated_types {
44+ let matching_target_types = & root_types. matching_target_types ;
45+ let mut uncorrelated_target_types = root_types. uncorrelated_target_types . clone ( ) ;
46+ let deferred_known_call_target_types =
47+ root_types. deferred_known_call_target_types . as_deref ( ) ;
4048
4149 let root_matching_target_type = if matching_target_types. is_empty ( ) {
4250 None
4351 } else {
44- let matching_target_type = LuaType :: from_vec ( matching_target_types) ;
52+ let matching_target_type = LuaType :: from_vec ( matching_target_types. clone ( ) ) ;
4553 let narrowed_correlated_type =
4654 TypeOps :: Intersect . apply ( db, & antecedent_type, & matching_target_type) ;
4755 if narrowed_correlated_type. is_never ( ) {
@@ -185,13 +193,13 @@ fn collect_correlated_types_from_search_root(
185193 condition_position,
186194 search_root_flow_id,
187195 ) ;
188- let (
189- root_matching_target_types,
190- root_correlated_candidate_types,
191- root_unmatched_target_types,
196+ let CollectedCorrelatedTypes {
197+ matching_target_types : root_matching_target_types,
198+ correlated_candidate_types : root_correlated_candidate_types,
199+ unmatched_target_types : root_unmatched_target_types,
192200 has_unmatched_discriminant_origin,
193201 has_opaque_target_origin,
194- ) = collect_matching_correlated_types (
202+ } = collect_matching_correlated_types (
195203 db,
196204 cache,
197205 root,
@@ -277,7 +285,7 @@ fn collect_matching_correlated_types(
277285 discriminant_refs : & [ crate :: DeclMultiReturnRef ] ,
278286 target_refs : & [ crate :: DeclMultiReturnRef ] ,
279287 narrowed_discriminant_type : & LuaType ,
280- ) -> Result < ( Vec < LuaType > , Vec < LuaType > , Vec < LuaType > , bool , bool ) , InferFailReason > {
288+ ) -> Result < CollectedCorrelatedTypes , InferFailReason > {
281289 let mut matching_target_types = Vec :: new ( ) ;
282290 let mut correlated_candidate_types = Vec :: new ( ) ;
283291 let mut unmatched_target_types = Vec :: new ( ) ;
@@ -304,18 +312,16 @@ fn collect_matching_correlated_types(
304312 correlated_discriminant_call_expr_ids. insert ( discriminant_call_expr_id) ;
305313 correlated_target_call_expr_ids. insert ( target_ref. call_expr . get_syntax_id ( ) ) ;
306314 correlated_candidate_types. extend ( overload_rows. iter ( ) . map ( |overload| {
307- crate :: LuaSignature :: get_overload_row_slot ( overload, target_ref. return_index )
315+ LuaSignature :: get_overload_row_slot ( overload, target_ref. return_index )
308316 } ) ) ;
309317 matching_target_types. extend ( overload_rows. iter ( ) . filter_map ( |overload| {
310- let discriminant_type = crate :: LuaSignature :: get_overload_row_slot (
311- overload,
312- discriminant_ref. return_index ,
313- ) ;
318+ let discriminant_type =
319+ LuaSignature :: get_overload_row_slot ( overload, discriminant_ref. return_index ) ;
314320 if !TypeOps :: Intersect
315321 . apply ( db, & discriminant_type, narrowed_discriminant_type)
316322 . is_never ( )
317323 {
318- return Some ( crate :: LuaSignature :: get_overload_row_slot (
324+ return Some ( LuaSignature :: get_overload_row_slot (
319325 overload,
320326 target_ref. return_index ,
321327 ) ) ;
@@ -340,30 +346,30 @@ fn collect_matching_correlated_types(
340346 } ;
341347 let return_rows = instantiate_return_rows ( db, cache, call_expr, signature) ;
342348 unmatched_target_types. extend (
343- return_rows. iter ( ) . map ( |row| {
344- crate :: LuaSignature :: get_overload_row_slot ( row , target_ref . return_index )
345- } ) ,
349+ return_rows
350+ . iter ( )
351+ . map ( |row| LuaSignature :: get_overload_row_slot ( row , target_ref . return_index ) ) ,
346352 ) ;
347353 }
348354
349355 let has_unmatched_discriminant_origin = discriminant_refs. iter ( ) . any ( |discriminant_ref| {
350356 !correlated_discriminant_call_expr_ids. contains ( & discriminant_ref. call_expr . get_syntax_id ( ) )
351357 } ) ;
352- Ok ( (
358+ Ok ( CollectedCorrelatedTypes {
353359 matching_target_types,
354360 correlated_candidate_types,
355361 unmatched_target_types,
356362 has_unmatched_discriminant_origin,
357363 has_opaque_target_origin,
358- ) )
364+ } )
359365}
360366
361367fn infer_signature_for_call_ptr < ' a > (
362368 db : & ' a DbIndex ,
363369 cache : & mut LuaInferCache ,
364370 root : & LuaChunk ,
365371 call_expr_ptr : & LuaAstPtr < LuaCallExpr > ,
366- ) -> Result < Option < ( LuaCallExpr , & ' a crate :: LuaSignature ) > , InferFailReason > {
372+ ) -> Result < Option < ( LuaCallExpr , & ' a LuaSignature ) > , InferFailReason > {
367373 let Some ( call_expr) = call_expr_ptr. to_node ( root) else {
368374 return Ok ( None ) ;
369375 } ;
@@ -385,7 +391,7 @@ fn instantiate_return_rows(
385391 db : & DbIndex ,
386392 cache : & mut LuaInferCache ,
387393 call_expr : LuaCallExpr ,
388- signature : & crate :: LuaSignature ,
394+ signature : & LuaSignature ,
389395) -> Vec < Vec < LuaType > > {
390396 if signature. return_overloads . is_empty ( ) {
391397 let return_type = signature. get_return_type ( ) ;
@@ -404,15 +410,13 @@ fn instantiate_return_rows(
404410 } else {
405411 return_type
406412 } ;
407- return vec ! [ crate :: LuaSignature :: return_type_to_row(
408- instantiated_return_type,
409- ) ] ;
413+ return vec ! [ LuaSignature :: return_type_to_row( instantiated_return_type) ] ;
410414 }
411415
412416 let mut rows = Vec :: with_capacity ( signature. return_overloads . len ( ) ) ;
413417 for overload in & signature. return_overloads {
414418 let type_refs = & overload. type_refs ;
415- let overload_return_type = crate :: LuaSignature :: row_to_return_type ( type_refs. to_vec ( ) ) ;
419+ let overload_return_type = LuaSignature :: row_to_return_type ( type_refs. to_vec ( ) ) ;
416420 let instantiated_return_type = if overload_return_type. contain_tpl ( ) {
417421 let overload_func = LuaFunctionType :: new (
418422 signature. async_state ,
@@ -429,9 +433,7 @@ fn instantiate_return_rows(
429433 overload_return_type
430434 } ;
431435
432- rows. push ( crate :: LuaSignature :: return_type_to_row (
433- instantiated_return_type,
434- ) ) ;
436+ rows. push ( LuaSignature :: return_type_to_row ( instantiated_return_type) ) ;
435437 }
436438
437439 rows
0 commit comments