Skip to content

Commit 81ac6be

Browse files
committed
fix(flow): guard correlated return-overload reentry
Correlated @return_overload narrowing could recursively re-enter itself for the same target return slot while the first inference was still in flight. The failure mode is: - get_type_at_flow(target) reaches a condition node - condition narrowing sees that the discriminant is a sibling return value - narrow_var_from_return_overload_condition tries to recover the target's antecedent and uncorrelated fallback types via get_type_at_flow(target, ...) - that predecessor walk can hit another condition on the same discriminant before the first correlated query has unwound - the same target VarRefId is correlated again, repeating the cycle until indexing appears hung or the stack overflows This is not just extra work. The algorithm was missing an in-flight recursion break for correlated condition analysis, so the same unresolved question could be asked again before any cache entry existed. Add a small correlated_condition_guard keyed by VarRefId in LuaInferCache and use it only around narrow_var_from_return_overload_condition. Re-entry now returns Continue, letting the existing flow walk fall back to non-correlated inference instead of recursively opening the same correlated query. The regression test is a focused Lua block with a single pick() call and many repeated if-not-ok guards. Without this fix that block overflows the test stack. With the guard it builds a semantic model and the rest of the return-overload suite still passes.
1 parent 3e0bb16 commit 81ac6be

3 files changed

Lines changed: 125 additions & 61 deletions

File tree

