Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 104 additions & 32 deletions Source/LuaBridge/detail/CFunctions.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

namespace luabridge {

Expand Down Expand Up @@ -947,6 +948,7 @@ inline int newindex_metamethod_simple(lua_State* L)
}

luaL_error(L, "no writable member '%s'", key);
return 0;
}

//=================================================================================================
Expand All @@ -958,7 +960,6 @@ inline int newindex_metamethod_simple(lua_State* L)
inline int read_only_error(lua_State* L)
{
raise_lua_error(L, "'%s' is read-only", lua_tostring(L, lua_upvalueindex(1)));

return 0;
}

Expand Down Expand Up @@ -1551,67 +1552,140 @@ int invoke_proxy_destructor(lua_State* L)
return 1;
}

//=================================================================================================
/**
* @brief C++ storage for a single overload entry: arity and optional type checker.
*
* Stored inside an OverloadSet which is kept as a Lua full userdata (auto-GC'd).
*/
struct OverloadEntry
{
using TypeChecker = bool (*)(lua_State*, int start);

int arity; // -1 for variadic (lua_CFunction): always attempt
TypeChecker checker; // nullptr for variadic: skip type pre-checking
};

/**
* @brief C++ storage for all overloads of a function.
*
* Stored as a Lua full userdata so it is GC'd automatically when the closure is collected.
* The actual function closures are stored separately in a flat Lua table (upvalue 2).
*/
struct OverloadSet
{
std::vector<OverloadEntry> entries;
};

/**
* @brief Check a single argument type, skipping lua_State* (auto-injected, not on the Lua stack).
*/
template <class T>
bool overload_check_one_arg(lua_State* L, int& idx)
{
if constexpr (std::is_pointer_v<T> &&
std::is_same_v<std::remove_const_t<std::remove_pointer_t<T>>, lua_State>)
{
return true; // lua_State* is auto-injected by LuaBridge, not a Lua-visible argument
}
else
{
return Stack<T>::isInstance(L, idx++);
}
}

template <class ArgsPack, std::size_t... I>
bool overload_check_args_impl(lua_State* L, int start, std::index_sequence<I...>)
{
int idx = start;
return (overload_check_one_arg<std::tuple_element_t<I, ArgsPack>>(L, idx) && ...);
}

template <class ArgsPack>
bool overload_check_args(lua_State* L, int start)
{
return overload_check_args_impl<ArgsPack>(L, start,
std::make_index_sequence<std::tuple_size_v<ArgsPack>>{});
}

/**
* @brief Type checker instantiable as an OverloadEntry::TypeChecker function pointer.
*
* Checks that the Lua stack arguments starting at @p start match the types in @tparam ArgsPack,
* using Stack<T>::isInstance without raising errors. lua_State* arguments are skipped.
*/
template <class ArgsPack>
bool overload_type_checker(lua_State* L, int start)
{
return overload_check_args<ArgsPack>(L, start);
}

//=================================================================================================
/**
* @brief lua_CFunction to resolve an invocation between several overloads.
*
* The list of overloads is in the first upvalue. The arguments of the function call are at the top of the Lua stack.
* upvalue[1] = OverloadSet full userdata — C++ vector of {arity, type_checker} per overload.
* upvalue[2] = flat Lua table {[1]=func1, [2]=func2, ...} — the actual function closures.
*
* Dispatch:
* 1. Arity check in C++ (no Lua call).
* 2. Type check via Stack<T>::isInstance in C++ (no pcall) — skips clearly mismatched overloads.
* 3. Only calls lua_pcall for type-matched candidates, eliminating failed pcalls for type mismatches.
*/
template <bool Member>
inline int try_overload_functions(lua_State* L)
{
const int nargs = lua_gettop(L);
const int effective_args = nargs - (Member ? 1 : 0);
const int start_arg = Member ? 2 : 1;

LUABRIDGE_ASSERT(isfulluserdata(L, lua_upvalueindex(1)));
auto* overload_set = align<OverloadSet>(lua_touserdata(L, lua_upvalueindex(1)));

// get the list of overloads
lua_pushvalue(L, lua_upvalueindex(1));
// push flat functions table (upvalue 2)
lua_pushvalue(L, lua_upvalueindex(2));
LUABRIDGE_ASSERT(lua_istable(L, -1));
const int idx_overloads = nargs + 1;
const int num_overloads = get_length(L, idx_overloads);
const int idx_funcs = nargs + 1;

// create table to hold error messages
lua_createtable(L, num_overloads, 0);
lua_createtable(L, static_cast<int>(overload_set->entries.size()), 0);
const int idx_errors = nargs + 2;
int nerrors = 0;

// iterate through table, snippet taken from Lua docs
lua_pushnil(L); // first key
while (lua_next(L, idx_overloads) != 0)
for (int i = 0; i < static_cast<int>(overload_set->entries.size()); ++i)
{
LUABRIDGE_ASSERT(lua_istable(L, -1));

// check matching arity
lua_rawgeti(L, -1, 1);
LUABRIDGE_ASSERT(lua_isnumber(L, -1));
const auto& entry = overload_set->entries[i];

const int overload_arity = static_cast<int>(lua_tointeger(L, -1));
if (overload_arity >= 0 && overload_arity != effective_args)
// fast arity check (C++, no Lua calls)
if (entry.arity >= 0 && entry.arity != effective_args)
{
// store error message and try next overload
lua_pushfstring(L, "Skipped overload #%d with unmatched arity of %d instead of %d", nerrors, overload_arity, effective_args);
lua_pushfstring(L, "Skipped overload #%d with unmatched arity of %d instead of %d", i, entry.arity, effective_args);
lua_rawseti(L, idx_errors, ++nerrors);

lua_pop(L, 2); // pop arity, value (table)
continue;
}

lua_pop(L, 1); // pop arity
// fast type check (C++, no pcall) — avoids expensive pcall for clearly mismatched types
if (entry.checker != nullptr && !entry.checker(L, start_arg))
{
lua_pushfstring(L, "Skipped overload #%d with unmatched argument types", i);
lua_rawseti(L, idx_errors, ++nerrors);
continue;
}

// push function
lua_pushnumber(L, 2);
lua_gettable(L, -2);
// O(1) function lookup from flat table
lua_rawgeti(L, idx_funcs, i + 1);
LUABRIDGE_ASSERT(lua_isfunction(L, -1));

// push arguments
for (int i = 1; i <= nargs; ++i)
lua_pushvalue(L, i);
for (int j = 1; j <= nargs; ++j)
lua_pushvalue(L, j);

// call f, this pops the function and its args, pushes result(s)
const int err = lua_pcall(L, nargs, LUA_MULTRET, 0);
if (err == LUABRIDGE_LUA_OK)
{
// calculate number of return values and return
return lua_gettop(L) - nargs - 4; // 4: overloads, errors, key, table
// 2 extra items on stack below results: idx_funcs, idx_errors
return lua_gettop(L) - nargs - 2;
}
else if (err == LUA_ERRRUN)
{
Expand All @@ -1620,10 +1694,8 @@ inline int try_overload_functions(lua_State* L)
}
else
{
return lua_error_x(L); // critical error: rethrow
return lua_error_x(L); // critical error: rethrow
}

lua_pop(L, 1); // pop value (table)
}

lua_Debug debug;
Expand Down
Loading
Loading