diff --git a/Distribution/LuaBridge/LuaBridge.h b/Distribution/LuaBridge/LuaBridge.h index f2e5eafc..faf9bea7 100644 --- a/Distribution/LuaBridge/LuaBridge.h +++ b/Distribution/LuaBridge/LuaBridge.h @@ -3069,6 +3069,25 @@ class IsContainer namespace luabridge { namespace detail { +template +std::error_code getNilBadArgError(lua_State* L, int index) +{ + const std::string message = std::string(typeName()) + " expected, got no value"; + +#if LUABRIDGE_HAS_EXCEPTIONS + if (index > 0 && LuaException::areExceptionsEnabled(L)) + { + lua_pushstring(L, message.c_str()); + LuaException::raise(L, makeErrorCode(ErrorCode::InvalidTypeCast)); + } +#endif + + if (index > 0) + luaL_argerror(L, lua_absindex(L, index), message.c_str()); + + return makeErrorCode(ErrorCode::InvalidTypeCast); +} + class Userdata { private: @@ -3078,17 +3097,80 @@ class Userdata return (void)classKey, static_cast(lua_touserdata(L, lua_absindex(L, index))); } - static Userdata* getClass(lua_State* L, - int index, - const void* registryConstKey, - const void* registryClassKey, - bool canBeConst) + static std::error_code getClassErrorCode(lua_State* L, const void* registryClassKey) + { + lua_rawgetp_x(L, LUA_REGISTRYINDEX, registryClassKey); + const bool classIsRegistered = lua_istable(L, -1); + lua_pop(L, 1); + + return makeErrorCode(classIsRegistered ? ErrorCode::InvalidTypeCast : ErrorCode::ClassNotRegistered); + } + + static std::error_code getBadArgError(lua_State* L, int index, const void* registryClassKey) + { + const int absIndex = lua_absindex(L, index); + + lua_rawgetp_x(L, LUA_REGISTRYINDEX, registryClassKey); + const bool classIsRegistered = lua_istable(L, -1); + + const char* expected = "unregistered class"; + if (classIsRegistered) + { + lua_rawgetp_x(L, -1, getTypeKey()); + if (lua_isstring(L, -1)) + expected = lua_tostring(L, -1); + lua_pop(L, 1); + } + + const char* got = nullptr; + if (lua_isuserdata(L, absIndex)) + { + lua_getmetatable(L, absIndex); + if (lua_istable(L, -1)) + { + lua_rawgetp_x(L, -1, getTypeKey()); + if (lua_isstring(L, -1)) + got = lua_tostring(L, -1); + + lua_pop(L, 1); + } + + lua_pop(L, 1); + } + + if (! got) + got = lua_typename(L, lua_type(L, absIndex)); + + lua_pop(L, 1); + + const auto errorCode = classIsRegistered ? ErrorCode::InvalidTypeCast : ErrorCode::ClassNotRegistered; + +#if LUABRIDGE_HAS_EXCEPTIONS + if (LuaException::areExceptionsEnabled(L)) + { + const std::string message = std::string(expected) + " expected, got " + got; + lua_pushstring(L, message.c_str()); + + LuaException::raise(L, makeErrorCode(errorCode)); + } +#endif + + return makeErrorCode(errorCode); + } + + static TypeResult getClass(lua_State* L, + int index, + const void* registryConstKey, + const void* registryClassKey, + bool canBeConst) { const int result = lua_getmetatable(L, index); if (result == 0 || !lua_istable(L, -1)) { - lua_rawgetp_x(L, LUA_REGISTRYINDEX, registryClassKey); - return throwBadArg(L, index); + if (result != 0) + lua_pop(L, 1); + + return getBadArgError(L, index, registryClassKey); } lua_rawgetp_x(L, -1, getConstKey()); @@ -3115,8 +3197,8 @@ class Userdata if (lua_isnil(L, -1)) { - lua_pop(L, 2); - return throwBadArg(L, index); + lua_pop(L, 3); + return getBadArgError(L, index, registryClassKey); } lua_remove(L, -2); @@ -3161,48 +3243,6 @@ class Userdata } - static Userdata* throwBadArg(lua_State* L, int index) - { - LUABRIDGE_ASSERT(lua_istable(L, -1) || lua_isnil(L, -1)); - - const char* expected = 0; - if (lua_isnil(L, -1)) - { - expected = "unregistered class"; - } - else - { - lua_rawgetp_x(L, -1, getTypeKey()); - expected = lua_tostring(L, -1); - lua_pop(L, 1); - } - - const char* got = nullptr; - if (lua_isuserdata(L, index)) - { - lua_getmetatable(L, index); - if (lua_istable(L, -1)) - { - lua_rawgetp_x(L, -1, getTypeKey()); - if (lua_isstring(L, -1)) - got = lua_tostring(L, -1); - - lua_pop(L, 1); - } - - lua_pop(L, 1); - } - - if (!got) - { - lua_pop(L, 1); - got = lua_typename(L, lua_type(L, index)); - } - - luaL_argerror(L, index, lua_pushfstring(L, "%s expected, got %s", expected, got)); - return nullptr; - } - public: virtual ~Userdata() {} @@ -3213,7 +3253,7 @@ class Userdata } template - static T* get(lua_State* L, int index, bool canBeConst) + static TypeResult get(lua_State* L, int index, bool canBeConst) { if (lua_isnil(L, index)) return nullptr; @@ -3237,11 +3277,11 @@ class Userdata lua_pop(L, 1); } - auto* clazz = getClass(L, absIndex, constId, classId, canBeConst); + auto clazz = getClass(L, absIndex, constId, classId, canBeConst); if (! clazz) - return nullptr; + return clazz.error(); - return static_cast(clazz->getPointer()); + return static_cast((*clazz)->getPointer()); } template @@ -3666,11 +3706,11 @@ struct StackHelper { using CastType = std::remove_const_t::Type>; - auto* result = Userdata::get(L, index, true); + auto result = Userdata::get(L, index, true); if (! result) - return makeErrorCode(ErrorCode::InvalidTypeCast); + return result.error(); - return ContainerTraits::construct(result); + return ContainerTraits::construct(*result); } }; @@ -3689,11 +3729,14 @@ struct StackHelper static TypeResult> get(lua_State* L, int index) { - auto* result = Userdata::get(L, index, true); + auto result = Userdata::get(L, index, true); if (! result) - return makeErrorCode(ErrorCode::InvalidTypeCast); + return result.error(); - return std::cref(*result); + if (*result == nullptr) + return getNilBadArgError(L, index); + + return std::cref(**result); } }; @@ -3710,11 +3753,11 @@ struct RefStackHelper static ReturnType get(lua_State* L, int index) { - auto* result = Userdata::get(L, index, true); + auto result = Userdata::get(L, index, true); if (! result) - return makeErrorCode(ErrorCode::InvalidTypeCast); + return result.error(); - return ContainerTraits::construct(result); + return ContainerTraits::construct(*result); } }; @@ -3735,11 +3778,14 @@ struct RefStackHelper static ReturnType get(lua_State* L, int index) { - auto* result = Userdata::get(L, index, true); + auto result = Userdata::get(L, index, true); if (! result) - return makeErrorCode(ErrorCode::InvalidTypeCast); + return result.error(); - return std::ref(*result); + if (*result == nullptr) + return getNilBadArgError(L, index); + + return std::ref(**result); } }; @@ -3750,11 +3796,11 @@ struct UserdataGetter static ReturnType get(lua_State* L, int index) { - auto* result = Userdata::get(L, index, true); + auto result = Userdata::get(L, index, true); if (! result) - return makeErrorCode(ErrorCode::InvalidTypeCast); + return result.error(); - return result; + return *result; } }; @@ -3839,7 +3885,14 @@ struct StackOpSelector static Result push(lua_State* L, const T* value) { return UserdataPtr::push(L, value); } - static ReturnType get(lua_State* L, int index) { return Userdata::get(L, index, true); } + static ReturnType get(lua_State* L, int index) + { + auto result = Userdata::get(L, index, true); + if (! result) + return result.error(); + + return *result; + } template static bool isInstance(lua_State* L, int index) { return Userdata::isInstance(L, index); } @@ -6874,6 +6927,9 @@ inline int newindex_metamethod_simple(lua_State* L) { const char* key = lua_tostring(L, 2); + if (! lua_istable(L, lua_upvalueindex(1))) + luaL_error(L, "no writable member '%s'", key); + lua_pushvalue(L, 2); lua_rawget(L, lua_upvalueindex(1)); @@ -6886,7 +6942,6 @@ inline int newindex_metamethod_simple(lua_State* L) } luaL_error(L, "no writable member '%s'", key); - return 0; } } else @@ -6896,6 +6951,9 @@ inline int newindex_metamethod_simple(lua_State* L) { const char* key = lua_tostring(L, 2); + if (! lua_istable(L, lua_upvalueindex(1))) + luaL_error(L, "no writable member '%s'", key); + lua_pushvalue(L, 2); lua_rawget(L, lua_upvalueindex(1)); @@ -6907,7 +6965,6 @@ inline int newindex_metamethod_simple(lua_State* L) } luaL_error(L, "no writable member '%s'", key); - return 0; } } @@ -7028,7 +7085,9 @@ struct property_getter { static int call(lua_State* L) { - C* c = Userdata::get(L, 1, true); + auto c = Userdata::get(L, 1, true); + if (! c) + raise_lua_error(L, "%s", c.error_cstr()); T C::** mp = static_cast(lua_touserdata(L, lua_upvalueindex(1))); @@ -7038,7 +7097,7 @@ struct property_getter try { #endif - result = Stack::push(L, c->**mp); + result = Stack::push(L, (*c)->**mp); #if LUABRIDGE_HAS_EXCEPTIONS } @@ -7099,7 +7158,9 @@ struct property_setter { static int call(lua_State* L) { - C* c = Userdata::get(L, 1, false); + auto c = Userdata::get(L, 1, false); + if (! c) + raise_lua_error(L, "%s", c.error_cstr()); T C::** mp = static_cast(lua_touserdata(L, lua_upvalueindex(1))); @@ -7111,7 +7172,7 @@ struct property_setter if (! result) raise_lua_error(L, "%s", result.error_cstr()); - c->** mp = std::move(*result); + (*c)->** mp = std::move(*result); #if LUABRIDGE_HAS_EXCEPTIONS } @@ -7282,12 +7343,14 @@ int invoke_member_function(lua_State* L) LUABRIDGE_ASSERT(isfulluserdata(L, lua_upvalueindex(1))); - T* ptr = Userdata::get(L, 1, false); + auto ptr = Userdata::get(L, 1, false); + if (! ptr) + raise_lua_error(L, "%s", ptr.error_cstr()); const F& func = *static_cast(lua_touserdata(L, lua_upvalueindex(1))); LUABRIDGE_ASSERT(func != nullptr); - return function::call(L, ptr, func); + return function::call(L, *ptr, func); } template @@ -7297,12 +7360,14 @@ int invoke_const_member_function(lua_State* L) LUABRIDGE_ASSERT(isfulluserdata(L, lua_upvalueindex(1))); - const T* ptr = Userdata::get(L, 1, true); + auto ptr = Userdata::get(L, 1, true); + if (! ptr) + raise_lua_error(L, "%s", ptr.error_cstr()); const F& func = *static_cast(lua_touserdata(L, lua_upvalueindex(1))); LUABRIDGE_ASSERT(func != nullptr); - return function::call(L, ptr, func); + return function::call(L, *ptr, func); } template @@ -7312,7 +7377,9 @@ int invoke_member_cfunction(lua_State* L) LUABRIDGE_ASSERT(isfulluserdata(L, lua_upvalueindex(1))); - T* t = Userdata::get(L, 1, false); + auto t = Userdata::get(L, 1, false); + if (! t) + raise_lua_error(L, "%s", t.error_cstr()); const F& func = *static_cast(lua_touserdata(L, lua_upvalueindex(1))); LUABRIDGE_ASSERT(func != nullptr); @@ -7321,7 +7388,7 @@ int invoke_member_cfunction(lua_State* L) try { #endif - return (t->*func)(L); + return ((*t)->*func)(L); #if LUABRIDGE_HAS_EXCEPTIONS } @@ -7341,7 +7408,9 @@ int invoke_const_member_cfunction(lua_State* L) LUABRIDGE_ASSERT(isfulluserdata(L, lua_upvalueindex(1))); - const T* t = Userdata::get(L, 1, true); + auto t = Userdata::get(L, 1, true); + if (! t) + raise_lua_error(L, "%s", t.error_cstr()); const F& func = *static_cast(lua_touserdata(L, lua_upvalueindex(1))); LUABRIDGE_ASSERT(func != nullptr); @@ -7350,7 +7419,7 @@ int invoke_const_member_cfunction(lua_State* L) try { #endif - return (t->*func)(L); + return ((*t)->*func)(L); #if LUABRIDGE_HAS_EXCEPTIONS } @@ -8192,11 +8261,11 @@ struct destructor_forwarder void operator()(lua_State* L) { - auto* value = Userdata::get(L, -1, false); - if (value == nullptr) - raise_lua_error(L, "invalid object destruction"); + auto value = Userdata::get(L, -1, false); + if (! value) + raise_lua_error(L, "%s", value.error_cstr()); - std::invoke(m_func, value); + std::invoke(m_func, *value); } private: @@ -8970,6 +9039,39 @@ class LuaRef : public LuaRefBase return LuaRef(*this).rawget(key); } + template + T unsafeRawgetField(const char* key) const + { +#if LUABRIDGE_SAFE_STACK_CHECKS + luaL_checkstack(m_L, 3, detail::error_lua_stack_overflow); +#endif + + push(m_L); + lua_pushstring(m_L, key); + lua_rawget(m_L, -2); + + auto result = Stack::get(m_L, -1); + lua_pop(m_L, 2); + + return result.value(); + } + + template + void unsafeRawsetField(const char* key, T&& value) const + { +#if LUABRIDGE_SAFE_STACK_CHECKS + luaL_checkstack(m_L, 3, detail::error_lua_stack_overflow); +#endif + + push(m_L); + lua_pushstring(m_L, key); + [[maybe_unused]] const auto pushed = Stack>::push(m_L, std::forward(value)); + LUABRIDGE_ASSERT(static_cast(pushed)); + + lua_rawset(m_L, -3); + lua_pop(m_L, 1); + } + private: int m_tableRef = LUA_NOREF; int m_keyRef = LUA_NOREF; @@ -9450,6 +9552,17 @@ namespace luabridge { namespace detail { +template +bool is_handler_valid(const F& f) noexcept +{ + if constexpr (std::is_pointer_v>) + return f != nullptr; + else if constexpr (std::is_constructible_v>) + return static_cast(f); + else + return true; +} + template struct IsTuple : std::false_type { @@ -9522,14 +9635,22 @@ TypeResult decodeCallResult(lua_State* L, int firstResultIndex, int numReturn template TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... args) { - static constexpr bool isValidHandler = !std::is_same_v, detail::remove_cvref_t>; + static_assert(std::is_same_v, detail::remove_cvref_t> || std::is_invocable_r_v); + + static constexpr bool isValidHandler = + !std::is_same_v, detail::remove_cvref_t>; lua_State* L = object.state(); const StackRestore stackRestore(L); const int initialTop = lua_gettop(L); + bool hasHandler = false; if constexpr (isValidHandler) - detail::push_function(L, std::forward(errorHandler), ""); + { + hasHandler = detail::is_handler_valid(errorHandler); + if (hasHandler) + detail::push_function(L, std::forward(errorHandler), ""); + } object.push(); @@ -9539,7 +9660,7 @@ TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... arg return result.error(); } - const int messageHandlerIndex = isValidHandler ? (initialTop + 1) : 0; + const int messageHandlerIndex = hasHandler ? (initialTop + 1) : 0; const int code = lua_pcall(L, sizeof...(Args), LUA_MULTRET, messageHandlerIndex); if (code != LUABRIDGE_LUA_OK) { @@ -9557,7 +9678,7 @@ TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... arg return ec; } - if constexpr (isValidHandler) + if (hasHandler) lua_remove(L, initialTop + 1); const int firstResultIndex = initialTop + 1; @@ -9918,8 +10039,6 @@ class Namespace : public detail::Registrar s = s + message; luaL_error(L, "%s", s.c_str()); - - return 0; } #endif diff --git a/Source/LuaBridge/detail/CFunctions.h b/Source/LuaBridge/detail/CFunctions.h index dcd98e57..6be15c58 100644 --- a/Source/LuaBridge/detail/CFunctions.h +++ b/Source/LuaBridge/detail/CFunctions.h @@ -947,6 +947,7 @@ inline int newindex_metamethod_simple(lua_State* L) } luaL_error(L, "no writable member '%s'", key); + return 0; } //================================================================================================= diff --git a/Source/LuaBridge/detail/Invoke.h b/Source/LuaBridge/detail/Invoke.h index 71645c4d..f0b64b56 100644 --- a/Source/LuaBridge/detail/Invoke.h +++ b/Source/LuaBridge/detail/Invoke.h @@ -20,6 +20,17 @@ namespace luabridge { //================================================================================================= namespace detail { +template +bool is_handler_valid(const F& f) noexcept +{ + if constexpr (std::is_pointer_v>) + return f != nullptr; + else if constexpr (std::is_constructible_v>) + return static_cast(f); + else + return true; +} + template struct IsTuple : std::false_type { @@ -96,14 +107,22 @@ TypeResult decodeCallResult(lua_State* L, int firstResultIndex, int numReturn template TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... args) { - static constexpr bool isValidHandler = !std::is_same_v, detail::remove_cvref_t>; + static_assert(std::is_same_v, detail::remove_cvref_t> || std::is_invocable_r_v); + + static constexpr bool isValidHandler = + !std::is_same_v, detail::remove_cvref_t>; lua_State* L = object.state(); const StackRestore stackRestore(L); const int initialTop = lua_gettop(L); + bool hasHandler = false; if constexpr (isValidHandler) - detail::push_function(L, std::forward(errorHandler), ""); + { + hasHandler = detail::is_handler_valid(errorHandler); + if (hasHandler) + detail::push_function(L, std::forward(errorHandler), ""); + } object.push(); @@ -113,7 +132,7 @@ TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... arg return result.error(); } - const int messageHandlerIndex = isValidHandler ? (initialTop + 1) : 0; + const int messageHandlerIndex = hasHandler ? (initialTop + 1) : 0; const int code = lua_pcall(L, sizeof...(Args), LUA_MULTRET, messageHandlerIndex); if (code != LUABRIDGE_LUA_OK) { @@ -131,7 +150,7 @@ TypeResult callWithHandler(const Ref& object, F&& errorHandler, Args&&... arg return ec; } - if constexpr (isValidHandler) + if (hasHandler) lua_remove(L, initialTop + 1); const int firstResultIndex = initialTop + 1; diff --git a/Tests/Source/LuaRefTests.cpp b/Tests/Source/LuaRefTests.cpp index 00d4913a..bccdbe00 100644 --- a/Tests/Source/LuaRefTests.cpp +++ b/Tests/Source/LuaRefTests.cpp @@ -533,6 +533,112 @@ TEST_F(LuaRefTests, CallableWithHandler) EXPECT_TRUE(errorMessage.find("we failed badly") != std::string::npos); } +TEST_F(LuaRefTests, CallableWithHandlerAsIntToBoolValuedFunction) +{ + runLua("function f(x) return x <= 1 end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + bool calledHandler = false; + std::string errorMessage; + auto handler = [&](lua_State*) -> int + { + calledHandler = true; + + if (auto msg = lua_tostring(L, 1)) + errorMessage = msg; + + return 0; + }; + + auto result = f.callWithHandler(handler, 2); + ASSERT_TRUE(result); + EXPECT_FALSE(calledHandler); + EXPECT_FALSE(result.value()); +} + +TEST_F(LuaRefTests, CallableWithStdFunction) +{ + runLua("function f(x) error('we failed ' .. x) end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + bool calledHandler = false; + std::string errorMessage; + auto handler = [&](lua_State*) -> int + { + calledHandler = true; + + if (auto msg = lua_tostring(L, 1)) + errorMessage = msg; + + return 0; + }; + + std::function pHandler = handler; + + EXPECT_FALSE(f.callWithHandler(pHandler, "badly")); + EXPECT_TRUE(calledHandler); + EXPECT_TRUE(errorMessage.find("we failed badly") != std::string::npos); +} + +TEST_F(LuaRefTests, CallableWithNullifiedStdFunction) +{ + runLua("function f(x) error('we failed ' .. x) end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + std::function pHandler = nullptr; + EXPECT_FALSE(f.callWithHandler(pHandler, "badly")); + +#if LUABRIDGE_HAS_EXCEPTIONS + EXPECT_ANY_THROW(f.callWithHandler(pHandler, "badly").throw_on_error()); +#endif +} + +TEST_F(LuaRefTests, CallableWithCFunction) +{ + runLua("function f(x) error('we failed ' .. x) end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + lua_CFunction pHandler = +[](lua_State* L) { return 0; }; + EXPECT_FALSE(f.callWithHandler(pHandler, "badly")); +} + +TEST_F(LuaRefTests, CallableWithNullCFunction) +{ + runLua("function f(x) error('we failed ' .. x) end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + lua_CFunction pHandler = nullptr; + EXPECT_FALSE(f.callWithHandler(pHandler, "badly")); + +#if LUABRIDGE_HAS_EXCEPTIONS + EXPECT_ANY_THROW(f.callWithHandler(pHandler, "badly").throw_on_error()); +#endif +} + +#if LUABRIDGE_HAS_EXCEPTIONS +TEST_F(LuaRefTests, CallableWithThrowingHandler) +{ + runLua("function f(x) error('we failed ' .. x) end"); + auto f = luabridge::getGlobal(L, "f"); + EXPECT_TRUE(f.isCallable()); + + bool calledHandler = false; + auto handler = [&](lua_State*) -> int + { + calledHandler = true; + return 0; + }; + + EXPECT_ANY_THROW(f.callWithHandler(handler, "badly").throw_on_error()); + EXPECT_TRUE(calledHandler); +} +#endif + TEST_F(LuaRefTests, CallableWrapper) { runLua("function sum(a, b) return a + b end");