diff --git a/src/tools/wasm-reduce/wasm-reduce.cpp b/src/tools/wasm-reduce/wasm-reduce.cpp
index f9cd7b64412..627ddfbc0a0 100644
--- a/src/tools/wasm-reduce/wasm-reduce.cpp
+++ b/src/tools/wasm-reduce/wasm-reduce.cpp
@@ -29,6 +29,7 @@
 
 #include "ir/branch-utils.h"
 #include "ir/iteration.h"
+#include "ir/localize.h"
 #include "ir/properties.h"
 #include "ir/utils.h"
 #include "pass.h"
@@ -930,7 +931,8 @@ struct Reducer
       }
 
       std::cerr << "| try partition " << dd.partitionIndex() + 1 << " / "
-                << dd.partitionCount() << " (size " << dd.test.size() << ")\n";
+                << dd.partitionCount() << " (size " << dd.test.size() << " / "
+                << dd.working.size() << ")\n";
 
       Index removedSize = dd.working.size() - dd.test.size();
       std::vector<Expression*> oldBodies(removedSize);
@@ -982,66 +984,159 @@ struct Reducer
     }
   }
 
-  bool reduceFunctions() {
-    // try to remove functions
-    std::vector<Name> functionNames;
-    for (auto& func : module->functions) {
-      functionNames.push_back(func->name);
+  // Try to remove whole functions, using delta debugging to choose which
+  // candidates to keep. Functions referenced by a ref.func in module code are
+  // never removed. Exports and the start reference of removed functions are
+  // dropped, and calls and ref.funcs targeting them in bodies are replaced.
+  void reduceFunctions() {
+    std::cerr << "| try to remove functions\n";
+
+    // Find functions referenced from module code (i.e. global initializers). We
+    // will not attempt to remove these functions because we cannot generally
+    // replace their references with something valid.
+    // TODO: Look at how the function references are used. If they can be
+    // nullable, we can still consider deleting the functions.
+    struct UnremovableFinder : public PostWalker<UnremovableFinder> {
+      std::unordered_set<Name> unremovable;
+      void visitRefFunc(RefFunc* curr) { unremovable.insert(curr->func); }
+    };
+    UnremovableFinder finder;
+    finder.walkModuleCode(module.get());
+
+    // Find the indices of functions we can consider removing or must not
+    // remove.
+    std::vector<Index> unremovableIndices;
+    std::vector<Index> initialCandidates;
+    initialCandidates.reserve(module->functions.size() -
+                              finder.unremovable.size());
+    for (Index i = 0; i < module->functions.size(); ++i) {
+      if (finder.unremovable.contains(module->functions[i]->name)) {
+        unremovableIndices.push_back(i);
+      } else {
+        initialCandidates.push_back(i);
+      }
     }
-    auto numFuncs = functionNames.size();
-    if (numFuncs == 0) {
-      return false;
+
+    if (initialCandidates.empty()) {
+      return;
     }
-    uint64_t skip = 1;
-    uint64_t maxSkip = 1;
-    // If we just removed some functions in the previous iteration, keep trying
-    // to remove more as this is one of the most efficient ways to reduce.
-    bool justReduced = true;
-    // Start from a new place each time.
-    size_t base = deterministicRandom(numFuncs);
-    std::cerr << "| try to remove functions (base: " << base
-              << ", decisionCounter: " << decisionCounter << ", numFuncs "
-              << numFuncs << ")\n";
-    for (size_t x = 0; x < functionNames.size(); x++) {
-      size_t i = (base + x) % numFuncs;
-      if (!justReduced && functionsWeTriedToRemove.contains(functionNames[i]) &&
-          !shouldTryToReduce(std::max((factor / 5) + 1, uint64_t(20000)))) {
-        continue;
+
+    // Indices will change as we remove functions. Map the original indices to
+    // the present indices so we can use the original indices as stable
+    // identifiers. (Function names are not necessarily preserved through
+    // round-tripping.)
+    std::vector<std::optional<Index>> currentIndices;
+    currentIndices.reserve(module->functions.size());
+    for (Index i = 0; i < module->functions.size(); ++i) {
+      currentIndices.push_back(i);
+    }
+
+    DeltaDebugger dd(std::move(initialCandidates));
+    while (!dd.finished()) {
+      // Exit early if the test set size is less than the square root of the
+      // working set size. We don't want to waste time on very fine-grained
+      // partitions when we could switch to a different reduction strategy
+      // instead.
+      if (size_t sqrtRemaining = std::sqrt(dd.working.size());
+          dd.test.size() > 0 && dd.test.size() < sqrtRemaining) {
+        break;
       }
-      std::vector<Name> names;
-      for (size_t j = 0; names.size() < skip && i + j < functionNames.size();
-           j++) {
-        auto name = functionNames[i + j];
-        if (module->getFunctionOrNull(name)) {
-          names.push_back(name);
-          functionsWeTriedToRemove.insert(name);
+
+      std::cerr << "| try partition " << dd.partitionIndex() + 1 << " / "
+                << dd.partitionCount() << " (size " << dd.test.size() << " / "
+                << dd.working.size() << ")\n";
+
+      std::unordered_set<Index> keptIndices;
+      for (Index i : unremovableIndices) {
+        keptIndices.insert(*currentIndices[i]);
+      }
+      for (Index i : dd.test) {
+        keptIndices.insert(*currentIndices[i]);
+      }
+
+      // Get the list of kept functions and the new index mapping we will have
+      // to use if this reduction works.
+      std::vector<std::unique_ptr<Function>> newFuncs;
+      newFuncs.reserve(keptIndices.size());
+      std::vector<std::optional<Index>> newCurrentIndices;
+      newCurrentIndices.reserve(currentIndices.size());
+      for (size_t i = 0; i < currentIndices.size(); ++i) {
+        if (auto currIndex = currentIndices[i];
+            currIndex && keptIndices.contains(*currIndex)) {
+          newCurrentIndices.push_back(newFuncs.size());
+          newFuncs.emplace_back(std::move(module->functions[*currIndex]));
+        } else {
+          newCurrentIndices.push_back(std::nullopt);
         }
       }
-      if (names.size() == 0) {
-        continue;
+
+      module->functions = std::move(newFuncs);
+      module->updateFunctionsMap();
+
+      std::vector<Name> exportsToRemove;
+      for (auto& exp : module->exports) {
+        if (exp->kind == ExternalKind::Function &&
+            !module->getFunctionOrNull(*exp->getInternalName())) {
+          exportsToRemove.push_back(exp->name);
+        }
       }
-      std::cerr << "| trying at i=" << i << " of size " << names.size()
-                << "\n";
-      // Note that tryToRemoveFunctions() will reload the module if it fails,
-      // which means function names may change.
-      if (tryToRemoveFunctions(names)) {
-        noteReduction(names.size());
-        // Subtract 1 since the loop increments us anyhow by one: we want to
-        // skip over the skipped functions, and not any more.
-        x += skip - 1;
-        skip = std::min(factor, 2 * skip);
-        maxSkip = std::max(skip, maxSkip);
+      for (auto expName : exportsToRemove) {
+        module->removeExport(expName);
+      }
+
+      // The start function may have been removed as well. Clear the dangling
+      // reference so that the module remains valid.
+      if (module->start.is() && !module->getFunctionOrNull(module->start)) {
+        module->start = Name();
+      }
+
+      struct FunctionReplacer
+        : public WalkerPass<PostWalker<FunctionReplacer>> {
+        bool isFunctionParallel() override { return true; }
+        bool requiresNonNullableLocalFixups() override { return false; }
+        std::unique_ptr<Pass> create() override {
+          return std::make_unique<FunctionReplacer>();
+        }
+        void visitCall(Call* curr) {
+          if (getModule()->getFunctionOrNull(curr->target)) {
+            return;
+          }
+          Builder builder(*getModule());
+          auto* block =
+            ChildLocalizer(curr, getFunction(), *getModule(), getPassOptions())
+              .getChildrenReplacement();
+          auto* replacement = builder.replaceWithIdenticalType(curr);
+          if (replacement == curr) {
+            replacement = builder.makeUnreachable();
+          }
+          block->list.push_back(replacement);
+          block->type = curr->type;
+          replaceCurrent(block);
+        }
+        void visitRefFunc(RefFunc* curr) {
+          if (getModule()->getFunctionOrNull(curr->func)) {
+            return;
+          }
+          Builder builder(*getModule());
+          replaceCurrent(
+            builder.makeBlock({builder.makeUnreachable()}, curr->type));
+        }
+      };
+      PassRunner runner(module.get());
+      runner.add(std::make_unique<FunctionReplacer>());
+      runner.run();
+
+      assert(WasmValidator().validate(
+        *module, WasmValidator::Globally | WasmValidator::Quiet));
+      if (writeAndTestReduction()) {
+        noteReduction(dd.working.size() - dd.test.size());
+        currentIndices = std::move(newCurrentIndices);
+        dd.accept();
       } else {
-        skip = std::max(skip / 2, uint64_t(1)); // or 1?
-        x += factor / 100;
+        loadWorking();
+        dd.reject();
       }
     }
-    // If maxSkip is 1 then we never reduced at all. If it is 2 then we did
-    // manage to reduce individual functions, but all our attempts at
-    // exponential growth failed. Only suggest doing a new iteration of this
-    // function if we did in fact manage to grow, which indicated there are lots
-    // of opportunities here, and it is worth focusing on this.
-    return maxSkip > 2;
   }
 
   void visitModule([[maybe_unused]] Module* curr) {
@@ -1052,12 +1147,7 @@ struct Reducer
     curr = nullptr;
 
     reduceFunctionBodies();
-
-    // Reduction of entire functions at a time is very effective, and we do it
-    // with exponential growth and backoff, so keep doing it while it works.
-    // TODO: Figure out how to use delta debugging for this as well.
-    while (reduceFunctions()) {
-    }
+    reduceFunctions();
 
     shrinkElementSegments();
 
@@ -1567,6 +1657,7 @@ More documentation can be found at
       if (first) {
         reducer.loadWorking();
         reducer.reduceFunctionBodies();
+        reducer.reduceFunctions();
         first = false;
       }
 