Commit 03560af

Merge branch 'ServeurpersoCom:master' into master
2 parents 85b1e29 + 973a88e commit 03560af

6 files changed: 50 additions & 12 deletions

src/backend.h

Lines changed: 25 additions & 0 deletions
@@ -22,10 +22,22 @@ struct BackendPair {
     int gpu_cc; // CUDA compute capability (e.g. 720 for sm_72), 0 if not CUDA
 };
 
+// Cached backend state (shared across all modules in the same binary)
+static BackendPair g_backend_cache = {};
+static int g_backend_refs = 0;
+
 // Initialize backends: load all available (CUDA, Metal, Vulkan...),
 // pick the best one, keep CPU as fallback.
 // label: log prefix, e.g. "DiT", "VAE", "LM"
+// Subsequent calls reuse the same backend (single VMM pool).
 static BackendPair backend_init(const char * label) {
+    if (g_backend_refs > 0) {
+        g_backend_refs++;
+        fprintf(stderr, "[Load] %s backend: %s (shared)\n",
+                label, ggml_backend_name(g_backend_cache.backend));
+        return g_backend_cache;
+    }
+
     ggml_backend_load_all();
     BackendPair bp = {};
     bp.backend = ggml_backend_init_best();
@@ -54,9 +66,22 @@ static BackendPair backend_init(const char * label) {
     }
 #endif
 
+    g_backend_cache = bp;
+    g_backend_refs = 1;
     return bp;
 }
 
+// Release a backend reference. Frees GPU + CPU backends when refcount hits 0.
+static void backend_release(ggml_backend_t backend, ggml_backend_t cpu_backend) {
+    if (g_backend_refs <= 0) return;
+    g_backend_refs--;
+    if (g_backend_refs == 0) {
+        if (backend && backend != cpu_backend) ggml_backend_free(backend);
+        if (cpu_backend) ggml_backend_free(cpu_backend);
+        g_backend_cache = {};
+    }
+}
+
 // Create a scheduler from a backend pair.
 // max_nodes: graph size hint (4096 for small models, 8192 for large)
 static ggml_backend_sched_t backend_sched_new(BackendPair bp, int max_nodes) {
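Not part of the commit, but a quick usage sketch of the refcounting above (module labels illustrative; this assumes BackendPair carries both a backend and a cpu_backend member, as the free paths below suggest):

    BackendPair dit = backend_init("DiT"); // refs 0 -> 1: loads backends, picks best
    BackendPair vae = backend_init("VAE"); // refs 1 -> 2: returns the cached pair
    BackendPair lm  = backend_init("LM");  // refs 2 -> 3: same backend, one VMM pool

    // ... inference ...

    backend_release(lm.backend,  lm.cpu_backend);  // refs 3 -> 2, nothing freed
    backend_release(vae.backend, vae.cpu_backend); // refs 2 -> 1, nothing freed
    backend_release(dit.backend, dit.cpu_backend); // refs 1 -> 0, frees GPU + CPU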

src/cond-enc.h

Lines changed: 1 addition & 2 deletions
@@ -356,8 +356,7 @@ static void cond_ggml_forward(CondGGML * m,
 // Free
 static void cond_ggml_free(CondGGML * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     wctx_free(&m->wctx);
     *m = {};
 }
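For context, a sketch (not from this commit, labels illustrative) of the hazard this one-line change avoids now that the backend is shared:

    // Two modules now hold the same ggml_backend_t from the cache, so the
    // old per-module frees would free one handle twice:
    BackendPair a = backend_init("Cond"); // first init, refs = 1
    BackendPair b = backend_init("VAE");  // cached handles, refs = 2
    ggml_backend_free(a.backend);         // frees the device backend
    ggml_backend_free(b.backend);         // double free: undefined behavior
    // backend_release() instead decrements the refcount and frees only when
    // the last user is done, which is why each *_free() now calls it.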

src/dit.h

Lines changed: 1 addition & 2 deletions
@@ -410,8 +410,7 @@ bool dit_ggml_load_lora(DiTGGML * m, const char * lora_path, float scale);
 
 static void dit_ggml_free(DiTGGML * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     wctx_free(&m->wctx);
     if (m->lora_wctx.ctx) wctx_free(&m->lora_wctx);
     *m = {};

src/qwen3-enc.h

Lines changed: 1 addition & 2 deletions
@@ -467,8 +467,7 @@ static void qwen3_embed_lookup(Qwen3GGML * m, const int * token_ids, int S, floa
 // Free
 static void qwen3_free(Qwen3GGML * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     wctx_free(&m->wctx);
     *m = {};
 }

src/qwen3-lm.h

Lines changed: 21 additions & 4 deletions
@@ -275,7 +275,8 @@ static struct ggml_tensor * qw3lm_build_attn(
         int kv_pos,
         int kv_len,
         int n_tokens,
-        bool use_flash_attn = true) {
+        bool use_flash_attn = true,
+        bool clamp_fp16 = false) {
 
     int D = c.head_dim;
     int Nh = c.n_heads;
@@ -328,6 +329,12 @@ static struct ggml_tensor * qw3lm_build_attn(
     k = ggml_cont(ctx, k);
     v = ggml_cont(ctx, v);
 
+    // Clamp V before F16 cast: sub-Ampere tensor cores accumulate in FP16,
+    // V projection can overflow to inf which corrupts all subsequent attention
+    if (clamp_fp16) {
+        v = ggml_clamp(ctx, v, -65504.0f, 65504.0f);
+    }
+
     // Write K,V to cache at kv_pos
     // Cache layout: [D, max_seq, Nkv] f16
     size_t nb1 = (size_t)D * ggml_type_size(GGML_TYPE_F16);
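To see the failure mode in isolation, here is a small standalone demo (my addition, not from the commit) using ggml's public fp16 helpers: FP16's largest finite value is 65504, and anything that rounds past it becomes +inf on the cast, after which inf propagates through every later matmul.

    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        const float vals[] = { 65504.0f, 65536.0f, 1e6f };
        for (int i = 0; i < 3; i++) {
            float back = ggml_fp16_to_fp32(ggml_fp32_to_fp16(vals[i]));
            printf("%g -> fp16 -> %g\n", vals[i], back);
        }
        // prints: 65504 -> fp16 -> 65504
        //         65536 -> fp16 -> inf   (hence the clamp to +/- 65504 above)
        //         1e+06 -> fp16 -> inf
        return 0;
    }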
@@ -410,10 +417,13 @@ static void qw3lm_forward(Qwen3LM * m, const int * token_ids, int n_tokens,
     struct ggml_tensor * attn = qw3lm_build_attn(
         ctx, gf, c, ly, norm, positions, mask,
         m->kv_k[kv_set][l], m->kv_v[kv_set][l],
-        kv_pos, kv_len, n_tokens, m->use_flash_attn);
+        kv_pos, kv_len, n_tokens, m->use_flash_attn, m->clamp_fp16);
 
     // Residual
     hidden = ggml_add(ctx, hidden, attn);
+    if (m->clamp_fp16) {
+        hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
+    }
 
     // Post-attention norm + MLP
     norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
@@ -577,6 +587,11 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids,
     k = ggml_cont(ctx, k);
     v = ggml_cont(ctx, v);
 
+    // Clamp V before F16 cast (sub-Ampere FP16 accumulation overflow)
+    if (m->clamp_fp16) {
+        v = ggml_clamp(ctx, v, -65504.0f, 65504.0f);
+    }
+
     // Batched attention with 4D KV cache
     float scale = 1.0f / sqrtf((float)D);
 
@@ -633,6 +648,9 @@ static void qw3lm_forward_batch(Qwen3LM * m, const int * token_ids,
     // Batched O proj
     struct ggml_tensor * attn_out = qwen3_linear(ctx, ly->o_proj, attn_cat);
     hidden = ggml_add(ctx, hidden, attn_out);
+    if (m->clamp_fp16) {
+        hidden = ggml_clamp(ctx, hidden, -65504.0f, 65504.0f);
+    }
 
     // Batched FFN
     norm = qwen3_rms_norm(ctx, hidden, ly->post_attn_layernorm, c.rms_norm_eps);
@@ -706,8 +724,7 @@ static void qw3lm_free(Qwen3LM * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
     if (m->kv_buf) ggml_backend_buffer_free(m->kv_buf);
     if (m->kv_ctx) ggml_free(m->kv_ctx);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     wctx_free(&m->wctx);
     *m = {};
 }
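The diff shows where m->clamp_fp16 is consumed but not where it is set. One plausible wiring, sketched here as an assumption (everything beyond the names in the diff is hypothetical), keys it off the compute capability that backend_init() already records:

    // gpu_cc is 720 for sm_72, 860 for sm_86, 0 for non-CUDA backends;
    // per the comment in the diff, sub-Ampere tensor cores accumulate in
    // FP16, so only CUDA parts below sm_80 would need the clamp.
    BackendPair bp  = backend_init("LM");
    m->backend      = bp.backend;
    m->clamp_fp16   = (bp.gpu_cc > 0 && bp.gpu_cc < 800);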

src/vae.h

Lines changed: 1 addition & 2 deletions
@@ -552,8 +552,7 @@ static void vae_ggml_free(VAEGGML * m) {
     if (m->sched) ggml_backend_sched_free(m->sched);
     if (m->buf) ggml_backend_buffer_free(m->buf);
     if (m->weight_ctx) ggml_free(m->weight_ctx);
-    if (m->backend && m->backend != m->cpu_backend) ggml_backend_free(m->backend);
-    if (m->cpu_backend) ggml_backend_free(m->cpu_backend);
+    backend_release(m->backend, m->cpu_backend);
     *m = {};
 }
 