Skip to content

Commit 4644c19

Browse files
authored
Initialize threadResource only once per thread (#2199)
Only initialize once per thread
1 parent a5fd5a2 commit 4644c19

1 file changed

Lines changed: 107 additions & 39 deletions

File tree

mlir/tools/rocmlir-tuning-driver/rocmlir-tuning-driver.cpp

Lines changed: 107 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,55 @@ struct CompilationResult {
292292
SmallVector<uint32_t> gridSizes;
293293
};
294294

295+
// Thread-local resources to avoid per-config initialization overhead.
296+
// Each worker thread gets its own context, PassManagers, and parsed module.
297+
// Note: MLIR's MLIRContext cannot be safely shared across parallel pass
298+
// executions - it asserts when the registry is modified during multi-threaded
299+
// execution. Therefore, each thread needs its own context.
300+
struct ThreadResources {
301+
std::unique_ptr<MLIRContext> ctx;
302+
std::unique_ptr<PassManager> applicabilityPM;
303+
std::unique_ptr<PassManager> compilationPM;
304+
OwningOpRef<ModuleOp> sourceModule;
305+
306+
ThreadResources() = default;
307+
ThreadResources(ThreadResources &&) = default;
308+
ThreadResources &operator=(ThreadResources &&) = default;
309+
310+
// Non-copyable
311+
ThreadResources(const ThreadResources &) = delete;
312+
ThreadResources &operator=(const ThreadResources &) = delete;
313+
314+
// Initialize all resources for this thread
315+
bool initialize(const std::string &sourceModuleStr,
316+
const rock::KernelOptions &applicabilityOpts,
317+
const rock::KernelOptions &compilationKernOpts,
318+
const rock::BackendOptions &backendOpts) {
319+
DialectRegistry registry;
320+
registerRocMLIRDialects(registry);
321+
ctx = std::make_unique<MLIRContext>(registry);
322+
ctx->getDiagEngine().registerHandler([](Diagnostic &) {});
323+
324+
// Pre-build pipelines once per thread
325+
applicabilityPM = std::make_unique<PassManager>(
326+
ctx.get(), PassManager::getAnyOpAnchorName(),
327+
PassManager::Nesting::Implicit);
328+
compilationPM = std::make_unique<PassManager>(
329+
ctx.get(), PassManager::getAnyOpAnchorName(),
330+
PassManager::Nesting::Implicit);
331+
332+
rock::buildKernelPipeline(*applicabilityPM, applicabilityOpts);
333+
rock::buildKernelPipeline(*compilationPM, compilationKernOpts);
334+
rock::buildBackendPipeline(*compilationPM, backendOpts);
335+
336+
// Parse source module once per thread
337+
sourceModule = parseSourceString<ModuleOp>(sourceModuleStr, ctx.get());
338+
return sourceModule && *sourceModule;
339+
}
340+
341+
bool isValid() const { return sourceModule && *sourceModule; }
342+
};
343+
295344
static LogicalResult
296345
measureSmallKernel(unsigned iterations, hipStream_t stream,
297346
const std::vector<hipFunction_t> &functions,
@@ -740,51 +789,64 @@ static LogicalResult runTuningLoop(ModuleOp source) {
740789
// Don't create more threads than configs to compile
741790
numThreads = std::min(numThreads, static_cast<unsigned>(configs.size()));
742791

743-
// Serialize source module once (shared by all threads for cloning)
792+
// Serialize source module once (shared by all threads for parsing)
744793
std::string sourceModuleStr;
745-
llvm::raw_string_ostream sourceOs(sourceModuleStr);
746-
source->print(sourceOs);
747-
sourceOs.flush();
794+
{
795+
llvm::raw_string_ostream sourceOs(sourceModuleStr);
796+
source->print(sourceOs);
797+
}
798+
799+
// PHASE 2: Pre-initialize thread resources (contexts, PassManagers, parsed
800+
// modules). This avoids the expensive per-config overhead of creating
801+
// contexts, parsing modules, and building pipelines.
802+
// Note: MLIR's MLIRContext cannot be safely shared across parallel pass
803+
// executions, so each thread needs its own context.
804+
std::vector<ThreadResources> threadResources(numThreads);
805+
std::atomic<bool> initFailed{false};
806+
807+
{
808+
std::vector<std::thread> initThreads;
809+
initThreads.reserve(numThreads);
810+
for (unsigned i = 0; i < numThreads; ++i) {
811+
initThreads.emplace_back([&, i]() {
812+
if (!threadResources[i].initialize(sourceModuleStr, applicabilityOpts,
813+
compilationKernOpts,
814+
backendOpts)) {
815+
initFailed.store(true, std::memory_order_relaxed);
816+
}
817+
});
818+
}
819+
for (auto &t : initThreads) {
820+
t.join();
821+
}
822+
}
748823

749-
// PHASE 2: Parallel compilation phase
824+
if (initFailed.load(std::memory_order_relaxed)) {
825+
llvm::errs() << "Failed to initialize thread resources\n";
826+
return failure();
827+
}
828+
829+
// PHASE 3: Parallel compilation phase using pre-initialized resources
750830
std::vector<CompilationResult> compilationResults(configs.size());
751831
std::mutex outputMutex; // For thread-safe console output
752832
std::atomic<bool> compilationFailed{
753833
false}; // Flag to signal early termination
754834

755-
auto compileConfig = [&](size_t idx) -> CompilationResult {
835+
// Compile a single config using pre-initialized thread resources
836+
auto compileConfig = [&](size_t idx,
837+
ThreadResources &res) -> CompilationResult {
756838
CompilationResult result;
757839
result.perfConfig = configs[idx];
758-
// Each thread needs its own context and pass managers for thread-safety
759-
DialectRegistry threadRegistry;
760-
registerRocMLIRDialects(threadRegistry);
761-
MLIRContext threadCtx(threadRegistry);
762-
threadCtx.getDiagEngine().registerHandler([](Diagnostic &diag) {});
763-
764-
// Parse the serialized module in this thread's context
765-
OwningOpRef<ModuleOp> threadSource =
766-
parseSourceString<ModuleOp>(sourceModuleStr, &threadCtx);
767-
if (!threadSource)
768-
return result;
769-
770-
// Set up pipelines for this thread
771-
PassManager threadApplicability(&threadCtx,
772-
PassManager::getAnyOpAnchorName(),
773-
PassManager::Nesting::Implicit);
774-
PassManager threadCompilation(&threadCtx,
775-
PassManager::getAnyOpAnchorName(),
776-
PassManager::Nesting::Implicit);
777840

778-
rock::buildKernelPipeline(threadApplicability, applicabilityOpts);
779-
rock::buildKernelPipeline(threadCompilation, compilationKernOpts);
780-
rock::buildBackendPipeline(threadCompilation, backendOpts);
841+
if (!res.isValid())
842+
return result;
781843

782844
StringAttr perfConfigAttr =
783-
StringAttr::get(&threadCtx, result.perfConfig);
845+
StringAttr::get(res.ctx.get(), result.perfConfig);
784846

785847
// Helper to copy IR with perf config set
786-
auto copyIRThread = [&](ModuleOp src,
787-
StringAttr attr) -> OwningOpRef<ModuleOp> {
848+
auto copyIR = [&](ModuleOp src,
849+
StringAttr attr) -> OwningOpRef<ModuleOp> {
788850
OwningOpRef<ModuleOp> copy = cast<ModuleOp>(src->clone());
789851
copy->walk([&attr](rock::RockGemmWrapperInterface op) {
790852
op->setAttr("perf_config", attr);
@@ -795,16 +857,16 @@ static LogicalResult runTuningLoop(ModuleOp source) {
795857
return copy;
796858
};
797859

798-
if (doesModuleHaveFusions(threadSource.get()) &&
799-
!rock::isModuleFusible(threadSource.get(), result.perfConfig)) {
860+
if (doesModuleHaveFusions(res.sourceModule.get()) &&
861+
!rock::isModuleFusible(res.sourceModule.get(), result.perfConfig)) {
800862
result.status = CompilationStatus::NotApplicable;
801863
return result;
802864
}
803865

804-
// Applicability check
866+
// Applicability check - clone the pre-parsed module
805867
OwningOpRef<ModuleOp> sourceCopy =
806-
copyIRThread(threadSource.get(), perfConfigAttr);
807-
if (failed(threadApplicability.run(sourceCopy.get()))) {
868+
copyIR(res.sourceModule.get(), perfConfigAttr);
869+
if (failed(res.applicabilityPM->run(sourceCopy.get()))) {
808870
result.status = CompilationStatus::NotApplicable;
809871
return result;
810872
}
@@ -823,8 +885,8 @@ static LogicalResult runTuningLoop(ModuleOp source) {
823885
tunedFunc->getAttrOfType<IntegerAttr>("grid_size").getInt());
824886
}
825887

826-
// Compilation
827-
if (failed(threadCompilation.run(sourceCopy.get()))) {
888+
// Compilation - use pre-built pipeline
889+
if (failed(res.compilationPM->run(sourceCopy.get()))) {
828890
std::lock_guard<std::mutex> lock(outputMutex);
829891
llvm::errs() << "Backend pipeline failed for config: "
830892
<< result.perfConfig << "\n";
@@ -860,8 +922,14 @@ static LogicalResult runTuningLoop(ModuleOp source) {
860922
// load balancing by allowing fast threads to pick up more work.
861923
{
862924
std::atomic<size_t> nextIdx{0};
925+
std::atomic<unsigned> nextThreadId{0};
863926

864927
auto worker = [&]() {
928+
// Each worker gets assigned a unique thread ID for its resources
929+
unsigned myThreadId =
930+
nextThreadId.fetch_add(1, std::memory_order_relaxed);
931+
ThreadResources &myRes = threadResources[myThreadId];
932+
865933
while (true) {
866934
if (compilationFailed.load(std::memory_order_relaxed))
867935
break;
@@ -870,7 +938,7 @@ static LogicalResult runTuningLoop(ModuleOp source) {
870938
if (idx >= configs.size())
871939
break;
872940

873-
compilationResults[idx] = compileConfig(idx);
941+
compilationResults[idx] = compileConfig(idx, myRes);
874942
}
875943
};
876944

0 commit comments

Comments
 (0)