Skip to content

Commit b1873c8

Browse files
authored
Segmenter Shared Container Fix (#6025)
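Statements cleaned up by StatementGuard must be removed only from the Fusion that created them, not popped from the tail of the entire IrContainer, which may be shared with other Fusions whose statements are interleaved there.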
1 parent 53d4ac2 commit b1873c8

3 files changed

Lines changed: 89 additions & 58 deletions

File tree

csrc/fusion.cpp

Lines changed: 88 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -184,18 +184,19 @@ struct Fusion::ContainerMutator {
184184
}
185185
}
186186

187-
static int64_t numValsExcludingShortcuts(const Fusion* self) noexcept {
188-
auto* c = self->ir_container();
189-
// Use direct field access. Avoids re-entering valsOwnedBy() which acquires
190-
// shared_lock.
191-
const auto it = c->per_fusion_vals_.find(self);
192-
int64_t count = it != c->per_fusion_vals_.end()
193-
? static_cast<int64_t>(it->second.size())
194-
: 0;
195-
count -= (self->zero_val_ != nullptr) + (self->one_val_ != nullptr) +
196-
(self->true_val_ != nullptr) + (self->false_val_ != nullptr) +
197-
(self->magic_zero_val_ != nullptr);
198-
return count;
187+
// Null out self's shortcut-val pointer cache if v is one of them.
188+
static void nullOutShortcutIfNeeded(Fusion* self, Val* v) {
189+
if (v == self->zero_val_) {
190+
self->zero_val_ = nullptr;
191+
} else if (v == self->one_val_) {
192+
self->one_val_ = nullptr;
193+
} else if (v == self->true_val_) {
194+
self->true_val_ = nullptr;
195+
} else if (v == self->false_val_) {
196+
self->false_val_ = nullptr;
197+
} else if (v == self->magic_zero_val_) {
198+
self->magic_zero_val_ = nullptr;
199+
}
199200
}
200201

201202
static void removeStatementsCreatedAfter(
@@ -204,42 +205,83 @@ struct Fusion::ContainerMutator {
204205
int64_t num_vals_before) {
205206
auto* c = self->ir_container();
206207

207-
// Remove expressions before values because we need to change Val::uses_.
208-
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
209-
// Pop from global deque back — statements created by this Fusion during
210-
// the guard scope are at the tail (LIFO invariant).
211-
Expr* e = c->exprs_up_.back().get();
212-
NVF_ERROR(
213-
c->per_fusion_exprs_[self].count(e) > 0,
214-
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
215-
for (Val* in : e->inputs()) {
216-
in->removeUse(e);
208+
// Use direct field access — hasMultipleFusions() acquires shared_lock which
209+
// deadlocks since the caller already holds unique_lock on mutex_.
210+
if (c->sharing_fusions_.size() <= 1) {
211+
// Fast path: single Fusion owns this container, so the LIFO invariant
212+
// holds — self's newest statements are always at the global deque tail.
213+
// Remove expressions before values because we need to change Val::uses_.
214+
while (std::ssize(c->per_fusion_exprs_[self]) > num_exprs_before) {
215+
Expr* e = c->exprs_up_.back().get();
216+
NVF_ERROR(
217+
c->per_fusion_exprs_[self].count(e) > 0,
218+
"removeStatementsCreatedAfter: tail expr belongs to another Fusion");
219+
for (Val* out : e->outputs()) {
220+
out->setDefinition(nullptr);
221+
}
222+
for (Val* in : e->inputs()) {
223+
in->removeUse(e);
224+
}
225+
c->per_fusion_exprs_[self].erase(e);
226+
c->exprs_.erase(e);
227+
c->exprs_up_.pop_back();
217228
}
218-
c->per_fusion_exprs_[self].erase(e);
219-
c->exprs_.erase(e);
220-
c->exprs_up_.pop_back();
221-
}
222-
223-
while (numValsExcludingShortcuts(self) > num_vals_before) {
224-
Val* v = c->vals_up_.back().get();
225-
NVF_ERROR(
226-
c->per_fusion_vals_[self].count(v) > 0,
227-
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
228-
// Null out shortcut caches if they point to vals about to be destroyed
229-
if (v == self->zero_val_) {
230-
self->zero_val_ = nullptr;
231-
} else if (v == self->one_val_) {
232-
self->one_val_ = nullptr;
233-
} else if (v == self->true_val_) {
234-
self->true_val_ = nullptr;
235-
} else if (v == self->false_val_) {
236-
self->false_val_ = nullptr;
237-
} else if (v == self->magic_zero_val_) {
238-
self->magic_zero_val_ = nullptr;
229+
while (std::ssize(c->per_fusion_vals_[self]) > num_vals_before) {
230+
Val* v = c->vals_up_.back().get();
231+
NVF_ERROR(
232+
c->per_fusion_vals_[self].count(v) > 0,
233+
"removeStatementsCreatedAfter: tail val belongs to another Fusion");
234+
nullOutShortcutIfNeeded(self, v);
235+
c->per_fusion_vals_[self].erase(v);
236+
c->vals_.erase(v);
237+
c->vals_up_.pop_back();
239238
}
240-
c->per_fusion_vals_[self].erase(v);
241-
c->vals_.erase(v);
242-
c->vals_up_.pop_back();
239+
} else {
240+
// Slow path: shared container — other Fusions' statements may be
241+
// interleaved at the tail of the global deques. Use std::erase_if
242+
// (C++20) to scan forward: skip the first num_before of self's
243+
// statements (old, to keep), then erase the remainder (added during
244+
// the guard scope). Entered whenever the container is shared,
245+
// regardless of success or failure; if no new statements were added
246+
// the scan completes trivially. O(total statements in container).
247+
int64_t exprs_kept = 0;
248+
std::erase_if(c->exprs_up_, [&](const std::unique_ptr<Expr>& e_up) {
249+
Expr* e = e_up.get();
250+
if (c->per_fusion_exprs_[self].count(e) == 0) {
251+
return false; // belongs to another Fusion — keep
252+
}
253+
if (exprs_kept < num_exprs_before) {
254+
++exprs_kept;
255+
return false; // self's old expr — keep
256+
}
257+
// self's new expr — remove (clean up uses and index maps first)
258+
for (Val* out : e->outputs()) {
259+
out->setDefinition(nullptr);
260+
}
261+
for (Val* in : e->inputs()) {
262+
in->removeUse(e);
263+
}
264+
c->per_fusion_exprs_[self].erase(e);
265+
c->exprs_.erase(e);
266+
return true;
267+
});
268+
269+
int64_t vals_kept = 0;
270+
std::erase_if(c->vals_up_, [&](const std::unique_ptr<Val>& v_up) {
271+
Val* v = v_up.get();
272+
if (c->per_fusion_vals_[self].count(v) == 0) {
273+
return false; // belongs to another Fusion — keep
274+
}
275+
if (vals_kept < num_vals_before) {
276+
++vals_kept;
277+
return false; // self's old val — keep
278+
}
279+
// self's new val — remove (null shortcut cache pointer if applicable)
280+
nullOutShortcutIfNeeded(self, v);
281+
c->per_fusion_vals_[self].erase(v);
282+
c->vals_.erase(v);
283+
return true;
284+
});
243285
}
244286
}
245287
};
@@ -622,10 +664,6 @@ void Fusion::removeStatementsCreatedAfter(
622664
this, num_exprs_before, num_vals_before);
623665
}
624666

625-
int64_t Fusion::numValsExcludingShortcuts() const noexcept {
626-
return ContainerMutator::numValsExcludingShortcuts(this);
627-
}
628-
629667
void Fusion::addInput(Val* input) {
630668
assertInContainer(input, "Cannot register input ");
631669

csrc/fusion.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -556,13 +556,6 @@ class NVF_API Fusion : public PolymorphicBase {
556556
return std::ssize(ir_container()->valsOwnedBy(this));
557557
}
558558

559-
//! Return per-Fusion val count excluding shortcut vals (zero_val_, etc.).
560-
//! Shortcut vals are registered in both per_fusion_vals_ and vals_up_, but
561-
//! since they're singletons that should persist across StatementGuard scopes,
562-
//! this count excludes them so the LIFO pop-back in
563-
//! removeStatementsCreatedAfter correctly skips over them.
564-
int64_t numValsExcludingShortcuts() const noexcept;
565-
566559
// Shortcut values (frequently used constants)
567560
Val* zeroVal();
568561
Val* oneVal();

csrc/statement_guard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ StatementGuard::StatementGuard(Fusion* fusion)
2020
return fusion;
2121
}()),
2222
prev_num_exprs_(fusion_->numExprs()),
23-
prev_num_vals_(fusion_->numValsExcludingShortcuts()) {}
23+
prev_num_vals_(fusion_->numVals()) {}
2424

2525
StatementGuard::~StatementGuard() {
2626
fusion_->removeStatementsCreatedAfter(prev_num_exprs_, prev_num_vals_);

0 commit comments

Comments
 (0)