From 32f737c5a0bc84334565c33cdab713753ff40376 Mon Sep 17 00:00:00 2001 From: Frotty Date: Thu, 12 Mar 2026 17:16:04 +0100 Subject: [PATCH] finally fixed lua dispatch --- .../lua/translation/LuaTranslator.java | 291 +++++++++++++++++- .../tests/LuaTranslationTests.java | 248 +++++++++++++++ 2 files changed, 524 insertions(+), 15 deletions(-) diff --git a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java index 26626d111..c833d7a8a 100644 --- a/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java +++ b/de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/lua/translation/LuaTranslator.java @@ -1,9 +1,11 @@ package de.peeeq.wurstscript.translation.lua.translation; import de.peeeq.datastructures.UnionFind; +import de.peeeq.wurstscript.ast.ClassDef; import de.peeeq.wurstscript.ast.Element; import de.peeeq.wurstscript.ast.FuncDef; import de.peeeq.wurstscript.ast.NameDef; +import de.peeeq.wurstscript.ast.WParameter; import de.peeeq.wurstscript.ast.WPackage; import de.peeeq.wurstscript.jassIm.*; import de.peeeq.wurstscript.luaAst.*; @@ -15,6 +17,9 @@ import de.peeeq.wurstscript.utils.Utils; import org.jetbrains.annotations.NotNull; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardOpenOption; import java.util.*; import java.util.regex.Pattern; import java.util.stream.Stream; @@ -88,6 +93,8 @@ public class LuaTranslator { "leaderboardFromIndex", "multiboardFromIndex", "trackableFromIndex", "lightningFromIndex", "ubersplatFromIndex", "framehandleFromIndex", "oskeytypeFromIndex" ); + private static final boolean DEBUG_LUA_DISPATCH = "1".equals(System.getenv("WURST_DEBUG_LUA_DISPATCH")) + || Boolean.getBoolean("wurst.debug.lua.dispatch"); final ImProg prog; final LuaCompilationUnit luaModel; @@ -432,10 +439,20 @@ private void normalizeMethodNames() { } } - // give all related methods the same name - for (Map.Entry> entry : methodUnions.groups().entrySet()) { - String name = uniqueName(entry.getKey().getName()); - for (ImMethod method : entry.getValue()) { + // give all related methods the same name in deterministic order + List> groups = new ArrayList<>(); + for (Set group : methodUnions.groups().values()) { + List sortedGroup = new ArrayList<>(group); + sortedGroup.sort(Comparator.comparing(this::methodSortKey)); + groups.add(sortedGroup); + } + groups.sort(Comparator.comparing(g -> g.isEmpty() ? "" : methodSortKey(g.get(0)))); + for (List group : groups) { + if (group.isEmpty()) { + continue; + } + String name = uniqueName(group.get(0).getName()); + for (ImMethod method : group) { method.setName(name); } } @@ -907,8 +924,7 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod private void initClassTables(ImClass c) { LuaVariable classVar = luaClassVar.getFor(c); // create methods: - Set methods = new HashSet<>(); - createMethods(c, classVar, methods); + createMethods(c, classVar); // set supertype metadata: LuaTableFields superClasses = LuaAst.LuaTableFields(); @@ -929,25 +945,270 @@ private void initClassTables(ImClass c) { } - private void createMethods(ImClass c, LuaVariable classVar, Set methods) { - for (ImMethod method : c.getMethods()) { - if (methods.contains(method.getName())) { + private void createMethods(ImClass c, LuaVariable classVar) { + List allMethods = collectMethodsInHierarchy(c); + Set inHierarchy = new HashSet<>(allMethods); + UnionFind unions = new UnionFind<>(); + for (ImMethod method : allMethods) { + unions.find(method); + for (ImMethod subMethod : method.getSubMethods()) { + if (inHierarchy.contains(subMethod)) { + unions.union(method, subMethod); + } + } + } + + Map> groupedMethods = new HashMap<>(); + for (ImMethod method : allMethods) { + ImMethod root = unions.find(method); + groupedMethods.computeIfAbsent(root, k -> new ArrayList<>()).add(method); + } + + List> groups = new ArrayList<>(groupedMethods.values()); + groups.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0)))); + Map slotToImpl = new TreeMap<>(); + for (List groupMethods : groups) { + if (groupMethods == null || groupMethods.isEmpty()) { continue; } - methods.add(method.getName()); - if (method.getIsAbstract()) { + groupMethods.sort(Comparator.comparing(this::methodSortKey)); + ImMethod chosen = chooseBestImplementationForClass(c, groupMethods); + if (chosen == null || chosen.getIsAbstract() || chosen.getImplementation() == null) { + continue; + } + Set slotNames = collectDispatchSlotNames(c, groupMethods); + for (String slotName : slotNames) { + ImMethod current = slotToImpl.get(slotName); + if (current == null || compareDispatchCandidates(c, chosen, current) < 0) { + slotToImpl.put(slotName, chosen); + } + } + String debugKey = groupMethods.get(0).getName(); + debugDispatchGroup(c, debugKey, slotNames, groupMethods, chosen); + } + for (Map.Entry e : slotToImpl.entrySet()) { + ImMethod impl = e.getValue(); + if (impl == null || impl.getImplementation() == null) { continue; } luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess( LuaAst.LuaExprVarAccess(classVar), - method.getName()), - LuaAst.LuaExprFuncRef(luaFunc.getFor(method.getImplementation())) + e.getKey()), + LuaAst.LuaExprFuncRef(luaFunc.getFor(impl.getImplementation())) )); } - // also create links for inherited methods + } + + private Set collectDispatchSlotNames(ImClass receiverClass, List groupMethods) { + Set slotNames = new TreeSet<>(); + Set semanticNames = new TreeSet<>(); + for (ImMethod m : groupMethods) { + if (m == null) { + continue; + } + String methodName = m.getName(); + if (!methodName.isEmpty()) { + slotNames.add(methodName); + } + ImClass owner = m.attrClass(); + String semanticName = semanticNameFromMethodName(methodName); + if (!semanticName.isEmpty()) { + semanticNames.add(semanticName); + } + if (owner != null && !semanticName.isEmpty()) { + slotNames.add(owner.getName() + "_" + semanticName); + } + } + if (receiverClass != null && !semanticNames.isEmpty()) { + Set classNames = new TreeSet<>(); + collectClassNamesInHierarchy(receiverClass, classNames, new HashSet<>()); + for (String className : classNames) { + for (String semanticName : semanticNames) { + slotNames.add(className + "_" + semanticName); + } + } + } + return slotNames; + } + + private void collectClassNamesInHierarchy(ImClass c, Set out, Set visited) { + if (c == null || !visited.add(c)) { + return; + } + out.add(c.getName()); for (ImClassType sc : c.getSuperClasses()) { - createMethods(sc.getClassDef(), classVar, methods); + collectClassNamesInHierarchy(sc.getClassDef(), out, visited); + } + } + + private List collectMethodsInHierarchy(ImClass c) { + List result = new ArrayList<>(); + collectMethodsInHierarchy(c, result, new HashSet<>()); + result.sort(Comparator.comparing(this::methodSortKey)); + return result; + } + + private void collectMethodsInHierarchy(ImClass c, List out, Set visited) { + if (c == null || !visited.add(c)) { + return; + } + out.addAll(c.getMethods()); + List superClasses = new ArrayList<>(c.getSuperClasses()); + superClasses.sort(Comparator.comparing(t -> classSortKey(t.getClassDef()))); + for (ImClassType sc : superClasses) { + collectMethodsInHierarchy(sc.getClassDef(), out, visited); + } + } + + private ImMethod chooseBestImplementationForClass(ImClass receiverClass, List candidates) { + List concrete = new ArrayList<>(); + for (ImMethod m : candidates) { + if (!m.getIsAbstract() && m.getImplementation() != null) { + concrete.add(m); + } + } + if (concrete.isEmpty()) { + return null; + } + concrete.sort((a, b) -> compareDispatchCandidates(receiverClass, a, b)); + return concrete.get(0); + } + + private int compareDispatchCandidates(ImClass receiverClass, ImMethod a, ImMethod b) { + boolean aLocal = isImplementationFromClass(a, receiverClass); + boolean bLocal = isImplementationFromClass(b, receiverClass); + if (aLocal != bLocal) { + return aLocal ? -1 : 1; + } + int aDist = classDistance(receiverClass, a.attrClass()); + int bDist = classDistance(receiverClass, b.attrClass()); + if (aDist != bDist) { + return Integer.compare(aDist, bDist); + } + boolean aNoOp = isNoOpImplementation(a); + boolean bNoOp = isNoOpImplementation(b); + if (aNoOp != bNoOp) { + return aNoOp ? 1 : -1; + } + return methodSortKey(a).compareTo(methodSortKey(b)); + } + + private boolean isImplementationFromClass(ImMethod method, ImClass ownerClass) { + if (method == null || ownerClass == null || method.getImplementation() == null) { + return false; + } + return method.getImplementation().getName().startsWith(ownerClass.getName() + "_"); + } + + private boolean isNoOpImplementation(ImMethod method) { + return method != null + && method.getImplementation() != null + && method.getImplementation().getName().contains("NoOpState_"); + } + + private int classDistance(ImClass from, ImClass to) { + if (from == null || to == null) { + return Integer.MAX_VALUE; + } + if (from == to) { + return 0; + } + ArrayDeque queue = new ArrayDeque<>(); + Map dist = new HashMap<>(); + queue.add(from); + dist.put(from, 0); + while (!queue.isEmpty()) { + ImClass current = queue.removeFirst(); + int currentDist = dist.get(current); + List superClasses = new ArrayList<>(current.getSuperClasses()); + superClasses.sort(Comparator.comparing(t -> classSortKey(t.getClassDef()))); + for (ImClassType sc : superClasses) { + ImClass next = sc.getClassDef(); + if (next == null || dist.containsKey(next)) { + continue; + } + int nextDist = currentDist + 1; + if (next == to) { + return nextDist; + } + dist.put(next, nextDist); + queue.add(next); + } + } + return Integer.MAX_VALUE; + } + + private void debugDispatchGroup(ImClass receiverClass, String key, Set slotNames, List groupMethods, ImMethod chosen) { + if (!DEBUG_LUA_DISPATCH && !isSuspiciousGroup(slotNames, groupMethods, chosen)) { + return; + } + String chosenImpl = chosen != null && chosen.getImplementation() != null ? chosen.getImplementation().getName() : "null"; + StringBuilder candidates = new StringBuilder(); + List sorted = new ArrayList<>(groupMethods); + sorted.sort(Comparator.comparing(this::methodSortKey)); + for (ImMethod m : sorted) { + String impl = m.getImplementation() != null ? m.getImplementation().getName() : "null"; + if (candidates.length() > 0) { + candidates.append("; "); + } + candidates.append(m.getName()).append("->").append(impl).append("@").append(classSortKey(m.attrClass())); + } + System.err.println("[LuaDispatch] class=" + classSortKey(receiverClass) + + " key=" + key + + " slots=" + slotNames + + " chosen=" + chosenImpl + + " candidates=[" + candidates + "]"); + if (DEBUG_LUA_DISPATCH) { + String line = "[LuaDispatch] class=" + classSortKey(receiverClass) + + " key=" + key + + " slots=" + slotNames + + " chosen=" + chosenImpl + + " candidates=[" + candidates + "]" + + System.lineSeparator(); + try { + Files.writeString(Path.of("C:/Users/Frotty/Documents/GitHub/WurstScript/lua-dispatch-debug.log"), + line, StandardOpenOption.CREATE, StandardOpenOption.APPEND); + } catch (Exception ignored) { + } + } + } + + private boolean isSuspiciousGroup(Set slotNames, List groupMethods, ImMethod chosen) { + if (slotNames.size() > 1) { + return true; + } + boolean hasNonNoOp = false; + for (ImMethod m : groupMethods) { + if (!isNoOpImplementation(m) && m.getImplementation() != null) { + hasNonNoOp = true; + break; + } + } + return hasNonNoOp && isNoOpImplementation(chosen); + } + + private String methodSortKey(ImMethod m) { + String owner = classSortKey(m.attrClass()); + String impl = m.getImplementation() != null ? m.getImplementation().getName() : ""; + return owner + "|" + m.getName() + "|" + impl; + } + + private String semanticNameFromMethodName(String methodName) { + if (methodName == null || methodName.isEmpty()) { + return ""; + } + int lastUnderscore = methodName.lastIndexOf('_'); + if (lastUnderscore >= 0 && lastUnderscore + 1 < methodName.length()) { + return methodName.substring(lastUnderscore + 1); + } + return methodName; + } + + private String classSortKey(ImClass c) { + if (c == null) { + return ""; } + return c.getName(); } @NotNull diff --git a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java index f58f32afb..792451582 100644 --- a/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java +++ b/de.peeeq.wurstscript/src/test/java/tests/wurstscript/tests/LuaTranslationTests.java @@ -329,6 +329,26 @@ public void methodFieldNameCollision() throws IOException { assertFunctionBodyContains(compiled, "Foo_Foo_size", "return this.Foo_size\n", false); } + @Test + public void overloadedMethodsDoNotAliasInLuaDispatchTables() throws IOException { + test().testLua(true).lines( + "package Test", + "class Writer", + " function write(int i)", + " skip", + " function write(string s)", + " skip", + "init", + " let w = new Writer()", + " w.write(1)", + " w.write(\"x\")" + ); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_overloadedMethodsDoNotAliasInLuaDispatchTables.lua"), Charsets.UTF_8); + assertTrue(compiled.contains("Writer.Writer_write = Writer_Writer_write")); + assertTrue(compiled.contains("Writer.Writer_write1 = Writer_Writer_write1")); + assertFalse(compiled.contains("Writer.Writer_write = Writer_Writer_write1")); + } + @Test public void mainAndConfigNamesFixed() throws IOException { test().testLua(true).lines( @@ -530,6 +550,234 @@ public void newGenericsStringFieldAssignmentRoundTripsInLua() throws IOException assertFunctionBodyContains(compiled, "testGenericStringField", "__wurst_stringFromIndex", false); } + @Test + public void genericOverrideChainBindsRootSlotToMostSpecificImplInLua() throws IOException { + test().testLua(true).compilationUnits(genericOverrideReproUnits()); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_genericOverrideChainBindsRootSlotToMostSpecificImplInLua.lua"), Charsets.UTF_8); + + Matcher slotMatcher = Pattern.compile("FSM_currentState:([A-Za-z0-9_]*_update)\\(").matcher(compiled); + assertTrue("Expected FSM to dispatch through a virtual *_update slot.", slotMatcher.find()); + String dispatchedSlot = slotMatcher.group(1); + + String[] states = {"FindBuilder", "PlanNextAction", "FindSpot", "BuildAtTarget", "QuickBuild", "RescueStrikeTarget"}; + for (String state : states) { + assertContainsRegex(compiled, state + "\\." + dispatchedSlot + "\\s*=\\s*" + state + "_" + state + "_update"); + assertDoesNotContainRegex(compiled, state + "\\." + dispatchedSlot + "\\s*=\\s*NoOpState_NoOpState_update"); + } + } + + @Test + public void luaOutputIsDeterministicForGenericOverrideSlots() throws IOException { + test().testLua(true).compilationUnits(genericOverrideReproUnits()); + String first = Files.toString(new File("test-output/lua/LuaTranslationTests_luaOutputIsDeterministicForGenericOverrideSlots.lua"), Charsets.UTF_8); + + GlobalCaches.clearAll(); + + test().testLua(true).compilationUnits(genericOverrideReproUnits()); + String second = Files.toString(new File("test-output/lua/LuaTranslationTests_luaOutputIsDeterministicForGenericOverrideSlots.lua"), Charsets.UTF_8); + + assertEquals(first, second); + } + + @Test + public void genericOverrideChainBindsGlobalStateSlotToMostSpecificImplInLua() throws IOException { + test().testLua(true).compilationUnits(genericOverrideGlobalStateReproUnits()); + String compiled = Files.toString(new File("test-output/lua/LuaTranslationTests_genericOverrideChainBindsGlobalStateSlotToMostSpecificImplInLua.lua"), Charsets.UTF_8); + + Matcher slotMatcher = Pattern.compile("FSM_globalState:([A-Za-z0-9_]*_update)\\(").matcher(compiled); + assertTrue("Expected FSM global state to dispatch through a virtual *_update slot.", slotMatcher.find()); + String dispatchedSlot = slotMatcher.group(1); + + assertContainsRegex(compiled, "GlobalCheckState\\." + dispatchedSlot + "\\s*=\\s*GlobalCheckState_GlobalCheckState_update"); + assertDoesNotContainRegex(compiled, "GlobalCheckState\\." + dispatchedSlot + "\\s*=\\s*NoOpState_NoOpState_update"); + } + + private CU[] genericOverrideReproUnits() { + return new CU[]{ + compilationUnit("fsmLib.wurst", + "package FSMReproLib", + "", + "public abstract class State", + " function enter(T owner)", + " function update(T owner, real dt)", + " function exit(T owner)", + "", + "public class NoOpState extends State", + " override function enter(T owner)", + " override function update(T owner, real dt)", + " override function exit(T owner)", + "", + "public class FSM", + " T owner", + " State currentState = null", + "", + " construct(T owner)", + " this.owner = owner", + "", + " function setInitialState(State st)", + " currentState = st", + " if currentState != null", + " currentState.enter(owner)", + "", + " function update(real dt)", + " if currentState != null", + " currentState.update(owner, dt)" + ), + compilationUnit("repro.wurst", + "package GenericOverrideSlotRepro", + "import FSMReproLib", + "", + "public class Owner", + " FSM fsm = new FSM(this)", + " int findBuilderTicks = 0", + " int planTicks = 0", + " int findSpotTicks = 0", + " int buildTicks = 0", + "", + "public constant globalCheckState = new GlobalCheckState()", + "public constant findBuilderState = new FindBuilder()", + "public constant planNextActionState = new PlanNextAction()", + "public constant findSpotState = new FindSpot()", + "public constant buildAtTargetState = new BuildAtTarget()", + "public constant quickBuildState = new QuickBuild()", + "public constant rescueStrikeTargetState = new RescueStrikeTarget()", + "", + "class GlobalCheckState extends NoOpState", + " override function update(Owner o, real dt)", + "", + "class FindBuilder extends NoOpState", + " override function enter(Owner o)", + " o.findBuilderTicks = 0", + " override function update(Owner o, real dt)", + " o.findBuilderTicks++", + "", + "class PlanNextAction extends NoOpState", + " override function update(Owner o, real dt)", + " o.planTicks++", + "", + "class FindSpot extends NoOpState", + " override function update(Owner o, real dt)", + " o.findSpotTicks++", + "", + "class BuildAtTarget extends NoOpState", + " override function update(Owner o, real dt)", + " o.buildTicks++", + + "class QuickBuild extends NoOpState", + " override function update(Owner o, real dt)", + " o.buildTicks++", + "", + "class RescueStrikeTarget extends NoOpState", + " override function update(Owner o, real dt)", + " o.buildTicks++", + "", + "function runOne(State st) returns int", + " let o = new Owner()", + " o.fsm.setInitialState(st)", + " for i = 0 to 4", + " o.fsm.update(0.1)", + " return o.findBuilderTicks", + "", + "init", + " runOne(findBuilderState)", + " runOne(planNextActionState)", + " runOne(findSpotState)", + " runOne(buildAtTargetState)", + " runOne(quickBuildState)", + " runOne(rescueStrikeTargetState)" + ) + }; + } + + private CU[] genericOverrideGlobalStateReproUnits() { + return new CU[]{ + compilationUnit("fsmLib.wurst", + "package FSMReproLibGlobal", + "", + "public abstract class State", + " function enter(T owner)", + " function update(T owner, real dt)", + " function exit(T owner)", + "", + "public class NoOpState extends State", + " override function enter(T owner)", + " override function update(T owner, real dt)", + " override function exit(T owner)", + "", + "public class FSM", + " T owner", + " State globalState = null", + " State currentState = null", + "", + " construct(T owner)", + " this.owner = owner", + "", + " function setInitialState(State st)", + " currentState = st", + " if currentState != null", + " currentState.enter(owner)", + "", + " function setGlobalState(State st)", + " globalState = st", + "", + " function update(real dt)", + " if globalState != null", + " globalState.update(owner, dt)", + " if currentState != null", + " currentState.update(owner, dt)" + ), + compilationUnit("repro.wurst", + "package GenericOverrideGlobalStateSlotRepro", + "import FSMReproLibGlobal", + "", + "public class Owner", + " FSM fsm = new FSM(this)", + " int globalTicks = 0", + " int findBuilderTicks = 0", + " int planTicks = 0", + " int findSpotTicks = 0", + " int buildTicks = 0", + "", + "public constant globalCheckState = new GlobalCheckState()", + "public constant findBuilderState = new FindBuilder()", + "public constant planNextActionState = new PlanNextAction()", + "public constant findSpotState = new FindSpot()", + "public constant buildAtTargetState = new BuildAtTarget()", + "public constant quickBuildState = new QuickBuild()", + "", + "class GlobalCheckState extends NoOpState", + " override function update(Owner o, real dt)", + " o.globalTicks++", + "", + "class FindBuilder extends NoOpState", + " override function update(Owner o, real dt)", + " o.findBuilderTicks++", + "", + "class PlanNextAction extends NoOpState", + " override function update(Owner o, real dt)", + " o.planTicks++", + "", + "class FindSpot extends NoOpState", + " override function update(Owner o, real dt)", + " o.findSpotTicks++", + "", + "class BuildAtTarget extends NoOpState", + " override function update(Owner o, real dt)", + " o.buildTicks++", + "", + "class QuickBuild extends NoOpState", + " override function update(Owner o, real dt)", + " o.buildTicks++", + "", + "init", + " let o = new Owner()", + " o.fsm.setGlobalState(globalCheckState)", + " o.fsm.setInitialState(buildAtTargetState)", + " o.fsm.update(0.1)" + ) + }; + } + @Test public void largeFunctionSpillsLocalsIntoTableInLua() throws IOException { List lines = new ArrayList<>();