diff --git a/infini_train/src/nn/parallel/ddp/reducer.cc b/infini_train/src/nn/parallel/ddp/reducer.cc index 1bdd29e1..a13a3937 100644 --- a/infini_train/src/nn/parallel/ddp/reducer.cc +++ b/infini_train/src/nn/parallel/ddp/reducer.cc @@ -285,14 +285,19 @@ void Reducer::PrepareForBackward() { // If ZeroGrad(set_to_none=True), grad is nullptr at this point // If ZeroGrad(set_to_none=False), grad is set to view of bucket.contents (or modified by user) // Either way, we reset grad to view of bucket.contents - // Since bucket.contents might not be zeroed, we need to overwrite it on next grad accumulation + // Since bucket.contents might not be zeroed, we might need to overwrite it on next grad accumulation if (!grad || (grad.get() != view.get())) { if (grad) { - LOG(WARNING) << "gradient_as_bucket_view is enabled, but param " << param - << " has a non-view grad tensor. Automatically overwriting it with bucket view."; + // Buckets might be rebuilt between micro-batches when grad accumulation is on. + // In this case, the old grad may point to a diffrent previous bucket view, the new bucket view + // should continue accumulating from the same value left in old bucket view on next backward. + view->CopyFrom(grad); + } else { + // In this case, ZeroGrad(set_to_none=true) leaves grad null. + // Bucket view may contain stale data, so we must overwrite it on next backward. + param->MarkGradOverwriteOnNextAccum(); } param->set_grad(view); - param->MarkGradOverwriteOnNextAccum(); } } } diff --git a/scripts/test_config.json b/scripts/test_config.json index aa7de8d1..70563dbd 100644 --- a/scripts/test_config.json +++ b/scripts/test_config.json @@ -10,7 +10,7 @@ "CKPT_ROOT_DIR": "/data1/ckpt", "COMPARE_LOG_DIR": "", "RUN_CTEST": "true", - "RUN_PROFILE_TEST": "false" + "RUN_PROFILE_TEST": "true" }, "basic_compile_commands": [ { @@ -673,4 +673,4 @@ ] } ] -} \ No newline at end of file +}