@@ -184,18 +184,19 @@ struct Fusion::ContainerMutator {
184184 }
185185 }
186186
187- static int64_t numValsExcludingShortcuts (const Fusion* self) noexcept {
188- auto * c = self->ir_container ();
189- // Use direct field access. Avoids re-entering valsOwnedBy() which acquires
190- // shared_lock.
191- const auto it = c->per_fusion_vals_ .find (self);
192- int64_t count = it != c->per_fusion_vals_ .end ()
193- ? static_cast <int64_t >(it->second .size ())
194- : 0 ;
195- count -= (self->zero_val_ != nullptr ) + (self->one_val_ != nullptr ) +
196- (self->true_val_ != nullptr ) + (self->false_val_ != nullptr ) +
197- (self->magic_zero_val_ != nullptr );
198- return count;
187+ // Reset to nullptr whichever of self's cached shortcut-val pointers (zero_val_, one_val_, true_val_, false_val_, magic_zero_val_) equals v. The checks form an else-if chain, so at most one pointer is cleared; if v matches none of them, self is left untouched.
188+ static void nullOutShortcutIfNeeded (Fusion* self, Val* v) {
189+ if (v == self->zero_val_ ) {
190+ self->zero_val_ = nullptr ;
191+ } else if (v == self->one_val_ ) {
192+ self->one_val_ = nullptr ;
193+ } else if (v == self->true_val_ ) {
194+ self->true_val_ = nullptr ;
195+ } else if (v == self->false_val_ ) {
196+ self->false_val_ = nullptr ;
197+ } else if (v == self->magic_zero_val_ ) {
198+ self->magic_zero_val_ = nullptr ;
199+ }
199200 }
200201
201202 static void removeStatementsCreatedAfter (
@@ -204,42 +205,83 @@ struct Fusion::ContainerMutator {
204205 int64_t num_vals_before) {
205206 auto * c = self->ir_container ();
206207
207- // Remove expressions before values because we need to change Val::uses_.
208- while (std::ssize (c->per_fusion_exprs_ [self]) > num_exprs_before) {
209- // Pop from global deque back — statements created by this Fusion during
210- // the guard scope are at the tail (LIFO invariant).
211- Expr* e = c->exprs_up_ .back ().get ();
212- NVF_ERROR (
213- c->per_fusion_exprs_ [self].count (e) > 0 ,
214- " removeStatementsCreatedAfter: tail expr belongs to another Fusion" );
215- for (Val* in : e->inputs ()) {
216- in->removeUse (e);
208+ // Use direct field access — hasMultipleFusions() acquires shared_lock which
209+ // deadlocks since the caller already holds unique_lock on mutex_.
210+ if (c->sharing_fusions_ .size () <= 1 ) {
211+ // Fast path: single Fusion owns this container, so the LIFO invariant
212+ // holds — self's newest statements are always at the global deque tail.
213+ // Remove expressions before values because we need to change Val::uses_.
214+ while (std::ssize (c->per_fusion_exprs_ [self]) > num_exprs_before) {
215+ Expr* e = c->exprs_up_ .back ().get ();
216+ NVF_ERROR (
217+ c->per_fusion_exprs_ [self].count (e) > 0 ,
218+ " removeStatementsCreatedAfter: tail expr belongs to another Fusion" );
219+ for (Val* out : e->outputs ()) {
220+ out->setDefinition (nullptr );
221+ }
222+ for (Val* in : e->inputs ()) {
223+ in->removeUse (e);
224+ }
225+ c->per_fusion_exprs_ [self].erase (e);
226+ c->exprs_ .erase (e);
227+ c->exprs_up_ .pop_back ();
217228 }
218- c->per_fusion_exprs_ [self].erase (e);
219- c->exprs_ .erase (e);
220- c->exprs_up_ .pop_back ();
221- }
222-
223- while (numValsExcludingShortcuts (self) > num_vals_before) {
224- Val* v = c->vals_up_ .back ().get ();
225- NVF_ERROR (
226- c->per_fusion_vals_ [self].count (v) > 0 ,
227- " removeStatementsCreatedAfter: tail val belongs to another Fusion" );
228- // Null out shortcut caches if they point to vals about to be destroyed
229- if (v == self->zero_val_ ) {
230- self->zero_val_ = nullptr ;
231- } else if (v == self->one_val_ ) {
232- self->one_val_ = nullptr ;
233- } else if (v == self->true_val_ ) {
234- self->true_val_ = nullptr ;
235- } else if (v == self->false_val_ ) {
236- self->false_val_ = nullptr ;
237- } else if (v == self->magic_zero_val_ ) {
238- self->magic_zero_val_ = nullptr ;
229+ while (std::ssize (c->per_fusion_vals_ [self]) > num_vals_before) {
230+ Val* v = c->vals_up_ .back ().get ();
231+ NVF_ERROR (
232+ c->per_fusion_vals_ [self].count (v) > 0 ,
233+ " removeStatementsCreatedAfter: tail val belongs to another Fusion" );
234+ nullOutShortcutIfNeeded (self, v);
235+ c->per_fusion_vals_ [self].erase (v);
236+ c->vals_ .erase (v);
237+ c->vals_up_ .pop_back ();
239238 }
240- c->per_fusion_vals_ [self].erase (v);
241- c->vals_ .erase (v);
242- c->vals_up_ .pop_back ();
239+ } else {
240+ // Slow path: shared container — other Fusions' statements may be
241+ // interleaved at the tail of the global deques. Use std::erase_if
242+ // (C++20) to scan forward: skip the first num_before of self's
243+ // statements (old, to keep), then erase the remainder (added during
244+ // the guard scope). Entered whenever the container is shared,
245+ // regardless of success or failure; if no new statements were added
246+ // the scan completes trivially. O(total statements in container).
247+ int64_t exprs_kept = 0 ;
248+ std::erase_if (c->exprs_up_ , [&](const std::unique_ptr<Expr>& e_up) {
249+ Expr* e = e_up.get ();
250+ if (c->per_fusion_exprs_ [self].count (e) == 0 ) {
251+ return false ; // belongs to another Fusion — keep
252+ }
253+ if (exprs_kept < num_exprs_before) {
254+ ++exprs_kept;
255+ return false ; // self's old expr — keep
256+ }
257+ // self's new expr — remove (clean up uses and index maps first)
258+ for (Val* out : e->outputs ()) {
259+ out->setDefinition (nullptr );
260+ }
261+ for (Val* in : e->inputs ()) {
262+ in->removeUse (e);
263+ }
264+ c->per_fusion_exprs_ [self].erase (e);
265+ c->exprs_ .erase (e);
266+ return true ;
267+ });
268+
269+ int64_t vals_kept = 0 ;
270+ std::erase_if (c->vals_up_ , [&](const std::unique_ptr<Val>& v_up) {
271+ Val* v = v_up.get ();
272+ if (c->per_fusion_vals_ [self].count (v) == 0 ) {
273+ return false ; // belongs to another Fusion — keep
274+ }
275+ if (vals_kept < num_vals_before) {
276+ ++vals_kept;
277+ return false ; // self's old val — keep
278+ }
279+ // self's new val — remove (null shortcut cache pointer if applicable)
280+ nullOutShortcutIfNeeded (self, v);
281+ c->per_fusion_vals_ [self].erase (v);
282+ c->vals_ .erase (v);
283+ return true ;
284+ });
243285 }
244286 }
245287};
@@ -622,10 +664,6 @@ void Fusion::removeStatementsCreatedAfter(
622664 this , num_exprs_before, num_vals_before);
623665}
624666
625- int64_t Fusion::numValsExcludingShortcuts () const noexcept {
626- return ContainerMutator::numValsExcludingShortcuts (this );
627- }
628-
629667void Fusion::addInput (Val* input) {
630668 assertInContainer (input, " Cannot register input " );
631669
0 commit comments