Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions infini_train/src/nn/parallel/ddp/reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,19 @@ void Reducer::PrepareForBackward() {
// If ZeroGrad(set_to_none=True), grad is nullptr at this point
// If ZeroGrad(set_to_none=False), grad is set to view of bucket.contents (or modified by user)
// Either way, we reset grad to view of bucket.contents
// Since bucket.contents might not be zeroed, we need to overwrite it on next grad accumulation
// Since bucket.contents might not be zeroed, we might need to overwrite it on next grad accumulation
if (!grad || (grad.get() != view.get())) {
if (grad) {
LOG(WARNING) << "gradient_as_bucket_view is enabled, but param " << param
<< " has a non-view grad tensor. Automatically overwriting it with bucket view.";
// Buckets might be rebuilt between micro-batches when grad accumulation is on.
// In this case, the old grad may point to a diffrent previous bucket view, the new bucket view
// should continue accumulating from the same value left in old bucket view on next backward.
view->CopyFrom(grad);
} else {
// In this case, ZeroGrad(set_to_none=true) leaves grad null.
// Bucket view may contain stale data, so we must overwrite it on next backward.
param->MarkGradOverwriteOnNextAccum();
}
param->set_grad(view);
param->MarkGradOverwriteOnNextAccum();
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions scripts/test_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"CKPT_ROOT_DIR": "/data1/ckpt",
"COMPARE_LOG_DIR": "",
"RUN_CTEST": "true",
"RUN_PROFILE_TEST": "false"
"RUN_PROFILE_TEST": "true"
},
"basic_compile_commands": [
{
Expand Down Expand Up @@ -673,4 +673,4 @@
]
}
]
}
}
Loading