Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package de.peeeq.wurstscript.translation.lua.translation;

import de.peeeq.datastructures.UnionFind;
import de.peeeq.wurstscript.ast.ClassDef;
import de.peeeq.wurstscript.ast.Element;
import de.peeeq.wurstscript.ast.FuncDef;
import de.peeeq.wurstscript.ast.NameDef;
import de.peeeq.wurstscript.ast.WParameter;
import de.peeeq.wurstscript.ast.WPackage;
import de.peeeq.wurstscript.jassIm.*;
import de.peeeq.wurstscript.luaAst.*;
Expand All @@ -15,6 +17,9 @@
import de.peeeq.wurstscript.utils.Utils;
import org.jetbrains.annotations.NotNull;

import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Stream;
Expand Down Expand Up @@ -88,6 +93,8 @@ public class LuaTranslator {
"leaderboardFromIndex", "multiboardFromIndex", "trackableFromIndex", "lightningFromIndex",
"ubersplatFromIndex", "framehandleFromIndex", "oskeytypeFromIndex"
);
private static final boolean DEBUG_LUA_DISPATCH = "1".equals(System.getenv("WURST_DEBUG_LUA_DISPATCH"))
|| Boolean.getBoolean("wurst.debug.lua.dispatch");

final ImProg prog;
final LuaCompilationUnit luaModel;
Expand Down Expand Up @@ -432,10 +439,20 @@ private void normalizeMethodNames() {
}
}

