Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -309,31 +309,19 @@ private void analyzeComponent(List<Node> scc, Map<Node, Knowledge> knowledge) {
} else {
Value newValue = null;

// Try constant folding first
ImExpr foldedExpr = tryConstantFold(right, newOut);
if (foldedExpr != null && foldedExpr != right) {
// We successfully folded to a constant
newValue = Value.tryValue(foldedExpr);
if (newValue != null) {
// Replace the RHS with the folded constant in the AST
right.replaceBy(foldedExpr);
}
}

// If no folding happened, try regular value propagation
if (newValue == null) {
if (right instanceof ImConst) {
newValue = Value.tryValue(right);
} else if (right instanceof ImVarAccess) {
ImVar varRight = ((ImVarAccess) right).getVar();
if(newOut.containsKey(varRight)) {
newValue = newOut.get(varRight).getOrNull();
} else {
newValue = Value.tryValue(right);
}
} else if(right instanceof ImTupleExpr) {
// Constant folding is intentionally centralized in SimpleRewrites.
// This pass performs propagation only to keep fold semantics in one place.
if (right instanceof ImConst) {
newValue = Value.tryValue(right);
} else if (right instanceof ImVarAccess) {
ImVar varRight = ((ImVarAccess) right).getVar();
if(newOut.containsKey(varRight)) {
newValue = newOut.get(varRight).getOrNull();
} else {
newValue = Value.tryValue(right);
}
} else if(right instanceof ImTupleExpr) {
newValue = Value.tryValue(right);
}

if (newValue == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,6 @@ private String skipReason(ImFunction caller, ImFunctionCall call, ImFunction f)
if (call.getCallType() == CallType.EXECUTE) {
return "execute_call";
}
if (translator.isLuaTarget() && !maxOneReturn(f)) {
return "lua_multi_return_inline_disabled";
}
if (translator.isLuaTarget() && containsFuncRef(f)) {
return "lua_callback_funcref_barrier";
}
Expand Down Expand Up @@ -241,14 +238,14 @@ public void visit(ImFunctionCall called) {
private ImStmts rewriteForEarlyReturns(ImStmts body, ImVar doneVar, ImVar retVar) {
ImStmts rewritten = JassIm.ImStmts();
for (ImStmt s : body) {
ImStmt transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
ImStmts transformed = rewriteStmtForEarlyReturn(s, doneVar, retVar);
ImExpr notDone = JassIm.ImOperatorCall(de.peeeq.wurstscript.WurstOperator.NOT, JassIm.ImExprs(JassIm.ImVarAccess(doneVar)));
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, JassIm.ImStmts(transformed), JassIm.ImStmts()));
rewritten.add(JassIm.ImIf(s.attrTrace(), notDone, transformed, JassIm.ImStmts()));
}
return rewritten;
}

private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
private ImStmts rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar) {
if (s instanceof ImReturn) {
ImReturn r = (ImReturn) s;
ImStmts b = JassIm.ImStmts();
Expand All @@ -258,27 +255,27 @@ private ImStmt rewriteStmtForEarlyReturn(ImStmt s, ImVar doneVar, ImVar retVar)
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(retVar), rv));
}
b.add(JassIm.ImSet(r.getTrace(), JassIm.ImVarAccess(doneVar), JassIm.ImBoolVal(true)));
return ImHelper.statementExprVoid(b);
return b;
} else if (s instanceof ImIf) {
ImIf imIf = (ImIf) s;
ImStmts thenBlock = rewriteForEarlyReturns(imIf.getThenBlock().copy(), doneVar, retVar);
ImStmts elseBlock = rewriteForEarlyReturns(imIf.getElseBlock().copy(), doneVar, retVar);
return JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock);
return JassIm.ImStmts(JassIm.ImIf(imIf.getTrace(), imIf.getCondition().copy(), thenBlock, elseBlock));
} else if (s instanceof ImLoop) {
ImLoop l = (ImLoop) s;
ImStmts loopBody = JassIm.ImStmts();
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
return JassIm.ImLoop(l.getTrace(), loopBody);
return JassIm.ImStmts(JassIm.ImLoop(l.getTrace(), loopBody));
} else if (s instanceof ImVarargLoop) {
ImVarargLoop l = (ImVarargLoop) s;
ImStmts loopBody = JassIm.ImStmts();
loopBody.add(JassIm.ImExitwhen(l.getTrace(), JassIm.ImVarAccess(doneVar)));
loopBody.addAll(rewriteForEarlyReturns(l.getBody().copy(), doneVar, retVar).removeAll());
return JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar());
return JassIm.ImStmts(JassIm.ImVarargLoop(l.getTrace(), loopBody, l.getLoopVar()));
}
// Keep tree ownership valid when rewrapping statements into new blocks.
return s.copy();
return JassIm.ImStmts(s.copy());
}