crates/emmylua_code_analysis/src/compilation/test/return_overload_flow_test.rs

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
mod test {
33
use crate::{DiagnosticCode, VirtualWorkspace};
44

5+
const STACKED_CORRELATED_GUARDS: usize = 180;
6+
57
#[test]
68
fn test_return_overload_narrow_after_not() {
79
let mut ws = VirtualWorkspace::new();
@@ -541,4 +543,46 @@ mod test {
541543
assert!(after_guard.contains("integer"));
542544
assert!(!after_guard.contains("string"));
543545
}
546+
547+
#[test]
548+
fn test_return_overload_stacked_same_discriminant_guards_build_semantic_model() {
549+
let mut ws = VirtualWorkspace::new();
550+
let repeated_guards =
551+
"if not ok then error(result) end\n".repeat(STACKED_CORRELATED_GUARDS);
552+
let block = format!(
553+
r#"
554+
---@generic T, E
555+
---@param ok boolean
556+
---@param success T
557+
---@param failure E
558+
---@return boolean
559+
---@return T|E
560+
---@return_overload true, T
561+
---@return_overload false, E
562+
local function pick(ok, success, failure)
563+
if ok then
564+
return true, success
565+
end
566+
return false, failure
567+
end
568+
569+
local cond ---@type boolean
570+
local ok, result = pick(cond, 1, "error")
571+
572+
{repeated_guards}
573+
narrowed = result
574+
"#,
575+
);
576+
577+
let file_id = ws.def(&block);
578+
579+
assert!(
580+
ws.analysis
581+
.compilation
582+
.get_semantic_model(file_id)
583+
.is_some(),
584+
"expected semantic model for stacked correlated-guard repro"
585+
);
586+
assert_eq!(ws.expr_ty("narrowed"), ws.ty("integer"));
587+
}
544588
}

crates/emmylua_code_analysis/src/semantic/cache/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ pub struct LuaInferCache {
2323
pub call_cache:
2424
HashMap<(LuaSyntaxId, Option<usize>, LuaType), CacheEntry<Arc<LuaFunctionType>>>,
2525
pub flow_node_cache: HashMap<(VarRefId, FlowId), CacheEntry<LuaType>>,
26+
pub correlated_condition_guard: HashSet<VarRefId>,
2627
pub index_ref_origin_type_cache: HashMap<VarRefId, CacheEntry<LuaType>>,
2728
pub expr_var_ref_id_cache: HashMap<LuaSyntaxId, VarRefId>,
2829
pub narrow_by_literal_stop_position_cache: HashSet<LuaSyntaxId>,
@@ -36,6 +37,7 @@ impl LuaInferCache {
3637
expr_cache: HashMap::new(),
3738
call_cache: HashMap::new(),
3839
flow_node_cache: HashMap::new(),
40+
correlated_condition_guard: HashSet::new(),
3941
index_ref_origin_type_cache: HashMap::new(),
4042
expr_var_ref_id_cache: HashMap::new(),
4143
narrow_by_literal_stop_position_cache: HashSet::new(),
@@ -58,6 +60,7 @@ impl LuaInferCache {
5860
self.expr_cache.clear();
5961
self.call_cache.clear();
6062
self.flow_node_cache.clear();
63+
self.correlated_condition_guard.clear();
6164
self.index_ref_origin_type_cache.clear();
6265
self.expr_var_ref_id_cache.clear();
6366
}

crates/emmylua_code_analysis/src/semantic/infer/narrow/condition_flow/correlated_flow.rs

Lines changed: 78 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use std::collections::HashSet;
1+
use std::{
2+
collections::HashSet,
3+
panic::{AssertUnwindSafe, catch_unwind, resume_unwind},
4+
};
25

36
use emmylua_parser::{LuaAstPtr, LuaCallExpr, LuaChunk};
47

@@ -23,76 +26,90 @@ pub(in crate::semantic::infer::narrow::condition_flow) fn narrow_var_from_return
2326
condition_position: rowan::TextSize,
2427
narrowed_discriminant_type: &LuaType,
2528
) -> Result<ResultTypeOrContinue, InferFailReason> {
26-
let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else {
27-
return Ok(ResultTypeOrContinue::Continue);
28-
};
29-
if !tree.has_decl_multi_return_refs(&discriminant_decl_id)
30-
|| !tree.has_decl_multi_return_refs(&target_decl_id)
31-
{
29+
let guard_key = var_ref_id.clone();
30+
if !cache.correlated_condition_guard.insert(guard_key.clone()) {
3231
return Ok(ResultTypeOrContinue::Continue);
3332
}
3433

35-
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
36-
let search_root_flow_ids = tree.get_decl_multi_return_search_roots(
37-
&discriminant_decl_id,
38-
&target_decl_id,
39-
condition_position,
40-
antecedent_flow_id,
41-
);
42-
let mut matching_target_types = Vec::new();
43-
let mut uncorrelated_target_types = Vec::new();
44-
for search_root_flow_id in search_root_flow_ids {
45-
let (root_matching_target_types, root_uncorrelated_target_type) =
46-
collect_correlated_types_from_search_root(
47-
db,
48-
tree,
49-
cache,
50-
root,
51-
var_ref_id,
52-
discriminant_decl_id,
53-
target_decl_id,
54-
condition_position,
55-
search_root_flow_id,
56-
narrowed_discriminant_type,
57-
)?;
58-
matching_target_types.extend(root_matching_target_types);
59-
if let Some(root_uncorrelated_target_type) = root_uncorrelated_target_type {
60-
uncorrelated_target_types.push(root_uncorrelated_target_type);
34+
let result = catch_unwind(AssertUnwindSafe(|| {
35+
let Some(target_decl_id) = var_ref_id.get_decl_id_ref() else {
36+
return Ok(ResultTypeOrContinue::Continue);
37+
};
38+
if !tree.has_decl_multi_return_refs(&discriminant_decl_id)
39+
|| !tree.has_decl_multi_return_refs(&target_decl_id)
40+
{
41+
return Ok(ResultTypeOrContinue::Continue);
6142
}
62-
}
6343

64-
if matching_target_types.is_empty() {
65-
return Ok(ResultTypeOrContinue::Continue);
66-
}
44+
let antecedent_flow_id = get_single_antecedent(tree, flow_node)?;
45+
let search_root_flow_ids = tree.get_decl_multi_return_search_roots(
46+
&discriminant_decl_id,
47+
&target_decl_id,
48+
condition_position,
49+
antecedent_flow_id,
50+
);
51+
let mut matching_target_types = Vec::new();
52+
let mut uncorrelated_target_types = Vec::new();
53+
for search_root_flow_id in search_root_flow_ids {
54+
let (root_matching_target_types, root_uncorrelated_target_type) =
55+
collect_correlated_types_from_search_root(
56+
db,
57+
tree,
58+
cache,
59+
root,
60+
var_ref_id,
61+
discriminant_decl_id,
62+
target_decl_id,
63+
condition_position,
64+
search_root_flow_id,
65+
narrowed_discriminant_type,
66+
)?;
67+
matching_target_types.extend(root_matching_target_types);
68+
if let Some(root_uncorrelated_target_type) = root_uncorrelated_target_type {
69+
uncorrelated_target_types.push(root_uncorrelated_target_type);
70+
}
71+
}
6772

68-
let matching_target_type = LuaType::from_vec(matching_target_types);
69-
let antecedent_type = get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
70-
let narrowed_correlated_type =
71-
TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type);
72-
if narrowed_correlated_type.is_never() {
73-
return Ok(ResultTypeOrContinue::Continue);
74-
}
73+
if matching_target_types.is_empty() {
74+
return Ok(ResultTypeOrContinue::Continue);
75+
}
7576

76-
if uncorrelated_target_types.is_empty() {
77-
return Ok(if narrowed_correlated_type == antecedent_type {
78-
ResultTypeOrContinue::Continue
77+
let matching_target_type = LuaType::from_vec(matching_target_types);
78+
let antecedent_type =
79+
get_type_at_flow(db, tree, cache, root, var_ref_id, antecedent_flow_id)?;
80+
let narrowed_correlated_type =
81+
TypeOps::Intersect.apply(db, &antecedent_type, &matching_target_type);
82+
if narrowed_correlated_type.is_never() {
83+
return Ok(ResultTypeOrContinue::Continue);
84+
}
85+
86+
if uncorrelated_target_types.is_empty() {
87+
return Ok(if narrowed_correlated_type == antecedent_type {
88+
ResultTypeOrContinue::Continue
89+
} else {
90+
ResultTypeOrContinue::Result(narrowed_correlated_type)
91+
});
92+
}
93+
94+
let uncorrelated_target_type = LuaType::from_vec(uncorrelated_target_types);
95+
let merged_type = if uncorrelated_target_type.is_never() {
96+
narrowed_correlated_type
7997
} else {
80-
ResultTypeOrContinue::Result(narrowed_correlated_type)
81-
});
82-
}
98+
LuaType::from_vec(vec![narrowed_correlated_type, uncorrelated_target_type])
99+
};
83100

84-
let uncorrelated_target_type = LuaType::from_vec(uncorrelated_target_types);
85-
let merged_type = if uncorrelated_target_type.is_never() {
86-
narrowed_correlated_type
87-
} else {
88-
LuaType::from_vec(vec![narrowed_correlated_type, uncorrelated_target_type])
89-
};
101+
Ok(if merged_type == antecedent_type {
102+
ResultTypeOrContinue::Continue
103+
} else {
104+
ResultTypeOrContinue::Result(merged_type)
105+
})
106+
}));
90107

91-
Ok(if merged_type == antecedent_type {
92-
ResultTypeOrContinue::Continue
93-
} else {
94-
ResultTypeOrContinue::Result(merged_type)
95-
})
108+
cache.correlated_condition_guard.remove(&guard_key);
109+
match result {
110+
Ok(result) => result,
111+
Err(payload) => resume_unwind(payload),
112+
}
96113
}
97114

98115
#[allow(clippy::too_many_arguments)]

0 commit comments

Comments
 (0)