diff --git a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs index 5cbb83f2b..059241716 100644 --- a/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs +++ b/crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs @@ -141,4 +141,42 @@ mod test { assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string")); assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string")); } + + #[test] + fn test_return_overload_infers_callable_union_member() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): integer + + ---@type FnA | FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } + + #[test] + fn test_return_overload_selects_callable_intersection_member() { + let mut ws = VirtualWorkspace::new_with_init_std_lib(); + + ws.def( + r#" + ---@alias FnA fun(x: integer): integer + ---@alias FnB fun(x: string): boolean + + ---@type FnA & FnB + local run + + _, a = pcall(run, 1) + "#, + ); + + assert_eq!(ws.expr_ty("a"), ws.ty("integer|string")); + } } diff --git a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs index d0a6b24c9..2a4c551cf 100644 --- a/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs +++ b/crates/emmylua_code_analysis/src/semantic/generic/instantiate_type/instantiate_func_generic.rs @@ -22,6 +22,7 @@ use crate::{ }, infer::InferFailReason, infer_expr, + overload_resolve::resolve_signature_by_args, }, }; use crate::{ @@ -156,10 +157,116 @@ pub fn infer_callable_return_from_remaining_args( return Ok(None); } - let Some(callable) = as_doc_function_type(context.db, callable_type)? else { + let mut overloads = Vec::new(); + collect_callable_overloads(context.db, callable_type, &mut overloads)?; + if overloads.is_empty() { return Ok(None); + } + + let db = context.db; + + // Fall back to the union of all candidate returns when args cannot narrow the callable. + let fallback_return = || { + LuaType::from_vec( + overloads + .iter() + .map(|callable| { + let mut callable_tpls = HashSet::new(); + callable.visit_type(&mut |ty| { + if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty + { + callable_tpls.insert(generic_tpl.get_tpl_id()); + } + }); + + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); + infer_return_from_callable(db, callable, &callable_substitutor) + }) + .collect(), + ) }; + let call_arg_types = match infer_expr_list_types(db, context.cache, arg_exprs, None, infer_expr) + { + Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), + Err(_) => return Ok(Some(fallback_return())), + }; + if call_arg_types.is_empty() { + return Ok(None); + } + + let instantiated_overloads = overloads + .iter() + .map(|callable| instantiate_callable_from_arg_types(context, callable, &call_arg_types)) + .collect::>(); + + Ok(Some( + resolve_signature_by_args(db, &instantiated_overloads, &call_arg_types, false, None) + .map(|callable| callable.get_ret().clone()) + .unwrap_or_else(|_| fallback_return()), + )) +} + +fn collect_callable_overloads( + db: &DbIndex, + callable_type: &LuaType, + overloads: &mut Vec>, +) -> Result<(), InferFailReason> { + // TODO: Distinguish callable union vs intersection semantics here instead of flattening both + // into one overload-candidate pool. Keep in sync with `infer_union` / `infer_intersection`. + match callable_type { + LuaType::DocFunction(doc_func) => overloads.push(doc_func.clone()), + LuaType::Signature(sig_id) => { + let signature = db + .get_signature_index() + .get(sig_id) + .ok_or(InferFailReason::None)?; + overloads.extend(signature.overloads.iter().cloned()); + overloads.push(signature.to_doc_func_type()); + } + LuaType::Ref(type_id) | LuaType::Def(type_id) => { + if let Some(origin_type) = db + .get_type_index() + .get_type_decl(type_id) + .ok_or(InferFailReason::None)? + .get_alias_origin(db, None) + { + collect_callable_overloads(db, &origin_type, overloads)?; + } + } + LuaType::Generic(generic) => { + let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec()); + if let Some(origin_type) = db + .get_type_index() + .get_type_decl(&generic.get_base_type_id()) + .ok_or(InferFailReason::None)? + .get_alias_origin(db, Some(&substitutor)) + { + collect_callable_overloads(db, &origin_type, overloads)?; + } + } + LuaType::Union(union) => { + for member in union.into_vec() { + collect_callable_overloads(db, &member, overloads)?; + } + } + LuaType::Intersection(intersection) => { + for member in intersection.get_types() { + collect_callable_overloads(db, member, overloads)?; + } + } + _ => {} + } + + Ok(()) +} + +fn instantiate_callable_from_arg_types( + context: &mut TplContext, + callable: &Arc, + call_arg_types: &[LuaType], +) -> Arc { let mut callable_tpls = HashSet::new(); callable.visit_type(&mut |ty| { if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty { @@ -167,20 +274,7 @@ pub fn infer_callable_return_from_remaining_args( } }); if callable_tpls.is_empty() { - return Ok(Some(callable.get_ret().clone())); - } - - let mut callable_substitutor = TypeSubstitutor::new(); - callable_substitutor.add_need_infer_tpls(callable_tpls); - let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor); - - let call_arg_types = - match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) { - Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::>(), - Err(_) => return Ok(Some(fallback_return)), - }; - if call_arg_types.is_empty() { - return Ok(None); + return callable.clone(); } let callable_param_types = callable @@ -188,28 +282,20 @@ pub fn infer_callable_return_from_remaining_args( .iter() .map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown)) .collect::>(); - + let mut callable_substitutor = TypeSubstitutor::new(); + callable_substitutor.add_need_infer_tpls(callable_tpls); let mut callable_context = TplContext { db: context.db, cache: context.cache, substitutor: &mut callable_substitutor, call_expr: context.call_expr.clone(), }; - if tpl_pattern_match_args( - &mut callable_context, - &callable_param_types, - &call_arg_types, - ) - .is_err() - { - return Ok(Some(fallback_return)); - } + let _ = tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types); - Ok(Some(infer_return_from_callable( - context.db, - &callable, - &callable_substitutor, - ))) + match instantiate_doc_function(context.db, callable, &callable_substitutor) { + LuaType::DocFunction(func) => func, + _ => callable.clone(), + } } fn infer_generic_types_from_call( diff --git a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs index bb8344e41..0b1b69dc4 100644 --- a/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs +++ b/crates/emmylua_code_analysis/src/semantic/overload_resolve/mod.rs @@ -16,7 +16,7 @@ use super::{ infer::{InferCallFuncResult, InferFailReason}, }; -use resolve_signature_by_args::resolve_signature_by_args; +pub(crate) use resolve_signature_by_args::resolve_signature_by_args; pub fn resolve_signature( db: &DbIndex,