private void rateInlinableFunctions() {
Expand Down Expand Up @@ -338,11 +335,6 @@ private boolean shouldInline(ImFunction caller, ImFunctionCall call, ImFunction
if (f.isNative() || call.getCallType() == CallType.EXECUTE) {
return false;
}
if (translator.isLuaTarget() && !maxOneReturn(f)) {
// Conservative safety: Lua inliner multi-return rewriting is not yet fully robust
// across all lowered patterns. Keep call semantics intact for now.
return false;
}
if (translator.isLuaTarget() && containsFuncRef(f)) {
// Functions that build callback refs are lowered with Lua-specific wrappers/xpcall.
// Keeping them as standalone calls avoids callback context/vararg scope breakage.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ public class LuaTranslator {

final ImProg prog;
final LuaCompilationUnit luaModel;
private final LuaStatements deferredMainInit = LuaAst.LuaStatements();
private final Set<String> usedNames = new HashSet<>(Arrays.asList(
// reserved function names
"print", "tostring", "error",
Expand Down Expand Up @@ -180,7 +181,7 @@ public LuaMethod initFor(ImMethod a) {
GetAForB<ImClass, LuaVariable> luaClassVar = new GetAForB<ImClass, LuaVariable>() {
@Override
public LuaVariable initFor(ImClass a) {
return LuaAst.LuaVariable(uniqueName(a.getName()), LuaAst.LuaNoExpr());
return LuaAst.LuaVariable(uniqueName(a.getName()), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields()));
}
};

Expand Down Expand Up @@ -321,9 +322,10 @@ public LuaCompilationUnit translate() {
initClassTables(c);
}

emitExperimentalHashtableLeakGuards();
prependDeferredMainInitToMain();
cleanStatements();
enforceLuaLocalLimits();
emitExperimentalHashtableLeakGuards();

return luaModel;
}
Expand Down Expand Up @@ -353,12 +355,39 @@ private void ensureWurstContextCallbackHelpers() {
}

private void emitExperimentalHashtableLeakGuards() {
luaModel.add(LuaAst.LuaLiteral("-- Wurst experimental Lua assertion guards: raw WC3 hashtable natives must not be called."));
deferMainInit(LuaAst.LuaLiteral("-- Wurst experimental Lua assertion guards: raw WC3 hashtable natives must not be called."));
deferMainInit(LuaAst.LuaLiteral("do"));
deferMainInit(LuaAst.LuaLiteral(" local __wurst_guard_ok = pcall(function()"));
for (String nativeName : allHashtableNativeNames()) {
luaModel.add(LuaAst.LuaLiteral("if " + nativeName + " ~= nil then " + nativeName
deferMainInit(LuaAst.LuaLiteral(" if " + nativeName + " ~= nil then " + nativeName
+ " = function(...) error(\"Wurst Lua assertion failed: unexpected call to native " + nativeName
+ ". Expected __wurst_" + nativeName + ".\") end end"));
}
deferMainInit(LuaAst.LuaLiteral(" end)"));
deferMainInit(LuaAst.LuaLiteral(" if not __wurst_guard_ok then"));
deferMainInit(LuaAst.LuaLiteral(" -- Some Lua runtimes lock native globals. Compile-time leak checks stay authoritative."));
deferMainInit(LuaAst.LuaLiteral(" end"));
deferMainInit(LuaAst.LuaLiteral("end"));
}

private void deferMainInit(LuaStatement statement) {
deferredMainInit.add(statement);
}

private void prependDeferredMainInitToMain() {
if (deferredMainInit.isEmpty()) {
return;
}
ImFunction mainIm = imTr.getMainFunc();
if (mainIm == null) {
return;
}
LuaFunction mainLua = luaFunc.getFor(mainIm);
LuaStatements mainBody = mainLua.getBody();
for (int i = deferredMainInit.size() - 1; i >= 0; i--) {
LuaStatement stmt = deferredMainInit.remove(i);
mainBody.add(0, stmt);
}
}

public static void assertNoLeakedHashtableNativeCalls(String luaCode) {
Expand Down Expand Up @@ -534,15 +563,17 @@ private void createInstanceOfFunction() {

private void createObjectIndexFunctions() {
String vName = "__wurst_objectIndexMap";
LuaVariable v = LuaAst.LuaVariable(vName, LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
)));
LuaVariable v = LuaAst.LuaVariable(vName, LuaAst.LuaExprNull());
luaModel.add(v);

LuaVariable im = LuaAst.LuaVariable("__wurst_number_wrapper_map", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(v), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
)));
))));

LuaVariable im = LuaAst.LuaVariable("__wurst_number_wrapper_map", LuaAst.LuaExprNull());
luaModel.add(im);
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(im), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0"))
))));

