@@ -292,6 +292,55 @@ struct CompilationResult {
292292 SmallVector<uint32_t > gridSizes;
293293};
294294
295+ // Thread-local resources to avoid per-config initialization overhead.
296+ // Each worker thread gets its own context, PassManagers, and parsed module.
297+ // Note: MLIR's MLIRContext cannot be safely shared across parallel pass
298+ // executions - it asserts when the registry is modified during multi-threaded
299+ // execution. Therefore, each thread needs its own context.
300+ struct ThreadResources {
301+ std::unique_ptr<MLIRContext> ctx;
302+ std::unique_ptr<PassManager> applicabilityPM;
303+ std::unique_ptr<PassManager> compilationPM;
304+ OwningOpRef<ModuleOp> sourceModule;
305+
306+ ThreadResources () = default ;
307+ ThreadResources (ThreadResources &&) = default ;
308+ ThreadResources &operator =(ThreadResources &&) = default ;
309+
310+ // Non-copyable
311+ ThreadResources (const ThreadResources &) = delete ;
312+ ThreadResources &operator =(const ThreadResources &) = delete ;
313+
314+ // Initialize all resources for this thread
315+ bool initialize (const std::string &sourceModuleStr,
316+ const rock::KernelOptions &applicabilityOpts,
317+ const rock::KernelOptions &compilationKernOpts,
318+ const rock::BackendOptions &backendOpts) {
319+ DialectRegistry registry;
320+ registerRocMLIRDialects (registry);
321+ ctx = std::make_unique<MLIRContext>(registry);
322+ ctx->getDiagEngine ().registerHandler ([](Diagnostic &) {});
323+
324+ // Pre-build pipelines once per thread
325+ applicabilityPM = std::make_unique<PassManager>(
326+ ctx.get (), PassManager::getAnyOpAnchorName (),
327+ PassManager::Nesting::Implicit);
328+ compilationPM = std::make_unique<PassManager>(
329+ ctx.get (), PassManager::getAnyOpAnchorName (),
330+ PassManager::Nesting::Implicit);
331+
332+ rock::buildKernelPipeline (*applicabilityPM, applicabilityOpts);
333+ rock::buildKernelPipeline (*compilationPM, compilationKernOpts);
334+ rock::buildBackendPipeline (*compilationPM, backendOpts);
335+
336+ // Parse source module once per thread
337+ sourceModule = parseSourceString<ModuleOp>(sourceModuleStr, ctx.get ());
338+ return sourceModule && *sourceModule;
339+ }
340+
341+ bool isValid () const { return sourceModule && *sourceModule; }
342+ };
343+
295344static LogicalResult
296345measureSmallKernel (unsigned iterations, hipStream_t stream,
297346 const std::vector<hipFunction_t> &functions,
@@ -740,51 +789,64 @@ static LogicalResult runTuningLoop(ModuleOp source) {
740789 // Don't create more threads than configs to compile
741790 numThreads = std::min (numThreads, static_cast <unsigned >(configs.size ()));
742791
743- // Serialize source module once (shared by all threads for cloning )
792+ // Serialize source module once (shared by all threads for parsing )
744793 std::string sourceModuleStr;
745- llvm::raw_string_ostream sourceOs (sourceModuleStr);
746- source->print (sourceOs);
747- sourceOs.flush ();
794+ {
795+ llvm::raw_string_ostream sourceOs (sourceModuleStr);
796+ source->print (sourceOs);
797+ }
798+
799+ // PHASE 2: Pre-initialize thread resources (contexts, PassManagers, parsed
800+ // modules). This avoids the expensive per-config overhead of creating
801+ // contexts, parsing modules, and building pipelines.
802+ // Note: MLIR's MLIRContext cannot be safely shared across parallel pass
803+ // executions, so each thread needs its own context.
804+ std::vector<ThreadResources> threadResources (numThreads);
805+ std::atomic<bool > initFailed{false };
806+
807+ {
808+ std::vector<std::thread> initThreads;
809+ initThreads.reserve (numThreads);
810+ for (unsigned i = 0 ; i < numThreads; ++i) {
811+ initThreads.emplace_back ([&, i]() {
812+ if (!threadResources[i].initialize (sourceModuleStr, applicabilityOpts,
813+ compilationKernOpts,
814+ backendOpts)) {
815+ initFailed.store (true , std::memory_order_relaxed);
816+ }
817+ });
818+ }
819+ for (auto &t : initThreads) {
820+ t.join ();
821+ }
822+ }
748823
749- // PHASE 2: Parallel compilation phase
824+ if (initFailed.load (std::memory_order_relaxed)) {
825+ llvm::errs () << " Failed to initialize thread resources\n " ;
826+ return failure ();
827+ }
828+
829+ // PHASE 3: Parallel compilation phase using pre-initialized resources
750830 std::vector<CompilationResult> compilationResults (configs.size ());
751831 std::mutex outputMutex; // For thread-safe console output
752832 std::atomic<bool > compilationFailed{
753833 false }; // Flag to signal early termination
754834
755- auto compileConfig = [&](size_t idx) -> CompilationResult {
835+ // Compile a single config using pre-initialized thread resources
836+ auto compileConfig = [&](size_t idx,
837+ ThreadResources &res) -> CompilationResult {
756838 CompilationResult result;
757839 result.perfConfig = configs[idx];
758- // Each thread needs its own context and pass managers for thread-safety
759- DialectRegistry threadRegistry;
760- registerRocMLIRDialects (threadRegistry);
761- MLIRContext threadCtx (threadRegistry);
762- threadCtx.getDiagEngine ().registerHandler ([](Diagnostic &diag) {});
763-
764- // Parse the serialized module in this thread's context
765- OwningOpRef<ModuleOp> threadSource =
766- parseSourceString<ModuleOp>(sourceModuleStr, &threadCtx);
767- if (!threadSource)
768- return result;
769-
770- // Set up pipelines for this thread
771- PassManager threadApplicability (&threadCtx,
772- PassManager::getAnyOpAnchorName (),
773- PassManager::Nesting::Implicit);
774- PassManager threadCompilation (&threadCtx,
775- PassManager::getAnyOpAnchorName (),
776- PassManager::Nesting::Implicit);
777840
778- rock::buildKernelPipeline (threadApplicability, applicabilityOpts);
779- rock::buildKernelPipeline (threadCompilation, compilationKernOpts);
780- rock::buildBackendPipeline (threadCompilation, backendOpts);
841+ if (!res.isValid ())
842+ return result;
781843
782844 StringAttr perfConfigAttr =
783- StringAttr::get (&threadCtx , result.perfConfig );
845+ StringAttr::get (res. ctx . get () , result.perfConfig );
784846
785847 // Helper to copy IR with perf config set
786- auto copyIRThread = [&](ModuleOp src,
787- StringAttr attr) -> OwningOpRef<ModuleOp> {
848+ auto copyIR = [&](ModuleOp src,
849+ StringAttr attr) -> OwningOpRef<ModuleOp> {
788850 OwningOpRef<ModuleOp> copy = cast<ModuleOp>(src->clone ());
789851 copy->walk ([&attr](rock::RockGemmWrapperInterface op) {
790852 op->setAttr (" perf_config" , attr);
@@ -795,16 +857,16 @@ static LogicalResult runTuningLoop(ModuleOp source) {
795857 return copy;
796858 };
797859
798- if (doesModuleHaveFusions (threadSource .get ()) &&
799- !rock::isModuleFusible (threadSource .get (), result.perfConfig )) {
860+ if (doesModuleHaveFusions (res. sourceModule .get ()) &&
861+ !rock::isModuleFusible (res. sourceModule .get (), result.perfConfig )) {
800862 result.status = CompilationStatus::NotApplicable;
801863 return result;
802864 }
803865
804- // Applicability check
866+ // Applicability check - clone the pre-parsed module
805867 OwningOpRef<ModuleOp> sourceCopy =
806- copyIRThread (threadSource .get (), perfConfigAttr);
807- if (failed (threadApplicability. run (sourceCopy.get ()))) {
868+ copyIR (res. sourceModule .get (), perfConfigAttr);
869+ if (failed (res. applicabilityPM -> run (sourceCopy.get ()))) {
808870 result.status = CompilationStatus::NotApplicable;
809871 return result;
810872 }
@@ -823,8 +885,8 @@ static LogicalResult runTuningLoop(ModuleOp source) {
823885 tunedFunc->getAttrOfType <IntegerAttr>(" grid_size" ).getInt ());
824886 }
825887
826- // Compilation
827- if (failed (threadCompilation. run (sourceCopy.get ()))) {
888+ // Compilation - use pre-built pipeline
889+ if (failed (res. compilationPM -> run (sourceCopy.get ()))) {
828890 std::lock_guard<std::mutex> lock (outputMutex);
829891 llvm::errs () << " Backend pipeline failed for config: "
830892 << result.perfConfig << " \n " ;
@@ -860,8 +922,14 @@ static LogicalResult runTuningLoop(ModuleOp source) {
860922 // load balancing by allowing fast threads to pick up more work.
861923 {
862924 std::atomic<size_t > nextIdx{0 };
925+ std::atomic<unsigned > nextThreadId{0 };
863926
864927 auto worker = [&]() {
928+ // Each worker gets assigned a unique thread ID for its resources
929+ unsigned myThreadId =
930+ nextThreadId.fetch_add (1 , std::memory_order_relaxed);
931+ ThreadResources &myRes = threadResources[myThreadId];
932+
865933 while (true ) {
866934 if (compilationFailed.load (std::memory_order_relaxed))
867935 break ;
@@ -870,7 +938,7 @@ static LogicalResult runTuningLoop(ModuleOp source) {
870938 if (idx >= configs.size ())
871939 break ;
872940
873- compilationResults[idx] = compileConfig (idx);
941+ compilationResults[idx] = compileConfig (idx, myRes );
874942 }
875943 };
876944
0 commit comments