@@ -147,6 +147,9 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
147147 std::swap (a.true_val_ , b.true_val_ );
148148 std::swap (a.false_val_ , b.false_val_ );
149149 std::swap (a.magic_zero_val_ , b.magic_zero_val_ );
150+
151+ std::swap (a.axioms_ , b.axioms_ );
152+ std::swap (a.metadata_ , b.metadata_ );
150153}
151154
152155std::unique_ptr<SegmentedFusion> Fusion::segment (
@@ -208,6 +211,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
208211
209212 to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_ ;
210213
214+ if (from->axioms_ != nullptr ) {
215+ to->axioms_ = std::make_unique<std::vector<Val*>>();
216+ to->axioms_ ->reserve (from->axioms_ ->size ());
217+ for (auto pred : *from->axioms_ ) {
218+ to->axioms_ ->push_back (ir_cloner.clone (pred));
219+ }
220+ }
221+
222+ for (auto & [key, val_expr] : from->metadata_ ) {
223+ to->metadata_ [ir_cloner.clone (key)] = std::make_pair (
224+ ir_cloner.clone (val_expr.first ), ir_cloner.clone (val_expr.second ));
225+ }
226+
211227 if (from->all_tvs_ptr_ != nullptr ) {
212228 to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
213229 to->all_tvs_ptr_ ->reserve (from->all_tvs_ptr_ ->size ());
@@ -274,13 +290,14 @@ void Fusion::clear() noexcept {
274290 managed_data_.clear ();
275291 managed_named_data_.clear ();
276292
277- // Reset per-Fusion special values (they'll be recreated lazily if needed)
278- // The actual Val objects were removed by removeStatementsOwnedBy above.
279- zero_val_ = nullptr ;
280- one_val_ = nullptr ;
281- true_val_ = nullptr ;
282- false_val_ = nullptr ;
283- magic_zero_val_ = nullptr ;
293+ // Reset per-Fusion special values (they'll be recreated lazily if needed).
294+ // These unique_ptrs own the Val objects; ir_container()->clear() above only
295+ // removed them from vals_ (they were already absent from vals_up_).
296+ zero_val_.reset ();
297+ one_val_.reset ();
298+ true_val_.reset ();
299+ false_val_.reset ();
300+ magic_zero_val_.reset ();
284301
285302 axioms_.reset ();
286303 metadata_.clear ();
@@ -318,6 +335,13 @@ void Fusion::removeExpr(Expr* expr) {
318335void Fusion::removeVal (Val* val) {
319336 assertInContainer (val, " Cannot remove val " );
320337
338+ // Don't remove cached special vals — they are lazily created singletons
339+ if (val == zero_val_.get () || val == one_val_.get () ||
340+ val == true_val_.get () || val == false_val_.get () ||
341+ val == magic_zero_val_.get ()) {
342+ return ;
343+ }
344+
321345 NVF_CHECK (
322346 !val->isFusionInput (),
323347 " Cannot remove val as it is an input of the fusion." );
@@ -712,38 +736,55 @@ void Fusion::printTransforms() {
712736
713737Val* Fusion::zeroVal () {
714738 if (!zero_val_) {
715- zero_val_ = IrBuilder::createInContainer<Val>(this , 0L , DataType::Index);
739+ auto val = IrBuilder::createInContainer<Val>(this , 0L , DataType::Index);
740+ NVF_ERROR (ir_container ()->vals_up_ .back ().get () == val);
741+ zero_val_ = std::unique_ptr<Val>(ir_container ()->vals_up_ .back ().release ());
742+ ir_container ()->vals_up_ .pop_back ();
716743 }
717- return zero_val_;
744+ return zero_val_. get () ;
718745}
719746
720747Val* Fusion::oneVal () {
721748 if (!one_val_) {
722- one_val_ = IrBuilder::createInContainer<Val>(this , 1L , DataType::Index);
749+ auto val = IrBuilder::createInContainer<Val>(this , 1L , DataType::Index);
750+ NVF_ERROR (ir_container ()->vals_up_ .back ().get () == val);
751+ one_val_ = std::unique_ptr<Val>(ir_container ()->vals_up_ .back ().release ());
752+ ir_container ()->vals_up_ .pop_back ();
723753 }
724- return one_val_;
754+ return one_val_. get () ;
725755}
726756
727757Val* Fusion::falseVal () {
728758 if (!false_val_) {
729- false_val_ = IrBuilder::createInContainer<Val>(this , false , DataType::Bool);
759+ auto val = IrBuilder::createInContainer<Val>(this , false , DataType::Bool);
760+ NVF_ERROR (ir_container ()->vals_up_ .back ().get () == val);
761+ false_val_ =
762+ std::unique_ptr<Val>(ir_container ()->vals_up_ .back ().release ());
763+ ir_container ()->vals_up_ .pop_back ();
730764 }
731- return false_val_;
765+ return false_val_. get () ;
732766}
733767
734768Val* Fusion::trueVal () {
735769 if (!true_val_) {
736- true_val_ = IrBuilder::createInContainer<Val>(this , true , DataType::Bool);
770+ auto val = IrBuilder::createInContainer<Val>(this , true , DataType::Bool);
771+ NVF_ERROR (ir_container ()->vals_up_ .back ().get () == val);
772+ true_val_ = std::unique_ptr<Val>(ir_container ()->vals_up_ .back ().release ());
773+ ir_container ()->vals_up_ .pop_back ();
737774 }
738- return true_val_;
775+ return true_val_. get () ;
739776}
740777
741778NamedScalar* Fusion::magicZeroVal () {
742779 if (!magic_zero_val_) {
743- magic_zero_val_ = IrBuilder::createInContainer<NamedScalar>(
780+ auto val = IrBuilder::createInContainer<NamedScalar>(
744781 this , kMagicZeroName , DataType::Index);
782+ NVF_ERROR (ir_container ()->vals_up_ .back ().get () == val);
783+ magic_zero_val_ = std::unique_ptr<NamedScalar>(
784+ ir_container ()->vals_up_ .back ().release ()->as <NamedScalar>());
785+ ir_container ()->vals_up_ .pop_back ();
745786 }
746- return magic_zero_val_;
787+ return magic_zero_val_. get () ;
747788}
748789
749790Val* Fusion::zeroVal (DataType dtype) {
@@ -770,12 +811,10 @@ Val* Fusion::oneVal(DataType dtype) {
770811
771812Val* Fusion::metadataOf (Val* v) {
772813 if (metadata_.count (v) == 0 ) {
773- // Create metadata val owned by the same Fusion as v
774- Fusion* owner = v->container ();
775814 auto metadata_val =
776- IrBuilder::createInContainer<Val>(owner , metaDataTypeOf (v));
815+ IrBuilder::createInContainer<Val>(this , metaDataTypeOf (v));
777816 auto metadata_expr =
778- IrBuilder::createInContainer<GetMetaData>(owner , metadata_val, v);
817+ IrBuilder::createInContainer<GetMetaData>(this , metadata_val, v);
779818 metadata_[v] = std::make_pair (metadata_val, metadata_expr);
780819 }
781820 return metadata_.at (v).first ;
0 commit comments