Skip to content

Commit f273b16

Browse files
committed
Per-Fusion Axioms and Metadata
Moved `axioms_` and `metadata_` from `IrContainer` to the `Fusion` class. This completes the deprecation of `parent_` usage for val-creating methods, which was necessary because `parent_` implies a 1-1 relationship (container → Fusion), but Phase 2 has 1-many (shared containers). Methods that used `parent_` to create vals were moved to Fusion: - `metadataOf(Val*)` - Now uses `v->container()` to get owning Fusion - `axioms()` - Now creates axiom vals owned by `this` Fusion - `assumePositive/assumeNonNegative` - Now adds to `this` Fusion's axioms - Added `axioms_` and `metadata_` private members - Changed method declarations from forwarding to actual implementations - Added includes for `ir/builder.h` and `ir/internal_nodes.h` - Implemented `metadataOf()`, `axioms()`, `assumePositive()`, `assumeNonNegative()` methods - Updated `clear()` to reset `axioms_` and `metadata_` - Removed `metadataOf()`, `axioms()`, `assumePositive()`, `assumeNonNegative()` declarations - Removed `lazyInitAxioms()` declaration - Removed `axioms_` and `metadata_` members - Removed implementations of above methods - Updated `IrContainer::swap` to remove axioms_/metadata_ swapping - Updated `IrContainer::copy` to remove axioms_/metadata_ handling - Updated `IrContainer::clear` to remove axioms_/metadata_ clearing Each Fusion now has its own axioms and metadata cache. This ensures: 1. No ownership conflicts when multiple Fusions share an IrContainer 2. Correct behavior when one Fusion is destroyed (doesn't affect others) 3. Lazy creation pattern preserved (create on first access) This is a prerequisite for the copy/move semantics changes which will swap/transfer these per-Fusion members.
1 parent d6c9b7c commit f273b16

4 files changed

Lines changed: 80 additions & 90 deletions

File tree

csrc/fusion.cpp

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
#include <host_ir/container.h>
2121
#include <instrumentation.h>
2222
#include <ir/all_nodes.h>
23+
#include <ir/builder.h>
2324
#include <ir/cloner.h>
25+
#include <ir/internal_nodes.h>
2426
#include <ir/printer.h>
2527
#include <ir/utils.h>
2628
#include <iter_visitor.h>
@@ -280,6 +282,10 @@ void Fusion::clear() noexcept {
280282
false_val_ = nullptr;
281283
magic_zero_val_ = nullptr;
282284

285+
// Reset per-Fusion axioms and metadata (Phase 2)
286+
axioms_.reset();
287+
metadata_.clear();
288+
283289
invalidateTvsAndUses();
284290

285291
is_during_update_uses_ = false;
@@ -768,6 +774,54 @@ Val* Fusion::oneVal(DataType dtype) {
768774
}
769775
}
770776

