From 7eb80a93e3d86dc0a119a4d382a9a51a7cada6a9 Mon Sep 17 00:00:00 2001 From: Frotty Date: Fri, 13 Mar 2026 00:39:16 +0100 Subject: [PATCH] fix lua inlining --- .../java/de/peeeq/wurstscript/RunArgs.java | 19 ++- .../imoptimizer/GlobalsInliner.java | 18 +-- .../translation/imoptimizer/ImInliner.java | 110 +++++++++++-- .../lua/translation/LuaTranslator.java | 147 ++++++++++++++--- .../tests/LuaTranslationTests.java | 148 ++++++++++++++++++ .../wurstscript/tests/OptimizerTests.java | 34 ++++ 6 files changed, 436 insertions(+), 40 deletions(-) diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/RunArgs.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/RunArgs.java index ad1d65893..9566d251f 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/RunArgs.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/RunArgs.java @@ -132,7 +132,6 @@ public RunArgs(String... args) { addOptionWithArg("functionSplitLimit", "The maximum number of operations in a function before it is split by the function splitter (used for compiletime functions)", s -> functionSplitLimit = Integer.parseInt(s, 10)); - optionPrettyPrint = addOption("prettyPrint", "Pretty print the input file, or all sub-directory if the given path is: '...'"); nextArg: @@ -149,6 +148,20 @@ public RunArgs(String... args) { o.isSet = true; continue nextArg; } else if ((o.argHandler != null && isDoubleArg(a, o))) { + String value = a.substring(a.indexOf(" ") + 1).trim(); + if (value.isEmpty()) { + throw new RuntimeException("Missing value for option: -" + o.name); + } + o.argHandler.accept(value); + o.isSet = true; + continue nextArg; + } else if (o.argHandler != null && isEqualsArg(a, o)) { + String value = a.substring(a.indexOf("=") + 1).trim(); + if (value.isEmpty()) { + throw new RuntimeException("Missing value for option: -" + o.name); + } + o.argHandler.accept(value); + o.isSet = true; continue nextArg; } } @@ -170,6 +183,10 @@ private boolean isDoubleArg(String arg, RunOption option) { return (arg.contains(" ") && ("-" + option.name).equals(arg.substring(0, arg.indexOf(" ")))); } + private boolean isEqualsArg(String arg, RunOption option) { + return arg.startsWith("-" + option.name + "="); + } + private RunOption addOption(String name, String descr) { RunOption opt = new RunOption(name, descr); options.add(opt); diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/GlobalsInliner.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/GlobalsInliner.java index add8abad4..bc83c1afd 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/GlobalsInliner.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/GlobalsInliner.java @@ -16,7 +16,6 @@ import java.util.stream.Collectors; public class GlobalsInliner implements OptimizerPass { - public int optimize(ImTranslator trans) { int obsoleteCount = 0; ImProg prog = trans.getImProg(); @@ -44,7 +43,7 @@ public int optimize(ImTranslator trans) { ImVarWrite obs = null; for (ImVarWrite write : v.attrWrites()) { ImFunction func = write.getNearestFunc(); - if (isInInit(func)) { + if (isInInitGlobals(func)) { right = write.getRight(); obs = write; break; @@ -67,19 +66,21 @@ public int optimize(ImTranslator trans) { List initWrites = new ArrayList<>(); for (ImVarWrite imVarWrite : v.attrWrites()) { ImFunction nearestFunc = imVarWrite.getNearestFunc(); - if (isInInit(nearestFunc)) { + if (isInInitGlobals(nearestFunc)) { initWrites.add(imVarWrite); } } if (initWrites.size() == 1) { if(v.getType() instanceof ImSimpleType) { - ImExpr write = v.attrWrites().iterator().next().getRight(); + ImVarWrite initWrite = initWrites.get(0); + ImExpr write = initWrite.getRight(); try { ImExpr defaultValue = ImHelper.defaultValueForType((ImSimpleType) v.getType()); boolean isDefault = defaultValue.structuralEquals(write); if (isDefault) { - // Assignment is default value and can be removed - v.attrWrites().iterator().next().replaceBy(ImHelper.nullExpr()); + // Only remove the init write when it assigns the default value. + // Never touch non-init writes here. + initWrite.replaceBy(ImHelper.nullExpr()); } } catch (Exception e) { throw new CompileError(write.attrTrace().attrErrorPos(), @@ -140,9 +141,8 @@ public String getName() { } - private static boolean isInInit(ImFunction func) { - return func != null && (func.getName().startsWith("init_") || func.getName().equals("main") || func.getName().startsWith("InitTrig_") - || func.getName().equals("initGlobals")); + private static boolean isInInitGlobals(ImFunction func) { + return func != null && func.getName().equals("initGlobals"); } } diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java index 8eae0c9de..424c951b2 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java @@ -28,6 +28,7 @@ public class ImInliner { private final Map callCounts = Maps.newLinkedHashMap(); private final Map funcSizes = Maps.newLinkedHashMap(); private final Set done = Sets.newLinkedHashSet(); + private final Map containsFuncRefCache = Maps.newLinkedHashMap(); private final double inlineTreshold = 50; static { @@ -49,8 +50,7 @@ public void doInlining() { } private void inlineFunctions() { - - for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) { + for (ImFunction f : sortedFunctions(ImHelper.calculateFunctionsOfProg(prog))) { inlineFunctions(f); } } @@ -61,7 +61,7 @@ private void inlineFunctions(ImFunction f) { } done.add(f); // first inline functions called from this function - for (ImFunction called : translator.getCalledFunctions().get(f)) { + for (ImFunction called : sortedFunctions(translator.getCalledFunctions().get(f))) { inlineFunctions(called); } boolean[] changed = new boolean[]{false}; @@ -73,10 +73,10 @@ private ImFunction inlineFunctions(ImFunction f, Element parent, int parentI, El if (e instanceof ImFunctionCall) { ImFunctionCall call = (ImFunctionCall) e; ImFunction called = call.getFunc(); - boolean canInline = f != called && shouldInline(call, called); + boolean canInline = f != called && shouldInline(f, call, called); if (LOG_INLINER) { String msg = "[INLINER] caller=" + f.getName() + " callee=" + called.getName() + " decision=" + (canInline ? "inline" : "keep") + - (canInline ? "" : " reason=" + skipReason(call, called)); + (canInline ? "" : " reason=" + skipReason(f, call, called)); WLogger.info(msg); System.out.println(msg); } @@ -109,13 +109,19 @@ private ImFunction inlineFunctions(ImFunction f, Element parent, int parentI, El return null; } - private String skipReason(ImFunctionCall call, ImFunction f) { + private String skipReason(ImFunction caller, ImFunctionCall call, ImFunction f) { if (f.isNative()) { return "native"; } if (call.getCallType() == CallType.EXECUTE) { return "execute_call"; } + if (translator.isLuaTarget() && !maxOneReturn(f)) { + return "lua_multi_return_inline_disabled"; + } + if (translator.isLuaTarget() && containsFuncRef(f)) { + return "lua_callback_funcref_barrier"; + } if (!inlinableFunctions.contains(f)) { return "not_in_inlinable_set"; } @@ -276,12 +282,18 @@ private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) } private void rateInlinableFunctions() { - for (Map.Entry edge : translator.getCalledFunctions().entries()) { + List> edges = new ArrayList<>(translator.getCalledFunctions().entries()); + edges.sort((a, b) -> { + int c = functionSortKey(a.getKey()).compareTo(functionSortKey(b.getKey())); + if (c != 0) return c; + return functionSortKey(a.getValue()).compareTo(functionSortKey(b.getValue())); + }); + for (Map.Entry edge : edges) { // For bloat control we need how often a function is used (incoming edges), // not how many calls it performs itself (outgoing edges). incCallCount(edge.getValue()); } - for (ImFunction f : inlinableFunctions) { + for (ImFunction f : sortedFunctions(inlinableFunctions)) { int size = estimateSize(f); funcSizes.put(f, size); } @@ -322,10 +334,25 @@ private int getFuncSize(ImFunction f) { } } - private boolean shouldInline(ImFunctionCall call, ImFunction f) { + private boolean shouldInline(ImFunction caller, ImFunctionCall call, ImFunction f) { if (f.isNative() || call.getCallType() == CallType.EXECUTE) { return false; } + if (translator.isLuaTarget() && !maxOneReturn(f)) { + // Conservative safety: Lua inliner multi-return rewriting is not yet fully robust + // across all lowered patterns. Keep call semantics intact for now. + return false; + } + if (translator.isLuaTarget() && containsFuncRef(f)) { + // Functions that build callback refs are lowered with Lua-specific wrappers/xpcall. + // Keeping them as standalone calls avoids callback context/vararg scope breakage. + return false; + } + if (isLuaTypeCastingCompatFunction(f)) { + // In Lua these compat wrappers are rewritten to object index helpers. + // If they are inlined beforehand, old TypeCasting bodies leak through. + return false; + } double threshold = inlineTreshold; for (ImExpr arg : call.getArguments()) { @@ -362,6 +389,31 @@ private boolean containsCallTo(ImFunction f, Element e) { return false; } + private boolean containsFuncRef(ImFunction f) { + if (f == null) { + return false; + } + Boolean cached = containsFuncRefCache.get(f); + if (cached != null) { + return cached; + } + boolean result = containsFuncRef(f.getBody()); + containsFuncRefCache.put(f, result); + return result; + } + + private boolean containsFuncRef(Element e) { + if (e instanceof ImFuncRef) { + return true; + } + for (int i = 0; i < e.size(); i++) { + if (containsFuncRef(e.get(i))) { + return true; + } + } + return false; + } + private int estimateSize(ImFunction f) { int[] r = new int[]{0}; estimateSize(f.getBody(), r); @@ -390,24 +442,46 @@ private int getCallCount(ImFunction f) { } private void collectInlinableFunctions() { - for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) { + for (ImFunction f : sortedFunctions(ImHelper.calculateFunctionsOfProg(prog))) { if (isInlineCandidate(f)) { inlinableFunctions.add(f); } } // Some call targets can survive in the call graph but not in prog/classes lists. - for (ImFunction f : translator.getCalledFunctions().values()) { + for (ImFunction f : sortedFunctions(translator.getCalledFunctions().values())) { if (isInlineCandidate(f)) { inlinableFunctions.add(f); } } } + private List sortedFunctions(Collection functions) { + List r = new ArrayList<>(functions); + r.sort(Comparator.comparing(this::functionSortKey)); + return r; + } + + private String functionSortKey(ImFunction f) { + if (f == null) { + return ""; + } + StringBuilder sb = new StringBuilder(); + sb.append(f.getName()).append("|"); + sb.append(f.getReturnType()).append("|"); + for (ImVar p : f.getParameters()) { + sb.append(p.getType()).append(","); + } + return sb.toString(); + } + private boolean isInlineCandidate(ImFunction f) { if (f.hasFlag(FunctionFlagEnum.IS_COMPILETIME_NATIVE) || f.hasFlag(FunctionFlagEnum.IS_NATIVE)) { // do not inline natives return false; } + if (isLuaTypeCastingCompatFunction(f)) { + return false; + } if (f == translator.getGlobalInitFunc()) { return false; } @@ -419,6 +493,20 @@ private boolean isInlineCandidate(ImFunction f) { return true; } + private boolean isLuaTypeCastingCompatFunction(ImFunction f) { + if (!translator.isLuaTarget() || f == null) { + return false; + } + de.peeeq.wurstscript.ast.Element trace = f.attrTrace(); + if (trace instanceof de.peeeq.wurstscript.ast.FuncDef fd + && fd.attrNearestPackage() instanceof de.peeeq.wurstscript.ast.WPackage p + && "TypeCasting".equals(p.getName())) { + String name = fd.getName(); + return name.endsWith("FromIndex") || name.endsWith("ToIndex"); + } + return false; + } + private boolean maxOneReturn(ImFunction f) { return maxOneReturn(f.getBody()); } diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java index c833d7a8a..af1d31642 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java @@ -1,6 +1,7 @@ package de.peeeq.wurstscript.translation.lua.translation; import de.peeeq.datastructures.UnionFind; +import de.peeeq.wurstscript.WLogger; import de.peeeq.wurstscript.ast.ClassDef; import de.peeeq.wurstscript.ast.Element; import de.peeeq.wurstscript.ast.FuncDef; @@ -95,6 +96,8 @@ public class LuaTranslator { ); private static final boolean DEBUG_LUA_DISPATCH = "1".equals(System.getenv("WURST_DEBUG_LUA_DISPATCH")) || Boolean.getBoolean("wurst.debug.lua.dispatch"); + private static final boolean DEBUG_LUA_LOCALS = "1".equals(System.getenv("WURST_DEBUG_LUA_LOCALS")) + || Boolean.getBoolean("wurst.debug.lua.locals"); final ImProg prog; final LuaCompilationUnit luaModel; @@ -319,6 +322,7 @@ public LuaCompilationUnit translate() { } cleanStatements(); + enforceLuaLocalLimits(); emitExperimentalHashtableLeakGuards(); return luaModel; @@ -767,7 +771,8 @@ private void translateFunc(ImFunction f) { // translate body: translateStatements(lf.getBody(), f.getBody()); - spillLocalsIntoTableIfNeeded(lf, functionLocals); + // local-limit enforcement is done after final statement cleanup, + // because cleanup and later rewrites can still introduce locals. } if (f.isExtern() || f.isNative()) { @@ -828,43 +833,147 @@ private boolean rewriteTypeCastingCompatFunction(ImFunction f, LuaFunction lf) { return false; } - private void spillLocalsIntoTableIfNeeded(LuaFunction lf, List functionLocals) { - if (functionLocals.size() <= LUA_LOCALS_LIMIT) { + private void enforceLuaLocalLimits() { + luaModel.accept(new LuaModel.DefaultVisitor() { + @Override + public void visit(LuaFunction f) { + super.visit(f); + spillLocalsIntoTableIfNeeded(f.getName(), f.getParams(), f.getBody()); + } + + @Override + public void visit(LuaMethod m) { + super.visit(m); + spillLocalsIntoTableIfNeeded(m.getName(), m.getParams(), m.getBody()); + } + }); + } + + private void spillLocalsIntoTableIfNeeded(String functionName, LuaParams params, LuaStatements body) { + List scopeLocals = collectFunctionScopeLocals(body); + int localCount = params.size() + scopeLocals.size(); + if (DEBUG_LUA_LOCALS) { + WLogger.info("[LUA_LOCALS] function=" + functionName + " params=" + params.size() + + " locals=" + scopeLocals.size() + " total=" + localCount); + } + if (localCount < LUA_LOCALS_LIMIT || scopeLocals.isEmpty()) { return; } - Set localSet = new HashSet<>(functionLocals); - LuaVariable localsTable = LuaAst.LuaVariable(uniqueName("__wurst_locals"), - LuaAst.LuaTableConstructor(LuaAst.LuaTableFields())); + LuaVariable localsTable = findTopLevelLocalsTable(body); + if (localsTable == null) { + localsTable = LuaAst.LuaVariable(uniqueName("__wurst_locals"), + LuaAst.LuaTableConstructor(LuaAst.LuaTableFields())); + } + // Must be declared before any rewritten uses; otherwise accesses become global lookups. + if (!body.isEmpty() && body.get(0) != localsTable) { + body.remove(localsTable); + body.add(0, localsTable); + } + + Set localSet = new LinkedHashSet<>(); + for (LuaVariable v : scopeLocals) { + if (v != localsTable) { + localSet.add(v); + } + } + if (localSet.isEmpty()) { + return; + } + + if (DEBUG_LUA_LOCALS) { + WLogger.info("[LUA_LOCALS] spill function=" + functionName + " total=" + localCount + + " spilledLocals=" + localSet.size()); + } + + final LuaVariable tableVar = localsTable; + final Map localSlots = createLocalSlots(localSet); // Rewrite accesses first, then replace declarations with table init assignments. - lf.getBody().forEachElement(e -> { + forEachElementRec(body, e -> { if (e instanceof LuaExprVarAccess) { LuaExprVarAccess va = (LuaExprVarAccess) e; - if (localSet.contains(va.getVar())) { - LuaExpr tableRef = LuaAst.LuaExprVarAccess(localsTable); - LuaExpr key = LuaAst.LuaExprStringVal(va.getVar().getName()); + LuaVariable var = va.getVar(); + Integer slot = localSlots.get(var); + if (slot != null) { + LuaExpr tableRef = LuaAst.LuaExprVarAccess(tableVar); + LuaExpr key = LuaAst.LuaExprIntVal("" + slot); va.replaceBy(LuaAst.LuaExprArrayAccess(tableRef, LuaAst.LuaExprlist(key))); } } }); - List oldBody = new ArrayList<>(lf.getBody()); - lf.getBody().clear(); - lf.getBody().add(localsTable); + rewriteLocalDeclarationsToTableAssignments(body, localSet, localSlots, tableVar); + } - for (LuaStatement stmt : oldBody) { + private Map createLocalSlots(Set localSet) { + Map r = new LinkedHashMap<>(); + int i = 1; + for (LuaVariable v : localSet) { + r.put(v, i++); + } + return r; + } + + private LuaVariable findTopLevelLocalsTable(LuaStatements body) { + for (LuaStatement stmt : body) { + if (stmt instanceof LuaVariable) { + LuaVariable v = (LuaVariable) stmt; + if (v.getName().startsWith("__wurst_locals") && v.getInitialValue() instanceof LuaTableConstructor) { + return v; + } + } + } + return null; + } + + private List collectFunctionScopeLocals(LuaStatements body) { + List result = new ArrayList<>(); + collectFunctionScopeLocalsRec(body, result); + return result; + } + + private void rewriteLocalDeclarationsToTableAssignments(LuaStatements stmts, Set localSet, Map localSlots, LuaVariable tableVar) { + ListIterator it = stmts.listIterator(); + while (it.hasNext()) { + LuaStatement stmt = it.next(); if (stmt instanceof LuaVariable && localSet.contains(stmt)) { LuaVariable localDecl = (LuaVariable) stmt; - LuaExpr key = LuaAst.LuaExprStringVal(localDecl.getName()); - LuaExpr left = LuaAst.LuaExprArrayAccess(LuaAst.LuaExprVarAccess(localsTable), LuaAst.LuaExprlist(key)); - lf.getBody().add(LuaAst.LuaAssignment(left, ((LuaExpr) localDecl.getInitialValue()).copy())); - } else { - lf.getBody().add(stmt); + Integer slot = localSlots.get(localDecl); + if (slot == null) { + continue; + } + LuaExpr key = LuaAst.LuaExprIntVal("" + slot); + LuaExpr left = LuaAst.LuaExprArrayAccess(LuaAst.LuaExprVarAccess(tableVar), LuaAst.LuaExprlist(key)); + LuaExprOpt initVal = localDecl.getInitialValue(); + LuaExpr right = initVal instanceof LuaExpr ? (LuaExpr) initVal.copy() : LuaAst.LuaExprNull(); + it.set(LuaAst.LuaAssignment(left, right)); + } else if (stmt instanceof LuaIf) { + LuaIf luaIf = (LuaIf) stmt; + rewriteLocalDeclarationsToTableAssignments(luaIf.getThenStmts(), localSet, localSlots, tableVar); + rewriteLocalDeclarationsToTableAssignments(luaIf.getElseStmts(), localSet, localSlots, tableVar); + } else if (stmt instanceof LuaWhile) { + LuaWhile luaWhile = (LuaWhile) stmt; + rewriteLocalDeclarationsToTableAssignments(luaWhile.getBody(), localSet, localSlots, tableVar); } } } + private void collectFunctionScopeLocalsRec(de.peeeq.wurstscript.luaAst.Element e, List out) { + if (e instanceof LuaExprFunctionAbstraction || e instanceof LuaFunction || e instanceof LuaMethod) { + return; + } + if (e instanceof LuaVariable) { + out.add((LuaVariable) e); + } + e.forEachElement(child -> collectFunctionScopeLocalsRec(child, out)); + } + + private void forEachElementRec(de.peeeq.wurstscript.luaAst.Element root, java.util.function.Consumer action) { + action.accept(root); + root.forEachElement(child -> forEachElementRec(child, action)); + } + void translateStatements(List res, ImStmts stmts) { for (ImStmt s : stmts) { s.translateStmtToLua(res, this); diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java index 792451582..a454f01fa 100644 --- a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java @@ -797,8 +797,135 @@ public void largeFunctionSpillsLocalsIntoTableInLua() throws IOException { String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_largeFunctionSpillsLocalsIntoTableInLua.lua"), Charsets.UTF_8); assertTrue(compiled.contains("function huge(")); assertTrue(compiled.contains("__wurst_locals")); + assertContainsRegex(compiled, "__wurst_locals\\[[0-9]+\\]"); assertFalse(compiled.contains("local v0")); assertFalse(compiled.contains("local v209")); + assertFalse(compiled.contains("\nsum = ")); + assertFalse(compiled.contains("\nv0 = ")); + assertFalse(compiled.contains("takesInt(sum)")); + } + + @Test + public void inlinerDoesNotForceSpillWhenCallerStaysBelowLimit() throws IOException { + List lines = new ArrayList<>(); + lines.add("package Test"); + lines.add("native takesInt(int i)"); + lines.add("function small(int x) returns int"); + lines.add(" let a = x"); + lines.add(" let b = a + 1"); + lines.add(" let c = b + 1"); + lines.add(" let d = c + 1"); + lines.add(" return d"); + lines.add("function caller()"); + lines.add(" var sum = 0"); + for (int i = 0; i < 195; i++) { + lines.add(" let v" + i + " = " + i); + lines.add(" sum += v" + i); + } + lines.add(" sum += small(1)"); + lines.add(" takesInt(sum)"); + lines.add("init"); + lines.add(" caller()"); + + test().testLua(true).lines(lines.toArray(new String[0])); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_inlinerDoesNotForceSpillWhenCallerStaysBelowLimit.lua"), Charsets.UTF_8); + int start = compiled.indexOf("function caller("); + assertTrue("caller function not found in generated lua output", start >= 0); + int end = compiled.indexOf("\nend", start); + assertTrue("caller function end not found in generated lua output", end > start); + String callerBody = compiled.substring(start, end); + + assertFalse("caller should not spill locals into table in this shape", callerBody.contains("__wurst_locals")); + assertTrue("caller should keep direct call in this shape", callerBody.contains("small(1)")); + } + + @Test + public void spilledLocalsKeepNestedBlockInitializationsInLua() throws IOException { + List lines = new ArrayList<>(); + lines.add("package Test"); + lines.add("native takesInt(int i)"); + lines.add("function hugeNested(boolean b)"); + lines.add(" var sum = 0"); + lines.add(" if b"); + lines.add(" let inside = 7"); + lines.add(" sum += inside"); + for (int i = 0; i < 210; i++) { + lines.add(" let v" + i + " = " + i); + lines.add(" sum += v" + i); + } + lines.add(" takesInt(sum)"); + lines.add("init"); + lines.add(" hugeNested(true)"); + + test().testLua(true).lines(lines.toArray(new String[0])); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_spilledLocalsKeepNestedBlockInitializationsInLua.lua"), Charsets.UTF_8); + int start = compiled.indexOf("function hugeNested("); + assertTrue("hugeNested function not found in generated lua output", start >= 0); + int end = compiled.indexOf("\nend", start); + assertTrue("hugeNested function end not found in generated lua output", end > start); + String body = compiled.substring(start, end); + + assertTrue(body.contains("__wurst_locals")); + assertContainsRegex(body, "__wurst_locals\\[[0-9]+\\]\\s*=\\s*7"); + assertFalse("nested block local declaration should be rewritten", body.contains("local inside")); + } + + @Test + public void spilledLocalsDeclareTableBeforeFirstUseInLua() throws IOException { + List lines = new ArrayList<>(); + lines.add("package Test"); + lines.add("native takesInt(int i)"); + lines.add("function huge()"); + lines.add(" var sum = 0"); + for (int i = 0; i < 210; i++) { + lines.add(" let v" + i + " = " + i); + lines.add(" sum += v" + i); + } + lines.add(" takesInt(sum)"); + lines.add("init"); + lines.add(" huge()"); + + test().testLua(true).lines(lines.toArray(new String[0])); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_spilledLocalsDeclareTableBeforeFirstUseInLua.lua"), Charsets.UTF_8); + int start = compiled.indexOf("function huge("); + assertTrue("huge function not found in generated lua output", start >= 0); + int end = compiled.indexOf("\nend", start); + assertTrue("huge function end not found in generated lua output", end > start); + String body = compiled.substring(start, end); + + int declarationPos = body.indexOf("local __wurst_locals"); + int firstUsePos = body.indexOf("__wurst_locals["); + assertTrue("expected __wurst_locals declaration in spilled function body", declarationPos >= 0); + assertTrue("expected __wurst_locals use in spilled function body", firstUsePos >= 0); + assertTrue("__wurst_locals must be declared before first table access", declarationPos < firstUsePos); + } + + @Test + public void luaInlinerDoesNotInlineFunctionsWithMultipleReturns() throws IOException { + test().testLua(true).lines( + "package Test", + "native takesInt(int i)", + "function choose(boolean b, int x) returns int", + " if b", + " return x + 1", + " return x + 2", + "function caller()", + " let v = choose(true, 40)", + " takesInt(v)", + "init", + " caller()" + ); + + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_luaInlinerDoesNotInlineFunctionsWithMultipleReturns.lua"), Charsets.UTF_8); + int start = compiled.indexOf("function caller("); + assertTrue("caller function not found in generated lua output", start >= 0); + int end = compiled.indexOf("\nend", start); + assertTrue("caller function end not found in generated lua output", end > start); + String callerBody = compiled.substring(start, end); + + assertTrue("caller should keep a direct call to choose for Lua target", callerBody.contains("choose(true, 40)")); + assertFalse("caller should not contain multi-return inline control vars", callerBody.contains("inlineDone")); + assertFalse("caller should not contain multi-return inline return temp vars", callerBody.contains("inlineRet")); } @Test @@ -1046,6 +1173,27 @@ public void nestedForForceUsesRemappedHelpersInLua() throws IOException { assertTrue("expected at least two remapped __wurst_ForForce call sites for nested loops", count >= 2); } + @Test + public void luaInlinerKeepsCallbackFuncRefFunctionsAsCallBoundary() throws IOException { + test().testLua(true).withStdLib().lines( + "package Test", + "player picked", + "function pick(force f) returns player", + " picked = null", + " ForForce(f, () -> begin", + " picked = GetEnumPlayer()", + " end)", + " return picked", + "function caller(force f) returns player", + " return pick(f)", + "init", + " let f = CreateForce()", + " let p = caller(f)" + ); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_luaInlinerKeepsCallbackFuncRefFunctionsAsCallBoundary.lua"), Charsets.UTF_8); + assertContainsRegex(compiled, "function\\s+caller\\s*\\([^\\)]*\\)\\s+return\\s+pick\\([^\\)]*\\)\\s+end"); + } + @Test public void wurstGetEnumPlayerPrefersNativeBeforeOverride() throws IOException { test().testLua(true).withStdLib().lines( diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java index ca9a7b211..9f121c9dc 100644 --- a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/OptimizerTests.java @@ -97,6 +97,40 @@ public void test_inline_globals() { "endpackage"); } + @Test + public void globalsInlinerDoesNotRemoveNonInitDefaultWrite() { + test().executeProg().lines( + "package test", + " native testSuccess()", + " boolean g = false", + " @noinline function resetG()", + " g = false", + " function setG()", + " g = true", + " init", + " setG()", + " resetG()", + " if not g", + " testSuccess()" + ); + } + + @Test + public void globalsInlinerRespectsInitReadBeforeSingleWriteOrder() { + test().executeProg().lines( + "package test", + " native testSuccess()", + " int g = 0", + " boolean sawDefault = false", + " init", + " if g == 0", + " sawDefault = true", + " g = 5", + " if sawDefault and g == 5", + " testSuccess()" + ); + } + @Test public void test_nullsetter1() {