From be9caeb293f35e4ba89c6326eeecb55b480bec6f Mon Sep 17 00:00:00 2001 From: Lewis Russell Date: Fri, 20 Mar 2026 14:01:39 +0000 Subject: [PATCH] fix(generic): resolve higher-order callable returns by args Flatten callable unions and intersections into callable candidates, instantiate them from the remaining argument types, and select the matching return via overload resolution instead of unioning member returns directly. Add focused pcall regressions for callable union and intersection values. Callable union handling still mirrors overload-style resolution and is not semantically complete yet; leave that for follow-up work. --- .../src/compilation/test/pcall_test.rs | 38 +++++ .../instantiate_func_generic.rs | 146 ++++++++++++++---- .../src/semantic/overload_resolve/mod.rs | 2 +- 3 files changed, 155 insertions(+), 31 deletions(-) diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 5cbb83f2b..059241716 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -141,4 +141,42 @@ mod test { assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); } + + #[test] + fn test_return_overload_infers_callable_union_member() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): integer + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_selects_callable_intersection_member() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index d0a6b24c9..2a4c551cf 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -22,6 +22,7 @@ use crate::{ }, infer::InferFailReason, infer_expr, + overload_resolve::resolve_signature_by_args, }, }; use crate::{ @@ -156,10 +157,116 @@ pub fn infer_callable_return_from_remaining_args( return Ok(None); } - let Some(callable) = as_doc_function_type(context.db, callable_type)? else { + let mut overloads = Vec::new(); + collect_callable_overloads(context.db, callable_type, &mut overloads)?; + if overloads.is_empty() { return Ok(None); + } + + let db = context.db; + + // Fall back to the union of all candidate returns when args cannot narrow the callable. + let fallback_return = || { + LuaType::from_vec( + overloads + .iter() + .map(|callable| { + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + infer_return_from_callable(db, callable, &callable_substitutor) + }) + .collect(), + ) }; + let call_arg_types = match infer_expr_list_types(db, context.cache, arg_exprs, None, infer_expr) + { + Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), + Err(_) => return Ok(Some(fallback_return())), + }; + if call_arg_types.is_empty() { + return Ok(None); + } + + let instantiated_overloads = overloads + .iter() + .map(|callable| instantiate_callable_from_arg_types(context, callable, &call_arg_types)) + .collect::>(); + + Ok(Some( + resolve_signature_by_args(db, &instantiated_overloads, &call_arg_types, false, None) + .map(|callable| callable.get_ret().clone()) + .unwrap_or_else(|_| fallback_return()), + )) +} + +fn collect_callable_overloads( + db: &DbIndex, + callable_type: &LuaType, + overloads: &mut Vec>, +) -> Result<(), InferFailReason> { + // TODO: Distinguish callable union vs intersection semantics here instead of flattening both + // into one overload-candidate pool. Keep in sync with `infer_union` / `infer_intersection`. + match callable_type { + LuaType::DocFunction(doc_func) => overloads.push(doc_func.clone()), + LuaType::Signature(sig_id) => { + let signature = db + .get_signature_index() + .get(sig_id) + .ok_or(InferFailReason::None)?; + overloads.extend(signature.overloads.iter().cloned()); + overloads.push(signature.to_doc_func_type()); + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + if let Some(origin_type) = db + .get_type_index() + .get_type_decl(type_id) + .ok_or(InferFailReason::None)? + .get_alias_origin(db, None) + { + collect_callable_overloads(db, &origin_type, overloads)?; + } + } + LuaType::Generic(generic) => { + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + if let Some(origin_type) = db + .get_type_index() + .get_type_decl(&generic.get_base_type_id()) + .ok_or(InferFailReason::None)? + .get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overloads(db, &origin_type, overloads)?; + } + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overloads(db, &member, overloads)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overloads(db, member, overloads)?; + } + } + _ => {} + } + + Ok(()) +} + +fn instantiate_callable_from_arg_types( + context: &mut TplContext, + callable: &Arc, + call_arg_types: &[LuaType], +) -> Arc { let mut callable_tpls = HashSet::new(); callable.visit_type(&mut |ty| { if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { @@ -167,20 +274,7 @@ pub fn infer_callable_return_from_remaining_args( } }); if callable_tpls.is_empty() { - return Ok(Some(callable.get_ret().clone())); - } - - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls); - let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor); - - let call_arg_types = - match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { - Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), - Err(_) => return Ok(Some(fallback_return)), - }; - if call_arg_types.is_empty() { - return Ok(None); + return callable.clone(); } let callable_param_types = callable @@ -188,28 +282,20 @@ pub fn infer_callable_return_from_remaining_args( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); let mut callable_context = TplContext { db: context.db, cache: context.cache, substitutor: &mut callable_substitutor, call_expr: context.call_expr.clone(), }; - if tpl_pattern_match_args( - &mut callable_context, - &callable_param_types, - &call_arg_types, - ) - .is_err() - { - return Ok(Some(fallback_return)); - } + let _ = tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types); - Ok(Some(infer_return_from_callable( - context.db, - &callable, - &callable_substitutor, - ))) + match instantiate_doc_function(context.db, callable, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + } } fn infer_generic_types_from_call( diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index bb8344e41..0b1b69dc4 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -16,7 +16,7 @@ use super::{ infer::{InferCallFuncResult, InferFailReason}, }; -use resolve_signature_by_args::resolve_signature_by_args; +pub(crate) use resolve_signature_by_args::resolve_signature_by_args; pub fn resolve_signature( db: &DbIndex,