vae: drop BF16 activation casts, decode in F32 throughout #7
Merged
lmangani merged 3 commits into audiohacking:master on Mar 2, 2026
Conversation
Snake and col2im_1d kernels compute in F32 internally, so the BF16 casts were round-trip bandwidth waste: 3 dispatches and 16 bytes/elem vs 1 dispatch and 8 bytes/elem. Removes 82 graph nodes per tile (417 -> 335). Weights stay BF16 in GGUF; mul_mat dequantizes on the fly.

Benchmark (M2 Pro 16GB, 86.8s audio, Q8_0, chunk=1024 overlap=16): 38.89s -> 26.82s (-31%)
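The per-element bandwidth figures above can be checked with a quick back-of-the-envelope calculation. This is my own arithmetic sketch of where the 16 vs 8 bytes/elem comes from, not code from the PR; it assumes the old path was cast-in, kernel, cast-out (three dispatches) and the new path is a single F32 kernel dispatch.

```python
F32, BF16 = 4, 2  # bytes per element

# Old path: cast F32->BF16, run kernel on BF16 tensors, cast BF16->F32.
cast_in  = F32 + BF16   # read F32 activation, write BF16 copy
kernel   = BF16 + BF16  # read BF16 input, write BF16 output (math is F32 in registers)
cast_out = BF16 + F32   # read BF16 result, write F32
old_dispatches, old_bytes = 3, cast_in + kernel + cast_out

# New path: one kernel dispatch, F32 in and out.
new_dispatches, new_bytes = 1, F32 + F32

print(old_dispatches, old_bytes)  # 3 16
print(new_dispatches, new_bytes)  # 1 8
```

Since the kernels already compute in F32 internally, the casts bought nothing except halved storage for a tensor that is immediately read back and discarded.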
The generic kernel_im2col dispatches (IC, 1, OW) threadgroups with K threads each. For 1D convolutions with small kernels (k=1 or k=7), this wastes 78-97% of SIMD lanes (7 or 1 active threads per 32-wide SIMD group). Add a dedicated kernel_im2col_1d with a flat dispatch identical to snake and col2im_1d: (total/256, 1, 1) threadgroups with 256 threads. The existing im2col dispatch branches on is_2D at runtime; the 2D path and kernel are unchanged.

VAE decode benchmark (M2 Pro 16GB, 86.8s audio @ 48kHz stereo):

- chunk=256 overlap=64, old im2col: 71.2s (17 tiles)
- chunk=1024 overlap=16, old im2col: 38.9s (3 tiles)
- chunk=256 overlap=64, im2col_1d: 31.8s (17 tiles)
- chunk=1024 overlap=16, im2col_1d: 18.3s (3 tiles)
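The 78-97% lane-waste claim follows directly from the SIMD-group width. A minimal sketch of that arithmetic (my own, under the assumption from the PR text that a K-thread threadgroup leaves only k of each 32 SIMD lanes active):

```python
SIMD_WIDTH = 32  # SIMD-group width on Apple GPUs

# Fraction of lanes idle when only k threads per group do work.
waste = {k: 1 - k / SIMD_WIDTH for k in (1, 7)}

print(round(waste[7] * 100))  # 78  -> k=7 wastes ~78% of lanes
print(round(waste[1] * 100))  # 97  -> k=1 wastes ~97% of lanes
```

A flat (total/256, 1, 1) dispatch instead assigns one element per thread, so every 32-wide SIMD group is fully populated regardless of the convolution kernel size.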
Sub-Ampere GPUs (cc < 800) use FP16 tensor-core accumulation in GGML's
mul_mat (max finite value 65504). Deep transformer layers can overflow to
inf, and rms_norm then computes inf/inf = NaN, silently corrupting the
pipeline: the LM produces 0 audio codes, the condition encoder feeds NaN
to the DiT, and the output WAV is silent.
The fix detects GPU compute capability at init and conditionally clamps
hidden states to [-65504, 65504] before rms_norm on affected hardware.
On Ampere+ (FP32 accumulation), no clamp op is added, so there is zero
overhead. Tested on a Jetson Xavier NX (sm_72) with Q8_0 models.
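The failure mode and the clamp are easy to reproduce in miniature. A hedged Python sketch (my own illustration, not the PR's code; rms_norm here mirrors the x / sqrt(mean(x^2) + eps) shape of GGML's rms_norm):

```python
import math

FP16_MAX = 65504.0  # largest finite FP16 value

def rms_norm(xs, eps=1e-6):
    # x / sqrt(mean(x^2) + eps), as in GGML's rms_norm
    ms = sum(x * x for x in xs) / len(xs)
    return [x / math.sqrt(ms + eps) for x in xs]

# A hidden state where FP16 accumulation overflowed one element:
hidden = [float('inf'), 2.0, -3.0]
bad = rms_norm(hidden)       # inf / inf -> nan, poisoning downstream ops

# The fix: clamp to the FP16 range before rms_norm (only when cc < 800):
clamped = [max(-FP16_MAX, min(FP16_MAX, x)) for x in hidden]
good = rms_norm(clamped)     # finite output, pipeline survives
```

Once the NaN appears it propagates silently through attention and sampling, which is why the symptom is an empty code sequence and a silent WAV rather than a crash.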
- src/backend.h: Add gpu_cc to BackendPair. Query cc via a forward-declared cudaDeviceGetAttribute (no cuda_runtime.h include).
- src/cond-enc.h: Clamp lyric/timbre encoder output before rms_norm when cc < 800; prevents NaN in DiT cross-attention.
- src/qwen3-lm.h: Clamp hidden state after each MLP residual in the prefill and decode loops (36 layers).
- CMakeLists.txt: Link CUDA::cudart for cudaDeviceGetAttribute.
- src/bpe.h: Fix a -Wconversion warning (int-to-char cast).
Fixing this in GGML's mul_mat or rms_norm would require touching core
operators across all architectures for a niche hardware edge case.
Closes #4