Skip to content

Commit 88b2e60

Browse files
committed
Per-Fusion name counters fix duplicate TV names after copy
Move val/expr name counters from IrContainer to Fusion so each Fusion independently tracks name assignment. This fixes CI failures where Fusion::copy left the dest counter at N (number of cloned vals) instead of max(name)+1 when source names were non-sequential, causing newly created TVs to collide with existing names. The fix adds val_type_name_map_ and expr_name_counter_ to Fusion, and updates registerVal/registerExpr to use the Fusion-level counters. Fusion::copy syncs counters from source to dest after cloning. Fusion::swap exchanges counters. Fusion::clear resets them.
1 parent fa6bb78 commit 88b2e60

2 files changed

Lines changed: 35 additions & 2 deletions

File tree

csrc/fusion.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
157157
std::swap(a.magic_zero_val_, b.magic_zero_val_);
158158
std::swap(a.axioms_, b.axioms_);
159159
std::swap(a.metadata_, b.metadata_);
160+
std::swap(a.val_type_name_map_, b.val_type_name_map_);
161+
std::swap(a.expr_name_counter_, b.expr_name_counter_);
160162

161163
// Update Statement::ir_container_ pointers: a's old statements now belong
162164
// to b, and b's old statements now belong to a
@@ -208,6 +210,16 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
208210
ir_cloner.clone(val);
209211
}
210212

213+
// Sync per-Fusion name counters from source to dest.
214+
// During cloning, registerVal increments the dest Fusion's counter for each
215+
// val, then IrBuilder::clone overrides the name with setName(src->name()).
216+
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
217+
// creating intermediate TVs), the dest counter ends up at N (number of vals)
218+
// instead of max(name)+1. Copying the source's counter state ensures new
219+
// vals created post-copy won't collide with existing names.
220+
to->val_type_name_map_ = from->val_type_name_map_;
221+
to->expr_name_counter_ = from->expr_name_counter_;
222+
211223
// Wire up definitions and uses on cloned vals
212224
for (auto val : from->vals()) {
213225
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
@@ -366,6 +378,9 @@ void Fusion::clear() noexcept {
366378
axioms_.reset();
367379
metadata_.clear();
368380

381+
val_type_name_map_.clear();
382+
expr_name_counter_ = 0;
383+
369384
invalidateTvsAndUses();
370385

371386
is_during_update_uses_ = false;
@@ -981,7 +996,7 @@ void Fusion::registerVal(Val* val) {
981996
c->vals_up_.emplace_back(val);
982997
c->vals_.insert(val);
983998
c->per_fusion_vals_[this].insert(val);
984-
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
999+
val->setName(IrContainerPasskey(), getValName(val->vtype()));
9851000
}
9861001

9871002
void Fusion::registerExpr(Expr* expr) {
@@ -998,7 +1013,7 @@ void Fusion::registerExpr(Expr* expr) {
9981013
c->exprs_up_.emplace_back(expr);
9991014
c->exprs_.insert(expr);
10001015
c->per_fusion_exprs_[this].insert(expr);
1001-
expr->setName(IrContainerPasskey(), c->getExprName());
1016+
expr->setName(IrContainerPasskey(), getExprName());
10021017

10031018
for (Val* input : expr->inputs()) {
10041019
assertInContainer(input, "Input to expr is invalid, ");

csrc/fusion.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,6 +647,24 @@ class NVF_API Fusion : public PolymorphicBase {
647647
std::unique_ptr<std::vector<Val*>> axioms_;
648648

649649
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
650+
651+
// Per-Fusion name counters. Each Fusion independently tracks name assignment
652+
// so that cloned Fusions get matching names (T0→T0) regardless of whether
653+
// they share an IrContainer. This is required by downstream consumers that
654+
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
655+
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
656+
StmtNameType expr_name_counter_ = 0;
657+
658+
StmtNameType getValName(ValType vtype) {
659+
if (val_type_name_map_.find(vtype) == val_type_name_map_.end()) {
660+
val_type_name_map_[vtype] = 0;
661+
}
662+
return val_type_name_map_[vtype]++;
663+
}
664+
665+
StmtNameType getExprName() {
666+
return expr_name_counter_++;
667+
}
650668
};
651669

652670
// Template implementations for Fusion::manage<T>() that use IrCloner

0 commit comments

Comments
 (0)