Skip to content

Commit f982c6e

Browse files
committed
Fix review issues in per-Fusion vals/axioms/metadata migration
- Add missing swap of axioms_ and metadata_ in Fusion::swap to prevent dangling pointers after move/assignment
- Add missing cloning of axioms_ and metadata_ in Fusion::copy to preserve custom assumptions and metadata cache across copies
- Guard Fusion::removeVal against removing cached special vals
- Use std::unique_ptr for special vals and steal from vals_up_ to preserve the original invariant (shortcuts in vals_ but not vals_up_)
- Fix metadataOf to use 'this' instead of v->container()
1 parent d14b90b commit f982c6e

2 files changed

Lines changed: 68 additions & 27 deletions

File tree

csrc/fusion.cpp

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,9 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
147147
std::swap(a.true_val_, b.true_val_);
148148
std::swap(a.false_val_, b.false_val_);
149149
std::swap(a.magic_zero_val_, b.magic_zero_val_);
150+
151+
std::swap(a.axioms_, b.axioms_);
152+
std::swap(a.metadata_, b.metadata_);
150153
}
151154

152155
std::unique_ptr<SegmentedFusion> Fusion::segment(
@@ -208,6 +211,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
208211

209212
to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_;
210213

214+
if (from->axioms_ != nullptr) {
215+
to->axioms_ = std::make_unique<std::vector<Val*>>();
216+
to->axioms_->reserve(from->axioms_->size());
217+
for (auto pred : *from->axioms_) {
218+
to->axioms_->push_back(ir_cloner.clone(pred));
219+
}
220+
}
221+
222+
for (auto& [key, val_expr] : from->metadata_) {
223+
to->metadata_[ir_cloner.clone(key)] = std::make_pair(
224+
ir_cloner.clone(val_expr.first), ir_cloner.clone(val_expr.second));
225+
}
226+
211227
if (from->all_tvs_ptr_ != nullptr) {
212228
to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
213229
to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size());
@@ -274,13 +290,14 @@ void Fusion::clear() noexcept {
274290
managed_data_.clear();
275291
managed_named_data_.clear();
276292

277-
// Reset per-Fusion special values (they'll be recreated lazily if needed)
278-
// The actual Val objects were removed by removeStatementsOwnedBy above.
279-
zero_val_ = nullptr;
280-
one_val_ = nullptr;
281-
true_val_ = nullptr;
282-
false_val_ = nullptr;
283-
magic_zero_val_ = nullptr;
293+
// Reset per-Fusion special values (they'll be recreated lazily if needed).
294+
// These unique_ptrs own the Val objects; ir_container()->clear() above only
295+
// removed them from vals_ (they were already absent from vals_up_).
296+
zero_val_.reset();
297+
one_val_.reset();
298+
true_val_.reset();
299+
false_val_.reset();
300+
magic_zero_val_.reset();
284301

285302
axioms_.reset();
286303
metadata_.clear();
@@ -318,6 +335,13 @@ void Fusion::removeExpr(Expr* expr) {
318335
void Fusion::removeVal(Val* val) {
319336
assertInContainer(val, "Cannot remove val ");
320337

338+
// Don't remove cached special vals — they are lazily created singletons
339+
if (val == zero_val_.get() || val == one_val_.get() ||
340+
val == true_val_.get() || val == false_val_.get() ||
341+
val == magic_zero_val_.get()) {
342+
return;
343+
}
344+
321345
NVF_CHECK(
322346
!val->isFusionInput(),
323347
"Cannot remove val as it is an input of the fusion.");
@@ -712,38 +736,55 @@ void Fusion::printTransforms() {
712736

713737
Val* Fusion::zeroVal() {
714738
if (!zero_val_) {
715-
zero_val_ = IrBuilder::createInContainer<Val>(this, 0L, DataType::Index);
739+
auto val = IrBuilder::createInContainer<Val>(this, 0L, DataType::Index);
740+
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
741+
zero_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
742+
ir_container()->vals_up_.pop_back();
716743
}
717-
return zero_val_;
744+
return zero_val_.get();
718745
}
719746

720747
Val* Fusion::oneVal() {
721748
if (!one_val_) {
722-
one_val_ = IrBuilder::createInContainer<Val>(this, 1L, DataType::Index);
749+
auto val = IrBuilder::createInContainer<Val>(this, 1L, DataType::Index);
750+
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
751+
one_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
752+
ir_container()->vals_up_.pop_back();
723753
}
724-
return one_val_;
754+
return one_val_.get();
725755
}
726756

727757
Val* Fusion::falseVal() {
728758
if (!false_val_) {
729-
false_val_ = IrBuilder::createInContainer<Val>(this, false, DataType::Bool);
759+
auto val = IrBuilder::createInContainer<Val>(this, false, DataType::Bool);
760+
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
761+
false_val_ =
762+
std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
763+
ir_container()->vals_up_.pop_back();
730764
}
731-
return false_val_;
765+
return false_val_.get();
732766
}
733767

734768
Val* Fusion::trueVal() {
735769
if (!true_val_) {
736-
true_val_ = IrBuilder::createInContainer<Val>(this, true, DataType::Bool);
770+
auto val = IrBuilder::createInContainer<Val>(this, true, DataType::Bool);
771+
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
772+
true_val_ = std::unique_ptr<Val>(ir_container()->vals_up_.back().release());
773+
ir_container()->vals_up_.pop_back();
737774
}
738-
return true_val_;
775+
return true_val_.get();
739776
}
740777

741778
NamedScalar* Fusion::magicZeroVal() {
742779
if (!magic_zero_val_) {
743-
magic_zero_val_ = IrBuilder::createInContainer<NamedScalar>(
780+
auto val = IrBuilder::createInContainer<NamedScalar>(
744781
this, kMagicZeroName, DataType::Index);
782+
NVF_ERROR(ir_container()->vals_up_.back().get() == val);
783+
magic_zero_val_ = std::unique_ptr<NamedScalar>(
784+
ir_container()->vals_up_.back().release()->as<NamedScalar>());
785+
ir_container()->vals_up_.pop_back();
745786
}
746-
return magic_zero_val_;
787+
return magic_zero_val_.get();
747788
}
748789

749790
Val* Fusion::zeroVal(DataType dtype) {
@@ -770,12 +811,10 @@ Val* Fusion::oneVal(DataType dtype) {
770811

771812
Val* Fusion::metadataOf(Val* v) {
772813
if (metadata_.count(v) == 0) {
773-
// Create metadata val owned by the same Fusion as v
774-
Fusion* owner = v->container();
775814
auto metadata_val =
776-
IrBuilder::createInContainer<Val>(owner, metaDataTypeOf(v));
815+
IrBuilder::createInContainer<Val>(this, metaDataTypeOf(v));
777816
auto metadata_expr =
778-
IrBuilder::createInContainer<GetMetaData>(owner, metadata_val, v);
817+
IrBuilder::createInContainer<GetMetaData>(this, metadata_val, v);
779818
metadata_[v] = std::make_pair(metadata_val, metadata_expr);
780819
}
781820
return metadata_.at(v).first;

csrc/fusion.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,9 @@ class NVF_API Fusion : public PolymorphicBase {
550550
return ir_container()->numExprs();
551551
}
552552

553-
int64_t numVals(bool include_shortcuts) const noexcept {
553+
// When include_shortcuts is true, count cached special vals (zeroVal, etc.)
554+
// which live outside vals_up_ but inside vals_.
555+
int64_t numVals(bool include_shortcuts = true) const noexcept {
554556
return ir_container()->numVals(include_shortcuts);
555557
}
556558

@@ -640,11 +642,11 @@ class NVF_API Fusion : public PolymorphicBase {
640642
inline static const std::string exact_mappings_key = "exact_mappings";
641643
std::unique_ptr<IrContainer> ir_container_;
642644

643-
Val* zero_val_ = nullptr;
644-
Val* one_val_ = nullptr;
645-
Val* true_val_ = nullptr;
646-
Val* false_val_ = nullptr;
647-
NamedScalar* magic_zero_val_ = nullptr;
645+
std::unique_ptr<Val> zero_val_;
646+
std::unique_ptr<Val> one_val_;
647+
std::unique_ptr<Val> true_val_;
648+
std::unique_ptr<Val> false_val_;
649+
std::unique_ptr<NamedScalar> magic_zero_val_;
648650

649651
std::unique_ptr<std::vector<Val*>> axioms_;
650652

0 commit comments

Comments
 (0)