1313
1414#include " glog/logging.h"
1515
16+ #include " example/common/nvtx.h"
1617#include " example/common/utils.h"
1718#include " infini_train/include/device.h"
1819#include " infini_train/include/nn/functional.h"
@@ -123,6 +124,7 @@ std::shared_ptr<Tensor> PrecomputeFreqsCis(int64_t dim, int64_t end, float theta
123124} // namespace
124125
125126std::vector<std::shared_ptr<Tensor>> SwiGLU::Forward (const std::vector<std::shared_ptr<Tensor>> &x) {
127+ NVTX_RANGE (" SwiGLU" );
126128 return {x[0 ] * nn::function::Sigmoid (x[0 ])};
127129}
128130
@@ -133,9 +135,19 @@ RMSNorm::RMSNorm(int64_t dim, float eps, infini_train::Device device) : Cloneabl
133135}
134136
135137std::vector<std::shared_ptr<Tensor>> RMSNorm::Forward (const std::vector<std::shared_ptr<Tensor>> &x) {
136- // broadcasted Mul([4, 64, 2048] * [4, 64, 1])
137- auto norm = x[0 ] * nn::function::Rsqrt (nn::function::Mean (nn::function::Pow (x[0 ], 2 ), -1 , true ) + eps_);
138- return {norm * parameters_[kParamWeightName ]};
138+ NVTX_RANGE (" RMSNorm" );
139+ auto input = x[0 ];
140+ auto input_dtype = input->Dtype ();
141+ // Compute entirely in fp32 (like PyTorch), then cast back once — avoids per-op autocast CastKernel overhead
142+ if (input_dtype != DataType::kFLOAT32 ) {
143+ input = std::make_shared<Tensor>(input->To (DataType::kFLOAT32 ));
144+ }
145+ auto norm = input * nn::function::Rsqrt (nn::function::Mean (nn::function::Pow (input, 2 ), -1 , true ) + eps_);
146+ auto result = norm * parameters_[kParamWeightName ]; // both fp32, no cast
147+ if (result->Dtype () != input_dtype) {
148+ result = std::make_shared<Tensor>(result->To (input_dtype));
149+ }
150+ return {result};
139151}
140152
141153CausalSelfAttention::CausalSelfAttention (const LLaMA3Config &config)
@@ -202,41 +214,54 @@ std::vector<std::shared_ptr<Tensor>> CausalSelfAttention::Forward(const std::vec
202214 // -> RoPE on q, k
203215 // q: (B, T, H_local, D)
204216 // k: (B, T, KV_local, D)
205- std::tie (q, k) = ApplyRotaryEmbedding (q, k, freqs_cis);
206-
207- // TODO(zbl): use kv cache during inference
208- // if (use_kv_) { ... }
209-
210- // align n_head in GQA
211- // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
212- k = RepeatKV (k, n_rep_);
213- v = RepeatKV (v, n_rep_);
214-
215- // (B, T, H_local, D) -> (B, H_local, T, D)
216- q = q->Transpose (1 , 2 );
217- k = k->Transpose (1 , 2 );
218- v = v->Transpose (1 , 2 );
219-
220- // TODO(zbl): support flash attention later
221- // if (flash_) { ... }
222-
223- // manual implementation of attention
224- // this materializes the large (T,T) matrix for all the queries and keys
225-
226- // q: (B, H_local, T, D)
227- // k: (B, H_local, T, D) -> (B, H_local, D, T)
228- // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
229- auto att = q->Matmul (k->Transpose (-2 , -1 )) * (1.0 / std::sqrt (static_cast <float >(D)));
230- if (mask) {
231- // mask: (1, 1, T, T)
232- att = att->MaskedFill (mask, std::numeric_limits<float >::lowest ());
217+ std::shared_ptr<Tensor> y;
218+ {
219+ NVTX_RANGE (" AttentionForward" );
220+ {
221+ NVTX_RANGE (" RoPE" );
222+ std::tie (q, k) = ApplyRotaryEmbedding (q, k, freqs_cis);
223+ }
224+
225+ // TODO(zbl): use kv cache during inference
226+ // if (use_kv_) { ... }
227+
228+ // align n_head in GQA
229+ // (B, T, KV_local, D) -> (B, T, H_local, D) via RepeatKV
230+ {
231+ NVTX_RANGE (" RepeatKV" );
232+ k = RepeatKV (k, n_rep_);
233+ v = RepeatKV (v, n_rep_);
234+ }
235+
236+ // (B, T, H_local, D) -> (B, H_local, T, D)
237+ q = q->Transpose (1 , 2 );
238+ k = k->Transpose (1 , 2 );
239+ v = v->Transpose (1 , 2 );
240+
241+ // TODO(zbl): support flash attention later
242+ // if (flash_) { ... }
243+
244+ // manual implementation of attention
245+ // this materializes the large (T,T) matrix for all the queries and keys
246+
247+ // q: (B, H_local, T, D)
248+ // k: (B, H_local, T, D) -> (B, H_local, D, T)
249+ // q @ k.T: (B, H_local, T, T) -> mul 1.0 / sqrt(D) -> (B, H_local, T, T)
250+ auto att = q->Matmul (k->Transpose (-2 , -1 )) * (1.0 / std::sqrt (static_cast <float >(D)));
251+ if (mask) {
252+ // mask: (1, 1, T, T)
253+ att = att->MaskedFill (mask, std::numeric_limits<float >::lowest ());
254+ }
255+ // (B, H_local, T, T)
256+ {
257+ NVTX_RANGE (" Softmax" );
258+ att = nn::function::Softmax (att, -1 );
259+ }
260+ // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
261+ y = att->Matmul (v);
262+ // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
263+ y = y->Transpose (1 , 2 )->Contiguous ()->View ({B, T, C_local});
233264 }
234- // (B, H_local, T, T)
235- att = nn::function::Softmax (att, -1 );
236- // att: (B, H_local, T, T) @ v: (B, H_local, T, D) -> y: (B, H_local, T, D)
237- auto y = att->Matmul (v);
238- // (B, H_local, T, D) -> Transpose(1, 2) -> (B, T, H_local, D) -> (B, T, C_local)
239- y = y->Transpose (1 , 2 )->Contiguous ()->View ({B, T, C_local});
240265 // output projection
241266 // (B, T, C_local) -> RowParallelLinear(C, C) -> (B, T, C)
242267 y = (*modules_[kCProjLayerName ])({y})[0 ];
@@ -329,6 +354,7 @@ LLaMA3FirstStage::LLaMA3FirstStage(const LLaMA3Config &config) : CloneableModule
329354}
330355
331356std::vector<std::shared_ptr<Tensor>> LLaMA3FirstStage::Forward (const std::vector<std::shared_ptr<Tensor>> &x) {
357+ NVTX_RANGE (" Embedding" );
332358 return (*modules_[LLaMA3FirstStage::kWTELayerName ])(x);
333359}
334360
@@ -360,8 +386,12 @@ std::vector<std::shared_ptr<Tensor>> LLaMA3Chunk::Forward(const std::vector<std:
360386 auto freqs_view = buffers_[kFreqsCisName ]->Slice (0 , start_pos, start_pos + t, 1 );
361387
362388 // TODO(lzm): add dtype support for nn::function::Ones later
363- std::shared_ptr<Tensor> ones = std::make_shared<Tensor>(nn::function::Ones ({t, t})->To (x1->GetDevice ()));
364- std::shared_ptr<Tensor> mask = nn::function::Triu (ones, 1 )->View ({1 , 1 , t, t});
389+ std::shared_ptr<Tensor> mask;
390+ {
391+ NVTX_RANGE (" BuildMask" );
392+ std::shared_ptr<Tensor> ones = std::make_shared<Tensor>(nn::function::Ones ({t, t})->To (x1->GetDevice ()));
393+ mask = nn::function::Triu (ones, 1 )->View ({1 , 1 , t, t});
394+ }
365395
366396 std::shared_ptr<Tensor> start_pos_ptr = nullptr ;
367397
@@ -386,6 +416,7 @@ LLaMA3LastStage::LLaMA3LastStage(const LLaMA3Config &config) : CloneableModule(k
386416}
387417
388418std::vector<std::shared_ptr<Tensor>> LLaMA3LastStage::Forward (const std::vector<std::shared_ptr<Tensor>> &x) {
419+ NVTX_RANGE (" LMHead" );
389420 // (bs, seq_len, n_embd) -> RMSNorm -> (bs, seq_len, n_embd)
390421 auto x1 = (*modules_[kLnFLayerName ])(x);
391422
0 commit comments