777+
// =========================================================================
778+
// Per-Fusion Metadata and Axioms (Phase 2)
779+
// These are per-Fusion to avoid ownership issues with shared containers.
780+
// =========================================================================
781+
782+
Val* Fusion::metadataOf(Val* v) {
783+
if (metadata_.count(v) == 0) {
784+
// Create metadata val owned by the same Fusion as v
785+
Fusion* owner = v->container();
786+
auto metadata_val =
787+
IrBuilder::createInContainer<Val>(owner, metaDataTypeOf(v));
788+
auto metadata_expr =
789+
IrBuilder::createInContainer<GetMetaData>(owner, metadata_val, v);
790+
metadata_[v] = std::make_pair(metadata_val, metadata_expr);
791+
}
792+
return metadata_.at(v).first;
793+
}
794+
795+
const std::vector<Val*>& Fusion::axioms() {
796+
if (!axioms_) {
797+
axioms_ = std::make_unique<std::vector<Val*>>();
798+
axioms_->reserve(kParallelTypeThreads.size() * 3);
799+
auto zero = zeroVal();
800+
for (auto p : kParallelTypeThreads) {
801+
auto pidx = NamedScalar::getParallelIndex(p);
802+
auto pdim = NamedScalar::getParallelDim(p);
803+
axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero));
804+
axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero));
805+
axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim));
806+
}
807+
}
808+
return *axioms_;
809+
}
810+
811+
void Fusion::assumePositive(Val* val) {
812+
NVF_ERROR(inContainer(val));
813+
// Lazy init axioms, then add the assumption
814+
axioms();
815+
axioms_->emplace_back(IrBuilder::gtExpr(val, zeroVal()));
816+
}
817+
818+
void Fusion::assumeNonNegative(Val* val) {
819+
NVF_ERROR(inContainer(val));
820+
// Lazy init axioms, then add the assumption
821+
axioms();
822+
axioms_->emplace_back(IrBuilder::geExpr(val, zeroVal()));
823+
}
824+
771825
void Fusion::registerVal(Val* val) {
772826
if (inContainer(val)) {
773827
return;

csrc/fusion.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -566,22 +566,15 @@ class NVF_API Fusion : public PolymorphicBase {
566566
Val* zeroVal(DataType dtype);
567567
Val* oneVal(DataType dtype);
568568

569-
Val* metadataOf(Val* val) {
570-
return ir_container()->metadataOf(val);
571-
}
569+
// Phase 2: Per-Fusion metadata and axioms
570+
// These are now per-Fusion to avoid ownership issues with shared containers.
571+
Val* metadataOf(Val* val);
572572

573573
// Axioms (CUDA programming assumptions)
574-
const std::vector<Val*>& axioms() {
575-
return ir_container()->axioms();
576-
}
574+
const std::vector<Val*>& axioms();
577575

578-
void assumePositive(Val* val) {
579-
ir_container()->assumePositive(val);
580-
}
581-
582-
void assumeNonNegative(Val* val) {
583-
ir_container()->assumeNonNegative(val);
584-
}
576+
void assumePositive(Val* val);
577+
void assumeNonNegative(Val* val);
585578

586579
// Statement removal
587580
void removeStatementsCreatedAfter(
@@ -661,6 +654,14 @@ class NVF_API Fusion : public PolymorphicBase {
661654
Val* true_val_ = nullptr;
662655
Val* false_val_ = nullptr;
663656
NamedScalar* magic_zero_val_ = nullptr;
657+
658+
// Phase 2: Per-Fusion axioms (CUDA programming assumptions)
659+
// These are per-Fusion to avoid ownership issues with shared containers.
660+
std::unique_ptr<std::vector<Val*>> axioms_;
661+
662+
// Phase 2: Per-Fusion metadata cache
663+
// Maps Val* to (metadata_val, metadata_expr) pairs
664+
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
664665
};
665666

666667
// Template implementations for Fusion::manage<T>() that use IrCloner

csrc/ir/container.cpp

Lines changed: 7 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
8181
std::swap(a.val_type_name_map_, b.val_type_name_map_);
8282
std::swap(a.expr_name_counter_, b.expr_name_counter_);
8383

84-
std::swap(a.metadata_, b.metadata_);
85-
8684
std::swap(a.parent_, b.parent_);
8785

88-
// Note: Special values (zero_val_, one_val_, etc.) are now per-Fusion,
89-
// not per-IrContainer. They are swapped as part of the Fusion-level swap.
90-
std::swap(a.axioms_, b.axioms_);
86+
// Note: Special values, axioms, and metadata are now per-Fusion,
87+
// not per-IrContainer. They are handled by Fusion::swap.
9188
}
9289

9390
IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
9491
to->clear();
92+
9593
IrCloner ir_cloner(to->parent());
9694

9795
// Copy values in deterministic order
@@ -113,14 +111,7 @@ IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
113111
to->val_type_name_map_ = from->val_type_name_map_;
114112
to->expr_name_counter_ = from->expr_name_counter_;
115113

116-
if (from->axioms_ != nullptr) {
117-
to->axioms_ = std::make_unique<std::vector<Val*>>();
118-
for (auto pred : *from->axioms_) {
119-
to->axioms_->push_back(ir_cloner.clone(pred));
120-
}
121-
}
122-
123-
to->metadata_ = ir_cloner.clone(from->metadata_);
114+
// Note: axioms and metadata are now per-Fusion, handled by Fusion::copy
124115

125116
return ir_cloner;
126117
}
@@ -201,9 +192,7 @@ void IrContainer::clear() noexcept {
201192
vals_up_.clear();
202193
exprs_.clear();
203194
exprs_up_.clear();
204-
axioms_.reset();
205195
val_type_name_map_.clear();
206-
metadata_.clear();
207196
expr_name_counter_ = 0;
208197
}
209198

@@ -239,51 +228,11 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
239228
return true;
240229
}
241230

242-
// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal)
243-
// are now per-Fusion. Use Fusion::zeroVal() etc. instead.
231+
// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal),
232+
// metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(),
233+
// Fusion::metadataOf(), Fusion::axioms(), etc. instead.
244234
// This avoids ownership conflicts when multiple Fusions share an IrContainer.
245235

246-
Val* IrContainer::metadataOf(Val* v) {
247-
if (metadata_.count(v) == 0) {
248-
auto metadata_val =
249-
IrBuilder::createInContainer<Val>(this->parent(), metaDataTypeOf(v));
250-
auto metadata_expr = IrBuilder::createInContainer<GetMetaData>(
251-
this->parent(), metadata_val, v);
252-
metadata_[v] = std::make_pair(metadata_val, metadata_expr);
253-
}
254-
return metadata_.at(v).first;
255-
}
256-
257-
void IrContainer::lazyInitAxioms() {
258-
if (!axioms_) {
259-
axioms_ = std::make_unique<std::vector<Val*>>();
260-
axioms_->reserve(kParallelTypeThreads.size() * 3);
261-
// Use parent()->zeroVal() since special values are now per-Fusion
262-
auto zero = parent()->zeroVal();
263-
for (auto p : kParallelTypeThreads) {
264-
auto pidx = NamedScalar::getParallelIndex(p);
265-
auto pdim = NamedScalar::getParallelDim(p);
266-
axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero));
267-
axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero));
268-
axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim));
269-
}
270-
}
271-
}
272-
273-
void IrContainer::assumePositive(Val* val) {
274-
NVF_ERROR(val->container() == this->parent());
275-
lazyInitAxioms();
276-
// Use parent()->zeroVal() since special values are now per-Fusion
277-
axioms_->emplace_back(IrBuilder::gtExpr(val, parent()->zeroVal()));
278-
}
279-
280-
void IrContainer::assumeNonNegative(Val* val) {
281-
NVF_ERROR(val->container() == this->parent());
282-
lazyInitAxioms();
283-
// Use parent()->zeroVal() since special values are now per-Fusion
284-
axioms_->emplace_back(IrBuilder::geExpr(val, parent()->zeroVal()));
285-
}
286-
287236
void IrContainer::removeStatementsCreatedAfter(
288237
int64_t prev_num_exprs,
289238
int64_t prev_num_vals) {

csrc/ir/container.h

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,11 @@ class IrContainer {
8888
return include_shortcuts ? std::ssize(vals_) : std::ssize(vals_up_);
8989
}
9090

91-
// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal)
92-
// are now per-Fusion. Use Fusion::zeroVal() etc. instead.
91+
// Note: Shortcut values (zeroVal, oneVal, trueVal, falseVal, magicZeroVal),
92+
// metadata, and axioms are now per-Fusion. Use Fusion::zeroVal(),
93+
// Fusion::metadataOf(), Fusion::axioms(), etc. instead.
9394
// This avoids ownership conflicts when multiple Fusions share an IrContainer.
9495

95-
Val* metadataOf(Val*);
96-
97-
// Axioms about CUDA programming, for example: threadIdx.x < blockDim.x
98-
const std::vector<Val*>& axioms() {
99-
lazyInitAxioms();
100-
return *axioms_;
101-
}
102-
103-
void assumePositive(Val* val);
104-
void assumeNonNegative(Val* val);
105-
10696
protected:
10797
static IrCloner copy(const IrContainer* from, IrContainer* to);
10898

@@ -136,8 +126,6 @@ class IrContainer {
136126

137127
void clear() noexcept;
138128

139-
void lazyInitAxioms();
140-
141129
friend class StatementGuard;
142130

143131
// A simple garbage collection mechanism to remove all Exprs and Vals that
@@ -173,10 +161,8 @@ class IrContainer {
173161
// Note: Special values (zero_val_, one_val_, true_val_, false_val_,
174162
// magic_zero_val_) are now per-Fusion, stored in Fusion class.
175163
// This avoids ownership conflicts when multiple Fusions share an IrContainer.
176-
// See Fusion::zeroVal(), etc. for the per-Fusion implementation.
177-
178-
std::unique_ptr<std::vector<Val*>> axioms_;
179-
std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
164+
// See Fusion::zeroVal(), Fusion::axioms(), Fusion::metadataOf(), etc.
165+
// for the per-Fusion implementations.
180166

181167
public:
182168
Fusion* parent() const {

0 commit comments

Comments
 (0)