From 856417d84e3687fa2ba2eed71b4007926ff10fdd Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sat, 21 Mar 2026 12:22:34 +0000 Subject: [PATCH 1/5] test(flow): characterize loop post-flow boundaries --- .../src/compilation/test/flow.rs | 105 ++++++++++++++++++ 1 file changed, 105 insertions(+) diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 32228bae1..d45c1acc9 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -627,6 +627,111 @@ end assert_eq!(b, LuaType::String); } + #[test] + fn test_while_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local condition ---@type boolean + local value ---@type string? + + while condition do + value = "loop" + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_repeat_loop_post_flow_keeps_body_assignment() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local condition ---@type boolean + local value ---@type string? + + repeat + value = "loop" + until condition + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string")); + } + + #[test] + fn test_numeric_for_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local value ---@type string? + + for i = 1, 3 do + value = "loop" + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_for_in_loop_post_flow_keeps_incoming_type_after_break() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + local value ---@type string? + + for _, _value in ipairs({ "loop" }) do + value = "loop" + break + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + + #[test] + fn test_nested_while_loop_post_flow_keeps_incoming_type() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local outer_condition ---@type boolean + local inner_condition ---@type boolean + local value ---@type string? + + while outer_condition do + while inner_condition do + value = "loop" + break + end + + break + end + + after_loop = value + "#, + ); + + assert_eq!(ws.expr_ty("after_loop"), ws.ty("string?")); + } + #[test] fn test_issue_347() { let mut ws = VirtualWorkspace::new(); From 2be6c506d12a3e412e7c221e1b97c27f97268d42 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Sun, 22 Mar 2026 00:26:49 +0000 Subject: [PATCH 2/5] fix(flow): apply deferred narrows after antecedent resolution When walking backward through flow, collect condition narrows as pending actions instead of applying them while recursively querying antecedent types. The eager path mixed narrowing with antecedent resolution, so stacked guards could re-enter flow inference from inside condition evaluation and build deep recursive chains across repeated truthiness, type-guard, signature-cast, and correlated return-overload checks. Resolve the antecedent type first, then apply the pending narrows in reverse order. That keeps narrowing as a post-pass over an already-resolved input, avoids re-entering the same condition chain while answering the current flow query, and lets same-variable self/member guards wait until earlier guards have narrowed the receiver enough for reliable lookup. Key the flow cache by whether condition narrowing is enabled, and separate assignment source lookup from condition application. Reuse a narrowed source only when the RHS preserves that precision; broader RHS expressions fall back to the antecedent type with condition narrowing disabled so reassignment clears stale branch narrows, while exact literals and compatible partial table/object rewrites still preserve useful narrowing. Add regression coverage for stacked guards, correlated overload joins, pcall aliases, and assign/return diagnostics. --- .../src/compilation/test/flow.rs | 837 ++++++++++++++++++ .../src/compilation/test/pcall_test.rs | 34 + .../test/return_overload_flow_test.rs | 608 +++++++++++++ .../test/assign_type_mismatch_test.rs | 60 ++ .../test/return_type_mismatch_test.rs | 152 ++++ .../src/semantic/cache/mod.rs | 2 +- .../narrow/condition_flow/binary_flow.rs | 257 +++--- .../infer/narrow/condition_flow/call_flow.rs | 334 +++---- .../narrow/condition_flow/correlated_flow.rs | 285 ++++-- .../infer/narrow/condition_flow/index_flow.rs | 107 +-- .../infer/narrow/condition_flow/mod.rs | 315 +++++-- .../infer/narrow/get_type_at_cast_flow.rs | 4 +- .../semantic/infer/narrow/get_type_at_flow.rs | 473 ++++++---- .../src/semantic/infer/narrow/mod.rs | 15 +- .../semantic/infer/narrow/narrow_type/mod.rs | 3 + 15 files changed, 2768 insertions(+), 718 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index d45c1acc9..a9003de65 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; + const STACKED_TYPE_GUARDS: usize = 180; + #[test] fn test_closure_return() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -101,6 +103,540 @@ mod test { )); } + #[test] + fn test_stacked_same_var_type_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if type(value) ~= 'string' then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + local value ---@type string|integer|boolean + + {repeated_guards} + local narrowed ---@type string + narrowed = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable type guard repro" + ); + assert!(ws.check_code_for(DiagnosticCode::AssignTypeMismatch, &block)); + } + + #[test] + fn test_stacked_same_var_call_type_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not instance_of(value, 'string') then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable call type guard repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_stacked_same_var_call_type_guard_eq_false_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if instance_of(value, 'string') == false then return end\n" + .repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked binary call type guard repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + + #[test] + fn test_branch_join_keeps_union_when_only_one_side_narrows() { + let mut ws = VirtualWorkspace::new(); + let block = r#" + local cond ---@type boolean + local value ---@type string|integer + + if cond then + if type(value) ~= 'string' then + return + end + end + + after_join = value + "#; + + let file_id = ws.def(block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for branch join merge-safety repro" + ); + assert_eq!(ws.expr_ty("after_join"), ws.ty("string|integer")); + } + + #[test] + fn test_stacked_same_field_truthiness_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not value.foo then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class HasFoo + ---@field foo string + + ---@class NoFoo + ---@field bar integer + + local value ---@type HasFoo|NoFoo + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-field truthiness repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("HasFoo")); + } + + #[test] + fn test_stacked_return_cast_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not is_player(creature) then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Creature + + {repeated_guards} + after_guard = creature + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked return-cast repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_stacked_return_cast_self_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not creature:is_player() then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + local creature = {{}} + + ---@return boolean + ---@return_cast self Player else Monster + function creature:is_player() + return true + end + + {repeated_guards} + after_guard = creature + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked self return-cast repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_pending_replay_order_uses_type_guard_before_self_return_cast_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = {} + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local branch ---@type boolean + local creature = branch and checker or false + + if type(creature) ~= "table" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_pending_replay_order_with_three_guards_before_self_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class PlayerA + + ---@class MonsterA + + ---@class PlayerB + + ---@class MonsterB + + local checker_a = { + kind = "checker_a", + } + + ---@return boolean + ---@return_cast self PlayerA else MonsterA + function checker_a:is_player() + return true + end + + local checker_b = { + kind = "checker_b", + } + + ---@return boolean + ---@return_cast self PlayerB else MonsterB + function checker_b:is_player() + return true + end + + local allow_false ---@type boolean + local choose_a ---@type boolean + local creature = allow_false and false or (choose_a and checker_a or checker_b) + + if type(creature) ~= "table" then + return + end + + if creature.kind ~= "checker_a" then + return + end + + if creature:is_player() == false then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("PlayerA")); + } + + #[test] + fn test_return_cast_self_guard_uses_prior_narrowing_for_method_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_return_cast_self_guard_without_prior_method_lookup_narrowing_does_not_apply() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + } + + local branch ---@type boolean + local creature = branch and checker or monster + before_guard = creature + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.expr_ty("before_guard")); + } + + #[test] + fn test_return_cast_self_guard_with_multiple_method_candidates_uses_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class PlayerA + + ---@class MonsterA + + ---@class PlayerB + + ---@class MonsterB + + local checker_a = { + kind = "checker_a", + } + + ---@return boolean + ---@return_cast self PlayerA else MonsterA + function checker_a:is_player() + return true + end + + local checker_b = { + kind = "checker_b", + } + + ---@return boolean + ---@return_cast self PlayerB else MonsterB + function checker_b:is_player() + return true + end + + local branch ---@type boolean + local creature = branch and checker_a or checker_b + + if creature.kind ~= "checker_a" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("PlayerA")); + } + + #[test] + fn test_return_cast_self_guard_with_non_callable_member_uses_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + is_player = false, + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if not creature:is_player() then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + + #[test] + fn test_return_cast_self_guard_eq_false_uses_prior_narrowing_for_method_lookup() { + let mut ws = VirtualWorkspace::new(); + ws.def( + r#" + ---@class Player + + ---@class Monster + + local checker = { + kind = "checker", + } + + ---@return boolean + ---@return_cast self Player else Monster + function checker:is_player() + return true + end + + local monster = { + kind = "monster", + is_player = false, + } + + local branch ---@type boolean + local creature = branch and checker or monster + + if creature.kind ~= "checker" then + return + end + + if creature:is_player() == false then + return + end + + after_guard = creature + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); + } + #[test] fn test_issue_100() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -1892,4 +2428,305 @@ _2 = a[1] let type_str = ws.humanize_type_detailed(e_ty); assert_eq!(type_str, "(A|B|table)"); } + + #[test] + fn test_assignment_from_wider_single_return_call_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param ok boolean + ---@return Foo|Bar + local function pick(ok) + if ok then + return { kind = "foo", a = 1 } + end + + return { kind = "bar", b = 2 } + end + + local ok ---@type boolean + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = pick(ok) + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar")); + } + + #[test] + fn test_assignment_after_pending_return_cast_guard_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Creature + + ---@class Player: Creature + + ---@class Monster: Creature + + ---@param creature Creature + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Creature + local next_creature ---@type Creature + + if not is_player(creature) then + return + end + + before_assign = creature + creature = next_creature + after_assign = creature + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("Player")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Creature")); + } + + #[test] + fn test_assignment_after_binary_call_guard_eq_false_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T + ---@param inst any + ---@param type `T` + ---@return TypeGuard + local function instance_of(inst, type) + return true + end + + local value ---@type string|integer|boolean + local next_value ---@type string|integer|boolean + + if instance_of(value, 'string') == false then + return + end + + before_assign = value + value = next_value + after_assign = value + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("string")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("string|integer|boolean")); + } + + #[test] + fn test_assignment_after_mixed_eager_and_pending_guards_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Player + ---@field kind "player" + + ---@class Monster + ---@field kind "monster" + + ---@param creature Player|Monster + ---@return boolean + ---@return_cast creature Player else Monster + local function is_player(creature) + return true + end + + local creature ---@type Player|Monster + local next_creature ---@type Player|Monster + + if creature.kind ~= "player" then + return + end + + if not is_player(creature) then + return + end + + before_assign = creature + creature = next_creature + after_assign = creature + "#, + ); + + assert_eq!(ws.expr_ty("before_assign"), ws.ty("Player")); + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Player|Monster")); + } + + #[test] + fn test_assignment_from_nullable_union_keeps_rhs_members() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string? + local y ---@type number? + + if x then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("number?")); + } + + #[test] + fn test_assignment_from_partially_overlapping_union_keeps_rhs_members() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string|number + local y ---@type integer|string + + if x == 1 then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("integer|string")); + } + + #[test] + fn test_partial_table_reassignment_preserves_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = {} + x.kind = "foo" + x.a = 1 + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo")); + } + + #[test] + fn test_partial_table_reassignment_with_discriminant_preserves_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = { kind = "foo" } + x.a = 1 + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo")); + } + + #[test] + fn test_exact_string_reassignment_preserves_literal_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type string|number + + if x == 1 then + x = "a" + after_assign = x + end + "#, + ); + + let after_assign = ws.expr_ty("after_assign"); + assert_eq!(ws.humanize_type(after_assign), r#""a""#); + } + + #[test] + fn test_assignment_from_broad_string_drops_literal_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + local x ---@type "a"|boolean + local y ---@type string + + if x == "a" then + x = y + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("string")); + } + + #[test] + fn test_partial_table_reassignment_with_conflicting_discriminant_drops_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + local x ---@type Foo|Bar + + if x.kind == "foo" then + x = { kind = "bar" } + after_assign = x + end + "#, + ); + + assert_eq!(ws.expr_ty("after_assign"), ws.ty("Foo|Bar")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 5cbb83f2b..7700022c1 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + const STACKED_PCALL_ALIAS_GUARDS: usize = 180; + #[test] fn test_issue_263() { let mut ws = VirtualWorkspace::new_with_init_std_lib(); @@ -141,4 +143,36 @@ mod test { assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); } + + #[test] + fn test_pcall_stacked_alias_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + let repeated_guards = + "if failed then error(result) end\n".repeat(STACKED_PCALL_ALIAS_GUARDS); + let block = format!( + r#" + ---@return integer + local function foo() + return 1 + end + + local ok, result = pcall(foo) + local failed = ok == false + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked pcall alias guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } } diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs index f0a74b505..e1fc57836 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs @@ -2,6 +2,8 @@ mod test { use crate::{DiagnosticCode, VirtualWorkspace}; + const STACKED_CORRELATED_GUARDS: usize = 180; + #[test] fn test_return_overload_narrow_after_not() { let mut ws = VirtualWorkspace::new(); @@ -541,4 +543,610 @@ mod test { assert!(after_guard.contains("integer")); assert!(!after_guard.contains("string")); } + + #[test] + fn test_return_overload_stacked_same_discriminant_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not ok then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked correlated-guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_eq_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if ok == false then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked correlated-eq repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_mixed_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if ok == false then error(result) end\nif not ok then error(result) end\n" + .repeat(STACKED_CORRELATED_GUARDS / 2); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "error") + + {repeated_guards} + narrowed = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked mixed correlated-guard repro" + ); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_stacked_noncorrelated_origin_guards_keep_extra_type() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = + "if not ok then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS); + let block = format!( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if branch then + result = false + end + + {repeated_guards} + after_guard = result + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked noncorrelated correlated-guard repro" + ); + let after_guard_ty = ws.expr_ty("after_guard"); + let after_guard = ws.humanize_type(after_guard_ty); + assert!(after_guard.contains("false")); + assert!(after_guard.contains("integer")); + assert!(!after_guard.contains("string")); + } + + #[test] + fn test_return_overload_uncorrelated_later_guard_keeps_prior_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local ok, result = pick(cond, 1, "err") + + if not ok then + error(result) + end + + ok = cond + + if not ok then + error(result) + end + + narrowed = result + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_unmatched_discriminant_call_keeps_target_wide() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return integer|string + ---@return_overload true, integer + ---@return_overload false, string + local function pick(ok) + if ok then + return true, 1 + end + return false, "err" + end + + ---@param ok boolean + ---@return boolean + local function bounce(ok) + return ok + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond) + + if branch then + ok = bounce(other) + end + + if not ok then + error(result) + end + + after_guard = result + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_unmatched_target_call_keeps_guard_union() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return boolean|table + ---@return_overload true, boolean + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, true + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("boolean|table|integer")); + } + + #[test] + fn test_return_overload_unmatched_target_root_then_truthiness_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + after_guard = result + + if result then + truthy = result + end + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|table|integer")); + assert_eq!(ws.expr_ty("truthy"), ws.ty("table|integer")); + } + + #[test] + fn test_return_overload_unmatched_target_root_then_type_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + after_guard = result + + if type(result) == "table" then + table_result = result + end + end + "#, + ); + + assert_eq!(ws.expr_ty("after_guard"), ws.ty("false|table|integer")); + let table_result = ws.expr_ty("table_result"); + assert_eq!(ws.humanize_type(table_result), "table"); + } + + #[test] + fn test_return_overload_post_guard_reassign_clears_mixed_root_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return boolean + ---@return false|table + ---@return_overload true, false + ---@return_overload false, table + local function pick_right(ok) + if ok then + return true, false + end + return false, {} + end + + local cond ---@type boolean + local other ---@type boolean + local branch ---@type boolean + local next_result ---@type string + local tag, result = pick_left(cond) + + if branch then + _, result = pick_right(other) + end + + if tag == "left_ok" then + before_reassign = result + result = next_result + after_reassign = result + end + "#, + ); + + assert_eq!(ws.expr_ty("before_reassign"), ws.ty("false|table|integer")); + assert_eq!(ws.expr_ty("after_reassign"), ws.ty("string")); + } + + #[test] + fn test_return_overload_reassign_from_fresh_call_ignores_prior_guard() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + local cond ---@type boolean + local branch ---@type boolean + local ok, result = pick(cond, 1, "err") + + if not ok then + error(result) + end + + if branch then + ok, result = pick(cond, "x", 2) + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_branch_reassign_to_different_call_preserves_matching_root_narrowing() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return "right_ok"|"right_err" + ---@return boolean|table + ---@return_overload "right_ok", boolean + ---@return_overload "right_err", table + local function pick_right(ok) + if ok then + return "right_ok", true + end + return "right_err", {} + end + + local cond ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + tag, result = pick_right(cond) + end + + at_join = result + + if tag == "left_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("at_join"), ws.ty("boolean|table|integer|string")); + assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer")); + } + + #[test] + fn test_return_overload_branch_reassign_to_different_call_narrows_alternate_matching_root() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return "left_ok"|"left_err" + ---@return integer|string + ---@return_overload "left_ok", integer + ---@return_overload "left_err", string + local function pick_left(ok) + if ok then + return "left_ok", 1 + end + return "left_err", "err" + end + + ---@param ok boolean + ---@return "right_ok"|"right_err" + ---@return boolean|table + ---@return_overload "right_ok", boolean + ---@return_overload "right_err", table + local function pick_right(ok) + if ok then + return "right_ok", true + end + return "right_err", {} + end + + local cond ---@type boolean + local branch ---@type boolean + local tag, result = pick_left(cond) + + if branch then + tag, result = pick_right(cond) + end + + if tag == "right_ok" then + narrowed = result + end + "#, + ); + + assert_eq!(ws.expr_ty("narrowed"), ws.ty("boolean")); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs index fa8e9c209..f67193db2 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/assign_type_mismatch_test.rs @@ -1184,4 +1184,64 @@ return t "#, )); } + + #[test] + fn test_exact_string_reassignment_in_narrowed_branch_keeps_assign_literal() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + local x ---@type string|number + + if x == 1 then + x = "a" + + ---@type "a" + local y = x + end + "#, + )); + } + + #[test] + fn test_return_overload_mixed_guards_keep_assign_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::AssignTypeMismatch, + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + ---@param cond boolean + local function test(cond) + local ok, result = pick(cond, 1, "err") + + if ok == false then + error(result) + end + + if not ok then + error(result) + end + + ---@type integer + local narrowed = result + end + "#, + )); + } } diff --git a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs index 1fcb2438c..bdf6cd523 100644 --- a/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs +++ b/crates/emmylua_code_analysis/src/diagnostic/test/return_type_mismatch_test.rs @@ -131,6 +131,158 @@ mod tests { )); } + #[test] + fn test_discriminated_union_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = { kind = "foo", a = 1 } + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_discriminated_union_partial_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = {} + x.kind = "foo" + x.a = 1 + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_discriminated_union_partial_literal_assignment_keeps_branch_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@class Foo + ---@field kind "foo" + ---@field a integer + + ---@class Bar + ---@field kind "bar" + ---@field b integer + + ---@param x Foo|Bar + ---@return Foo + local function test(x) + if x.kind == "foo" then + x = { kind = "foo" } + x.a = 1 + return x + end + + return { kind = "foo", a = 2 } + end + "# + )); + } + + #[test] + fn test_exact_string_reassignment_in_narrowed_branch_keeps_return_literal() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@param x string|number + ---@return "a" + local function test(x) + if x == 1 then + x = "a" + return x + end + + return "a" + end + "# + )); + } + + #[test] + fn test_return_overload_mixed_guards_keep_return_narrowing() { + let mut ws = VirtualWorkspace::new(); + + assert!(ws.check_code_for( + DiagnosticCode::ReturnTypeMismatch, + r#" + ---@generic T, E + ---@param ok boolean + ---@param success T + ---@param failure E + ---@return boolean + ---@return T|E + ---@return_overload true, T + ---@return_overload false, E + local function pick(ok, success, failure) + if ok then + return true, success + end + return false, failure + end + + ---@param cond boolean + ---@return integer + local function test(cond) + local ok, result = pick(cond, 1, "err") + + if ok == false then + error(result) + end + + if not ok then + error(result) + end + + return result + end + "# + )); + } + #[test] fn test_variadic_return_type_mismatch() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index a11a2f6d4..8958d49a1 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -22,7 +22,7 @@ pub struct LuaInferCache { pub expr_cache: HashMap>, pub call_cache: HashMap<(LuaSyntaxId, Option, LuaType), CacheEntry>>, - pub flow_node_cache: HashMap<(VarRefId, FlowId), CacheEntry>, + pub(crate) flow_node_cache: HashMap<(VarRefId, FlowId, bool), CacheEntry>, pub index_ref_origin_type_cache: HashMap>, pub expr_var_ref_id_cache: HashMap, pub narrow_by_literal_stop_position_cache: HashSet, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index f8052cadf..f6bf3744d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -12,8 +12,9 @@ use crate::{ narrow::{ ResultTypeOrContinue, condition_flow::{ - InferConditionFlow, always_literal_equal, call_flow::get_type_at_call_expr, - correlated_flow::narrow_var_from_return_overload_condition, + ConditionFlowAction, InferConditionFlow, PendingConditionNarrow, + always_literal_equal, call_flow::get_type_at_call_expr, + correlated_flow::prepare_var_from_return_overload_condition, }, get_single_antecedent, get_type_at_flow::get_type_at_flow, @@ -33,13 +34,13 @@ pub fn get_type_at_binary_expr( flow_node: &FlowNode, binary_expr: LuaBinaryExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(op_token) = binary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some((left_expr, right_expr)) = binary_expr.get_exprs() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op_token.get_op() { @@ -76,7 +77,8 @@ pub fn get_type_at_binary_expr( right_expr, condition_flow, true, - ), + ) + .map(Into::into), BinaryOperator::OpGe => try_get_at_gt_or_ge_expr( db, tree, @@ -88,8 +90,9 @@ pub fn get_type_at_binary_expr( right_expr, condition_flow, false, - ), - _ => Ok(ResultTypeOrContinue::Continue), + ) + .map(Into::into), + _ => Ok(ConditionFlowAction::Continue), } } @@ -104,8 +107,8 @@ fn try_get_at_eq_or_neq_expr( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { - if let ResultTypeOrContinue::Result(result_type) = maybe_type_guard_binary( +) -> Result { + if let Some(action) = maybe_type_guard_binary_action( db, tree, cache, @@ -116,21 +119,7 @@ fn try_get_at_eq_or_neq_expr( right_expr.clone(), condition_flow, )? { - return Ok(ResultTypeOrContinue::Result(result_type)); - } - - if let ResultTypeOrContinue::Result(result_type) = maybe_field_literal_eq_narrow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - left_expr.clone(), - right_expr.clone(), - condition_flow, - )? { - return Ok(ResultTypeOrContinue::Result(result_type)); + return Ok(action); } let (left_expr, right_expr) = if !matches!( @@ -145,7 +134,21 @@ fn try_get_at_eq_or_neq_expr( (left_expr, right_expr) }; - maybe_var_eq_narrow( + if let Some(action) = maybe_field_literal_eq_action( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + left_expr.clone(), + right_expr.clone(), + condition_flow, + )? { + return Ok(action); + } + + get_var_eq_condition_action( db, tree, cache, @@ -196,7 +199,7 @@ fn try_get_at_gt_or_ge_expr( } let right_expr_type = infer_expr(db, cache, right_expr)?; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; match (&antecedent_type, &right_expr_type) { @@ -225,7 +228,7 @@ fn try_get_at_gt_or_ge_expr( } #[allow(clippy::too_many_arguments)] -fn maybe_type_guard_binary( +fn maybe_type_guard_binary_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -235,7 +238,7 @@ fn maybe_type_guard_binary( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { let (candidate_expr, literal_expr) = match (left_expr, right_expr) { // If either side is a literal expression and the other side is a type guard call expression // (or ref), we can narrow it @@ -249,7 +252,7 @@ fn maybe_type_guard_binary( let (Some(candidate_expr), Some(LuaLiteralToken::String(literal_string))) = (candidate_expr, literal_expr.and_then(|e| e.get_literal())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let candidate_expr = match candidate_expr { @@ -266,53 +269,61 @@ fn maybe_type_guard_binary( LuaExpr::CallExpr(call_expr) if call_expr.is_type() => Some(call_expr), _ => None, }) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(narrow) = type_call_name_to_type(&literal_string.get_value()) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(arg) = type_guard_expr .get_args_list() .and_then(|arg_list| arg_list.get_args().next()) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, arg) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + if maybe_var_ref_id == *var_ref_id { + return Ok(Some(ConditionFlowAction::Pending( + PendingConditionNarrow::TypeGuard { + narrow, + condition_flow, + }, + ))); + } + + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_var_ref_id, antecedent_flow_id)?; let narrowed_discriminant_type = match condition_flow { InferConditionFlow::TrueCondition => { - narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow) + narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow.clone()) } InferConditionFlow::FalseCondition => TypeOps::Remove.apply(db, &antecedent_type, &narrow), }; - if maybe_var_ref_id == *var_ref_id { - Ok(ResultTypeOrContinue::Result(narrowed_discriminant_type)) - } else { - let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); - }; - narrow_var_from_return_overload_condition( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - discriminant_decl_id, - type_guard_expr.get_position(), - &narrowed_discriminant_type, - ) - } + let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { + return Ok(None); + }; + + Ok(prepare_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + discriminant_decl_id, + type_guard_expr.get_position(), + &narrowed_discriminant_type, + )? + .map(PendingConditionNarrow::Correlated) + .map(ConditionFlowAction::Pending)) } /// Maps the string result of Lua's builtin `type()` call to the corresponding `LuaType`. @@ -335,12 +346,31 @@ fn narrow_eq_condition( antecedent_type: LuaType, right_expr_type: LuaType, condition_flow: InferConditionFlow, + allow_literal_equivalence: bool, ) -> LuaType { match condition_flow { InferConditionFlow::TrueCondition => { let left_maybe_type = TypeOps::Intersect.apply(db, &antecedent_type, &right_expr_type); if left_maybe_type.is_never() { + if allow_literal_equivalence { + let literal_matches = match &antecedent_type { + LuaType::Union(union) => union + .into_vec() + .into_iter() + .filter(|candidate| always_literal_equal(candidate, &right_expr_type)) + .collect::>(), + _ if always_literal_equal(&antecedent_type, &right_expr_type) => { + vec![antecedent_type.clone()] + } + _ => Vec::new(), + }; + + if !literal_matches.is_empty() { + return LuaType::from_vec(literal_matches); + } + } + antecedent_type } else { left_maybe_type @@ -353,7 +383,7 @@ fn narrow_eq_condition( } #[allow(clippy::too_many_arguments)] -fn maybe_var_eq_narrow( +fn get_var_eq_condition_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -363,28 +393,28 @@ fn maybe_var_eq_narrow( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { // only check left as need narrow match left_expr { LuaExpr::NameExpr(left_name_expr) => { let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::NameExpr(left_name_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let right_expr_type = infer_expr(db, cache, right_expr)?; if maybe_ref_id != *var_ref_id { let Some(discriminant_decl_id) = maybe_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_ref_id, antecedent_flow_id)?; let narrowed_discriminant_type = - narrow_eq_condition(db, antecedent_type, right_expr_type, condition_flow); - return narrow_var_from_return_overload_condition( + narrow_eq_condition(db, antecedent_type, right_expr_type, condition_flow, true); + return Ok(prepare_var_from_return_overload_condition( db, tree, cache, @@ -394,26 +424,32 @@ fn maybe_var_eq_narrow( discriminant_decl_id, left_name_expr.get_position(), &narrowed_discriminant_type, - ); + )? + .map(PendingConditionNarrow::Correlated) + .map(ConditionFlowAction::Pending) + .unwrap_or(ConditionFlowAction::Continue)); } - let left_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let result_type = match condition_flow { InferConditionFlow::TrueCondition => { // self 是特殊的, 我们删除其 nil 类型 if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil) } else { - narrow_eq_condition(db, left_type, right_expr_type, condition_flow) + return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); } } InferConditionFlow::FalseCondition => { - TypeOps::Remove.apply(db, &left_type, &right_expr_type) + return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); } }; - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(ConditionFlowAction::Result(result_type)) } LuaExpr::CallExpr(left_call_expr) => { if let LuaExpr::LiteralExpr(literal_expr) = right_expr { @@ -425,72 +461,61 @@ fn maybe_var_eq_narrow( condition_flow.get_negated() }; - return get_type_at_call_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - left_call_expr, - flow, - ); + return get_type_at_call_expr(db, cache, var_ref_id, left_call_expr, flow); } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), } }; - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } LuaExpr::IndexExpr(left_index_expr) => { let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(left_index_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_ref_id != *var_ref_id { // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } let right_expr_type = infer_expr(db, cache, right_expr)?; - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => right_expr_type, - InferConditionFlow::FalseCondition => { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type) - } - }; - Ok(ResultTypeOrContinue::Result(result_type)) + if condition_flow.is_false() { + return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + })); + } + + Ok(ConditionFlowAction::Result(right_expr_type)) } LuaExpr::UnaryExpr(unary_expr) => { let Some(op) = unary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op.get_op() { UnaryOperator::OpLen => {} - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), }; let Some(expr) = unary_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_ref_id != *var_ref_id { // If the reference declaration ID does not match, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } let right_expr_type = infer_expr(db, cache, right_expr)?; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; match (&antecedent_type, &right_expr_type) { @@ -501,25 +526,25 @@ fn maybe_var_eq_narrow( if condition_flow.is_true() { let new_array_type = LuaArrayType::new(array_type.get_base().clone(), LuaArrayLen::Max(*i)); - return Ok(ResultTypeOrContinue::Result(LuaType::Array( + return Ok(ConditionFlowAction::Result(LuaType::Array( new_array_type.into(), ))); } } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(ConditionFlowAction::Continue), } - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } _ => { // If the left expression is not a name or call expression, we cannot narrow it - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } } } #[allow(clippy::too_many_arguments)] -fn maybe_field_literal_eq_narrow( +fn maybe_field_literal_eq_action( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -529,7 +554,7 @@ fn maybe_field_literal_eq_narrow( left_expr: LuaExpr, right_expr: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { // only check left as need narrow let syntax_id = left_expr.get_syntax_id(); let (index_expr, literal_expr) = match (left_expr, right_expr) { @@ -539,16 +564,16 @@ fn maybe_field_literal_eq_narrow( (LuaExpr::LiteralExpr(literal_expr), LuaExpr::IndexExpr(index_expr)) => { (index_expr, literal_expr) } - _ => return Ok(ResultTypeOrContinue::Continue), + _ => return Ok(None), }; let Some(prefix_expr) = index_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; if maybe_var_ref_id != *var_ref_id { @@ -557,18 +582,18 @@ fn maybe_field_literal_eq_narrow( .contains(&syntax_id) && var_ref_id.start_with(&maybe_var_ref_id) { - return Ok(ResultTypeOrContinue::Result(get_var_ref_type( + return Ok(Some(ConditionFlowAction::Result(get_var_ref_type( db, cache, var_ref_id, - )?)); + )?))); } - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let left_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; let LuaType::Union(union_type) = left_type else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; cache @@ -599,16 +624,18 @@ fn maybe_field_literal_eq_narrow( match condition_flow { InferConditionFlow::TrueCondition => { if let Some(i) = opt_result { - return Ok(ResultTypeOrContinue::Result(union_types[i].clone())); + return Ok(Some(ConditionFlowAction::Result(union_types[i].clone()))); } } InferConditionFlow::FalseCondition => { if let Some(i) = opt_result { union_types.remove(i); - return Ok(ResultTypeOrContinue::Result(LuaType::from_vec(union_types))); + return Ok(Some(ConditionFlowAction::Result(LuaType::from_vec( + union_types, + )))); } } } - Ok(ResultTypeOrContinue::Continue) + Ok(None) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index 8ad1b8b0e..b9e1d9793 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -1,61 +1,88 @@ use std::{ops::Deref, sync::Arc}; -use emmylua_parser::{LuaCallExpr, LuaChunk, LuaExpr}; +use emmylua_parser::{LuaCallExpr, LuaExpr, LuaIndexMemberExpr}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType, - LuaFunctionType, LuaInferCache, LuaSignatureCast, LuaSignatureId, LuaType, TypeOps, - infer_call_expr_func, infer_expr, + DbIndex, InferFailReason, InferGuard, LuaAliasCallKind, LuaAliasCallType, LuaFunctionType, + LuaInferCache, LuaSignatureId, LuaType, infer_call_expr_func, infer_expr, semantic::infer::{ VarRefId, + infer_index::infer_member_by_member_key, narrow::{ - ResultTypeOrContinue, condition_flow::InferConditionFlow, get_single_antecedent, - get_type_at_cast_flow::cast_type, get_type_at_flow::get_type_at_flow, - narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, + ResultTypeOrContinue, + condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, + get_var_ref_type, narrow_false_or_nil, remove_false_or_nil, + var_ref_id::get_var_expr_var_ref_id, }, }, }; -#[allow(clippy::too_many_arguments)] pub fn get_type_at_call_expr( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(prefix_expr) = call_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - let maybe_func = infer_expr(db, cache, prefix_expr.clone())?; + let maybe_func = if call_expr.is_colon_call() { + match &prefix_expr { + LuaExpr::IndexExpr(index_expr) => { + if let Some(self_expr) = index_expr.get_prefix_expr() + && let Some(self_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) + && self_var_ref_id == *var_ref_id + { + let self_type = get_var_ref_type(db, cache, var_ref_id)?; + let member_type = infer_member_by_member_key( + db, + cache, + &self_type, + LuaIndexMemberExpr::IndexExpr(index_expr.clone()), + &InferGuard::new(), + )?; + + if needs_antecedent_same_var_colon_lookup(&member_type) { + // Keep the dedicated pending case here: replay needs the antecedent type + // for member lookup itself, not just for applying a cast after lookup. + return Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::SameVarColonCall { + index: LuaIndexMemberExpr::IndexExpr(index_expr.clone()), + condition_flow, + }, + )); + } else { + member_type + } + } else { + infer_expr(db, cache, prefix_expr.clone())? + } + } + _ => infer_expr(db, cache, prefix_expr.clone())?, + } + } else { + infer_expr(db, cache, prefix_expr.clone())? + }; match maybe_func { LuaType::DocFunction(f) => { let return_type = f.get_ret(); match return_type { LuaType::TypeGuard(_) => get_type_at_call_expr_by_type_guard( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, f, condition_flow, ), - _ => { - // If the return type is not a type guard, we cannot infer the type cast. - Ok(ResultTypeOrContinue::Continue) - } + _ => Ok(ConditionFlowAction::Continue), } } LuaType::Signature(signature_id) => { let Some(signature) = db.get_signature_index().get(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let ret = signature.get_return_type(); @@ -63,99 +90,92 @@ pub fn get_type_at_call_expr( LuaType::TypeGuard(_) => { return get_type_at_call_expr_by_type_guard( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, signature.to_doc_func_type(), condition_flow, ); } LuaType::Call(call) => { - return get_type_at_call_expr_by_call( + return Ok(get_type_at_call_expr_by_call( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, &call, condition_flow, - ); + )? + .into()); } _ => {} } let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match signature_cast.name.as_str() { "self" => get_type_at_call_expr_by_signature_self( db, - tree, cache, - root, var_ref_id, - flow_node, prefix_expr, - signature_cast, signature_id, condition_flow, ), name => get_type_at_call_expr_by_signature_param_name( db, - tree, cache, - root, var_ref_id, - flow_node, call_expr, - signature_cast, signature_id, name, condition_flow, ), } } - _ => { - // If the prefix expression is not a function, we cannot infer the type cast. - Ok(ResultTypeOrContinue::Continue) - } + _ => Ok(ConditionFlowAction::Continue), } } -#[allow(clippy::too_many_arguments)] -fn get_type_at_call_expr_by_type_guard( +fn needs_antecedent_same_var_colon_lookup(member_type: &LuaType) -> bool { + let candidate_members = match member_type { + LuaType::Union(union_type) => union_type.into_vec(), + LuaType::MultiLineUnion(multi_union) => match multi_union.to_union() { + LuaType::Union(union_type) => union_type.into_vec(), + _ => return false, + }, + _ => return false, + }; + + candidate_members.len() > 1 + && candidate_members.iter().any(|ty| { + matches!( + ty, + LuaType::DocFunction(_) | LuaType::Signature(_) | LuaType::Call(_) + ) + }) +} + +fn get_type_guard_call_info( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, - var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, func_type: Arc, - condition_flow: InferConditionFlow, -) -> Result { +) -> Result, InferFailReason> { let Some(arg_list) = call_expr.get_args_list() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(first_arg) = arg_list.get_args().next() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; let Some(maybe_ref_id) = get_var_expr_var_ref_id(db, cache, first_arg) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; - if maybe_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); - } - let mut return_type = func_type.get_ret().clone(); if return_type.contain_tpl() { let call_expr_type = LuaType::DocFunction(func_type); @@ -171,142 +191,94 @@ fn get_type_at_call_expr_by_type_guard( return_type = inst_func.get_ret().clone(); } - let guard_type = match return_type { - LuaType::TypeGuard(guard) => guard.deref().clone(), - _ => return Ok(ResultTypeOrContinue::Continue), + let LuaType::TypeGuard(guard) = return_type else { + return Ok(None); }; - match condition_flow { - InferConditionFlow::TrueCondition => Ok(ResultTypeOrContinue::Result(guard_type)), - InferConditionFlow::FalseCondition => { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - Ok(ResultTypeOrContinue::Result(TypeOps::Remove.apply( - db, - &antecedent_type, - &guard_type, - ))) - } + Ok(Some((maybe_ref_id, guard.deref().clone()))) +} + +fn get_type_at_call_expr_by_type_guard( + db: &DbIndex, + cache: &mut LuaInferCache, + var_ref_id: &VarRefId, + call_expr: LuaCallExpr, + func_type: Arc, + condition_flow: InferConditionFlow, +) -> Result { + let Some((maybe_ref_id, guard_type)) = + get_type_guard_call_info(db, cache, call_expr, func_type)? + else { + return Ok(ConditionFlowAction::Continue); + }; + + if maybe_ref_id != *var_ref_id { + return Ok(ConditionFlowAction::Continue); } + + Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::TypeGuard { + narrow: guard_type, + condition_flow, + }, + )) } -#[allow(clippy::too_many_arguments)] fn get_type_at_call_expr_by_signature_self( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_prefix: LuaExpr, - signature_cast: &LuaSignatureCast, signature_id: LuaSignatureId, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let LuaExpr::IndexExpr(call_prefix_index) = call_prefix else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(self_expr) = call_prefix_index.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, self_expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let signature_root = syntax_tree.get_chunk_node(); - - // Choose the appropriate cast based on condition_flow and whether fallback exists - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => { - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - InferConditionFlow::FalseCondition => { - // Use fallback_cast if available, otherwise use the default behavior - if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { - let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - fallback_op_type, - antecedent_type.clone(), - InferConditionFlow::TrueCondition, // Apply fallback as force cast - )? - } else { - // Original behavior: remove the true type from antecedent - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - } - }; - - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(get_signature_cast_pending(signature_id, condition_flow)) } #[allow(clippy::too_many_arguments)] fn get_type_at_call_expr_by_signature_param_name( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, - signature_cast: &LuaSignatureCast, signature_id: LuaSignatureId, name: &str, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let colon_call = call_expr.is_colon_call(); let Some(arg_list) = call_expr.get_args_list() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(signature) = db.get_signature_index().get(&signature_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(mut param_idx) = signature.find_param_idx(name) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let colon_define = signature.is_colon_define; match (colon_call, colon_define) { (true, false) => { if param_idx == 0 { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } param_idx -= 1; @@ -318,80 +290,34 @@ fn get_type_at_call_expr_by_signature_param_name( } let Some(expr) = arg_list.get_args().nth(param_idx) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, expr) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&signature_id.get_file_id()) else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let signature_root = syntax_tree.get_chunk_node(); - - // Choose the appropriate cast based on condition_flow and whether fallback exists - let result_type = match condition_flow { - InferConditionFlow::TrueCondition => { - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - InferConditionFlow::FalseCondition => { - // Use fallback_cast if available, otherwise use the default behavior - if let Some(fallback_cast_ptr) = &signature_cast.fallback_cast { - let Some(fallback_op_type) = fallback_cast_ptr.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - fallback_op_type, - antecedent_type.clone(), - InferConditionFlow::TrueCondition, // Apply fallback as force cast - )? - } else { - // Original behavior: remove the true type from antecedent - let Some(cast_op_type) = signature_cast.cast.to_node(&signature_root) else { - return Ok(ResultTypeOrContinue::Continue); - }; - cast_type( - db, - signature_id.get_file_id(), - cast_op_type, - antecedent_type, - condition_flow, - )? - } - } - }; + Ok(get_signature_cast_pending(signature_id, condition_flow)) +} - Ok(ResultTypeOrContinue::Result(result_type)) +fn get_signature_cast_pending( + signature_id: LuaSignatureId, + condition_flow: InferConditionFlow, +) -> ConditionFlowAction { + ConditionFlowAction::Pending(PendingConditionNarrow::SignatureCast { + signature_id, + condition_flow, + }) } -#[allow(unused, clippy::too_many_arguments)] fn get_type_at_call_expr_by_call( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, call_expr: LuaCallExpr, alias_call_type: &Arc, condition_flow: InferConditionFlow, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 8d0ddaac1..ad5e030c7 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -7,12 +7,101 @@ use crate::{ LuaInferCache, LuaType, TypeOps, infer_expr, instantiate_func_generic, semantic::infer::{ VarRefId, - narrow::{ResultTypeOrContinue, get_single_antecedent, get_type_at_flow::get_type_at_flow}, + narrow::{get_single_antecedent, get_type_at_flow::get_type_at_flow}, }, }; +#[derive(Debug, Clone)] +pub(in crate::semantic::infer::narrow) struct CorrelatedConditionNarrowing { + search_root_correlated_types: Vec, +} + +#[derive(Debug, Clone)] +struct SearchRootCorrelatedTypes { + matching_target_types: Vec, + uncorrelated_target_types: Vec, + deferred_known_call_target_types: Option>, +} + +impl CorrelatedConditionNarrowing { + pub(in crate::semantic::infer::narrow) fn apply( + self, + db: &DbIndex, + antecedent_type: LuaType, + ) -> LuaType { + let mut root_target_types = Vec::new(); + let mut found_matching_root = false; + for root_types in self.search_root_correlated_types { + let SearchRootCorrelatedTypes { + matching_target_types, + mut uncorrelated_target_types, + deferred_known_call_target_types, + } = root_types; + + let root_matching_target_type = if matching_target_types.is_empty() { + None + } else { + let matching_target_type = LuaType::from_vec(matching_target_types); + let narrowed_correlated_type = + TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); + if narrowed_correlated_type.is_never() { + None + } else { + found_matching_root = true; + Some(narrowed_correlated_type) + } + }; + + if let Some(known_call_target_types) = deferred_known_call_target_types { + let remaining_root_type = + if known_call_target_types.is_empty() && uncorrelated_target_types.is_empty() { + Some(antecedent_type.clone()) + } else { + subtract_correlated_candidate_types( + db, + antecedent_type.clone(), + &known_call_target_types, + ) + }; + if let Some(remaining_root_type) = remaining_root_type { + uncorrelated_target_types.push(remaining_root_type); + } + } + + let root_uncorrelated_target_type = (!uncorrelated_target_types.is_empty()) + .then(|| LuaType::from_vec(uncorrelated_target_types)); + + match (root_matching_target_type, root_uncorrelated_target_type) { + (Some(root_matching_target_type), Some(root_uncorrelated_target_type)) => { + root_target_types.push(LuaType::from_vec(vec![ + root_matching_target_type, + root_uncorrelated_target_type, + ])); + } + (Some(root_matching_target_type), None) => { + root_target_types.push(root_matching_target_type); + } + (None, Some(root_uncorrelated_target_type)) => { + root_target_types.push(root_uncorrelated_target_type); + } + (None, None) => {} + } + } + + if !found_matching_root { + return antecedent_type; + } + + if root_target_types.is_empty() { + antecedent_type + } else { + LuaType::from_vec(root_target_types) + } + } +} + #[allow(clippy::too_many_arguments)] -pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return_overload_condition( +pub(in crate::semantic::infer::narrow) fn prepare_var_from_return_overload_condition( db: &DbIndex, tree: &FlowTree, cache: &mut LuaInferCache, @@ -22,27 +111,27 @@ pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return discriminant_decl_id: LuaDeclId, condition_position: rowan::TextSize, narrowed_discriminant_type: &LuaType, -) -> Result { +) -> Result, InferFailReason> { let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); }; if !tree.has_decl_multi_return_refs(&discriminant_decl_id) || !tree.has_decl_multi_return_refs(&target_decl_id) { - return Ok(ResultTypeOrContinue::Continue); + return Ok(None); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let search_root_flow_ids = tree.get_decl_multi_return_search_roots( &discriminant_decl_id, &target_decl_id, condition_position, antecedent_flow_id, ); - let mut matching_target_types = Vec::new(); - let mut uncorrelated_target_types = Vec::new(); - for search_root_flow_id in search_root_flow_ids { - let (root_matching_target_types, root_uncorrelated_target_type) = + let root_correlated_types = search_root_flow_ids + .iter() + .copied() + .map(|search_root_flow_id| { collect_correlated_types_from_search_root( db, tree, @@ -52,47 +141,23 @@ pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return discriminant_decl_id, target_decl_id, condition_position, + antecedent_flow_id, search_root_flow_id, narrowed_discriminant_type, - )?; - matching_target_types.extend(root_matching_target_types); - if let Some(root_uncorrelated_target_type) = root_uncorrelated_target_type { - uncorrelated_target_types.push(root_uncorrelated_target_type); - } - } + ) + }) + .collect::, _>>()?; - if matching_target_types.is_empty() { - return Ok(ResultTypeOrContinue::Continue); - } - - let matching_target_type = LuaType::from_vec(matching_target_types); - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let narrowed_correlated_type = - TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); - if narrowed_correlated_type.is_never() { - return Ok(ResultTypeOrContinue::Continue); - } - - if uncorrelated_target_types.is_empty() { - return Ok(if narrowed_correlated_type == antecedent_type { - ResultTypeOrContinue::Continue - } else { - ResultTypeOrContinue::Result(narrowed_correlated_type) - }); + if root_correlated_types + .iter() + .all(|root_types| root_types.matching_target_types.is_empty()) + { + return Ok(None); } - let uncorrelated_target_type = LuaType::from_vec(uncorrelated_target_types); - let merged_type = if uncorrelated_target_type.is_never() { - narrowed_correlated_type - } else { - LuaType::from_vec(vec![narrowed_correlated_type, uncorrelated_target_type]) - }; - - Ok(if merged_type == antecedent_type { - ResultTypeOrContinue::Continue - } else { - ResultTypeOrContinue::Result(merged_type) - }) + Ok(Some(CorrelatedConditionNarrowing { + search_root_correlated_types: root_correlated_types, + })) } #[allow(clippy::too_many_arguments)] @@ -105,9 +170,10 @@ fn collect_correlated_types_from_search_root( discriminant_decl_id: LuaDeclId, target_decl_id: LuaDeclId, condition_position: rowan::TextSize, + antecedent_flow_id: FlowId, search_root_flow_id: FlowId, narrowed_discriminant_type: &LuaType, -) -> Result<(Vec, Option), InferFailReason> { +) -> Result { let (discriminant_refs, discriminant_has_non_reference_origin) = tree .get_decl_multi_return_ref_summary_at( &discriminant_decl_id, @@ -119,17 +185,12 @@ fn collect_correlated_types_from_search_root( condition_position, search_root_flow_id, ); - if discriminant_refs.is_empty() || target_refs.is_empty() { - return Ok(( - Vec::new(), - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), - )); - } - let ( root_matching_target_types, root_correlated_candidate_types, - has_unmatched_correlated_origin, + root_unmatched_target_types, + has_unmatched_discriminant_origin, + has_opaque_target_origin, ) = collect_matching_correlated_types( db, cache, @@ -138,27 +199,48 @@ fn collect_correlated_types_from_search_root( &target_refs, narrowed_discriminant_type, )?; - if root_matching_target_types.is_empty() { - return Ok(( - Vec::new(), - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id).ok(), - )); - } - let root_uncorrelated_target_type = if discriminant_has_non_reference_origin + let mut root_uncorrelated_target_types = root_unmatched_target_types; + let has_uncorrelated_origin = discriminant_has_non_reference_origin || target_has_non_reference_origin - || has_unmatched_correlated_origin - { - get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) - .ok() - .and_then(|root_type| { - subtract_correlated_candidate_types(db, root_type, &root_correlated_candidate_types) - }) - } else { - None - }; + || has_opaque_target_origin + || has_unmatched_discriminant_origin; + let correlated_candidate_types_is_empty = root_correlated_candidate_types.is_empty(); + let deferred_known_call_target_types = + if has_uncorrelated_origin && search_root_flow_id == antecedent_flow_id { + let mut known_call_target_types = root_correlated_candidate_types.clone(); + known_call_target_types.extend(root_uncorrelated_target_types.iter().cloned()); + Some(known_call_target_types) + } else { + None + }; + if has_uncorrelated_origin && deferred_known_call_target_types.is_none() { + if correlated_candidate_types_is_empty && root_uncorrelated_target_types.is_empty() { + if let Ok(root_type) = + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) + { + root_uncorrelated_target_types.push(root_type); + } + } else { + let mut known_call_target_types = root_correlated_candidate_types; + known_call_target_types.extend(root_uncorrelated_target_types.iter().cloned()); + if let Some(remaining_root_type) = + get_type_at_flow(db, tree, cache, root, var_ref_id, search_root_flow_id) + .ok() + .and_then(|root_type| { + subtract_correlated_candidate_types(db, root_type, &known_call_target_types) + }) + { + root_uncorrelated_target_types.push(remaining_root_type); + } + } + } - Ok((root_matching_target_types, root_uncorrelated_target_type)) + Ok(SearchRootCorrelatedTypes { + matching_target_types: root_matching_target_types, + uncorrelated_target_types: root_uncorrelated_target_types, + deferred_known_call_target_types, + }) } fn subtract_correlated_candidate_types( @@ -195,9 +277,10 @@ fn collect_matching_correlated_types( discriminant_refs: &[crate::DeclMultiReturnRef], target_refs: &[crate::DeclMultiReturnRef], narrowed_discriminant_type: &LuaType, -) -> Result<(Vec, Vec, bool), InferFailReason> { +) -> Result<(Vec, Vec, Vec, bool, bool), InferFailReason> { let mut matching_target_types = Vec::new(); let mut correlated_candidate_types = Vec::new(); + let mut unmatched_target_types = Vec::new(); let mut correlated_discriminant_call_expr_ids = HashSet::new(); let mut correlated_target_call_expr_ids = HashSet::new(); @@ -211,7 +294,7 @@ fn collect_matching_correlated_types( continue; } - let overload_rows = instantiate_return_overload_rows(db, cache, call_expr, signature); + let overload_rows = instantiate_return_rows(db, cache, call_expr, signature); let discriminant_call_expr_id = discriminant_ref.call_expr.get_syntax_id(); for target_ref in target_refs { @@ -243,15 +326,35 @@ fn collect_matching_correlated_types( } } - let has_unmatched_correlated_origin = discriminant_refs.iter().any(|discriminant_ref| { + let mut has_opaque_target_origin = false; + for target_ref in target_refs { + if correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) { + continue; + } + + let Some((call_expr, signature)) = + infer_signature_for_call_ptr(db, cache, root, &target_ref.call_expr)? + else { + has_opaque_target_origin = true; + continue; + }; + let return_rows = instantiate_return_rows(db, cache, call_expr, signature); + unmatched_target_types.extend( + return_rows.iter().map(|row| { + crate::LuaSignature::get_overload_row_slot(row, target_ref.return_index) + }), + ); + } + + let has_unmatched_discriminant_origin = discriminant_refs.iter().any(|discriminant_ref| { !correlated_discriminant_call_expr_ids.contains(&discriminant_ref.call_expr.get_syntax_id()) - }) || target_refs.iter().any(|target_ref| { - !correlated_target_call_expr_ids.contains(&target_ref.call_expr.get_syntax_id()) }); Ok(( matching_target_types, correlated_candidate_types, - has_unmatched_correlated_origin, + unmatched_target_types, + has_unmatched_discriminant_origin, + has_opaque_target_origin, )) } @@ -278,12 +381,34 @@ fn infer_signature_for_call_ptr<'a>( Ok(Some((call_expr, signature))) } -fn instantiate_return_overload_rows( +fn instantiate_return_rows( db: &DbIndex, cache: &mut LuaInferCache, call_expr: LuaCallExpr, signature: &crate::LuaSignature, ) -> Vec> { + if signature.return_overloads.is_empty() { + let return_type = signature.get_return_type(); + let instantiated_return_type = if return_type.contain_tpl() { + let func = LuaFunctionType::new( + signature.async_state, + signature.is_colon_define, + signature.is_vararg, + signature.get_type_params(), + return_type.clone(), + ); + match instantiate_func_generic(db, cache, &func, call_expr) { + Ok(instantiated) => instantiated.get_ret().clone(), + Err(_) => return_type, + } + } else { + return_type + }; + return vec![crate::LuaSignature::return_type_to_row( + instantiated_return_type, + )]; + } + let mut rows = Vec::with_capacity(signature.return_overloads.len()); for overload in &signature.return_overloads { let type_refs = &overload.type_refs; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs index ef1738954..9c080c3f0 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs @@ -1,123 +1,52 @@ -use emmylua_parser::{LuaChunk, LuaExpr, LuaIndexExpr, LuaIndexMemberExpr}; +use emmylua_parser::{LuaExpr, LuaIndexExpr, LuaIndexMemberExpr}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaInferCache, LuaType, TypeOps, + DbIndex, InferFailReason, LuaInferCache, semantic::infer::{ VarRefId, - infer_index::infer_member_by_member_key, narrow::{ - ResultTypeOrContinue, condition_flow::InferConditionFlow, get_single_antecedent, - get_type_at_flow::get_type_at_flow, narrow_false_or_nil, remove_false_or_nil, + condition_flow::{ConditionFlowAction, InferConditionFlow, PendingConditionNarrow}, var_ref_id::get_var_expr_var_ref_id, }, }, }; -#[allow(clippy::too_many_arguments)] pub fn get_type_at_index_expr( db: &DbIndex, - tree: &FlowTree, cache: &mut LuaInferCache, - root: &LuaChunk, var_ref_id: &VarRefId, - flow_node: &FlowNode, index_expr: LuaIndexExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::IndexExpr(index_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - if name_var_ref_id != *var_ref_id { - return maybe_field_exist_narrow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - index_expr, - condition_flow, - ); + if name_var_ref_id == *var_ref_id { + return Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::Truthiness(condition_flow), + )); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let result_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), - }; - - Ok(ResultTypeOrContinue::Result(result_type)) -} - -#[allow(clippy::too_many_arguments)] -fn maybe_field_exist_narrow( - db: &DbIndex, - tree: &FlowTree, - cache: &mut LuaInferCache, - root: &LuaChunk, - var_ref_id: &VarRefId, - flow_node: &FlowNode, - index_expr: LuaIndexExpr, - condition_flow: InferConditionFlow, -) -> Result { let Some(prefix_expr) = index_expr.get_prefix_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(maybe_var_ref_id) = get_var_expr_var_ref_id(db, cache, prefix_expr.clone()) else { // If we cannot find a reference declaration ID, we cannot narrow it - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if maybe_var_ref_id != *var_ref_id { - return Ok(ResultTypeOrContinue::Continue); - } - - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let left_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - let LuaType::Union(union_type) = &left_type else { - return Ok(ResultTypeOrContinue::Continue); - }; - - let index = LuaIndexMemberExpr::IndexExpr(index_expr); - let mut result = vec![]; - let union_types = union_type.into_vec(); - for sub_type in &union_types { - let member_type = match infer_member_by_member_key( - db, - cache, - sub_type, - index.clone(), - &InferGuard::new(), - ) { - Ok(member_type) => member_type, - Err(_) => continue, // If we cannot infer the member type, skip this type - }; - // donot use always true - if !member_type.is_always_falsy() { - result.push(sub_type.clone()); - } + return Ok(ConditionFlowAction::Continue); } - match condition_flow { - InferConditionFlow::TrueCondition => { - if !result.is_empty() { - return Ok(ResultTypeOrContinue::Result(LuaType::from_vec(result))); - } - } - InferConditionFlow::FalseCondition => { - if !result.is_empty() { - let target = LuaType::from_vec(result); - let t = TypeOps::Remove.apply(db, &left_type, &target); - return Ok(ResultTypeOrContinue::Result(t)); - } - } - } - - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::FieldTruthy { + index: LuaIndexMemberExpr::IndexExpr(index_expr), + condition_flow, + }, + )) } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 1a28ecb92..d65521404 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -1,24 +1,31 @@ mod binary_flow; mod call_flow; -mod correlated_flow; +pub(in crate::semantic::infer::narrow) mod correlated_flow; mod index_flow; -use self::correlated_flow::narrow_var_from_return_overload_condition; -use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator}; +use self::{ + binary_flow::get_type_at_binary_expr, + correlated_flow::{CorrelatedConditionNarrowing, prepare_var_from_return_overload_condition}, +}; +use emmylua_parser::{ + LuaAstNode, LuaChunk, LuaExpr, LuaIndexMemberExpr, LuaNameExpr, LuaUnaryExpr, UnaryOperator, +}; use crate::{ - DbIndex, FlowNode, FlowTree, InferFailReason, LuaInferCache, LuaType, + DbIndex, FlowNode, FlowTree, InferFailReason, InferGuard, LuaInferCache, LuaSignatureCast, + LuaSignatureId, LuaType, semantic::infer::{ VarRefId, + infer_index::infer_member_by_member_key, narrow::{ ResultTypeOrContinue, condition_flow::{ - binary_flow::get_type_at_binary_expr, call_flow::get_type_at_call_expr, - index_flow::get_type_at_index_expr, + call_flow::get_type_at_call_expr, index_flow::get_type_at_index_expr, }, get_single_antecedent, + get_type_at_cast_flow::cast_type, get_type_at_flow::get_type_at_flow, - narrow_false_or_nil, remove_false_or_nil, + narrow_down_type, narrow_false_or_nil, remove_false_or_nil, var_ref_id::get_var_expr_var_ref_id, }, }, @@ -48,6 +55,227 @@ impl InferConditionFlow { } } +#[derive(Debug)] +pub(in crate::semantic::infer::narrow) enum ConditionFlowAction { + Continue, + Result(LuaType), + Pending(PendingConditionNarrow), +} + +impl From for ConditionFlowAction { + fn from(result_or_continue: ResultTypeOrContinue) -> Self { + match result_or_continue { + ResultTypeOrContinue::Continue => ConditionFlowAction::Continue, + ResultTypeOrContinue::Result(result_type) => ConditionFlowAction::Result(result_type), + } + } +} + +#[derive(Debug, Clone)] +pub(in crate::semantic::infer::narrow) enum PendingConditionNarrow { + Truthiness(InferConditionFlow), + FieldTruthy { + index: LuaIndexMemberExpr, + condition_flow: InferConditionFlow, + }, + SameVarColonCall { + index: LuaIndexMemberExpr, + condition_flow: InferConditionFlow, + }, + SignatureCast { + signature_id: LuaSignatureId, + condition_flow: InferConditionFlow, + }, + Eq { + right_expr_type: LuaType, + condition_flow: InferConditionFlow, + }, + TypeGuard { + narrow: LuaType, + condition_flow: InferConditionFlow, + }, + Correlated(CorrelatedConditionNarrowing), +} + +impl PendingConditionNarrow { + pub(in crate::semantic::infer::narrow) fn apply( + self, + db: &DbIndex, + cache: &mut LuaInferCache, + antecedent_type: LuaType, + ) -> LuaType { + match self { + PendingConditionNarrow::Truthiness(condition_flow) => match condition_flow { + InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), + }, + PendingConditionNarrow::FieldTruthy { + index, + condition_flow, + } => { + let LuaType::Union(union_type) = &antecedent_type else { + return antecedent_type; + }; + + let union_types = union_type.into_vec(); + let mut result = vec![]; + for sub_type in &union_types { + let member_type = match infer_member_by_member_key( + db, + cache, + sub_type, + index.clone(), + &InferGuard::new(), + ) { + Ok(member_type) => member_type, + Err(_) => continue, + }; + + if !member_type.is_always_falsy() { + result.push(sub_type.clone()); + } + } + + if result.is_empty() { + antecedent_type + } else { + match condition_flow { + InferConditionFlow::TrueCondition => LuaType::from_vec(result), + InferConditionFlow::FalseCondition => { + let target = LuaType::from_vec(result); + crate::TypeOps::Remove.apply(db, &antecedent_type, &target) + } + } + } + } + PendingConditionNarrow::SameVarColonCall { + index, + condition_flow, + } => { + let Ok(member_type) = infer_member_by_member_key( + db, + cache, + &antecedent_type, + index, + &InferGuard::new(), + ) else { + return antecedent_type; + }; + + let LuaType::Signature(signature_id) = member_type else { + return antecedent_type; + }; + + let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) + else { + return antecedent_type; + }; + + if signature_cast.name != "self" { + return antecedent_type; + } + + apply_signature_cast( + db, + antecedent_type, + signature_id, + signature_cast, + condition_flow, + ) + } + PendingConditionNarrow::SignatureCast { + signature_id, + condition_flow, + } => { + let Some(signature_cast) = db.get_flow_index().get_signature_cast(&signature_id) + else { + return antecedent_type; + }; + + apply_signature_cast( + db, + antecedent_type, + signature_id, + signature_cast, + condition_flow, + ) + } + PendingConditionNarrow::Eq { + right_expr_type, + condition_flow, + } => match condition_flow { + InferConditionFlow::TrueCondition => { + let maybe_type = + crate::TypeOps::Intersect.apply(db, &antecedent_type, &right_expr_type); + if maybe_type.is_never() { + antecedent_type + } else { + maybe_type + } + } + InferConditionFlow::FalseCondition => { + crate::TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type) + } + }, + PendingConditionNarrow::TypeGuard { + narrow, + condition_flow, + } => match condition_flow { + InferConditionFlow::TrueCondition => { + narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow) + } + InferConditionFlow::FalseCondition => { + crate::TypeOps::Remove.apply(db, &antecedent_type, &narrow) + } + }, + PendingConditionNarrow::Correlated(correlated_narrowing) => { + correlated_narrowing.apply(db, antecedent_type) + } + } + } +} + +fn apply_signature_cast( + db: &DbIndex, + antecedent_type: LuaType, + signature_id: LuaSignatureId, + signature_cast: &LuaSignatureCast, + condition_flow: InferConditionFlow, +) -> LuaType { + let file_id = signature_id.get_file_id(); + let Some(syntax_tree) = db.get_vfs().get_syntax_tree(&file_id) else { + return antecedent_type; + }; + let signature_root = syntax_tree.get_chunk_node(); + + let (cast_ptr, cast_flow) = match condition_flow { + InferConditionFlow::TrueCondition => (&signature_cast.cast, condition_flow), + InferConditionFlow::FalseCondition => ( + signature_cast + .fallback_cast + .as_ref() + .unwrap_or(&signature_cast.cast), + signature_cast + .fallback_cast + .as_ref() + .map(|_| InferConditionFlow::TrueCondition) + .unwrap_or(condition_flow), + ), + }; + let Some(cast_op_type) = cast_ptr.to_node(&signature_root) else { + return antecedent_type; + }; + + cast_type( + db, + file_id, + cast_op_type, + antecedent_type.clone(), + cast_flow, + ) + .unwrap_or(antecedent_type) +} + #[allow(clippy::too_many_arguments)] pub fn get_type_at_condition_flow( db: &DbIndex, @@ -58,7 +286,7 @@ pub fn get_type_at_condition_flow( flow_node: &FlowNode, condition: LuaExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { match condition { LuaExpr::NameExpr(name_expr) => get_type_at_name_expr( db, @@ -70,28 +298,14 @@ pub fn get_type_at_condition_flow( name_expr, condition_flow, ), - LuaExpr::CallExpr(call_expr) => get_type_at_call_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - call_expr, - condition_flow, - ), - LuaExpr::IndexExpr(index_expr) => get_type_at_index_expr( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - index_expr, - condition_flow, - ), + LuaExpr::CallExpr(call_expr) => { + get_type_at_call_expr(db, cache, var_ref_id, call_expr, condition_flow) + } + LuaExpr::IndexExpr(index_expr) => { + get_type_at_index_expr(db, cache, var_ref_id, index_expr, condition_flow) + } LuaExpr::TableExpr(_) | LuaExpr::LiteralExpr(_) | LuaExpr::ClosureExpr(_) => { - Ok(ResultTypeOrContinue::Continue) + Ok(ConditionFlowAction::Continue) } LuaExpr::BinaryExpr(binary_expr) => get_type_at_binary_expr( db, @@ -115,7 +329,7 @@ pub fn get_type_at_condition_flow( ), LuaExpr::ParenExpr(paren_expr) => { let Some(inner_expr) = paren_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; get_type_at_condition_flow( @@ -142,11 +356,11 @@ fn get_type_at_name_expr( flow_node: &FlowNode, name_expr: LuaNameExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(name_var_ref_id) = get_var_expr_var_ref_id(db, cache, LuaExpr::NameExpr(name_expr.clone())) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; if name_var_ref_id != *var_ref_id { @@ -162,15 +376,9 @@ fn get_type_at_name_expr( ); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - - let result_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), - }; - - Ok(ResultTypeOrContinue::Result(result_type)) + Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::Truthiness(condition_flow), + )) } #[allow(clippy::too_many_arguments)] @@ -183,14 +391,15 @@ fn get_type_at_name_ref( flow_node: &FlowNode, name_expr: LuaNameExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(decl_id) = db .get_reference_index() .get_var_reference_decl(&cache.get_file_id(), name_expr.get_range()) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_discriminant_type = get_type_at_flow( db, tree, @@ -204,7 +413,7 @@ fn get_type_at_name_ref( InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), }; - if let ResultTypeOrContinue::Result(result_type) = narrow_var_from_return_overload_condition( + if let Some(correlated_narrowing) = prepare_var_from_return_overload_condition( db, tree, cache, @@ -215,15 +424,17 @@ fn get_type_at_name_ref( name_expr.get_position(), &narrowed_discriminant_type, )? { - return Ok(ResultTypeOrContinue::Result(result_type)); + return Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::Correlated(correlated_narrowing), + )); } let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(expr) = expr_ptr.to_node(root) else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; get_type_at_condition_flow( @@ -274,19 +485,19 @@ fn get_type_at_unary_flow( flow_node: &FlowNode, unary_expr: LuaUnaryExpr, condition_flow: InferConditionFlow, -) -> Result { +) -> Result { let Some(inner_expr) = unary_expr.get_expr() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; let Some(op) = unary_expr.get_op_token() else { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); }; match op.get_op() { UnaryOperator::OpNot => {} _ => { - return Ok(ResultTypeOrContinue::Continue); + return Ok(ConditionFlowAction::Continue); } } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs index 6aa14497a..17bf8193d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_cast_flow.rs @@ -50,7 +50,7 @@ fn get_type_at_cast_expr( return Ok(ResultTypeOrContinue::Continue); } - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let mut antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; for cast_op_type in tag_cast.get_op_types() { @@ -74,7 +74,7 @@ fn get_type_at_inline_cast( flow_node: &FlowNode, tag_cast: LuaDocTagCast, ) -> Result { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let antecedent_flow_id = get_single_antecedent(flow_node)?; let mut antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; for cast_op_type in tag_cast.get_op_types() { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 66ccda945..4ddf3da9c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -2,17 +2,23 @@ use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaExpr, LuaVarExpr}; use crate::{ CacheEntry, DbIndex, FlowId, FlowNode, FlowNodeKind, FlowTree, InferFailReason, LuaDeclId, - LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, infer_expr, - semantic::infer::{ - InferResult, VarRefId, infer_expr_list_value_type_at, - narrow::{ - ResultTypeOrContinue, - condition_flow::{InferConditionFlow, get_type_at_condition_flow}, - get_multi_antecedents, get_single_antecedent, - get_type_at_cast_flow::get_type_at_cast_flow, - get_var_ref_type, narrow_down_type, - var_ref_id::get_var_expr_var_ref_id, + LuaInferCache, LuaMemberId, LuaSignatureId, LuaType, TypeOps, check_type_compact, infer_expr, + semantic::{ + infer::{ + InferResult, VarRefId, infer_expr_list_value_type_at, + narrow::{ + ResultTypeOrContinue, + condition_flow::{ + ConditionFlowAction, InferConditionFlow, PendingConditionNarrow, + get_type_at_condition_flow, + }, + get_multi_antecedents, get_single_antecedent, + get_type_at_cast_flow::get_type_at_cast_flow, + get_var_ref_type, narrow_down_type, + var_ref_id::get_var_expr_var_ref_id, + }, }, + member::find_members, }, }; @@ -24,171 +30,292 @@ pub fn get_type_at_flow( var_ref_id: &VarRefId, flow_id: FlowId, ) -> InferResult { - let key = (var_ref_id.clone(), flow_id); - if let Some(cache_entry) = cache.flow_node_cache.get(&key) - && let CacheEntry::Cache(narrow_type) = cache_entry - { - return Ok(narrow_type.clone()); + get_type_at_flow_internal(db, tree, cache, root, var_ref_id, flow_id, true) +} + +fn can_reuse_narrowed_assignment_source( + db: &DbIndex, + narrowed_source_type: &LuaType, + expr_type: &LuaType, +) -> bool { + if matches!(expr_type, LuaType::TableConst(_) | LuaType::Object(_)) { + return is_partial_assignment_expr_compatible(db, narrowed_source_type, expr_type); } - let result_type; - let mut antecedent_flow_id = flow_id; - loop { - let flow_node = tree - .get_flow_node(antecedent_flow_id) - .ok_or(InferFailReason::None)?; - - match &flow_node.kind { - FlowNodeKind::Start | FlowNodeKind::Unreachable => { - result_type = get_var_ref_type(db, cache, var_ref_id)?; - break; - } - FlowNodeKind::LoopLabel | FlowNodeKind::Break | FlowNodeKind::Return => { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + if !is_exact_assignment_expr_type(expr_type) { + return false; + } + + match narrow_down_type(db, narrowed_source_type.clone(), expr_type.clone(), None) { + Some(narrowed_expr_type) => narrowed_expr_type == *expr_type, + None => true, + } +} + +fn preserves_assignment_expr_type(typ: &LuaType) -> bool { + matches!(typ, LuaType::TableConst(_) | LuaType::Object(_)) || is_exact_assignment_expr_type(typ) +} + +fn is_partial_assignment_expr_compatible( + db: &DbIndex, + source_type: &LuaType, + expr_type: &LuaType, +) -> bool { + if check_type_compact(db, source_type, expr_type).is_ok() { + return true; + } + + // Only preserve branch narrowing for concrete partial table/object literals. + // Broader RHS expressions can carry hidden state the current flow/type model cannot represent + // without wider semantic changes. + if !matches!(expr_type, LuaType::TableConst(_) | LuaType::Object(_)) { + return false; + } + + let expr_members = find_members(db, expr_type).unwrap_or_default(); + + if expr_members.is_empty() { + return true; + } + + let Some(source_members) = find_members(db, source_type) else { + return false; + }; + + expr_members.into_iter().all(|expr_member| { + match source_members + .iter() + .find(|source_member| source_member.key == expr_member.key) + { + Some(source_member) => { + is_partial_assignment_expr_compatible(db, &source_member.typ, &expr_member.typ) } - FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { - let multi_antecedents = get_multi_antecedents(tree, flow_node)?; - - let mut branch_result_type = LuaType::Unknown; - for &flow_id in &multi_antecedents { - let branch_type = get_type_at_flow(db, tree, cache, root, var_ref_id, flow_id)?; - branch_result_type = - TypeOps::Union.apply(db, &branch_result_type, &branch_type); + None => true, + } + }) +} + +fn is_exact_assignment_expr_type(typ: &LuaType) -> bool { + match typ { + LuaType::Nil | LuaType::DocBooleanConst(_) => true, + typ if typ.is_const() => !matches!(typ, LuaType::TableConst(_)), + LuaType::Union(union) => union.into_vec().iter().all(is_exact_assignment_expr_type), + LuaType::MultiLineUnion(multi_union) => { + is_exact_assignment_expr_type(&multi_union.to_union()) + } + LuaType::TypeGuard(inner) => is_exact_assignment_expr_type(inner), + _ => false, + } +} + +fn get_type_at_flow_internal( + db: &DbIndex, + tree: &FlowTree, + cache: &mut LuaInferCache, + root: &LuaChunk, + var_ref_id: &VarRefId, + flow_id: FlowId, + use_condition_narrowing: bool, +) -> InferResult { + let key = (var_ref_id.clone(), flow_id, use_condition_narrowing); + if let Some(cache_entry) = cache.flow_node_cache.get(&key) { + return match cache_entry { + CacheEntry::Cache(narrow_type) => Ok::(narrow_type.clone()), + CacheEntry::Ready => Err(InferFailReason::RecursiveInfer), + }; + } + + cache.flow_node_cache.insert(key.clone(), CacheEntry::Ready); + + let result = (|| { + let result_type; + let mut antecedent_flow_id = flow_id; + let mut pending_condition_narrows: Vec = Vec::new(); + loop { + let flow_node = tree + .get_flow_node(antecedent_flow_id) + .ok_or(InferFailReason::None)?; + + match &flow_node.kind { + FlowNodeKind::Start | FlowNodeKind::Unreachable => { + result_type = get_var_ref_type(db, cache, var_ref_id)?; + break; } - result_type = branch_result_type; - break; - } - FlowNodeKind::DeclPosition(position) => { - if *position <= var_ref_id.get_position() { - match get_var_ref_type(db, cache, var_ref_id) { - Ok(var_type) => { - result_type = var_type; - break; - } - Err(err) => { - // 尝试推断声明位置的类型, 如果发生错误则返回初始错误, 否则返回当前推断错误 - if let Some(init_type) = - try_infer_decl_initializer_type(db, cache, root, var_ref_id)? - { - result_type = init_type; + FlowNodeKind::LoopLabel | FlowNodeKind::Break | FlowNodeKind::Return => { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + FlowNodeKind::BranchLabel | FlowNodeKind::NamedLabel(_) => { + let multi_antecedents = get_multi_antecedents(tree, flow_node)?; + + let mut branch_result_type = LuaType::Unknown; + for &flow_id in &multi_antecedents { + let branch_type = get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + flow_id, + use_condition_narrowing, + )?; + branch_result_type = + TypeOps::Union.apply(db, &branch_result_type, &branch_type); + } + result_type = branch_result_type; + break; + } + FlowNodeKind::DeclPosition(position) => { + if *position <= var_ref_id.get_position() { + match get_var_ref_type(db, cache, var_ref_id) { + Ok(var_type) => { + result_type = var_type; break; } - - return Err(err); + Err(err) => { + // 尝试推断声明位置的类型, 如果发生错误则返回初始错误, 否则返回当前推断错误 + if let Some(init_type) = + try_infer_decl_initializer_type(db, cache, root, var_ref_id)? + { + result_type = init_type; + break; + } + + return Err(err); + } } + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; } - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; } - } - FlowNodeKind::Assignment(assign_ptr) => { - let assign_stat = assign_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_assign_stat( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - assign_stat, - )?; - - if let ResultTypeOrContinue::Result(assign_type) = result_or_continue { - result_type = assign_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::Assignment(assign_ptr) => { + let assign_stat = assign_ptr.to_node(root).ok_or(InferFailReason::None)?; + let result_or_continue = get_type_at_assign_stat( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + assign_stat, + )?; + + if let ResultTypeOrContinue::Result(assign_type) = result_or_continue { + result_type = assign_type; + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } - } - FlowNodeKind::ImplFunc(func_ptr) => { - let func_stat = func_ptr.to_node(root).ok_or(InferFailReason::None)?; - let Some(func_name) = func_stat.get_func_name() else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - continue; - }; - - let Some(ref_id) = get_var_expr_var_ref_id(db, cache, func_name.to_expr()) else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - continue; - }; - - if ref_id == *var_ref_id { - let Some(closure) = func_stat.get_closure() else { - return Err(InferFailReason::None); + FlowNodeKind::ImplFunc(func_ptr) => { + let func_stat = func_ptr.to_node(root).ok_or(InferFailReason::None)?; + let Some(func_name) = func_stat.get_func_name() else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; }; - result_type = LuaType::Signature(LuaSignatureId::from_closure( - cache.get_file_id(), - &closure, - )); - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + let Some(ref_id) = get_var_expr_var_ref_id(db, cache, func_name.to_expr()) + else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; + }; + + if ref_id == *var_ref_id { + let Some(closure) = func_stat.get_closure() else { + return Err(InferFailReason::None); + }; + + result_type = LuaType::Signature(LuaSignatureId::from_closure( + cache.get_file_id(), + &closure, + )); + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } - } - FlowNodeKind::TrueCondition(condition_ptr) => { - let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - condition, - InferConditionFlow::TrueCondition, - )?; - - if let ResultTypeOrContinue::Result(condition_type) = result_or_continue { - result_type = condition_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::TrueCondition(condition_ptr) + | FlowNodeKind::FalseCondition(condition_ptr) => { + if !use_condition_narrowing { + antecedent_flow_id = get_single_antecedent(flow_node)?; + continue; + } + + let condition_flow = + if matches!(&flow_node.kind, FlowNodeKind::TrueCondition(_)) { + InferConditionFlow::TrueCondition + } else { + InferConditionFlow::FalseCondition + }; + let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; + match get_type_at_condition_flow( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + condition, + condition_flow, + )? { + ConditionFlowAction::Pending(pending_condition_narrow) => { + pending_condition_narrows.push(pending_condition_narrow); + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + ConditionFlowAction::Result(condition_type) => { + result_type = condition_type; + break; + } + ConditionFlowAction::Continue => { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } + } } - } - FlowNodeKind::FalseCondition(condition_ptr) => { - let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - let result_or_continue = get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - condition, - InferConditionFlow::FalseCondition, - )?; - - if let ResultTypeOrContinue::Result(condition_type) = result_or_continue { - result_type = condition_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::ForIStat(_) => { + // todo check for `for i = 1, 10 do end` + antecedent_flow_id = get_single_antecedent(flow_node)?; } - } - FlowNodeKind::ForIStat(_) => { - // todo check for `for i = 1, 10 do end` - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - } - FlowNodeKind::TagCast(cast_ast_ptr) => { - let tag_cast = cast_ast_ptr.to_node(root).ok_or(InferFailReason::None)?; - let cast_or_continue = - get_type_at_cast_flow(db, tree, cache, root, var_ref_id, flow_node, tag_cast)?; - - if let ResultTypeOrContinue::Result(cast_type) = cast_or_continue { - result_type = cast_type; - break; - } else { - antecedent_flow_id = get_single_antecedent(tree, flow_node)?; + FlowNodeKind::TagCast(cast_ast_ptr) => { + let tag_cast = cast_ast_ptr.to_node(root).ok_or(InferFailReason::None)?; + let cast_or_continue = get_type_at_cast_flow( + db, tree, cache, root, var_ref_id, flow_node, tag_cast, + )?; + + if let ResultTypeOrContinue::Result(cast_type) = cast_or_continue { + result_type = cast_type; + break; + } else { + antecedent_flow_id = get_single_antecedent(flow_node)?; + } } } } + + let result_type = if use_condition_narrowing { + pending_condition_narrows.into_iter().rev().fold( + result_type, + |result_type, pending_condition_narrow| { + pending_condition_narrow.apply(db, cache, result_type) + }, + ) + } else { + result_type + }; + + Ok(result_type) + })(); + + match &result { + Ok(result_type) => { + cache + .flow_node_cache + .insert(key, CacheEntry::Cache(result_type.clone())); + } + Err(_) => { + cache.flow_node_cache.remove(&key); + } } - cache - .flow_node_cache - .insert(key, CacheEntry::Cache(result_type.clone())); - Ok(result_type) + result } fn get_type_at_assign_stat( @@ -231,12 +358,30 @@ fn get_type_at_assign_stat( return Ok(ResultTypeOrContinue::Continue); }; - let source_type = if let Some(explicit) = explicit_var_type.clone() { - explicit - } else { - let antecedent_flow_id = get_single_antecedent(tree, flow_node)?; - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)? - }; + let (source_type, reuse_source_narrowing) = + if let Some(explicit) = explicit_var_type.clone() { + (explicit, true) + } else { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + let narrowed_source_type = + get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + if can_reuse_narrowed_assignment_source(db, &narrowed_source_type, &expr_type) { + (narrowed_source_type, true) + } else { + ( + get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + antecedent_flow_id, + false, + )?, + false, + ) + } + }; let narrowed = if source_type == LuaType::Nil { None @@ -252,7 +397,11 @@ fn get_type_at_assign_stat( narrow_down_type(db, source_type.clone(), expr_type.clone(), declared) }; - let result_type = narrowed.unwrap_or(explicit_var_type.unwrap_or(expr_type)); + let result_type = if reuse_source_narrowing || preserves_assignment_expr_type(&expr_type) { + narrowed.unwrap_or_else(|| explicit_var_type.unwrap_or_else(|| expr_type.clone())) + } else { + expr_type + }; return Ok(ResultTypeOrContinue::Result(result_type)); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index b78b19c0c..afe3cc7c0 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -71,22 +71,11 @@ fn get_var_ref_type(db: &DbIndex, cache: &mut LuaInferCache, var_ref_id: &VarRef } } -fn get_single_antecedent(tree: &FlowTree, flow: &FlowNode) -> Result { +fn get_single_antecedent(flow: &FlowNode) -> Result { match &flow.antecedent { Some(antecedent) => match antecedent { FlowAntecedent::Single(id) => Ok(*id), - FlowAntecedent::Multiple(multi_id) => { - let multi_flow = tree - .get_multi_antecedents(*multi_id) - .ok_or(InferFailReason::None)?; - if !multi_flow.is_empty() { - // If there are multiple antecedents, we need to handle them separately - // For now, we just return the first one - Ok(multi_flow[0]) - } else { - Err(InferFailReason::None) - } - } + FlowAntecedent::Multiple(_) => Err(InferFailReason::None), }, None => Err(InferFailReason::None), } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs index 44e85e141..f6d62ea97 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/narrow_type/mod.rs @@ -207,6 +207,9 @@ pub fn narrow_down_type( .into_iter() .filter_map(|t| narrow_down_type(db, real_source_ref.clone(), t, declared.clone())) .collect::>(); + if source_types.is_empty() { + return None; + } let mut result_type = LuaType::Unknown; for source_type in source_types { result_type = TypeOps::Union.apply(db, &result_type, &source_type); From e66b98ce9118b4dc4a6e5710414cc71427f6b88c Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 26 Mar 2026 14:55:21 +0000 Subject: [PATCH 3/5] fix(flow): skip redundant assignment flow walk --- .../src/compilation/test/flow.rs | 31 +++++++++++++++ .../test/return_overload_flow_test.rs | 38 +++++++++++++++++++ .../semantic/infer/narrow/get_type_at_flow.rs | 25 +++++++++--- 3 files changed, 89 insertions(+), 5 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index a9003de65..5afd80f46 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -3,6 +3,7 @@ mod test { use crate::{DiagnosticCode, LuaType, VirtualWorkspace}; const STACKED_TYPE_GUARDS: usize = 180; + const LARGE_LINEAR_ASSIGNMENT_STEPS: usize = 2048; #[test] fn test_closure_return() { @@ -332,6 +333,36 @@ mod test { assert_eq!(ws.expr_ty("after_guard"), ws.ty("Player")); } + #[test] + fn test_large_linear_assignment_file_builds_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let mut block = String::from( + r#" + local value ---@type integer + value = 1 + + "#, + ); + + for i in 0..LARGE_LINEAR_ASSIGNMENT_STEPS { + block.push_str(&format!("local alias_{i} = value\n")); + block.push_str(&format!("value = alias_{i}\n")); + } + block.push_str("after_assign = value\n"); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for large linear assignment stress case" + ); + let after_assign = ws.expr_ty("after_assign"); + assert_eq!(ws.humanize_type(after_assign), "integer"); + } + #[test] fn test_pending_replay_order_uses_type_guard_before_self_return_cast_lookup() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs index e1fc57836..0cf6b74a8 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs @@ -39,6 +39,44 @@ mod test { assert_eq!(ws.expr_ty("a"), ws.ty("integer")); } + #[test] + fn test_return_overload_narrow_tracks_multiple_targets_from_same_call() { + let mut ws = VirtualWorkspace::new(); + + ws.def( + r#" + ---@param ok boolean + ---@return boolean + ---@return integer|string + ---@return string|boolean + ---@return_overload true, integer, string + ---@return_overload false, string, boolean + local function pick(ok) + if ok then + return true, 1, "value" + end + return false, "error", false + end + + local cond ---@type boolean + local ok, result, extra = pick(cond) + + if ok then + a = result + b = extra + else + c = result + d = extra + end + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer")); + assert_eq!(ws.expr_ty("b"), ws.ty("string")); + assert_eq!(ws.expr_ty("c"), ws.ty("string")); + assert_eq!(ws.expr_ty("d"), ws.ty("boolean")); + } + #[test] fn test_return_overload_reassign_clears_multi_return_mapping() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 4ddf3da9c..e1ecbeb30 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -363,11 +363,7 @@ fn get_type_at_assign_stat( (explicit, true) } else { let antecedent_flow_id = get_single_antecedent(flow_node)?; - let narrowed_source_type = - get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; - if can_reuse_narrowed_assignment_source(db, &narrowed_source_type, &expr_type) { - (narrowed_source_type, true) - } else { + if !preserves_assignment_expr_type(&expr_type) { ( get_type_at_flow_internal( db, @@ -380,6 +376,25 @@ fn get_type_at_assign_stat( )?, false, ) + } else { + let narrowed_source_type = + get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?; + if can_reuse_narrowed_assignment_source(db, &narrowed_source_type, &expr_type) { + (narrowed_source_type, true) + } else { + ( + get_type_at_flow_internal( + db, + tree, + cache, + root, + var_ref_id, + antecedent_flow_id, + false, + )?, + false, + ) + } } }; From f8171c44d93cc37fa184fcb8540ff01fa497b199 Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Thu, 26 Mar 2026 17:02:53 +0000 Subject: [PATCH 4/5] fix(flow): cache repeated condition actions Add a dedicated condition-flow cache keyed by variable reference, antecedent flow, and branch polarity so stacked guards can reuse earlier results instead of re-entering the same narrowing path recursively. Only build correlated return-overload narrows when both the discriminant and target participate in multi-return tracking. That keeps ordinary same-variable truthiness and equality guards on the direct narrowing path while preserving correlated narrowing for real return-overload cases. Add a regression test for repeated `if not value then return end` guards so the semantic model still builds and the narrowed type remains `string`. --- .../src/compilation/test/flow.rs | 25 +++++++ .../src/semantic/cache/mod.rs | 10 ++- .../src/semantic/infer/mod.rs | 1 + .../narrow/condition_flow/binary_flow.rs | 30 +++++++-- .../narrow/condition_flow/correlated_flow.rs | 2 +- .../infer/narrow/condition_flow/mod.rs | 67 ++++++++++--------- .../semantic/infer/narrow/get_type_at_flow.rs | 55 ++++++++++++--- .../src/semantic/infer/narrow/mod.rs | 1 + 8 files changed, 141 insertions(+), 50 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/flow.rs b/crates/emmylua_code_analysis/src/compilation/test/flow.rs index 5afd80f46..7ae876e62 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/flow.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/flow.rs @@ -131,6 +131,31 @@ mod test { assert!(ws.check_code_for(DiagnosticCode::AssignTypeMismatch, &block)); } + #[test] + fn test_stacked_same_var_truthiness_guards_build_semantic_model() { + let mut ws = VirtualWorkspace::new(); + let repeated_guards = "if not value then return end\n".repeat(STACKED_TYPE_GUARDS); + let block = format!( + r#" + local value ---@type string? + + {repeated_guards} + after_guard = value + "#, + ); + + let file_id = ws.def(&block); + + assert!( + ws.analysis + .compilation + .get_semantic_model(file_id) + .is_some(), + "expected semantic model for stacked same-variable truthiness repro" + ); + assert_eq!(ws.expr_ty("after_guard"), ws.ty("string")); + } + #[test] fn test_stacked_same_var_call_type_guards_build_semantic_model() { let mut ws = VirtualWorkspace::new(); diff --git a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs index 8958d49a1..f7ee7d8d0 100644 --- a/crates/emmylua_code_analysis/src/semantic/cache/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/cache/mod.rs @@ -7,7 +7,11 @@ use std::{ sync::Arc, }; -use crate::{FileId, FlowId, LuaFunctionType, db_index::LuaType, semantic::infer::VarRefId}; +use crate::{ + FileId, FlowId, LuaFunctionType, + db_index::LuaType, + semantic::infer::{ConditionFlowAction, VarRefId}, +}; #[derive(Debug)] pub enum CacheEntry { @@ -23,6 +27,8 @@ pub struct LuaInferCache { pub call_cache: HashMap<(LuaSyntaxId, Option, LuaType), CacheEntry>>, pub(crate) flow_node_cache: HashMap<(VarRefId, FlowId, bool), CacheEntry>, + pub(in crate::semantic) condition_flow_cache: + HashMap<(VarRefId, FlowId, bool), CacheEntry>, pub index_ref_origin_type_cache: HashMap>, pub expr_var_ref_id_cache: HashMap, pub narrow_by_literal_stop_position_cache: HashSet, @@ -36,6 +42,7 @@ impl LuaInferCache { expr_cache: HashMap::new(), call_cache: HashMap::new(), flow_node_cache: HashMap::new(), + condition_flow_cache: HashMap::new(), index_ref_origin_type_cache: HashMap::new(), expr_var_ref_id_cache: HashMap::new(), narrow_by_literal_stop_position_cache: HashSet::new(), @@ -58,6 +65,7 @@ impl LuaInferCache { self.expr_cache.clear(); self.call_cache.clear(); self.flow_node_cache.clear(); + self.condition_flow_cache.clear(); self.index_ref_origin_type_cache.clear(); self.expr_var_ref_id_cache.clear(); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs index e30b8477c..2036d2c9c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/mod.rs @@ -26,6 +26,7 @@ pub use infer_name::{find_self_decl_or_member_id, infer_param}; use infer_table::infer_table_expr; pub use infer_table::{infer_table_field_value_should_be, infer_table_should_be}; use infer_unary::infer_unary_expr; +pub(in crate::semantic) use narrow::ConditionFlowAction; pub use narrow::VarRefId; use rowan::TextRange; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index f6bf3744d..19da18aa3 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -297,6 +297,18 @@ fn maybe_type_guard_binary_action( ))); } + let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { + return Ok(None); + }; + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(None); + }; + if !tree.has_decl_multi_return_refs(&discriminant_decl_id) + || !tree.has_decl_multi_return_refs(&target_decl_id) + { + return Ok(None); + } + let antecedent_flow_id = get_single_antecedent(flow_node)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_var_ref_id, antecedent_flow_id)?; @@ -307,10 +319,6 @@ fn maybe_type_guard_binary_action( InferConditionFlow::FalseCondition => TypeOps::Remove.apply(db, &antecedent_type, &narrow), }; - let Some(discriminant_decl_id) = maybe_var_ref_id.get_decl_id_ref() else { - return Ok(None); - }; - Ok(prepare_var_from_return_overload_condition( db, tree, @@ -403,13 +411,20 @@ fn get_var_eq_condition_action( return Ok(ConditionFlowAction::Continue); }; - let antecedent_flow_id = get_single_antecedent(flow_node)?; - let right_expr_type = infer_expr(db, cache, right_expr)?; - if maybe_ref_id != *var_ref_id { let Some(discriminant_decl_id) = maybe_ref_id.get_decl_id_ref() else { return Ok(ConditionFlowAction::Continue); }; + let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else { + return Ok(ConditionFlowAction::Continue); + }; + if !tree.has_decl_multi_return_refs(&discriminant_decl_id) + || !tree.has_decl_multi_return_refs(&target_decl_id) + { + return Ok(ConditionFlowAction::Continue); + } + let antecedent_flow_id = get_single_antecedent(flow_node)?; + let right_expr_type = infer_expr(db, cache, right_expr)?; let antecedent_type = get_type_at_flow(db, tree, cache, root, &maybe_ref_id, antecedent_flow_id)?; let narrowed_discriminant_type = @@ -430,6 +445,7 @@ fn get_var_eq_condition_action( .unwrap_or(ConditionFlowAction::Continue)); } + let right_expr_type = infer_expr(db, cache, right_expr)?; let result_type = match condition_flow { InferConditionFlow::TrueCondition => { // self 是特殊的, 我们删除其 nil 类型 diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index ad5e030c7..6120c8a18 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -12,7 +12,7 @@ use crate::{ }; #[derive(Debug, Clone)] -pub(in crate::semantic::infer::narrow) struct CorrelatedConditionNarrowing { +pub(in crate::semantic) struct CorrelatedConditionNarrowing { search_root_correlated_types: Vec, } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index d65521404..8c33698ba 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -55,8 +55,8 @@ impl InferConditionFlow { } } -#[derive(Debug)] -pub(in crate::semantic::infer::narrow) enum ConditionFlowAction { +#[derive(Debug, Clone)] +pub(in crate::semantic) enum ConditionFlowAction { Continue, Result(LuaType), Pending(PendingConditionNarrow), @@ -72,7 +72,7 @@ impl From for ConditionFlowAction { } #[derive(Debug, Clone)] -pub(in crate::semantic::infer::narrow) enum PendingConditionNarrow { +pub(in crate::semantic) enum PendingConditionNarrow { Truthiness(InferConditionFlow), FieldTruthy { index: LuaIndexMemberExpr, @@ -399,34 +399,41 @@ fn get_type_at_name_ref( return Ok(ConditionFlowAction::Continue); }; - let antecedent_flow_id = get_single_antecedent(flow_node)?; - let antecedent_discriminant_type = get_type_at_flow( - db, - tree, - cache, - root, - &VarRefId::VarRef(decl_id), - antecedent_flow_id, - )?; - let narrowed_discriminant_type = match condition_flow { - InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_discriminant_type), - InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), - }; + if let Some(target_decl_id) = var_ref_id.get_decl_id_ref() + && tree.has_decl_multi_return_refs(&decl_id) + && tree.has_decl_multi_return_refs(&target_decl_id) + { + let antecedent_flow_id = get_single_antecedent(flow_node)?; + let antecedent_discriminant_type = get_type_at_flow( + db, + tree, + cache, + root, + &VarRefId::VarRef(decl_id), + antecedent_flow_id, + )?; + let narrowed_discriminant_type = match condition_flow { + InferConditionFlow::FalseCondition => { + narrow_false_or_nil(db, antecedent_discriminant_type) + } + InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_discriminant_type), + }; - if let Some(correlated_narrowing) = prepare_var_from_return_overload_condition( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - decl_id, - name_expr.get_position(), - &narrowed_discriminant_type, - )? { - return Ok(ConditionFlowAction::Pending( - PendingConditionNarrow::Correlated(correlated_narrowing), - )); + if let Some(correlated_narrowing) = prepare_var_from_return_overload_condition( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + decl_id, + name_expr.get_position(), + &narrowed_discriminant_type, + )? { + return Ok(ConditionFlowAction::Pending( + PendingConditionNarrow::Correlated(correlated_narrowing), + )); + } } let Some(expr_ptr) = tree.get_decl_ref_expr(&decl_id) else { diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index e1ecbeb30..5d179baae 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -246,17 +246,50 @@ fn get_type_at_flow_internal( } else { InferConditionFlow::FalseCondition }; - let condition = condition_ptr.to_node(root).ok_or(InferFailReason::None)?; - match get_type_at_condition_flow( - db, - tree, - cache, - root, - var_ref_id, - flow_node, - condition, - condition_flow, - )? { + let condition_key = ( + var_ref_id.clone(), + antecedent_flow_id, + matches!(condition_flow, InferConditionFlow::TrueCondition), + ); + let condition_action = { + if let Some(cache_entry) = cache.condition_flow_cache.get(&condition_key) { + match cache_entry { + CacheEntry::Cache(action) => { + Ok::(action.clone()) + } + CacheEntry::Ready => Err(InferFailReason::RecursiveInfer), + } + } else { + let condition = + condition_ptr.to_node(root).ok_or(InferFailReason::None)?; + cache + .condition_flow_cache + .insert(condition_key.clone(), CacheEntry::Ready); + let result = get_type_at_condition_flow( + db, + tree, + cache, + root, + var_ref_id, + flow_node, + condition, + condition_flow, + ); + match &result { + Ok(action) => { + cache + .condition_flow_cache + .insert(condition_key, CacheEntry::Cache(action.clone())); + } + Err(_) => { + cache.condition_flow_cache.remove(&condition_key); + } + } + result + } + }?; + + match condition_action { ConditionFlowAction::Pending(pending_condition_narrow) => { pending_condition_narrows.push(pending_condition_narrow); antecedent_flow_id = get_single_antecedent(flow_node)?; diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs index afe3cc7c0..4bd3de91d 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/mod.rs @@ -12,6 +12,7 @@ use crate::{ infer_name::{find_decl_member_type, infer_global_type}, }, }; +pub(in crate::semantic) use condition_flow::ConditionFlowAction; use emmylua_parser::{LuaAstNode, LuaChunk, LuaExpr}; pub use get_type_at_cast_flow::get_type_at_call_expr_inline_cast; pub use narrow_type::{narrow_down_type, narrow_false_or_nil, remove_false_or_nil}; From 6c4883d65c8f71f8736eeefdb9976c61afac6eeb Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Fri, 27 Mar 2026 11:55:39 +0000 Subject: [PATCH 5/5] refactor(flow): share deferred narrowing action data Store deferred condition-flow narrows behind Rc so cached ConditionFlowAction values can be cloned cheaply instead of recursively copying potentially deep PendingConditionNarrow trees. Also make pending and correlated replay borrow their payloads during flow reconstruction, which lets get_type_at_flow reuse the same cached action data across cache hits and replay passes while keeping the indirection introduced for enum size control. --- .../narrow/condition_flow/binary_flow.rs | 12 ++-- .../infer/narrow/condition_flow/call_flow.rs | 6 +- .../narrow/condition_flow/correlated_flow.rs | 72 ++++++++++--------- .../infer/narrow/condition_flow/index_flow.rs | 4 +- .../infer/narrow/condition_flow/mod.rs | 45 +++++++----- .../semantic/infer/narrow/get_type_at_flow.rs | 4 +- 6 files changed, 79 insertions(+), 64 deletions(-) diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs index 19da18aa3..a570878af 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/binary_flow.rs @@ -289,7 +289,7 @@ fn maybe_type_guard_binary_action( }; if maybe_var_ref_id == *var_ref_id { - return Ok(Some(ConditionFlowAction::Pending( + return Ok(Some(ConditionFlowAction::pending( PendingConditionNarrow::TypeGuard { narrow, condition_flow, @@ -331,7 +331,7 @@ fn maybe_type_guard_binary_action( &narrowed_discriminant_type, )? .map(PendingConditionNarrow::Correlated) - .map(ConditionFlowAction::Pending)) + .map(ConditionFlowAction::pending)) } /// Maps the string result of Lua's builtin `type()` call to the corresponding `LuaType`. @@ -441,7 +441,7 @@ fn get_var_eq_condition_action( &narrowed_discriminant_type, )? .map(PendingConditionNarrow::Correlated) - .map(ConditionFlowAction::Pending) + .map(ConditionFlowAction::pending) .unwrap_or(ConditionFlowAction::Continue)); } @@ -452,14 +452,14 @@ fn get_var_eq_condition_action( if var_ref_id.is_self_ref() && !right_expr_type.is_nil() { TypeOps::Remove.apply(db, &right_expr_type, &LuaType::Nil) } else { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { right_expr_type, condition_flow, })); } } InferConditionFlow::FalseCondition => { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { right_expr_type, condition_flow, })); @@ -499,7 +499,7 @@ fn get_var_eq_condition_action( let right_expr_type = infer_expr(db, cache, right_expr)?; if condition_flow.is_false() { - return Ok(ConditionFlowAction::Pending(PendingConditionNarrow::Eq { + return Ok(ConditionFlowAction::pending(PendingConditionNarrow::Eq { right_expr_type, condition_flow, })); diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs index b9e1d9793..0e6cdb03c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/call_flow.rs @@ -47,7 +47,7 @@ pub fn get_type_at_call_expr( if needs_antecedent_same_var_colon_lookup(&member_type) { // Keep the dedicated pending case here: replay needs the antecedent type // for member lookup itself, not just for applying a cast after lookup. - return Ok(ConditionFlowAction::Pending( + return Ok(ConditionFlowAction::pending( PendingConditionNarrow::SameVarColonCall { index: LuaIndexMemberExpr::IndexExpr(index_expr.clone()), condition_flow, @@ -216,7 +216,7 @@ fn get_type_at_call_expr_by_type_guard( return Ok(ConditionFlowAction::Continue); } - Ok(ConditionFlowAction::Pending( + Ok(ConditionFlowAction::pending( PendingConditionNarrow::TypeGuard { narrow: guard_type, condition_flow, @@ -308,7 +308,7 @@ fn get_signature_cast_pending( signature_id: LuaSignatureId, condition_flow: InferConditionFlow, ) -> ConditionFlowAction { - ConditionFlowAction::Pending(PendingConditionNarrow::SignatureCast { + ConditionFlowAction::pending(PendingConditionNarrow::SignatureCast { signature_id, condition_flow, }) diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs index 6120c8a18..5b9b88a4c 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs @@ -4,7 +4,7 @@ use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk}; use crate::{ DbIndex, FlowId, FlowNode, FlowTree, InferFailReason, LuaDeclId, LuaFunctionType, - LuaInferCache, LuaType, TypeOps, infer_expr, instantiate_func_generic, + LuaInferCache, LuaSignature, LuaType, TypeOps, infer_expr, instantiate_func_generic, semantic::infer::{ VarRefId, narrow::{get_single_antecedent, get_type_at_flow::get_type_at_flow}, @@ -23,25 +23,33 @@ struct SearchRootCorrelatedTypes { deferred_known_call_target_types: Option>, } +#[derive(Debug)] +struct CollectedCorrelatedTypes { + matching_target_types: Vec, + correlated_candidate_types: Vec, + unmatched_target_types: Vec, + has_unmatched_discriminant_origin: bool, + has_opaque_target_origin: bool, +} + impl CorrelatedConditionNarrowing { pub(in crate::semantic::infer::narrow) fn apply( - self, + &self, db: &DbIndex, antecedent_type: LuaType, ) -> LuaType { let mut root_target_types = Vec::new(); let mut found_matching_root = false; - for root_types in self.search_root_correlated_types { - let SearchRootCorrelatedTypes { - matching_target_types, - mut uncorrelated_target_types, - deferred_known_call_target_types, - } = root_types; + for root_types in &self.search_root_correlated_types { + let matching_target_types = &root_types.matching_target_types; + let mut uncorrelated_target_types = root_types.uncorrelated_target_types.clone(); + let deferred_known_call_target_types = + root_types.deferred_known_call_target_types.as_deref(); let root_matching_target_type = if matching_target_types.is_empty() { None } else { - let matching_target_type = LuaType::from_vec(matching_target_types); + let matching_target_type = LuaType::from_vec(matching_target_types.clone()); let narrowed_correlated_type = TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type); if narrowed_correlated_type.is_never() { @@ -185,13 +193,13 @@ fn collect_correlated_types_from_search_root( condition_position, search_root_flow_id, ); - let ( - root_matching_target_types, - root_correlated_candidate_types, - root_unmatched_target_types, + let CollectedCorrelatedTypes { + matching_target_types: root_matching_target_types, + correlated_candidate_types: root_correlated_candidate_types, + unmatched_target_types: root_unmatched_target_types, has_unmatched_discriminant_origin, has_opaque_target_origin, - ) = collect_matching_correlated_types( + } = collect_matching_correlated_types( db, cache, root, @@ -277,7 +285,7 @@ fn collect_matching_correlated_types( discriminant_refs: &[crate::DeclMultiReturnRef], target_refs: &[crate::DeclMultiReturnRef], narrowed_discriminant_type: &LuaType, -) -> Result<(Vec, Vec, Vec, bool, bool), InferFailReason> { +) -> Result { let mut matching_target_types = Vec::new(); let mut correlated_candidate_types = Vec::new(); let mut unmatched_target_types = Vec::new(); @@ -304,18 +312,16 @@ fn collect_matching_correlated_types( correlated_discriminant_call_expr_ids.insert(discriminant_call_expr_id); correlated_target_call_expr_ids.insert(target_ref.call_expr.get_syntax_id()); correlated_candidate_types.extend(overload_rows.iter().map(|overload| { - crate::LuaSignature::get_overload_row_slot(overload, target_ref.return_index) + LuaSignature::get_overload_row_slot(overload, target_ref.return_index) })); matching_target_types.extend(overload_rows.iter().filter_map(|overload| { - let discriminant_type = crate::LuaSignature::get_overload_row_slot( - overload, - discriminant_ref.return_index, - ); + let discriminant_type = + LuaSignature::get_overload_row_slot(overload, discriminant_ref.return_index); if !TypeOps::Intersect .apply(db, &discriminant_type, narrowed_discriminant_type) .is_never() { - return Some(crate::LuaSignature::get_overload_row_slot( + return Some(LuaSignature::get_overload_row_slot( overload, target_ref.return_index, )); @@ -340,22 +346,22 @@ fn collect_matching_correlated_types( }; let return_rows = instantiate_return_rows(db, cache, call_expr, signature); unmatched_target_types.extend( - return_rows.iter().map(|row| { - crate::LuaSignature::get_overload_row_slot(row, target_ref.return_index) - }), + return_rows + .iter() + .map(|row| LuaSignature::get_overload_row_slot(row, target_ref.return_index)), ); } let has_unmatched_discriminant_origin = discriminant_refs.iter().any(|discriminant_ref| { !correlated_discriminant_call_expr_ids.contains(&discriminant_ref.call_expr.get_syntax_id()) }); - Ok(( + Ok(CollectedCorrelatedTypes { matching_target_types, correlated_candidate_types, unmatched_target_types, has_unmatched_discriminant_origin, has_opaque_target_origin, - )) + }) } fn infer_signature_for_call_ptr<'a>( @@ -363,7 +369,7 @@ fn infer_signature_for_call_ptr<'a>( cache: &mut LuaInferCache, root: &LuaChunk, call_expr_ptr: &LuaAstPtr, -) -> Result, InferFailReason> { +) -> Result, InferFailReason> { let Some(call_expr) = call_expr_ptr.to_node(root) else { return Ok(None); }; @@ -385,7 +391,7 @@ fn instantiate_return_rows( db: &DbIndex, cache: &mut LuaInferCache, call_expr: LuaCallExpr, - signature: &crate::LuaSignature, + signature: &LuaSignature, ) -> Vec> { if signature.return_overloads.is_empty() { let return_type = signature.get_return_type(); @@ -404,15 +410,13 @@ fn instantiate_return_rows( } else { return_type }; - return vec![crate::LuaSignature::return_type_to_row( - instantiated_return_type, - )]; + return vec![LuaSignature::return_type_to_row(instantiated_return_type)]; } let mut rows = Vec::with_capacity(signature.return_overloads.len()); for overload in &signature.return_overloads { let type_refs = &overload.type_refs; - let overload_return_type = crate::LuaSignature::row_to_return_type(type_refs.to_vec()); + let overload_return_type = LuaSignature::row_to_return_type(type_refs.to_vec()); let instantiated_return_type = if overload_return_type.contain_tpl() { let overload_func = LuaFunctionType::new( signature.async_state, @@ -429,9 +433,7 @@ fn instantiate_return_rows( overload_return_type }; - rows.push(crate::LuaSignature::return_type_to_row( - instantiated_return_type, - )); + rows.push(LuaSignature::return_type_to_row(instantiated_return_type)); } rows diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs index 9c080c3f0..a2d7eb54a 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/index_flow.rs @@ -25,7 +25,7 @@ pub fn get_type_at_index_expr( }; if name_var_ref_id == *var_ref_id { - return Ok(ConditionFlowAction::Pending( + return Ok(ConditionFlowAction::pending( PendingConditionNarrow::Truthiness(condition_flow), )); } @@ -43,7 +43,7 @@ pub fn get_type_at_index_expr( return Ok(ConditionFlowAction::Continue); } - Ok(ConditionFlowAction::Pending( + Ok(ConditionFlowAction::pending( PendingConditionNarrow::FieldTruthy { index: LuaIndexMemberExpr::IndexExpr(index_expr), condition_flow, diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs index 8c33698ba..4e51cb355 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/mod.rs @@ -3,6 +3,8 @@ mod call_flow; pub(in crate::semantic::infer::narrow) mod correlated_flow; mod index_flow; +use std::rc::Rc; + use self::{ binary_flow::get_type_at_binary_expr, correlated_flow::{CorrelatedConditionNarrowing, prepare_var_from_return_overload_condition}, @@ -59,7 +61,7 @@ impl InferConditionFlow { pub(in crate::semantic) enum ConditionFlowAction { Continue, Result(LuaType), - Pending(PendingConditionNarrow), + Pending(Rc), } impl From for ConditionFlowAction { @@ -71,6 +73,14 @@ impl From for ConditionFlowAction { } } +impl ConditionFlowAction { + pub(in crate::semantic::infer::narrow) fn pending( + pending_condition_narrow: PendingConditionNarrow, + ) -> Self { + ConditionFlowAction::Pending(Rc::new(pending_condition_narrow)) + } +} + #[derive(Debug, Clone)] pub(in crate::semantic) enum PendingConditionNarrow { Truthiness(InferConditionFlow), @@ -99,13 +109,13 @@ pub(in crate::semantic) enum PendingConditionNarrow { impl PendingConditionNarrow { pub(in crate::semantic::infer::narrow) fn apply( - self, + &self, db: &DbIndex, cache: &mut LuaInferCache, antecedent_type: LuaType, ) -> LuaType { match self { - PendingConditionNarrow::Truthiness(condition_flow) => match condition_flow { + PendingConditionNarrow::Truthiness(condition_flow) => match condition_flow.clone() { InferConditionFlow::FalseCondition => narrow_false_or_nil(db, antecedent_type), InferConditionFlow::TrueCondition => remove_false_or_nil(antecedent_type), }, @@ -139,7 +149,7 @@ impl PendingConditionNarrow { if result.is_empty() { antecedent_type } else { - match condition_flow { + match condition_flow.clone() { InferConditionFlow::TrueCondition => LuaType::from_vec(result), InferConditionFlow::FalseCondition => { let target = LuaType::from_vec(result); @@ -156,7 +166,7 @@ impl PendingConditionNarrow { db, cache, &antecedent_type, - index, + index.clone(), &InferGuard::new(), ) else { return antecedent_type; @@ -178,9 +188,9 @@ impl PendingConditionNarrow { apply_signature_cast( db, antecedent_type, - signature_id, + signature_id.clone(), signature_cast, - condition_flow, + condition_flow.clone(), ) } PendingConditionNarrow::SignatureCast { @@ -195,18 +205,18 @@ impl PendingConditionNarrow { apply_signature_cast( db, antecedent_type, - signature_id, + signature_id.clone(), signature_cast, - condition_flow, + condition_flow.clone(), ) } PendingConditionNarrow::Eq { right_expr_type, condition_flow, - } => match condition_flow { + } => match condition_flow.clone() { InferConditionFlow::TrueCondition => { let maybe_type = - crate::TypeOps::Intersect.apply(db, &antecedent_type, &right_expr_type); + crate::TypeOps::Intersect.apply(db, &antecedent_type, right_expr_type); if maybe_type.is_never() { antecedent_type } else { @@ -214,18 +224,19 @@ impl PendingConditionNarrow { } } InferConditionFlow::FalseCondition => { - crate::TypeOps::Remove.apply(db, &antecedent_type, &right_expr_type) + crate::TypeOps::Remove.apply(db, &antecedent_type, right_expr_type) } }, PendingConditionNarrow::TypeGuard { narrow, condition_flow, - } => match condition_flow { + } => match condition_flow.clone() { InferConditionFlow::TrueCondition => { - narrow_down_type(db, antecedent_type, narrow.clone(), None).unwrap_or(narrow) + narrow_down_type(db, antecedent_type, narrow.clone(), None) + .unwrap_or_else(|| narrow.clone()) } InferConditionFlow::FalseCondition => { - crate::TypeOps::Remove.apply(db, &antecedent_type, &narrow) + crate::TypeOps::Remove.apply(db, &antecedent_type, narrow) } }, PendingConditionNarrow::Correlated(correlated_narrowing) => { @@ -376,7 +387,7 @@ fn get_type_at_name_expr( ); } - Ok(ConditionFlowAction::Pending( + Ok(ConditionFlowAction::pending( PendingConditionNarrow::Truthiness(condition_flow), )) } @@ -430,7 +441,7 @@ fn get_type_at_name_ref( name_expr.get_position(), &narrowed_discriminant_type, )? { - return Ok(ConditionFlowAction::Pending( + return Ok(ConditionFlowAction::pending( PendingConditionNarrow::Correlated(correlated_narrowing), )); } diff --git a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs index 5d179baae..416d942c9 100644 --- a/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs +++ b/crates/emmylua_code_analysis/src/semantic/infer/narrow/get_type_at_flow.rs @@ -1,3 +1,5 @@ +use std::rc::Rc; + use emmylua_parser::{LuaAssignStat, LuaAstNode, LuaChunk, LuaExpr, LuaVarExpr}; use crate::{ @@ -130,7 +132,7 @@ fn get_type_at_flow_internal( let result = (|| { let result_type; let mut antecedent_flow_id = flow_id; - let mut pending_condition_narrows: Vec = Vec::new(); + let mut pending_condition_narrows: Vec> = Vec::new(); loop { let flow_node = tree .get_flow_node(antecedent_flow_id)