{
String[] code = {
Expand Down Expand Up @@ -597,12 +628,13 @@ private void createObjectIndexFunctions() {
}

private void createStringIndexFunctions() {
LuaVariable map = LuaAst.LuaVariable("__wurst_string_index_map", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
LuaVariable map = LuaAst.LuaVariable("__wurst_string_index_map", LuaAst.LuaExprNull());
luaModel.add(map);
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(map), LuaAst.LuaTableConstructor(LuaAst.LuaTableFields(
LuaAst.LuaTableNamedField("counter", LuaAst.LuaExprIntVal("0")),
LuaAst.LuaTableNamedField("byString", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields())),
LuaAst.LuaTableNamedField("byIndex", LuaAst.LuaTableConstructor(LuaAst.LuaTableFields()))
)));
luaModel.add(map);
))));

{
String[] code = {
Expand Down Expand Up @@ -995,8 +1027,6 @@ private void translateClass(ImClass c) {

luaModel.add(initMethod);

classVar.setInitialValue(emptyTable());

// translate functions
for (ImFunction f : c.getFunctions()) {
translateFunc(f);
Expand Down Expand Up @@ -1038,14 +1068,14 @@ private void initClassTables(ImClass c) {
// set supertype metadata:
LuaTableFields superClasses = LuaAst.LuaTableFields();
collectSuperClasses(superClasses, c, new HashSet<>());
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
LuaAst.LuaExprVarAccess(classVar),
WURST_SUPERTYPES),
LuaAst.LuaTableConstructor(superClasses)
));

// set typeid metadata:
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
LuaAst.LuaExprVarAccess(classVar),
ExprTranslation.TYPE_ID),
LuaAst.LuaExprIntVal("" + prog.attrTypeId().get(c))
Expand Down Expand Up @@ -1100,7 +1130,7 @@ private void createMethods(ImClass c, LuaVariable classVar) {
if (impl == null || impl.getImplementation() == null) {
continue;
}
luaModel.add(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprFieldAccess(
LuaAst.LuaExprVarAccess(classVar),
e.getKey()),
LuaAst.LuaExprFuncRef(luaFunc.getFor(impl.getImplementation()))
Expand Down Expand Up @@ -1343,8 +1373,9 @@ private void translateGlobal(ImVar v) {
return;
}
LuaVariable lv = luaVar.getFor(v);
lv.setInitialValue(defaultValue(v.getType()));
lv.setInitialValue(LuaAst.LuaExprNull());
luaModel.add(lv);
deferMainInit(LuaAst.LuaAssignment(LuaAst.LuaExprVarAccess(lv), defaultValue(v.getType())));
}

private LuaExpr defaultValue(ImType type) {
Expand Down
Loading
Loading