Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions crates/emmylua_code_analysis/src/compilation/test/pcall_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,42 @@ mod test {
assert_eq!(ws.expr_ty("status"), ws.ty("boolean|string"));
assert_eq!(ws.expr_ty("payload"), ws.ty("integer|string"));
}

#[test]
fn test_return_overload_infers_callable_union_member() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: string): integer

---@type FnA | FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
}

#[test]
fn test_return_overload_selects_callable_intersection_member() {
let mut ws = VirtualWorkspace::new_with_init_std_lib();

ws.def(
r#"
---@alias FnA fun(x: integer): integer
---@alias FnB fun(x: string): boolean

---@type FnA & FnB
local run

_, a = pcall(run, 1)
"#,
);

assert_eq!(ws.expr_ty("a"), ws.ty("integer|string"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use crate::{
},
infer::InferFailReason,
infer_expr,
overload_resolve::resolve_signature_by_args,
},
};
use crate::{
Expand Down Expand Up @@ -156,60 +157,145 @@ pub fn infer_callable_return_from_remaining_args(
return Ok(None);
}

let Some(callable) = as_doc_function_type(context.db, callable_type)? else {
let mut overloads = Vec::new();
collect_callable_overloads(context.db, callable_type, &mut overloads)?;
if overloads.is_empty() {
return Ok(None);
}

let db = context.db;

// Fall back to the union of all candidate returns when args cannot narrow the callable.
let fallback_return = || {
LuaType::from_vec(
overloads
.iter()
.map(|callable| {
let mut callable_tpls = HashSet::new();
callable.visit_type(&mut |ty| {
if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty
{
callable_tpls.insert(generic_tpl.get_tpl_id());
}
});

let mut callable_substitutor = TypeSubstitutor::new();
callable_substitutor.add_need_infer_tpls(callable_tpls);
infer_return_from_callable(db, callable, &callable_substitutor)
})
.collect(),
)
};

let call_arg_types = match infer_expr_list_types(db, context.cache, arg_exprs, None, infer_expr)
{
Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::<Vec<_>>(),
Err(_) => return Ok(Some(fallback_return())),
};
if call_arg_types.is_empty() {
return Ok(None);
}

let instantiated_overloads = overloads
.iter()
.map(|callable| instantiate_callable_from_arg_types(context, callable, &call_arg_types))
.collect::<Vec<_>>();

Ok(Some(
resolve_signature_by_args(db, &instantiated_overloads, &call_arg_types, false, None)
.map(|callable| callable.get_ret().clone())
.unwrap_or_else(|_| fallback_return()),
))
}

fn collect_callable_overloads(
db: &DbIndex,
callable_type: &LuaType,
overloads: &mut Vec<Arc<LuaFunctionType>>,
) -> Result<(), InferFailReason> {
// TODO: Distinguish callable union vs intersection semantics here instead of flattening both
// into one overload-candidate pool. Keep in sync with `infer_union` / `infer_intersection`.
match callable_type {
LuaType::DocFunction(doc_func) => overloads.push(doc_func.clone()),
LuaType::Signature(sig_id) => {
let signature = db
.get_signature_index()
.get(sig_id)
.ok_or(InferFailReason::None)?;
overloads.extend(signature.overloads.iter().cloned());
overloads.push(signature.to_doc_func_type());
}
LuaType::Ref(type_id) | LuaType::Def(type_id) => {
if let Some(origin_type) = db
.get_type_index()
.get_type_decl(type_id)
.ok_or(InferFailReason::None)?
.get_alias_origin(db, None)
{
collect_callable_overloads(db, &origin_type, overloads)?;
}
}
LuaType::Generic(generic) => {
let substitutor = TypeSubstitutor::from_type_array(generic.get_params().to_vec());
if let Some(origin_type) = db
.get_type_index()
.get_type_decl(&generic.get_base_type_id())
.ok_or(InferFailReason::None)?
.get_alias_origin(db, Some(&substitutor))
{
collect_callable_overloads(db, &origin_type, overloads)?;
}
}
LuaType::Union(union) => {
for member in union.into_vec() {
collect_callable_overloads(db, &member, overloads)?;
}
}
LuaType::Intersection(intersection) => {
for member in intersection.get_types() {
collect_callable_overloads(db, member, overloads)?;
}
}
_ => {}
}

Ok(())
}

fn instantiate_callable_from_arg_types(
context: &mut TplContext,
callable: &Arc<LuaFunctionType>,
call_arg_types: &[LuaType],
) -> Arc<LuaFunctionType> {
let mut callable_tpls = HashSet::new();
callable.visit_type(&mut |ty| {
if let LuaType::TplRef(generic_tpl) | LuaType::ConstTplRef(generic_tpl) = ty {
callable_tpls.insert(generic_tpl.get_tpl_id());
}
});
if callable_tpls.is_empty() {
return Ok(Some(callable.get_ret().clone()));
}

let mut callable_substitutor = TypeSubstitutor::new();
callable_substitutor.add_need_infer_tpls(callable_tpls);
let fallback_return = infer_return_from_callable(context.db, &callable, &callable_substitutor);

let call_arg_types =
match infer_expr_list_types(context.db, context.cache, arg_exprs, None, infer_expr) {
Ok(types) => types.into_iter().map(|(ty, _)| ty).collect::<Vec<_>>(),
Err(_) => return Ok(Some(fallback_return)),
};
if call_arg_types.is_empty() {
return Ok(None);
return callable.clone();
}

let callable_param_types = callable
.get_params()
.iter()
.map(|(_, ty)| ty.clone().unwrap_or(LuaType::Unknown))
.collect::<Vec<_>>();

let mut callable_substitutor = TypeSubstitutor::new();
callable_substitutor.add_need_infer_tpls(callable_tpls);
let mut callable_context = TplContext {
db: context.db,
cache: context.cache,
substitutor: &mut callable_substitutor,
call_expr: context.call_expr.clone(),
};
if tpl_pattern_match_args(
&mut callable_context,
&callable_param_types,
&call_arg_types,
)
.is_err()
{
return Ok(Some(fallback_return));
}
let _ = tpl_pattern_match_args(&mut callable_context, &callable_param_types, call_arg_types);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why ignore the err


Ok(Some(infer_return_from_callable(
context.db,
&callable,
&callable_substitutor,
)))
match instantiate_doc_function(context.db, callable, &callable_substitutor) {
LuaType::DocFunction(func) => func,
_ => callable.clone(),
}
}

fn infer_generic_types_from_call(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use super::{
infer::{InferCallFuncResult, InferFailReason},
};

use resolve_signature_by_args::resolve_signature_by_args;
pub(crate) use resolve_signature_by_args::resolve_signature_by_args;

pub fn resolve_signature(
db: &DbIndex,
Expand Down
Loading