Skip to content

Commit 8112846

Browse files
committed
Fix dangling special val pointers after StatementGuard rollback
Special vals (trueVal, falseVal, oneVal, etc.) can be lazily created inside a StatementGuard scope (e.g. by simplifyExpr called from haveDifferentShardings). When the guard rolls back, it pops vals_up_ back to the snapshot, destroying those vals while the Fusion's special val cache pointers still reference them. Subsequent calls to the special val accessors then return dangling pointers, causing UB — this manifested as LoopShardedSplitReshapeIds incorrectly classifying a reshape as resharding on CI. Fusion::removeStatementsCreatedAfter now nulls out any special val cache pointers whose vals are about to be destroyed, so they get re-created on next access.
1 parent 5dbaa1a commit 8112846

3 files changed

Lines changed: 80 additions & 4 deletions

File tree

csrc/fusion.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,34 @@ void Fusion::removeVal(Val* val) {
401401
invalidateTvsAndUses();
402402
}
403403

404+
void Fusion::removeStatementsCreatedAfter(
405+
int64_t num_exprs_before,
406+
int64_t num_vals_before) {
407+
// Before removing vals from vals_up_, null out any special value caches that
408+
// point to vals that are about to be destroyed. This prevents dangling
409+
// pointers when special vals are lazily created inside a StatementGuard scope.
410+
int64_t current_num_vals = ir_container()->numVals();
411+
if (current_num_vals > num_vals_before) {
412+
auto& vals_up = ir_container()->vals_up_;
413+
for (int64_t i = num_vals_before; i < current_num_vals; ++i) {
414+
Val* v = vals_up[i].get();
415+
if (v == zero_val_) {
416+
zero_val_ = nullptr;
417+
} else if (v == one_val_) {
418+
one_val_ = nullptr;
419+
} else if (v == true_val_) {
420+
true_val_ = nullptr;
421+
} else if (v == false_val_) {
422+
false_val_ = nullptr;
423+
} else if (v == magic_zero_val_) {
424+
magic_zero_val_ = nullptr;
425+
}
426+
}
427+
}
428+
ir_container()->removeStatementsCreatedAfter(
429+
num_exprs_before, num_vals_before);
430+
}
431+
404432
void Fusion::addInput(Val* input) {
405433
assertInContainer(input, "Cannot register input ");
406434

csrc/fusion.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -574,10 +574,7 @@ class NVF_API Fusion : public PolymorphicBase {
574574
// Statement removal
575575
void removeStatementsCreatedAfter(
576576
int64_t num_exprs_before,
577-
int64_t num_vals_before) {
578-
ir_container()->removeStatementsCreatedAfter(
579-
num_exprs_before, num_vals_before);
580-
}
577+
int64_t num_vals_before);
581578

582579
protected:
583580
friend SegmentCandidateFinder;

tests/cpp/test_statement_guard.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,55 @@ TEST_F(StatementGuardTest, ExecuteAfterGuard) {
5151
executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__);
5252
}
5353

54+
// Regression test: special vals lazily created inside a StatementGuard scope
55+
// must not become dangling pointers after the guard rolls back.
56+
TEST_F(StatementGuardTest, LazySpecialValsNotDangling) {
57+
auto fusion = std::make_unique<Fusion>();
58+
FusionGuard fg(fusion.get());
59+
60+
TensorView* in = makeContigTensor(1);
61+
fusion->addInput(in);
62+
TensorView* out = set(in);
63+
fusion->addOutput(out);
64+
65+
// Force lazy creation of trueVal/falseVal inside a StatementGuard scope.
66+
// This reproduces the bug where haveDifferentShardings calls simplifyExpr
67+
// inside a StatementGuard, which can lazily create special vals that then
68+
// become dangling pointers when the guard rolls back.
69+
{
70+
StatementGuard sg(fusion.get());
71+
// Directly trigger lazy creation of trueVal and falseVal
72+
fusion->trueVal();
73+
fusion->falseVal();
74+
fusion->oneVal();
75+
}
76+
77+
// After the guard, the special vals should still be valid (re-created if the
78+
// originals were destroyed by the guard's rollback).
79+
Val* z = fusion->zeroVal();
80+
Val* o = fusion->oneVal();
81+
Val* t = fusion->trueVal();
82+
Val* f = fusion->falseVal();
83+
EXPECT_NE(z, nullptr);
84+
EXPECT_NE(o, nullptr);
85+
EXPECT_NE(t, nullptr);
86+
EXPECT_NE(f, nullptr);
87+
EXPECT_TRUE(z->isZeroInt());
88+
EXPECT_TRUE(o->isOneInt());
89+
EXPECT_TRUE(t->isTrue());
90+
EXPECT_TRUE(f->isFalse());
91+
92+
// The fusion should still be executable
93+
FusionExecutorCache executor_cache(std::move(fusion));
94+
at::Tensor in_tensor = at::randn({8}, at::device(at::kCUDA));
95+
auto out_tensors = executor_cache.runFusionWithInputs({in_tensor});
96+
ASSERT_EQ(out_tensors.size(), 1);
97+
testValidate(
98+
executor_cache.fusion(),
99+
{out_tensors[0].as<at::Tensor>()},
100+
{in_tensor},
101+
__LINE__,
102+
__FILE__);
103+
}
104+
54105
} // namespace nvfuser

0 commit comments

Comments
 (0)