@@ -275,7 +275,8 @@ static struct ggml_tensor * qw3lm_build_attn(
         int kv_pos,
         int kv_len,
         int n_tokens,
-        bool use_flash_attn = true) {
+        bool use_flash_attn = true,
+        bool clamp_fp16 = false) {

     int D = c.head_dim;
     int Nh = c.n_heads;
@@ -328,6 +329,12 @@ static struct ggml_tensor * qw3lm_build_attn(
     k = ggml_cont(ctx, k);
     v = ggml_cont(ctx, v);

+    // Clamp V before the F16 cast: sub-Ampere tensor cores accumulate in FP16,
+    // so the V projection can overflow to inf, which corrupts all subsequent attention.
+    if (clamp_fp16) {
+        v = ggml_clamp(ctx, v, -65504.0f, 65504.0f);
+    }
+
     // Write K,V to cache at kv_pos
     // Cache layout: [D, max_seq, Nkv] f16
     size_t nb1 = (size_t)D * ggml_type_size(GGML_TYPE_F16);
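For context on the 65504.0f bound: that is the largest finite value representable in IEEE-754 half precision, so anything beyond it becomes inf once the tensor is cast to F16 for the KV cache. A minimal standalone sketch, assuming only ggml's public FP16 conversion helpers (ggml_fp32_to_fp16 / ggml_fp16_to_fp32), illustrating why the clamp is applied before the cast:

// Illustrative sketch only: values outside [-65504, 65504] become inf after a
// round-trip through FP16, which is what the clamp above prevents element-wise.
#include <stdio.h>
#include "ggml.h"

int main(void) {
    float big     = 70000.0f;                                // out of FP16 range
    float clamped = big > 65504.0f ? 65504.0f : big;         // what ggml_clamp does per element

    printf("unclamped: %f -> %f\n", big,
           ggml_fp16_to_fp32(ggml_fp32_to_fp16(big)));       // prints inf
    printf("clamped:   %f -> %f\n", clamped,
           ggml_fp16_to_fp32(ggml_fp32_to_fp16(clamped)));   // prints 65504
    return 0;
}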
@@ -410,10 +417,13 @@ static void qw3lm_forward(Qwen3LM * m, const int * token_ids, int n_tokens,
     struct ggml_tensor * attn = qw3lm_build_attn(
         ctx, gf, c, ly, norm, positions, mask,
         m->kv_k[kv_set][l], m->kv_v[kv_set][l],
-        kv_pos, kv_len, n_tokens, m->use_flash_attn);
+        kv_pos, kv_len, n_tokens, m->use_flash_attn, m->clamp_fp16);

     // Residual
     hidden = ggml_add(ctx, hidden, attn);
+    if (m->clamp_fp16) {
+        hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
+    }

     // Post-attention norm + MLP
     norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
@@ -577,6 +587,11 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids,
     k = ggml_cont(ctx, k);
     v = ggml_cont(ctx, v);

+    // Clamp V before F16 cast (sub-Ampere FP16 accumulation overflow)
+    if (m->clamp_fp16) {
+        v = ggml_clamp(ctx, v, -65504.0f, 65504.0f);
+    }
+
     // Batched attention with 4D KV cache
     float scale = 1.0f / sqrtf((float)D);

@@ -633,6 +648,9 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids,
     // Batched O proj
     struct ggml_tensor * attn_out = qwen3_linear(ctx, ly->o_proj, attn_cat);
     hidden = ggml_add(ctx, hidden, attn_out);
+    if (m->clamp_fp16) {
+        hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
+    }

     // Batched FFN
     norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
@@ -706,8 +724,7 @@ static void qw3lm_free(Qwen3LM * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
     if (m->kv_buf) ggml_backend_buffer_free(m->kv_buf);
     if (m->kv_ctx) ggml_free(m->kv_ctx);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     wctx_free(&m->wctx);
     *m = {};
 }
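The backend_release helper called in the last hunk is not shown in this diff; a minimal sketch of what it plausibly looks like, assuming it only centralizes the two frees it replaces in qw3lm_free (including the guard against freeing the CPU backend twice when it also serves as the compute backend):

// Hypothetical sketch of backend_release (not part of this diff): assumes it simply
// wraps the two ggml_backend_free calls removed from qw3lm_free, skipping the compute
// backend when it is the same object as the CPU backend so the handle is freed once.
static void backend_release(ggml_backend_t backend, ggml_backend_t cpu_backend) {
    if (backend && backend != cpu_backend) ggml_backend_free(backend);
    if (cpu_backend) ggml_backend_free(cpu_backend);
}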