From 3012ad64e40194a103bb87d9e9c9571f5c5149ca Mon Sep 17 00:00:00 2001 From: Chase Cooper Date: Fri, 5 Jun 2026 00:58:08 -0400 Subject: [PATCH] Add proactive stack overflow detection via MaxCallDepth + tests Solves #168 --- src/Lua/Exceptions.cs | 2 +- src/Lua/LuaState.cs | 7 +++ src/Lua/Runtime/LuaVirtualMachine.cs | 11 ++++ src/Lua/Standard/BasicLibrary.cs | 2 + tests/Lua.Tests/StackOverflowTests.cs | 87 +++++++++++++++++++++++++++ 5 files changed, 108 insertions(+), 1 deletion(-) create mode 100644 tests/Lua.Tests/StackOverflowTests.cs diff --git a/src/Lua/Exceptions.cs b/src/Lua/Exceptions.cs index 4cb26291..4e9cffa5 100644 --- a/src/Lua/Exceptions.cs +++ b/src/Lua/Exceptions.cs @@ -99,7 +99,7 @@ static string GetMessageWithNearToken(string message, string? nearToken) public class LuaUndumpException(string message) : Exception(message); -class LuaStackOverflowException() : Exception("stack overflow") +public class LuaStackOverflowException() : Exception("stack overflow") { public override string ToString() { diff --git a/src/Lua/LuaState.cs b/src/Lua/LuaState.cs index f94db100..8f7616f6 100644 --- a/src/Lua/LuaState.cs +++ b/src/Lua/LuaState.cs @@ -197,6 +197,8 @@ public void Release() public LuaTable PreloadModules => GlobalState.PreloadModules; public LuaState MainThread => GlobalState.MainThread; + public int MaxCallDepth { get; set; } = 100_000; + public ILuaModuleLoader? ModuleLoader { get => GlobalState.ModuleLoader; @@ -285,6 +287,11 @@ int callerInstructionIndex [MethodImpl(MethodImplOptions.AggressiveInlining)] internal void PushCallStackFrame(in CallStackFrame frame) { + if (CallStackFrameCount >= MaxCallDepth) + { + throw new LuaStackOverflowException(); + } + CurrentException?.BuildOrGet(); CurrentException = null; ref var callStack = ref CoreData!.CallStack; diff --git a/src/Lua/Runtime/LuaVirtualMachine.cs b/src/Lua/Runtime/LuaVirtualMachine.cs index 49a13c60..8b7ef115 100644 --- a/src/Lua/Runtime/LuaVirtualMachine.cs +++ b/src/Lua/Runtime/LuaVirtualMachine.cs @@ -1419,6 +1419,11 @@ static bool Call(VirtualMachineExecutionContext context, out bool doRestart) var newFrame = func.CreateNewFrame(context, newBase, RA, variableArgumentCount); + if (!RuntimeHelpers.TryEnsureSufficientExecutionStack()) + { + throw new LuaStackOverflowException(); + } + state.PushCallStackFrame(newFrame); if (state.CallOrReturnHookMask.Value != 0 && !context.State.IsInHook) { @@ -1619,6 +1624,12 @@ static bool TailCall(VirtualMachineExecutionContext context, out bool doRestart) ); newBase = context.FrameBase + variableArgumentCount; stack.PopUntil(newBase + argumentCount); + + if (!RuntimeHelpers.TryEnsureSufficientExecutionStack()) + { + throw new LuaStackOverflowException(); + } + var lastFrame = state.GetCurrentFrame(); state.LastPc = context.Pc; state.LastCallerFunction = lastFrame.Function; diff --git a/src/Lua/Standard/BasicLibrary.cs b/src/Lua/Standard/BasicLibrary.cs index c57a0aa6..a659397d 100644 --- a/src/Lua/Standard/BasicLibrary.cs +++ b/src/Lua/Standard/BasicLibrary.cs @@ -367,6 +367,8 @@ CancellationToken cancellationToken throw; case OperationCanceledException: throw new LuaCanceledException(context.State, cancellationToken, ex); + case LuaStackOverflowException: + return context.Return(false, ex.Message); case LuaRuntimeException luaEx: { if ( diff --git a/tests/Lua.Tests/StackOverflowTests.cs b/tests/Lua.Tests/StackOverflowTests.cs new file mode 100644 index 00000000..36f936b2 --- /dev/null +++ b/tests/Lua.Tests/StackOverflowTests.cs @@ -0,0 +1,87 @@ +using Lua.Standard; + +namespace Lua.Tests; + +public class StackOverflowTests +{ + [Test] + public void ExceedingMaxCallDepth_ThrowsLuaRuntimeException() + { + var state = LuaState.Create(); + state.MaxCallDepth = 5; + + var ex = Assert.ThrowsAsync(async () => + await state.DoStringAsync( + """ + local function f() + f() + end + f() + """ + ).AsTask() + ); + + Assert.That(ex!.InnerException, Is.TypeOf()); + } + + [Test] + public async Task ExceedingMaxCallDepth_WithPCall_ReturnsErrorMessage() + { + var state = LuaState.Create(); + state.OpenStandardLibraries(); + state.MaxCallDepth = 5; + + var result = await state.DoStringAsync( + """ + local function f() + f() + end + local ok, msg = pcall(f) + return ok, msg + """ + ); + + Assert.That(result, Has.Length.EqualTo(2)); + Assert.That(result[0], Is.EqualTo(new LuaValue(false))); + Assert.That(result[1].Read(), Does.Contain("stack overflow")); + } + + [Test] + public async Task UnderMaxCallLimit_Succeeds() + { + var state = LuaState.Create(); + state.MaxCallDepth = 100; + + var result = await state.DoStringAsync( + """ + local function f(n) + if n <= 0 then return 'done' end + return f(n - 1) + end + return f(10) + """ + ); + + Assert.That(result, Has.Length.EqualTo(1)); + Assert.That(result[0], Is.EqualTo(new LuaValue("done"))); + } + + [Test] + public async Task DefaultMaxCallDepth_AllowsDeepRecursion() + { + var state = LuaState.Create(); + + var result = await state.DoStringAsync( + """ + local function f(n) + if n <= 0 then return n end + return f(n - 1) + end + return f(1000) + """ + ); + + Assert.That(result, Has.Length.EqualTo(1)); + Assert.That(result[0], Is.EqualTo(new LuaValue(0))); + } +}