Skip to content

Commit f8ff364

Browse files
committed
shared_ptr<IrContainer> transition and Fusion tracking infrastructure
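Change Fusion::ir_container_ from unique_ptr to shared_ptr to enable future container sharing between Fusions. Add a Fusion tracking API to IrContainer (addFusion/removeFusion/transferFusion/sharingCount/hasMultipleFusions/sharingFusions). Remove IrContainer::parent_, since a container no longer has a 1:1 relationship with a single owning Fusion; IrContainer::copy now takes the destination Fusion as an explicit argument instead of reading it from parent_. Disable parallel compilation during the shared_ptr transition, until std::shared_mutex is added to IrContainer.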
1 parent 54cd0fe commit f8ff364

5 files changed

Lines changed: 62 additions & 28 deletions

File tree

csrc/fusion.cpp

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
116116
// update the parent backpointers in those containers to point to their new
117117
// owners
118118
if (a.ir_container_) {
119-
// Also update all Statement ir_container_ pointers to point to new owner
120-
a.ir_container()->parent_ = &a;
121119
for (auto val : a.vals()) {
122120
val->ir_container_ = &a;
123121
}
@@ -126,8 +124,6 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
126124
}
127125
}
128126
if (b.ir_container_) {
129-
// Also update all Statement ir_container_ pointers to point to new owner
130-
b.ir_container()->parent_ = &b;
131127
for (auto val : b.vals()) {
132128
val->ir_container_ = &b;
133129
}
@@ -161,7 +157,8 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
161157
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
162158
to->clear();
163159

164-
auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());
160+
auto ir_cloner =
161+
IrContainer::copy(from->ir_container(), to->ir_container(), to);
165162

166163
// Remap cached special val pointers through the cloner
167164
if (from->zero_val_) {
@@ -254,8 +251,8 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
254251
}
255252

256253
// Default constructor
257-
Fusion::Fusion() : ir_container_(std::make_unique<IrContainer>()) {
258-
ir_container_->parent_ = this;
254+
Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
255+
ir_container_->addFusion(this);
259256
}
260257

261258
// Copy constructor
@@ -287,6 +284,9 @@ Fusion& Fusion::operator=(Fusion&& other) noexcept {
287284

288285
Fusion::~Fusion() {
289286
clear();
287+
if (ir_container_) {
288+
ir_container_->removeFusion(this);
289+
}
290290
}
291291

292292
void Fusion::clear() noexcept {
@@ -350,9 +350,7 @@ void Fusion::removeExpr(Expr* expr) {
350350
auto expr_in_deque = std::find_if(
351351
c->exprs_up_.begin(),
352352
c->exprs_up_.end(),
353-
[expr](std::unique_ptr<Expr>& expr_up) {
354-
return expr_up.get() == expr;
355-
});
353+
[expr](std::unique_ptr<Expr>& expr_up) { return expr_up.get() == expr; });
356354
NVF_ERROR(
357355
expr_in_deque != c->exprs_up_.end(),
358356
"Wanted to remove an expression but its unique ptr is missing.");

csrc/fusion.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,6 @@ class NVF_API Fusion : public PolymorphicBase {
148148
typedef std::unordered_map<int, std::vector<int64_t>> PermutationMap;
149149

150150
protected:
151-
// Direct access to underlying container
152151
IrContainer* ir_container() {
153152
NVF_ERROR(
154153
ir_container_.get() != nullptr,
@@ -163,6 +162,10 @@ class NVF_API Fusion : public PolymorphicBase {
163162
return ir_container_.get();
164163
}
165164

165+
std::shared_ptr<IrContainer> ir_container_ptr() const {
166+
return ir_container_;
167+
}
168+
166169
public:
167170
// Registration (public API with passkey)
168171
virtual void registerStmt(IrBuilderPasskey, Statement* stmt) {
@@ -635,7 +638,7 @@ class NVF_API Fusion : public PolymorphicBase {
635638
std::unique_ptr<std::vector<TensorView*>> all_tvs_ptr_ = nullptr;
636639

637640
inline static const std::string exact_mappings_key = "exact_mappings";
638-
std::unique_ptr<IrContainer> ir_container_;
641+
std::shared_ptr<IrContainer> ir_container_;
639642

640643
Val* zero_val_ = nullptr;
641644
Val* one_val_ = nullptr;

csrc/ir/container.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,14 +80,15 @@ void IrContainer::swap(IrContainer& a, IrContainer& b) noexcept {
8080

8181
std::swap(a.val_type_name_map_, b.val_type_name_map_);
8282
std::swap(a.expr_name_counter_, b.expr_name_counter_);
83-
84-
std::swap(a.parent_, b.parent_);
8583
}
8684

87-
IrCloner IrContainer::copy(const IrContainer* from, IrContainer* to) {
85+
IrCloner IrContainer::copy(
86+
const IrContainer* from,
87+
IrContainer* to,
88+
Fusion* dest_fusion) {
8889
to->clear();
8990

90-
IrCloner ir_cloner(to->parent());
91+
IrCloner ir_cloner(dest_fusion);
9192

9293
// Copy values in deterministic order
9394
for (auto val : from->deterministic_vals()) {
@@ -138,7 +139,7 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
138139
}
139140

140141
NVF_ERROR(
141-
const_stmt->container() == this->parent(),
142+
sharing_fusions_.count(const_stmt->container()) > 0,
142143
"Container claims to own stmt, but stmt disagrees.");
143144

144145
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
@@ -157,4 +158,29 @@ bool IrContainer::inContainer(const Statement* const_stmt) const {
157158
return true;
158159
}
159160

161+
void IrContainer::addFusion(Fusion* fusion) {
162+
sharing_fusions_.insert(fusion);
163+
}
164+
165+
void IrContainer::removeFusion(Fusion* fusion) {
166+
sharing_fusions_.erase(fusion);
167+
}
168+
169+
void IrContainer::transferFusion(Fusion* from, Fusion* to) {
170+
sharing_fusions_.erase(from);
171+
sharing_fusions_.insert(to);
172+
}
173+
174+
size_t IrContainer::sharingCount() const {
175+
return sharing_fusions_.size();
176+
}
177+
178+
bool IrContainer::hasMultipleFusions() const {
179+
return sharing_fusions_.size() > 1;
180+
}
181+
182+
const std::unordered_set<Fusion*>& IrContainer::sharingFusions() const {
183+
return sharing_fusions_;
184+
}
185+
160186
} // namespace nvfuser

csrc/ir/container.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ class IrContainer {
8686
}
8787

8888
protected:
89-
static IrCloner copy(const IrContainer* from, IrContainer* to);
89+
static IrCloner copy(
90+
const IrContainer* from,
91+
IrContainer* to,
92+
Fusion* dest_fusion);
9093

9194
static void swap(IrContainer& a, IrContainer& b) noexcept;
9295

@@ -127,16 +130,15 @@ class IrContainer {
127130
StmtNameType expr_name_counter_ = 0;
128131

129132
public:
130-
Fusion* parent() const {
131-
NVF_ERROR(
132-
parent_ != nullptr, "Call to IrContainer::parent() holds nullptr.")
133-
return parent_;
134-
}
133+
void addFusion(Fusion* fusion);
134+
void removeFusion(Fusion* fusion);
135+
void transferFusion(Fusion* from, Fusion* to);
136+
size_t sharingCount() const;
137+
bool hasMultipleFusions() const;
138+
const std::unordered_set<Fusion*>& sharingFusions() const;
135139

136140
private:
137-
// Parent Fusion that owns this container (for pure composition pattern)
138-
// Used by Statement::fusion() to navigate back to owning Fusion
139-
Fusion* parent_ = nullptr;
141+
std::unordered_set<Fusion*> sharing_fusions_;
140142
};
141143

142144
} // namespace nvfuser

csrc/runtime/fusion_kernel_runtime.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@
2828

2929
namespace nvfuser {
3030

31+
// TODO: Remove when std::shared_mutex is added to IrContainer.
32+
constexpr bool kPhase2DisableParallelCompile = true;
33+
3134
namespace {
3235
// Replace CUDA tensor with Meta tensor because storing tensors can cause
3336
// out-of-memory issues. Other arguments are returned as-is.
@@ -454,7 +457,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
454457
try {
455458
for (const auto& [group_to_run, group_runtime_inputs] :
456459
zip(runtime_workspace_.group_run_order, all_runtime_inputs)) {
457-
if (num_groups == 1 || isOptionDisabled(DisableOption::ParallelCompile)) {
460+
if (num_groups == 1 || kPhase2DisableParallelCompile ||
461+
isOptionDisabled(DisableOption::ParallelCompile)) {
458462
compileKernel(group_runtime_inputs, group_to_run);
459463
} else {
460464
// launch compileKernel thread here
@@ -488,7 +492,8 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) {
488492
throw;
489493
}
490494

491-
if (num_groups != 1 && !isOptionDisabled(DisableOption::ParallelCompile)) {
495+
if (num_groups != 1 && !kPhase2DisableParallelCompile &&
496+
!isOptionDisabled(DisableOption::ParallelCompile)) {
492497
// Wait until all segments finish compiling
493498
getThreadPool()->waitWorkComplete();
494499
NVF_ERROR(

0 commit comments

Comments
 (0)