Skip to content

Commit 35b7405

Browse files
committed
Per-Fusion name counters fix duplicate TV names after copy
Move val/expr name counters from IrContainer to Fusion so each Fusion independently tracks name assignment. This fixes CI failures where Fusion::copy left the dest counter at N (number of cloned vals) instead of max(name)+1 when source names were non-sequential, causing newly created TVs to collide with existing names. The fix adds val_type_name_map_ and expr_name_counter_ to Fusion, and updates registerVal/registerExpr to use the Fusion-level counters. Fusion::copy syncs counters from source to dest after cloning. Fusion::swap exchanges counters. Fusion::clear resets them.
1 parent a215b34 commit 35b7405

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

csrc/fusion.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
156156
std::swap(a.magic_zero_val_, b.magic_zero_val_);
157157
std::swap(a.axioms_, b.axioms_);
158158
std::swap(a.metadata_, b.metadata_);
159+
std::swap(a.val_type_name_map_, b.val_type_name_map_);
160+
std::swap(a.expr_name_counter_, b.expr_name_counter_);
159161

160162
// Update Statement::ir_container_ pointers: a's old statements now belong
161163
// to b, and b's old statements now belong to a
@@ -207,6 +209,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
207209
ir_cloner.clone(val);
208210
}
209211

212+
// Sync per-Fusion name counters from source to dest.
213+
// During cloning, registerVal increments the dest Fusion's counter for each
214+
// val, then IrBuilder::clone overrides the name with setName(src->name()).
215+
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
216+
// creating intermediate TVs), the dest counter ends up at N (number of vals)
217+
// instead of max(name)+1. Copying the source's counter state ensures new
218+
// vals created post-copy won't collide with existing names.
219+
to->val_type_name_map_ = from->val_type_name_map_;
220+
to->expr_name_counter_ = from->expr_name_counter_;
221+
210222
// Wire up definitions and uses on cloned vals
211223
for (auto val : from->vals()) {
212224
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
@@ -365,6 +377,9 @@ void Fusion::clear() noexcept {
365377
axioms_.reset();
366378
metadata_.clear();
367379

380+
val_type_name_map_.clear();
381+
expr_name_counter_ = 0;
382+
368383
invalidateTvsAndUses();
369384

370385
is_during_update_uses_ = false;
@@ -980,7 +995,7 @@ void Fusion::registerVal(Val* val) {
980995
c->vals_up_.emplace_back(val);
981996
c->vals_.insert(val);
982997
c->per_fusion_vals_[this].insert(val);
983-
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
998+
val->setName(IrContainerPasskey(), getValName(val->vtype()));
984999
}
9851000

9861001
void Fusion::registerExpr(Expr* expr) {
@@ -997,7 +1012,7 @@ void Fusion::registerExpr(Expr* expr) {
9971012
c->exprs_up_.emplace_back(expr);
9981013
c->exprs_.insert(expr);
9991014
c->per_fusion_exprs_[this].insert(expr);
1000-
expr->setName(IrContainerPasskey(), c->getExprName());
1015+
expr->setName(IrContainerPasskey(), getExprName());
10011016

10021017
for (Val* input : expr->inputs()) {
10031018
assertInContainer(input, "Input to expr is invalid, ");

csrc/fusion.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,24 @@ class NVF_API Fusion : public PolymorphicBase {
647647
std::unique_ptr<std::vector<Val*>> axioms_;
648648

649649
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
650+
651+
// Per-Fusion name counters. Each Fusion independently tracks name assignment
652+
// so that cloned Fusions get matching names (T0→T0) regardless of whether
653+
// they share an IrContainer. This is required by downstream consumers that
654+
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
655+
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
656+
StmtNameType expr_name_counter_ = 0;
657+
658+
StmtNameType getValName(ValType vtype) {
659+
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
660+
val_type_name_map_[vtype] = 0;
661+
}
662+
return val_type_name_map_[vtype]++;
663+
}
664+
665+
StmtNameType getExprName() {
666+
return expr_name_counter_++;
667+
}
650668
};
651669

652670
// Template implementations for Fusion::manage<T>() that use IrCloner

0 commit comments

Comments
 (0)