Skip to content

Commit 56cf217

Browse files
committed
Per-Fusion name counters for shared container name correspondence
Replace the global IrContainer name counters with per-Fusion counters so that cloned Fusions produce matching statement names (T0=T0, T1=T1) instead of ever-incrementing names (T0=T10). This fixes cross-Fusion name lookups in GreedyParams and normalization_utils, which use tv->name() as map keys.

Changes:
- Add per_fusion_val_name_map_ and per_fusion_expr_name_counter_ to IrContainer.
- Update getValName/getExprName to use the per-Fusion counter, with the global counter as a fallback.
- Update registerVal/registerExpr to pass the owning Fusion to the name generators.
- Handle the counter lifecycle in swap, copy, clear, destroy, and transferOwnership.
- Use deterministic_vals() in Fusion::copy for a stable clone ordering.
- Add 8 new tests for name correspondence (71/71 Phase 2 tests pass).
1 parent edff8af commit 56cf217

5 files changed

Lines changed: 334 additions & 16 deletions

File tree

csrc/fusion.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
218218
IrCloner ir_cloner(to);
219219

220220
// Phase 2: Clone only 'from's owned vals (not all vals in shared container)
221-
// Use ownedVals() to get only vals belonging to 'from'
222-
for (auto val : from->ownedVals()) {
221+
// CRITICAL: Use deterministic_vals() to get vals in insertion order.
222+
// Using ownedVals() (unordered_set) causes non-deterministic clone order,
223+
// which assigns different name() values to cloned vals between runs.
224+
// This breaks code that uses tv->name() as map keys (e.g., GreedyParams).
225+
for (auto val : from->deterministic_vals()) {
223226
ir_cloner.clone(val);
224227
}
225228

226229
// Update definition_ and uses_ on cloned vals
227-
for (auto val : from->ownedVals()) {
230+
for (auto val : from->deterministic_vals()) {
228231
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
229232
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
230233
}

csrc/fusion.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,13 @@ T* IrBuilder::clone(const T* src, IrCloner* ir_cloner) {
744744

745745
dest_container->registerStmt(IrBuilderPasskey(dest_container), dest_stmt);
746746

747-
if (src_container != dest_container) {
747+
// Phase 2 Task 10: For same-container cloning (shared IrContainer),
748+
// per-Fusion name counters produce matching names naturally (both start
749+
// at 0), so the name override below is NOT needed and is skipped.
750+
// For cross-container cloning (different IrContainers), we still need
751+
// to force the source name since the destination's global counter may
752+
// have diverged.
753+
if (src_container->ir_container() != dest_container->ir_container()) {
748754
dest_stmt->setName(IrBuilderPasskey(dest_container), src_stmt->name());
749755
}
750756

csrc/ir/container.cpp

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,24 @@ void IrContainer::transferStatementOwnership(Fusion* from, Fusion* to) {
268268
to_exprs.insert(exprs_it->second.begin(), exprs_it->second.end());
269269
per_fusion_exprs_.erase(exprs_it);
270270
}
271+
272+
// Transfer per-Fusion name counters (Phase 2 Task 10)
273+
auto val_names_it = per_fusion_val_name_map_.find(from);
274+
if (val_names_it != per_fusion_val_name_map_.end()) {
275+
// Merge counter maps: take max of each ValType counter
276+
auto& to_map = per_fusion_val_name_map_[to];
277+
for (auto& [vtype, counter] : val_names_it->second) {
278+
to_map[vtype] = std::max(to_map[vtype], counter);
279+
}
280+
per_fusion_val_name_map_.erase(val_names_it);
281+
}
282+
283+
auto expr_names_it = per_fusion_expr_name_counter_.find(from);
284+
if (expr_names_it != per_fusion_expr_name_counter_.end()) {
285+
auto& to_counter = per_fusion_expr_name_counter_[to];
286+
to_counter = std::max(to_counter, expr_names_it->second);
287+
per_fusion_expr_name_counter_.erase(expr_names_it);
288+
}
271289
}
272290

273291
void IrContainer::removeStatementsOwnedBy(Fusion* fusion) {
@@ -303,6 +321,10 @@ void IrContainer::removeStatementsOwnedByUnlocked(Fusion* fusion) {
303321
// Clean up per-Fusion tracking (Phase 2 Task 4)
304322
per_fusion_vals_.erase(fusion);
305323
per_fusion_exprs_.erase(fusion);
324+
325+
// Clean up per-Fusion name counters (Phase 2 Task 10)
326+
per_fusion_val_name_map_.erase(fusion);
327+
per_fusion_expr_name_counter_.erase(fusion);
306328
}
307329

308330
void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
@@ -334,6 +356,10 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
334356
std::swap(a.sharing_fusions_, b.sharing_fusions_);
335357
std::swap(a.per_fusion_vals_, b.per_fusion_vals_);
336358
std::swap(a.per_fusion_exprs_, b.per_fusion_exprs_);
359+
360+
// Swap per-Fusion name counters (Phase 2 Task 10)
361+
std::swap(a.per_fusion_val_name_map_, b.per_fusion_val_name_map_);
362+
std::swap(a.per_fusion_expr_name_counter_, b.per_fusion_expr_name_counter_);
337363
}
338364

339365
IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
@@ -354,6 +380,8 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
354380
to->expr_name_counter_ = 0;
355381
to->per_fusion_vals_.clear();
356382
to->per_fusion_exprs_.clear();
383+
to->per_fusion_val_name_map_.clear();
384+
to->per_fusion_expr_name_counter_.clear();
357385

358386
// NOTE: In Phase 2, we can't use to->parent() here because parent_ might
359387
// not be set correctly for shared containers. Fusion::copy handles this.
@@ -468,12 +496,16 @@ void IrContainer::registerVal(Val* val) {
468496
// Otherwise handle registration locally
469497
vals_up_.emplace_back(val);
470498
vals_.insert(val);
471-
val->setName(IrContainerPasskey(), getValName(val->vtype()));
499+
500+
// Phase 2 Task 10: Use per-Fusion counter if val has an owning Fusion.
501+
// This ensures cloned Fusions get matching names (T0=T0, T1=T1)
502+
// instead of incrementing global names (T0=T10, T1=T11).
503+
Fusion* owning_fusion = val->container();
504+
val->setName(IrContainerPasskey(), getValName(owning_fusion, val->vtype()));
472505

473506
// Track per-Fusion ownership (Phase 2 Task 4)
474-
// val->container() returns the owning Fusion
475-
if (val->container() != nullptr) {
476-
per_fusion_vals_[val->container()].insert(val);
507+
if (owning_fusion != nullptr) {
508+
per_fusion_vals_[owning_fusion].insert(val);
477509
}
478510
}
479511

@@ -486,12 +518,14 @@ void IrContainer::registerExpr(Expr* expr) {
486518
// Otherwise handle registration locally
487519
exprs_up_.emplace_back(expr);
488520
exprs_.insert(expr);
489-
expr->setName(IrContainerPasskey(), getExprName());
521+
522+
// Phase 2 Task 10: Use per-Fusion counter if expr has an owning Fusion.
523+
Fusion* owning_fusion = expr->container();
524+
expr->setName(IrContainerPasskey(), getExprName(owning_fusion));
490525

491526
// Track per-Fusion ownership (Phase 2 Task 4)
492-
// expr->container() returns the owning Fusion
493-
if (expr->container() != nullptr) {
494-
per_fusion_exprs_[expr->container()].insert(expr);
527+
if (owning_fusion != nullptr) {
528+
per_fusion_exprs_[owning_fusion].insert(expr);
495529
}
496530
}
497531

@@ -507,6 +541,10 @@ void IrContainer::clear() noexcept {
507541
// Clear per-Fusion tracking (Phase 2 Task 4)
508542
per_fusion_vals_.clear();
509543
per_fusion_exprs_.clear();
544+
545+
// Clear per-Fusion name counters (Phase 2 Task 10)
546+
per_fusion_val_name_map_.clear();
547+
per_fusion_expr_name_counter_.clear();
510548
}
511549

512550
bool IrContainer::inContainer(const Statement* const_stmt) const {

csrc/ir/container.h

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,14 +192,43 @@ class IrContainer {
192192
//! Register expr with this container.
193193
NVF_API void registerExpr(Expr* expr);
194194

195-
StmtNameType getValName(ValType vtype) {
195+
//! Get next val name, using per-Fusion counter if fusion is non-null,
196+
//! falling back to global counter otherwise.
197+
//! Per-Fusion counters ensure cloned Fusions produce matching names.
198+
StmtNameType getValName(Fusion* fusion, ValType vtype) {
199+
if (fusion != nullptr) {
200+
auto& name_map = per_fusion_val_name_map_[fusion];
201+
if (name_map.find(vtype) == name_map.end()) {
202+
name_map[vtype] = 0;
203+
}
204+
// Also advance global counter to keep it >= all per-Fusion counters
205+
// This prevents conflicts if global counter is used later
206+
auto& global = val_type_name_map_[vtype];
207+
auto per_fusion_name = name_map[vtype]++;
208+
if (global <= per_fusion_name) {
209+
global = per_fusion_name + 1;
210+
}
211+
return per_fusion_name;
212+
}
213+
// Global fallback for non-Fusion contexts
196214
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
197215
val_type_name_map_[vtype] = 0;
198216
}
199217
return val_type_name_map_[vtype]++;
200218
}
201219

202-
StmtNameType getExprName() {
220+
//! Get next expr name, using per-Fusion counter if fusion is non-null,
221+
//! falling back to global counter otherwise.
222+
StmtNameType getExprName(Fusion* fusion) {
223+
if (fusion != nullptr) {
224+
auto& counter = per_fusion_expr_name_counter_[fusion];
225+
auto per_fusion_name = counter++;
226+
// Also advance global counter
227+
if (expr_name_counter_ <= per_fusion_name) {
228+
expr_name_counter_ = per_fusion_name + 1;
229+
}
230+
return per_fusion_name;
231+
}
203232
return expr_name_counter_++;
204233
}
205234

@@ -237,12 +266,21 @@ class IrContainer {
237266
// something like check if an Expr is in this container
238267
std::unordered_set<Expr*> exprs_;
239268

240-
// Values names counters
269+
// Values names counters (global fallback for non-Fusion contexts)
241270
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
242271

243-
// Expression names counter
272+
// Expression names counter (global fallback for non-Fusion contexts)
244273
StmtNameType expr_name_counter_ = 0;
245274

275+
// Per-Fusion name counters (Phase 2 Task 10)
276+
// Each Fusion gets its own counter starting at 0, so cloned Fusions
277+
// produce matching names (T0=T0, T1=T1) instead of incrementing names.
278+
// This is critical for GreedyParams and normalization_utils which use
279+
// tv->name() as map keys across cloned Fusions.
280+
std::unordered_map<Fusion*, std::unordered_map<ValType, StmtNameType>>
281+
per_fusion_val_name_map_;
282+
std::unordered_map<Fusion*, StmtNameType> per_fusion_expr_name_counter_;
283+
246284
// Note: Special values (zero_val_, one_val_, true_val_, false_val_,
247285
// magic_zero_val_) are now per-Fusion, stored in Fusion class.
248286
// This avoids ownership conflicts when multiple Fusions share an IrContainer.

0 commit comments

Comments
 (0)