Skip to content

Commit de7afc3

Browse files
authored
fix lua inlining by avoiding breaking function references (#1168)
1 parent 1d240ee commit de7afc3

File tree

6 files changed

+436
-40
lines changed

6 files changed

+436
-40
lines changed

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/RunArgs.java

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,6 @@ public RunArgs(String... args) {
132132

133133
addOptionWithArg("functionSplitLimit", "The maximum number of operations in a function before it is split by the function splitter (used for compiletime functions)",
134134
s -> functionSplitLimit = Integer.parseInt(s, 10));
135-
136135
optionPrettyPrint = addOption("prettyPrint", "Pretty print the input file, or all sub-directory if the given path is: '...'");
137136

138137
nextArg:
@@ -149,6 +148,20 @@ public RunArgs(String... args) {
149148
o.isSet = true;
150149
continue nextArg;
151150
} else if ((o.argHandler != null && isDoubleArg(a, o))) {
151+
String value = a.substring(a.indexOf(" ") + 1).trim();
152+
if (value.isEmpty()) {
153+
throw new RuntimeException("Missing value for option: -" + o.name);
154+
}
155+
o.argHandler.accept(value);
156+
o.isSet = true;
157+
continue nextArg;
158+
} else if (o.argHandler != null && isEqualsArg(a, o)) {
159+
String value = a.substring(a.indexOf("=") + 1).trim();
160+
if (value.isEmpty()) {
161+
throw new RuntimeException("Missing value for option: -" + o.name);
162+
}
163+
o.argHandler.accept(value);
164+
o.isSet = true;
152165
continue nextArg;
153166
}
154167
}
@@ -170,6 +183,10 @@ private boolean isDoubleArg(String arg, RunOption option) {
170183
return (arg.contains(" ") && ("-" + option.name).equals(arg.substring(0, arg.indexOf(" "))));
171184
}
172185

186+
private boolean isEqualsArg(String arg, RunOption option) {
187+
return arg.startsWith("-" + option.name + "=");
188+
}
189+
173190
private RunOption addOption(String name, String descr) {
174191
RunOption opt = new RunOption(name, descr);
175192
options.add(opt);

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/GlobalsInliner.java

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import java.util.stream.Collectors;
1717

1818
public class GlobalsInliner implements OptimizerPass {
19-
2019
public int optimize(ImTranslator trans) {
2120
int obsoleteCount = 0;
2221
ImProg prog = trans.getImProg();
@@ -44,7 +43,7 @@ public int optimize(ImTranslator trans) {
4443
ImVarWrite obs = null;
4544
for (ImVarWrite write : v.attrWrites()) {
4645
ImFunction func = write.getNearestFunc();
47-
if (isInInit(func)) {
46+
if (isInInitGlobals(func)) {
4847
right = write.getRight();
4948
obs = write;
5049
break;
@@ -67,19 +66,21 @@ public int optimize(ImTranslator trans) {
6766
List<ImVarWrite> initWrites = new ArrayList<>();
6867
for (ImVarWrite imVarWrite : v.attrWrites()) {
6968
ImFunction nearestFunc = imVarWrite.getNearestFunc();
70-
if (isInInit(nearestFunc)) {
69+
if (isInInitGlobals(nearestFunc)) {
7170
initWrites.add(imVarWrite);
7271
}
7372
}
7473
if (initWrites.size() == 1) {
7574
if(v.getType() instanceof ImSimpleType) {
76-
ImExpr write = v.attrWrites().iterator().next().getRight();
75+
ImVarWrite initWrite = initWrites.get(0);
76+
ImExpr write = initWrite.getRight();
7777
try {
7878
ImExpr defaultValue = ImHelper.defaultValueForType((ImSimpleType) v.getType());
7979
boolean isDefault = defaultValue.structuralEquals(write);
8080
if (isDefault) {
81-
// Assignment is default value and can be removed
82-
v.attrWrites().iterator().next().replaceBy(ImHelper.nullExpr());
81+
// Only remove the init write when it assigns the default value.
82+
// Never touch non-init writes here.
83+
initWrite.replaceBy(ImHelper.nullExpr());
8384
}
8485
} catch (Exception e) {
8586
throw new CompileError(write.attrTrace().attrErrorPos(),
@@ -140,9 +141,8 @@ public String getName() {
140141
}
141142

142143

143-
private static boolean isInInit(ImFunction func) {
144-
return func != null && (func.getName().startsWith("init_") || func.getName().equals("main") || func.getName().startsWith("InitTrig_")
145-
|| func.getName().equals("initGlobals"));
144+
private static boolean isInInitGlobals(ImFunction func) {
145+
return func != null && func.getName().equals("initGlobals");
146146
}
147147

148148
}

de.peeeq.wurstscript/src/main/java/de/peeeq/wurstscript/translation/imoptimizer/ImInliner.java

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ public class ImInliner {
2828
private final Map<ImFunction, Integer> callCounts = Maps.newLinkedHashMap();
2929
private final Map<ImFunction, Integer> funcSizes = Maps.newLinkedHashMap();
3030
private final Set<ImFunction> done = Sets.newLinkedHashSet();
31+
private final Map<ImFunction, Boolean> containsFuncRefCache = Maps.newLinkedHashMap();
3132
private final double inlineTreshold = 50;
3233

3334
static {
@@ -49,8 +50,7 @@ public void doInlining() {
4950
}
5051

5152
private void inlineFunctions() {
52-
53-
for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) {
53+
for (ImFunction f : sortedFunctions(ImHelper.calculateFunctionsOfProg(prog))) {
5454
inlineFunctions(f);
5555
}
5656
}
@@ -61,7 +61,7 @@ private void inlineFunctions(ImFunction f) {
6161
}
6262
done.add(f);
6363
// first inline functions called from this function
64-
for (ImFunction called : translator.getCalledFunctions().get(f)) {
64+
for (ImFunction called : sortedFunctions(translator.getCalledFunctions().get(f))) {
6565
inlineFunctions(called);
6666
}
6767
boolean[] changed = new boolean[]{false};
@@ -73,10 +73,10 @@ private ImFunction inlineFunctions(ImFunction f, Element parent, int parentI, El
7373
if (e instanceof ImFunctionCall) {
7474
ImFunctionCall call = (ImFunctionCall) e;
7575
ImFunction called = call.getFunc();
76-
boolean canInline = f != called && shouldInline(call, called);
76+
boolean canInline = f != called && shouldInline(f, call, called);
7777
if (LOG_INLINER) {
7878
String msg = "[INLINER] caller=" + f.getName() + " callee=" + called.getName() + " decision=" + (canInline ? "inline" : "keep") +
79-
(canInline ? "" : " reason=" + skipReason(call, called));
79+
(canInline ? "" : " reason=" + skipReason(f, call, called));
8080
WLogger.info(msg);
8181
System.out.println(msg);
8282
}
@@ -109,13 +109,19 @@ private ImFunction inlineFunctions(ImFunction f, Element parent, int parentI, El
109109
return null;
110110
}
111111

112-
private String skipReason(ImFunctionCall call, ImFunction f) {
112+
private String skipReason(ImFunction caller, ImFunctionCall call, ImFunction f) {
113113
if (f.isNative()) {
114114
return "native";
115115
}
116116
if (call.getCallType() == CallType.EXECUTE) {
117117
return "execute_call";
118118
}
119+
if (translator.isLuaTarget() && !maxOneReturn(f)) {
120+
return "lua_multi_return_inline_disabled";
121+
}
122+
if (translator.isLuaTarget() && containsFuncRef(f)) {
123+
return "lua_callback_funcref_barrier";
124+
}
119125
if (!inlinableFunctions.contains(f)) {
120126
return "not_in_inlinable_set";
121127
}
@@ -276,12 +282,18 @@ private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar)
276282
}
277283

278284
private void rateInlinableFunctions() {
279-
for (Map.Entry<ImFunction, ImFunction> edge : translator.getCalledFunctions().entries()) {
285+
List<Map.Entry<ImFunction, ImFunction>> edges = new ArrayList<>(translator.getCalledFunctions().entries());
286+
edges.sort((a, b) -> {
287+
int c = functionSortKey(a.getKey()).compareTo(functionSortKey(b.getKey()));
288+
if (c != 0) return c;
289+
return functionSortKey(a.getValue()).compareTo(functionSortKey(b.getValue()));
290+
});
291+
for (Map.Entry<ImFunction, ImFunction> edge : edges) {
280292
// For bloat control we need how often a function is used (incoming edges),
281293
// not how many calls it performs itself (outgoing edges).
282294
incCallCount(edge.getValue());
283295
}
284-
for (ImFunction f : inlinableFunctions) {
296+
for (ImFunction f : sortedFunctions(inlinableFunctions)) {
285297
int size = estimateSize(f);
286298
funcSizes.put(f, size);
287299
}
@@ -322,10 +334,25 @@ private int getFuncSize(ImFunction f) {
322334
}
323335
}
324336

325-
private boolean shouldInline(ImFunctionCall call, ImFunction f) {
337+
private boolean shouldInline(ImFunction caller, ImFunctionCall call, ImFunction f) {
326338
if (f.isNative() || call.getCallType() == CallType.EXECUTE) {
327339
return false;
328340
}
341+
if (translator.isLuaTarget() && !maxOneReturn(f)) {
342+
// Conservative safety: Lua inliner multi-return rewriting is not yet fully robust
343+
// across all lowered patterns. Keep call semantics intact for now.
344+
return false;
345+
}
346+
if (translator.isLuaTarget() && containsFuncRef(f)) {
347+
// Functions that build callback refs are lowered with Lua-specific wrappers/xpcall.
348+
// Keeping them as standalone calls avoids callback context/vararg scope breakage.
349+
return false;
350+
}
351+
if (isLuaTypeCastingCompatFunction(f)) {
352+
// In Lua these compat wrappers are rewritten to object index helpers.
353+
// If they are inlined beforehand, old TypeCasting bodies leak through.
354+
return false;
355+
}
329356

330357
double threshold = inlineTreshold;
331358
for (ImExpr arg : call.getArguments()) {
@@ -362,6 +389,31 @@ private boolean containsCallTo(ImFunction f, Element e) {
362389
return false;
363390
}
364391

392+
private boolean containsFuncRef(ImFunction f) {
393+
if (f == null) {
394+
return false;
395+
}
396+
Boolean cached = containsFuncRefCache.get(f);
397+
if (cached != null) {
398+
return cached;
399+
}
400+
boolean result = containsFuncRef(f.getBody());
401+
containsFuncRefCache.put(f, result);
402+
return result;
403+
}
404+
405+
private boolean containsFuncRef(Element e) {
406+
if (e instanceof ImFuncRef) {
407+
return true;
408+
}
409+
for (int i = 0; i < e.size(); i++) {
410+
if (containsFuncRef(e.get(i))) {
411+
return true;
412+
}
413+
}
414+
return false;
415+
}
416+
365417
private int estimateSize(ImFunction f) {
366418
int[] r = new int[]{0};
367419
estimateSize(f.getBody(), r);
@@ -390,24 +442,46 @@ private int getCallCount(ImFunction f) {
390442
}
391443

392444
private void collectInlinableFunctions() {
393-
for (ImFunction f : ImHelper.calculateFunctionsOfProg(prog)) {
445+
for (ImFunction f : sortedFunctions(ImHelper.calculateFunctionsOfProg(prog))) {
394446
if (isInlineCandidate(f)) {
395447
inlinableFunctions.add(f);
396448
}
397449
}
398450
// Some call targets can survive in the call graph but not in prog/classes lists.
399-
for (ImFunction f : translator.getCalledFunctions().values()) {
451+
for (ImFunction f : sortedFunctions(translator.getCalledFunctions().values())) {
400452
if (isInlineCandidate(f)) {
401453
inlinableFunctions.add(f);
402454
}
403455
}
404456
}
405457

458+
private List<ImFunction> sortedFunctions(Collection<ImFunction> functions) {
459+
List<ImFunction> r = new ArrayList<>(functions);
460+
r.sort(Comparator.comparing(this::functionSortKey));
461+
return r;
462+
}
463+
464+
private String functionSortKey(ImFunction f) {
465+
if (f == null) {
466+
return "";
467+
}
468+
StringBuilder sb = new StringBuilder();
469+
sb.append(f.getName()).append("|");
470+
sb.append(f.getReturnType()).append("|");
471+
for (ImVar p : f.getParameters()) {
472+
sb.append(p.getType()).append(",");
473+
}
474+
return sb.toString();
475+
}
476+
406477
private boolean isInlineCandidate(ImFunction f) {
407478
if (f.hasFlag(FunctionFlagEnum.IS_COMPILETIME_NATIVE) || f.hasFlag(FunctionFlagEnum.IS_NATIVE)) {
408479
// do not inline natives
409480
return false;
410481
}
482+
if (isLuaTypeCastingCompatFunction(f)) {
483+
return false;
484+
}
411485
if (f == translator.getGlobalInitFunc()) {
412486
return false;
413487
}
@@ -419,6 +493,20 @@ private boolean isInlineCandidate(ImFunction f) {
419493
return true;
420494
}
421495

496+
private boolean isLuaTypeCastingCompatFunction(ImFunction f) {
497+
if (!translator.isLuaTarget() || f == null) {
498+
return false;
499+
}
500+
de.peeeq.wurstscript.ast.Element trace = f.attrTrace();
501+
if (trace instanceof de.peeeq.wurstscript.ast.FuncDef fd
502+
&& fd.attrNearestPackage() instanceof de.peeeq.wurstscript.ast.WPackage p
503+
&& "TypeCasting".equals(p.getName())) {
504+
String name = fd.getName();
505+
return name.endsWith("FromIndex") || name.endsWith("ToIndex");
506+
}
507+
return false;
508+
}
509+
422510
private boolean maxOneReturn(ImFunction f) {
423511
return maxOneReturn(f.getBody());
424512
}

0 commit comments

Comments
 (0)