// give all related methods the same name
for (Map.Entry<ImMethod, Set<ImMethod>> entry : methodUnions.groups().entrySet()) {
String name = uniqueName(entry.getKey().getName());
for (ImMethod method : entry.getValue()) {
// give all related methods the same name in deterministic order
List<List<ImMethod>> groups = new ArrayList<>();
for (Set<ImMethod> group : methodUnions.groups().values()) {
List<ImMethod> sortedGroup = new ArrayList<>(group);
sortedGroup.sort(Comparator.comparing(this::methodSortKey));
groups.add(sortedGroup);
}
groups.sort(Comparator.comparing(g -> g.isEmpty() ? "" : methodSortKey(g.get(0))));
for (List<ImMethod> group : groups) {
if (group.isEmpty()) {
continue;
}
String name = uniqueName(group.get(0).getName());
for (ImMethod method : group) {
method.setName(name);
}
}
Expand Down Expand Up @@ -907,8 +924,7 @@ private void createClassInitFunction(ImClass c, LuaVariable classVar, LuaMethod
private void initClassTables(ImClass c) {
LuaVariable classVar = luaClassVar.getFor(c);
// create methods:
Set<String> methods = new HashSet<>();
createMethods(c, classVar, methods);
createMethods(c, classVar);

// set supertype metadata:
LuaTableFields superClasses = LuaAst.LuaTableFields();
Expand All @@ -929,25 +945,270 @@ private void initClassTables(ImClass c) {

}

private void createMethods(ImClass c, LuaVariable classVar, Set<String> methods) {
for (ImMethod method : c.getMethods()) {
if (methods.contains(method.getName())) {
private void createMethods(ImClass c, LuaVariable classVar) {
List<ImMethod> allMethods = collectMethodsInHierarchy(c);
Set<ImMethod> inHierarchy = new HashSet<>(allMethods);
UnionFind<ImMethod> unions = new UnionFind<>();
for (ImMethod method : allMethods) {
unions.find(method);
for (ImMethod subMethod : method.getSubMethods()) {
if (inHierarchy.contains(subMethod)) {
unions.union(method, subMethod);
}
}
}

Map<ImMethod, List<ImMethod>> groupedMethods = new HashMap<>();
for (ImMethod method : allMethods) {
ImMethod root = unions.find(method);
groupedMethods.computeIfAbsent(root, k -> new ArrayList<>()).add(method);
}

List<List<ImMethod>> groups = new ArrayList<>(groupedMethods.values());
groups.sort(Comparator.comparing(group -> group.isEmpty() ? "" : methodSortKey(group.get(0))));
Map<String, ImMethod> slotToImpl = new TreeMap<>();
for (List<ImMethod> groupMethods : groups) {
if (groupMethods == null || groupMethods.isEmpty()) {
continue;
}
methods.add(method.getName());
if (method.getIsAbstract()) {
groupMethods.sort(Comparator.comparing(this::methodSortKey));
ImMethod chosen = chooseBestImplementationForClass(c, groupMethods);
if (chosen == null || chosen.getIsAbstract() || chosen.getImplementation() == null) {
continue;
}
Set<String> slotNames = collectDispatchSlotNames(c, groupMethods);
for (String slotName : slotNames) {
ImMethod current = slotToImpl.get(slotName);
if (current == null || compareDispatchCandidates(c, chosen, current) < 0) {
slotToImpl.put(slotName, chosen);
}
}
String debugKey = groupMethods.get(0).getName();
debugDispatchGroup(c, debugKey, slotNames, groupMethods, chosen);
}
for (Map.Entry<String, ImMethod> e : slotToImpl.entrySet()) {
ImMethod impl = e.getValue();
if (impl == null || impl.getImplementation() == null) {
continue;
}
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
LuaAst.LuaExprVarAccess(classVar),
method.getName()),
LuaAst.LuaExprFuncRef(luaFunc.getFor(method.getImplementation()))
e.getKey()),
LuaAst.LuaExprFuncRef(luaFunc.getFor(impl.getImplementation()))
));
}
// also create links for inherited methods
}

private Set<String> collectDispatchSlotNames(ImClass receiverClass, List<ImMethod> groupMethods) {
Set<String> slotNames = new TreeSet<>();
Set<String> semanticNames = new TreeSet<>();
for (ImMethod m : groupMethods) {
if (m == null) {
continue;
}
String methodName = m.getName();
if (!methodName.isEmpty()) {
slotNames.add(methodName);
}
ImClass owner = m.attrClass();
String semanticName = semanticNameFromMethodName(methodName);
if (!semanticName.isEmpty()) {
semanticNames.add(semanticName);
}
if (owner != null && !semanticName.isEmpty()) {
slotNames.add(owner.getName() + "_" + semanticName);
}
}
if (receiverClass != null && !semanticNames.isEmpty()) {
Set<String> classNames = new TreeSet<>();
collectClassNamesInHierarchy(receiverClass, classNames, new HashSet<>());
for (String className : classNames) {
for (String semanticName : semanticNames) {
slotNames.add(className + "_" + semanticName);
}
}
}
return slotNames;
}

private void collectClassNamesInHierarchy(ImClass c, Set<String> out, Set<ImClass> visited) {
if (c == null || !visited.add(c)) {
return;
}
out.add(c.getName());
for (ImClassType sc : c.getSuperClasses()) {
createMethods(sc.getClassDef(), classVar, methods);
collectClassNamesInHierarchy(sc.getClassDef(), out, visited);
}
}

private List<ImMethod> collectMethodsInHierarchy(ImClass c) {
List<ImMethod> result = new ArrayList<>();
collectMethodsInHierarchy(c, result, new HashSet<>());
result.sort(Comparator.comparing(this::methodSortKey));
return result;
}

private void collectMethodsInHierarchy(ImClass c, List<ImMethod> out, Set<ImClass> visited) {
if (c == null || !visited.add(c)) {
return;
}
out.addAll(c.getMethods());
List<ImClassType> superClasses = new ArrayList<>(c.getSuperClasses());
superClasses.sort(Comparator.comparing(t -> classSortKey(t.getClassDef())));
for (ImClassType sc : superClasses) {
collectMethodsInHierarchy(sc.getClassDef(), out, visited);
}
}

private ImMethod chooseBestImplementationForClass(ImClass receiverClass, List<ImMethod> candidates) {
List<ImMethod> concrete = new ArrayList<>();
for (ImMethod m : candidates) {
if (!m.getIsAbstract() && m.getImplementation() != null) {
concrete.add(m);
}
}
if (concrete.isEmpty()) {
return null;
}
concrete.sort((a, b) -> compareDispatchCandidates(receiverClass, a, b));
return concrete.get(0);
}

private int compareDispatchCandidates(ImClass receiverClass, ImMethod a, ImMethod b) {
boolean aLocal = isImplementationFromClass(a, receiverClass);
boolean bLocal = isImplementationFromClass(b, receiverClass);
if (aLocal != bLocal) {
return aLocal ? -1 : 1;
}
int aDist = classDistance(receiverClass, a.attrClass());
int bDist = classDistance(receiverClass, b.attrClass());
if (aDist != bDist) {
return Integer.compare(aDist, bDist);
}
boolean aNoOp = isNoOpImplementation(a);
boolean bNoOp = isNoOpImplementation(b);
if (aNoOp != bNoOp) {
return aNoOp ? 1 : -1;
}
return methodSortKey(a).compareTo(methodSortKey(b));
}

private boolean isImplementationFromClass(ImMethod method, ImClass ownerClass) {
if (method == null || ownerClass == null || method.getImplementation() == null) {
return false;
}
return method.getImplementation().getName().startsWith(ownerClass.getName() + "_");
}

private boolean isNoOpImplementation(ImMethod method) {
return method != null
&& method.getImplementation() != null
&& method.getImplementation().getName().contains("NoOpState_");
}

private int classDistance(ImClass from, ImClass to) {
if (from == null || to == null) {
return Integer.MAX_VALUE;
}
if (from == to) {
return 0;
}
ArrayDeque<ImClass> queue = new ArrayDeque<>();
Map<ImClass, Integer> dist = new HashMap<>();
queue.add(from);
dist.put(from, 0);
while (!queue.isEmpty()) {
ImClass current = queue.removeFirst();
int currentDist = dist.get(current);
List<ImClassType> superClasses = new ArrayList<>(current.getSuperClasses());
superClasses.sort(Comparator.comparing(t -> classSortKey(t.getClassDef())));
for (ImClassType sc : superClasses) {
ImClass next = sc.getClassDef();
if (next == null || dist.containsKey(next)) {
continue;
}
int nextDist = currentDist + 1;
if (next == to) {
return nextDist;
}
dist.put(next, nextDist);
queue.add(next);
}
}
return Integer.MAX_VALUE;
}

private void debugDispatchGroup(ImClass receiverClass, String key, Set<String> slotNames, List<ImMethod> groupMethods, ImMethod chosen) {
if (!DEBUG_LUA_DISPATCH && !isSuspiciousGroup(slotNames, groupMethods, chosen)) {
return;
}
String chosenImpl = chosen != null && chosen.getImplementation() != null ? chosen.getImplementation().getName() : "null";
StringBuilder candidates = new StringBuilder();
List<ImMethod> sorted = new ArrayList<>(groupMethods);
sorted.sort(Comparator.comparing(this::methodSortKey));
for (ImMethod m : sorted) {
String impl = m.getImplementation() != null ? m.getImplementation().getName() : "null";
if (candidates.length() > 0) {
candidates.append("; ");
}
candidates.append(m.getName()).append("->").append(impl).append("@").append(classSortKey(m.attrClass()));
}
System.err.println("[LuaDispatch] class=" + classSortKey(receiverClass)
+ " key=" + key
+ " slots=" + slotNames
+ " chosen=" + chosenImpl
+ " candidates=[" + candidates + "]");
if (DEBUG_LUA_DISPATCH) {
String line = "[LuaDispatch] class=" + classSortKey(receiverClass)
+ " key=" + key
+ " slots=" + slotNames
+ " chosen=" + chosenImpl
+ " candidates=[" + candidates + "]"
+ System.lineSeparator();
try {
Files.writeString(Path.of("C:/Users/Frotty/Documents/GitHub/WurstScript/lua-dispatch-debug.log"),
line, StandardOpenOption.CREATE, StandardOpenOption.APPEND);
} catch (Exception ignored) {
}
}
}

private boolean isSuspiciousGroup(Set<String> slotNames, List<ImMethod> groupMethods, ImMethod chosen) {
if (slotNames.size() > 1) {
return true;
}
boolean hasNonNoOp = false;
for (ImMethod m : groupMethods) {
if (!isNoOpImplementation(m) && m.getImplementation() != null) {
hasNonNoOp = true;
break;
}
}
return hasNonNoOp && isNoOpImplementation(chosen);
}

private String methodSortKey(ImMethod m) {
String owner = classSortKey(m.attrClass());
String impl = m.getImplementation() != null ? m.getImplementation().getName() : "";
return owner + "|" + m.getName() + "|" + impl;
}

private String semanticNameFromMethodName(String methodName) {
if (methodName == null || methodName.isEmpty()) {
return "";
}
int lastUnderscore = methodName.lastIndexOf('_');
if (lastUnderscore >= 0 && lastUnderscore + 1 < methodName.length()) {
return methodName.substring(lastUnderscore + 1);
}
return methodName;
}

private String classSortKey(ImClass c) {
if (c == null) {
return "";
}
return c.getName();
}

@NotNull
Expand Down
Loading
Loading