Commit bed10f3

feat(profiling): add NVTX profiling support and elementwise kernel optimizations
- Add USE_NVTX cmake option for NVIDIA Tools Extension profiling
- Add nvtx.h helper for scoped profiling ranges
- Add NVTX ranges to key operations: attention, RMSNorm, embedding, cross-entropy
- Optimize elementwise kernels with no-broadcast fast paths
- Remove unnecessary Fill(0) calls where cuBLAS beta=0 overwrites output
- Fix dtype mismatch in AccumulateGrad for bf16/fp32 under autocast
1 parent bdec219 commit bed10f3

12 files changed: 301 additions & 118 deletions

CMakeLists.txt

Lines changed: 5 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@ option(USE_CUDA "Support NVIDIA CUDA" OFF)
 option(PROFILE_MODE "ENABLE PROFILE MODE" OFF)
 option(USE_OMP "Use OpenMP as backend for Eigen" ON)
 option(USE_NCCL "Build project for distributed running" ON)
+option(USE_NVTX "Enable NVTX profiling" OFF)
 
 project(infini_train VERSION 0.5.0 LANGUAGES CXX)
 
@@ -41,6 +42,10 @@ if(PROFILE_MODE)
     add_compile_definitions(PROFILE_MODE=1)
 endif()
 
+if(USE_NVTX)
+    add_compile_definitions(USE_NVTX)
+endif()
+
 # ------------------------------------------------------------------------------
 # Sources
 # ------------------------------------------------------------------------------
```
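
With the option wired up, a typical workflow is to configure with `-DUSE_NVTX=ON` and capture a timeline with Nsight Systems. A minimal sketch, assuming a CUDA build and an example binary at `./build/llama3` (binary path and arguments are illustrative, not taken from this commit):

```sh
# Configure and build with NVTX ranges compiled in (the option defaults to OFF).
cmake -S . -B build -DUSE_CUDA=ON -DUSE_NVTX=ON
cmake --build build -j

# Capture CUDA activity plus the NVTX ranges added by this commit; the ranges
# (AttentionForward, RoPE, Softmax, RMSNorm, ...) appear as named timeline rows.
nsys profile -t cuda,nvtx -o llama3_report ./build/llama3
```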

example/common/nvtx.h

Lines changed: 23 additions & 0 deletions
```diff
@@ -0,0 +1,23 @@
+#pragma once
+
+#ifdef USE_NVTX
+#include <nvtx3/nvToolsExt.h>
+
+class NvtxRange {
+public:
+    explicit NvtxRange(const char *name) { nvtxRangePushA(name); }
+    ~NvtxRange() { nvtxRangePop(); }
+};
+
+#define NVTX_RANGE(name) NvtxRange nvtx_range_##__LINE__(name)
+
+#else
+
+class NvtxRange {
+public:
+    explicit NvtxRange(const char *) {}
+};
+
+#define NVTX_RANGE(name) NvtxRange nvtx_range_##__LINE__(name)
+
+#endif
```
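
One subtlety in the helper as committed: `##` suppresses macro expansion, so `nvtx_range_##__LINE__` always declares a variable literally named `nvtx_range___LINE__` rather than one numbered per line. That is harmless for the one-range-per-scope usage in this commit, but two `NVTX_RANGE` invocations in the same scope would collide. The conventional fix is a two-level paste so `__LINE__` expands first; a sketch (the `NVTX_CONCAT` helpers are not part of the commit):

```cpp
// Two-level concatenation: the outer macro lets __LINE__ expand to its number
// before the inner macro pastes the tokens together.
#define NVTX_CONCAT_IMPL(a, b) a##b
#define NVTX_CONCAT(a, b) NVTX_CONCAT_IMPL(a, b)
#define NVTX_RANGE(name) NvtxRange NVTX_CONCAT(nvtx_range_, __LINE__)(name)
```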

example/gpt2/main.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -384,7 +384,7 @@ void Train(const nn::parallel::Rank &rank) {
         loss = loss / grad_accum_steps;
 
         // disable autocast for the current step (backward is not under autocast)
-        autocast_guard.Disable();
+        // autocast_guard.Disable();
 
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
```

example/llama3/main.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -353,7 +353,7 @@ void Train(const nn::parallel::Rank &rank) {
         loss = loss / grad_accum_steps;
 
         // disable autocast for the current step (backward is not under autocast)
-        autocast_guard.Disable();
+        // autocast_guard.Disable();
 
         LOG(INFO) << "Rank " << rank.GlobalRank() << ": finish loss forward";
 
```
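
Note that both example trainers now leave autocast enabled through the backward pass (the `Disable()` call is commented out, while the comment above it still describes the old behavior), so backward can produce bf16 gradients for fp32 parameters; the AccumulateGrad dtype cast later in this commit appears to handle exactly that case.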

example/llama3/net.cc

Lines changed: 70 additions & 39 deletions
```diff
@@ -13,6 +13,7 @@
 
 #include "glog/logging.h"
 
+#include "example/common/nvtx.h"
 #include "example/common/utils.h"
 #include "infini_train/include/device.h"
 #include "infini_train/include/nn/functional.h"
@@ -123,6 +124,7 @@ std::shared_ptr<Tensor> PrecomputeFreqsCis(int64_t dim, int64_t end, float theta
 } // namespace
 
 std::vector<std::shared_ptr<Tensor>> SwiGLU::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
+    NVTX_RANGE("SwiGLU");
     return {x[0] * nn::function::Sigmoid(x[0])};
 }
 
@@ -133,9 +135,19 @@ RMSNorm::RMSNorm(int64_t dim, float eps, infini_train::Device device) : Cloneabl
 }
 
 std::vector<std::shared_ptr<Tensor>> RMSNorm::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
-    // broadcasted Mul([4, 64, 2048] * [4, 64, 1])
-    auto norm = x[0] * nn::function::Rsqrt(nn::function::Mean(nn::function::Pow(x[0], 2), -1, true) + eps_);
-    return {norm * parameters_[kParamWeightName]};
+    NVTX_RANGE("RMSNorm");
+    auto input = x[0];
+    auto input_dtype = input->Dtype();
+    // Compute entirely in fp32 (like PyTorch), then cast back once — avoids per-op autocast CastKernel overhead
+    if (input_dtype != DataType::kFLOAT32) {
+        input = std::make_shared<Tensor>(input->To(DataType::kFLOAT32));
+    }
+    auto norm = input * nn::function::Rsqrt(nn::function::Mean(nn::function::Pow(input, 2), -1, true) + eps_);
+    auto result = norm * parameters_[kParamWeightName]; // both fp32, no cast
+    if (result->Dtype() != input_dtype) {
+        result = std::make_shared<Tensor>(result->To(input_dtype));
+    }
+    return {result};
 }
 
 CausalSelfAttention::CausalSelfAttention(const LLaMA3Config &config)
@@ -202,41 +214,54 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
     // -> RoPE on q, k
     // q: (B, T, H_local, D)
     // k: (B, T, KV_local, D)
-    std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis);
-
-    // TODO(zbl): use kv cache during inference
-    // if (use_kv_) { ... }
-
-    // align n_head in GQA
-    // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
-    k = RepeatKV(k, n_rep_);
-    v = RepeatKV(v, n_rep_);
-
-    // (B, T, H_local, D) -> (B, H_local, T, D)
-    q = q->Transpose(1, 2);
-    k = k->Transpose(1, 2);
-    v = v->Transpose(1, 2);
-
-    // TODO(zbl): support flash attention later
-    // if (flash_) { ... }
-
-    // manual implementation of attention
-    // this materializes the large (T,T) matrix for all the queries and keys
-
-    // q: (B, H_local, T, D)
-    // k: (B, H_local, T, D) -> (B, H_local, D, T)
-    // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
-    auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
-    if (mask) {
-        // mask: (1, 1, T, T)
-        att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
+    std::shared_ptr<Tensor> y;
+    {
+        NVTX_RANGE("AttentionForward");
+        {
+            NVTX_RANGE("RoPE");
+            std::tie(q, k) = ApplyRotaryEmbedding(q, k, freqs_cis);
+        }
+
+        // TODO(zbl): use kv cache during inference
+        // if (use_kv_) { ... }
+
+        // align n_head in GQA
+        // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
+        {
+            NVTX_RANGE("RepeatKV");
+            k = RepeatKV(k, n_rep_);
+            v = RepeatKV(v, n_rep_);
+        }
+
+        // (B, T, H_local, D) -> (B, H_local, T, D)
+        q = q->Transpose(1, 2);
+        k = k->Transpose(1, 2);
+        v = v->Transpose(1, 2);
+
+        // TODO(zbl): support flash attention later
+        // if (flash_) { ... }
+
+        // manual implementation of attention
+        // this materializes the large (T,T) matrix for all the queries and keys
+
+        // q: (B, H_local, T, D)
+        // k: (B, H_local, T, D) -> (B, H_local, D, T)
+        // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
+        auto att = q->Matmul(k->Transpose(-2, -1)) * (1.0 / std::sqrt(static_cast<float>(D)));
+        if (mask) {
+            // mask: (1, 1, T, T)
+            att = att->MaskedFill(mask, std::numeric_limits<float>::lowest());
+        }
+        // (B, H_local, T, T)
+        {
+            NVTX_RANGE("Softmax");
+            att = nn::function::Softmax(att, -1);
+        }
+        // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
+        y = att->Matmul(v);
+        // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
+        y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
     }
-    // (B, H_local, T, T)
-    att = nn::function::Softmax(att, -1);
-    // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
-    auto y = att->Matmul(v);
-    // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
-    y = y->Transpose(1, 2)->Contiguous()->View({B, T, C_local});
     // output projection
     // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
     y = (*modules_[kCProjLayerName])({y})[0];
@@ -329,6 +354,7 @@ LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : CloneableModule
 }
 
 std::vector<std::shared_ptr<Tensor>> LLaMA3FirstStage::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
+    NVTX_RANGE("Embedding");
     return (*modules_[LLaMA3FirstStage::kWTELayerName])(x);
 }
 
@@ -360,8 +386,12 @@ std::vector<std::shared_ptr<Tensor>> LLaMA3Chunk::Forward(const std::vector<std:
     auto freqs_view = buffers_[kFreqsCisName]->Slice(0, start_pos, start_pos + t, 1);
 
     // TODO(lzm): add dtype support for nn::function::Ones later
-    std::shared_ptr<Tensor> ones = std::make_shared<Tensor>(nn::function::Ones({t, t})->To(x1->GetDevice()));
-    std::shared_ptr<Tensor> mask = nn::function::Triu(ones, 1)->View({1, 1, t, t});
+    std::shared_ptr<Tensor> mask;
+    {
+        NVTX_RANGE("BuildMask");
+        std::shared_ptr<Tensor> ones = std::make_shared<Tensor>(nn::function::Ones({t, t})->To(x1->GetDevice()));
+        mask = nn::function::Triu(ones, 1)->View({1, 1, t, t});
+    }
 
     std::shared_ptr<Tensor> start_pos_ptr = nullptr;
 
@@ -386,6 +416,7 @@ LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : CloneableModule(k
 }
 
 std::vector<std::shared_ptr<Tensor>> LLaMA3LastStage::Forward(const std::vector<std::shared_ptr<Tensor>> &x) {
+    NVTX_RANGE("LMHead");
     // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd)
     auto x1 = (*modules_[kLnFLayerName])(x);
 
```
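
The attention hunk is structured around `NvtxRange`'s RAII behavior: the range is popped in the destructor, so each profiled region is delimited by a plain C++ block, and `y` moves outside the block because the output projection still needs it afterwards. The pattern in miniature (a sketch, not code from the commit):

```cpp
std::shared_ptr<Tensor> y;              // declared outside: used after the range ends
{
    NVTX_RANGE("AttentionForward");     // nvtxRangePushA on construction
    {
        NVTX_RANGE("RoPE");             // nested range, shown inside the parent row
        // ... rotary embedding ...
    }                                   // "RoPE" popped at inner scope exit
    // ... attention math assigns y ...
}                                       // "AttentionForward" popped here
// output projection runs outside the profiled region
y = (*modules_[kCProjLayerName])({y})[0];
```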

infini_train/src/autograd/accumulate.cc

Lines changed: 5 additions & 0 deletions
```diff
@@ -26,6 +26,11 @@ AccumulateGrad::Backward(const std::vector<std::shared_ptr<Tensor>> &grad_output
     core::DeviceGuard guard(device);
 
     if (grad_output) {
+        // Cast grad to match parameter dtype (e.g. bf16 grad -> fp32 param under autocast)
+        if (grad_output->Dtype() != tensor_->Dtype()) {
+            grad_output = std::make_shared<Tensor>(grad_output->To(tensor_->Dtype()));
+        }
+
         if (grad) {
             if (tensor_->ConsumeGradOverwriteFlag()) {
                 // If the tensor is marked to overrite its current grad on next grad update
```
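
The cast sits at the top of `Backward` so every path below it (overwrite, accumulate, fresh grad) sees a gradient in the parameter's dtype. The same guard could be factored out wherever mixed-precision gradients meet fp32 master weights; a minimal sketch using only the calls visible in the diff (`CastLike` itself is hypothetical, not part of the commit):

```cpp
// Hypothetical helper: return `grad` converted to `ref`'s dtype, or unchanged
// if they already match. Mirrors the guard added in AccumulateGrad::Backward.
std::shared_ptr<Tensor> CastLike(std::shared_ptr<Tensor> grad,
                                 const std::shared_ptr<Tensor> &ref) {
    if (grad->Dtype() != ref->Dtype()) {
        grad = std::make_shared<Tensor>(grad->To(ref->Dtype()));
    }
    return grad;
}
```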

infini_train/src/autograd/elementwise.cc

Lines changed: 5 additions & 0 deletions
```diff
@@ -390,6 +390,11 @@ std::vector<std::shared_ptr<Tensor>> Add::Backward(const std::vector<std::shared
     CHECK_EQ(grad_outputs.size(), 1);
     const auto &grad_output = grad_outputs[0];
 
+    // Fast path: no broadcast — grad_a and grad_b are both just grad_output
+    if (a_dims_ == b_dims_) {
+        return {grad_output, grad_output};
+    }
+
     auto device = grad_output->GetDevice().type();
     auto [grad_a, grad_b] = Dispatcher::Instance().Call<std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>>>(
         {device, "AddBackward"}, grad_output, a_dims_, b_dims_);
```

infini_train/src/autograd/function.cc

Lines changed: 10 additions & 0 deletions
```diff
@@ -1,5 +1,9 @@
 #include "infini_train/include/autograd/function.h"
 
+#ifdef USE_NVTX
+#include <nvtx3/nvToolsExt.h>
+#endif
+
 #include "glog/logging.h"
 
 #include "infini_train/include/autograd/accumulate.h"
@@ -115,9 +119,15 @@ void Function::BackwardPartial(const std::shared_ptr<Tensor> &grad_output, int g
 
     std::vector<std::shared_ptr<Tensor>> grad_inputs;
     {
+#ifdef USE_NVTX
+        nvtxRangePushA(type().c_str());
+#endif
         autograd::NoGradGuard no_grad;
         // no_grad in autograd.Function.Backward()
         grad_inputs = Backward(grad_outputs_);
+#ifdef USE_NVTX
+        nvtxRangePop();
+#endif
     }
 
     // Call backward post-hooks
```
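
Here the range is pushed and popped manually rather than through the `NvtxRange` helper, presumably to avoid a dependency from the core library on an `example/` header. One trade-off: if `Backward()` throws, the matching `nvtxRangePop()` is skipped. A library-local RAII guard would close that gap; a sketch (not part of the commit):

```cpp
#ifdef USE_NVTX
// Illustrative RAII guard local to infini_train: pops the range on scope exit
// even if Backward() throws.
struct ScopedNvtxRange {
    explicit ScopedNvtxRange(const char *name) { nvtxRangePushA(name); }
    ~ScopedNvtxRange() { nvtxRangePop(); }
};
#define SCOPED_NVTX_RANGE(name) ScopedNvtxRange scoped_nvtx_range(name)
#else
#define SCOPED_NVTX_RANGE(name)
#endif

// Usage inside Function::BackwardPartial:
//     SCOPED_NVTX_RANGE(type().c_str());
//     autograd::NoGradGuard no_grad;
//     grad_inputs = Backward(grad_outputs_);
```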
