
Commit 15bff84

ggml webgpu: initial flashattention implementation (ggml-org#18610)
* FlashAttention (#13)
* Add inplace softmax
* Move rms_norm to split row approach
* Update debug for supports_op
* clean up debug statements
* neg f16xf32xip builds and runs; haven't actually run a model that uses the neg kernel yet, though
* neg passes backend test
* unary operators pass ggml tests
* rms_norm double declaration bug atoned
* abides by editor-config
* removed vestigial files
* fixed autoconfig
* All operators (including xielu) working
* removed unnecessary check for whether node->src[1] exists for unary operators
* responded to and dealt with PR comments
* implemented REPL_Template support and removed bug in unary operators kernel
* formatted embed wgsl and ggml-webgpu.cpp
* Faster tensors (#8): add fast matrix and matrix/vector multiplication
* Use map for shader replacements instead of pair of strings
* Wasm (#9)
* webgpu : fix build on emscripten
* more debugging stuff
* test-backend-ops: force single thread on wasm
* fix single-thread case for init_tensor_uniform
* use jspi
* add pthread
* test: remember to set n_thread for cpu backend
* Add buffer label and enable dawn-specific toggles to turn off some checks
* Intermediate state
* Fast working f16/f32 vec4
* Working float fast mul mat
* Clean up naming of mul_mat to match logical model, start work on q mul_mat
* Setup for subgroup matrix mat mul
* Basic working subgroup matrix
* Working subgroup matrix tiling
* Handle weirder sg matrix sizes (but still % sg matrix size)
* Working start to gemv
* working f16 accumulation with shared memory staging
* Print out available subgroup matrix configurations
* Vectorize dst stores for sg matrix shader
* Gemv working scalar
* Minor set_rows optimization (#4)
* updated optimization, fixed errors
* non-vectorized version now dispatches one thread per element
* Simplify
* Change logic for set_rows pipelines
---------
Co-authored-by: Neha Abbas <nehaabbas@macbookpro.lan>
Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
* Comment on dawn toggles
* Working subgroup matrix code for (semi)generic sizes
* Remove some comments
* Cleanup code
* Update dawn version and move to portable subgroup size
* Try to fix new dawn release
* Update subgroup size comment
* Only check for subgroup matrix configs if they are supported
* Add toggles for subgroup matrix/f16 support on nvidia+vulkan
* Make row/col naming consistent
* Refactor shared memory loading
* Move sg matrix stores to correct file
* Working q4_0
* Formatting
* Work with emscripten builds
* Fix test-backend-ops emscripten for f16/quantized types
* Use emscripten memory64 to support get_memory
* Add build flags and try ci
---------
Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
* Remove extra whitespace
* Move wasm single-thread logic out of test-backend-ops for cpu backend
* Disable multiple threads for emscripten single-thread builds in ggml_graph_plan
* Refactored pipelines and workgroup calculations (#10)
* refactored pipelines
* refactored workgroup calculation
* removed commented-out block of prior maps
* Clean up ceiling division pattern
---------
Co-authored-by: Neha Abbas <nehaabbas@eduroam-169-233-141-223.ucsc.edu>
Co-authored-by: Reese Levine <reeselevine1@gmail.com>
* Start work on flash attention
* Shader structure set up (many bugs still)
* debugging
* Working first test
* Working with head grouping, head sizes to 128, logit softcap, mask/sinks enabled, f32
* Generalize softmax to work with multiple subgroups, f16 accumulation, mask shared memory tiling
* Start work on integrating pre-wgsl
* Separate structs/initial shader compilation library into separate files
* Work on compilation choices for flashattention
* Work on subgroup matrix/tile size portability
* subgroup-size-agnostic online softmax
* Cleanups, quantization types
* more cleanup
* fix wasm build
* Refactor flashattention to increase parallelism, use direct loads for KV in some cases
* Checkpoint
* formatting
* Update to account for default kv cache padding
* formatting shader
* Add workflow for ggml-ci webgpu
* Try passing absolute path to dawn in ggml-ci
* Avoid error on device destruction, add todos for proper cleanup
* Fix unused warning
* Forgot one parameter unused
* Move some flashattn computation to f32 for correctness
1 parent 2524c26 commit 15bff84

6 files changed

Lines changed: 1838 additions & 47 deletions


.github/workflows/build.yml

Lines changed: 36 additions & 8 deletions
@@ -152,13 +152,13 @@ jobs:
           DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.zip"
-          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
           curl -L -o artifact.zip \
-            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
           mkdir dawn
           unzip artifact.zip
-          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release.tar.gz -C dawn --strip-components=1
+          tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
@@ -532,13 +532,13 @@ jobs:
           DAWN_VERSION="v2.0.0"
           DAWN_OWNER="reeselevine"
           DAWN_REPO="dawn"
-          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.zip"
-          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
           curl -L -o artifact.zip \
-            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}"
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
           mkdir dawn
           unzip artifact.zip
-          tar -xvf Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-ubuntu-latest-Release.tar.gz -C dawn --strip-components=1
+          tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1

       - name: Build
         id: cmake_build
@@ -1704,6 +1704,34 @@ jobs:
       run: |
         GG_BUILD_METAL=1 bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp

+  ggml-ci-mac-webgpu:
+    runs-on: [self-hosted, macOS, ARM64]
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Dawn Dependency
+        id: dawn-depends
+        run: |
+          DAWN_VERSION="v2.0.0"
+          DAWN_OWNER="reeselevine"
+          DAWN_REPO="dawn"
+          DAWN_ASSET_NAME="Dawn-5e9a4865b1635796ccc77dd30057f2b4002a1355-macos-latest-Release"
+          echo "Fetching release asset from https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
+          curl -L -o artifact.zip \
+            "https://github.com/${DAWN_OWNER}/${DAWN_REPO}/releases/download/${DAWN_VERSION}/${DAWN_ASSET_NAME}.zip"
+          mkdir dawn
+          unzip artifact.zip
+          tar -xvf ${DAWN_ASSET_NAME}.tar.gz -C dawn --strip-components=1
+
+      - name: Test
+        id: ggml-ci
+        run: |
+          GG_BUILD_WEBGPU=1 GG_BUILD_WEBGPU_DAWN_PREFIX="$GITHUB_WORKSPACE/dawn" \
+            bash ./ci/run.sh ~/results/llama.cpp ~/mnt/llama.cpp
+
   ggml-ci-mac-vulkan:
     runs-on: [self-hosted, macOS, ARM64]

ci/run.sh

Lines changed: 14 additions & 1 deletion
@@ -105,7 +105,20 @@ if [ ! -z ${GG_BUILD_VULKAN} ]; then
 fi
 
 if [ ! -z ${GG_BUILD_WEBGPU} ]; then
-    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_WEBGPU=1 -DGGML_METAL=OFF -DGGML_BLAS=OFF"
+
+    if [ ! -z "${GG_BUILD_WEBGPU_DAWN_PREFIX}" ]; then
+        if [ -z "${CMAKE_PREFIX_PATH}" ]; then
+            export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}"
+        else
+            export CMAKE_PREFIX_PATH="${GG_BUILD_WEBGPU_DAWN_PREFIX}:${CMAKE_PREFIX_PATH}"
+        fi
+    fi
+
+    # For some systems, Dawn_DIR needs to be set explicitly, e.g., the lib64 path
+    if [ ! -z "${GG_BUILD_WEBGPU_DAWN_DIR}" ]; then
+        CMAKE_EXTRA="${CMAKE_EXTRA} -DDawn_DIR=${GG_BUILD_WEBGPU_DAWN_DIR}"
+    fi
 fi
 
 if [ ! -z ${GG_BUILD_MUSA} ]; then
Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
#ifndef GGML_WEBGPU_SHADER_LIB_HPP
#define GGML_WEBGPU_SHADER_LIB_HPP

#include "ggml.h"
#include "pre_wgsl.hpp"

#include <algorithm>  // std::max, std::min
#include <string>
#include <vector>
#define GGML_WEBGPU_F16_SIZE_BYTES 2
#define GGML_WEBGPU_F32_SIZE_BYTES 4
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES 8u
#define GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE 128u
// Matches GGML_PAD(..., 256) in src/llama-context.cpp for KV cache sizing.
#define GGML_WEBGPU_KV_SEQ_PAD 256u

struct ggml_webgpu_flash_attn_shader_lib_context {
    ggml_type kv_type;
    uint32_t  head_dim_qk;
    uint32_t  head_dim_v;
    bool      kv_direct;
    bool      has_mask;
    bool      has_sinks;
    bool      uses_logit_softcap;
    uint32_t  sg_mat_m;
    uint32_t  sg_mat_n;
    uint32_t  sg_mat_k;
    size_t    wg_mem_limit_bytes;
    uint32_t  max_subgroup_size;
};

struct ggml_webgpu_flash_attn_shader_decisions {
    uint32_t q_tile  = 0;
    uint32_t kv_tile = 0;
    uint32_t wg_size = 0;
};

struct ggml_webgpu_processed_shader {
    std::string                             wgsl;
    std::string                             variant;
    ggml_webgpu_flash_attn_shader_decisions decisions;
};
// This is exposed because it's necessary in supports_op
inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile,
                                                  uint32_t kv_tile,
                                                  uint32_t head_dim_qk,
                                                  uint32_t head_dim_v,
                                                  bool     has_mask,
                                                  bool     kv_direct) {
    const uint32_t max_head_dim = std::max(head_dim_qk, head_dim_v);
    size_t         f16_elems    = 0;
    size_t         f32_elems    = 0;
    f16_elems += q_tile * head_dim_qk;  // q_shmem
    if (!kv_direct) {
        f16_elems += kv_tile * max_head_dim;  // kv_shmem
    }
    f16_elems += q_tile * head_dim_v;  // o_shmem
    if (has_mask) {
        f16_elems += q_tile * kv_tile;  // mask_shmem
    }
    f16_elems += q_tile * kv_tile;  // inter_shmem
    f32_elems += q_tile;            // row_max_shmem
    f32_elems += q_tile;            // exp_sum_shmem
    return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES;
}
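For intuition, here is a small self-contained sketch of the same budget with illustrative sizes (q_tile = 16, kv_tile = 64, both head dims 128, mask enabled, staged KV); these numbers are assumptions for the example, not values taken from the commit.

// Worked example of the shared-memory budget above (illustrative sizes only).
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t q_tile = 16, kv_tile = 64, head_dim_qk = 128, head_dim_v = 128;
    // f16 element counts: Q tile, staged KV tile, output tile, mask tile, score tile.
    const uint64_t f16_elems = q_tile * head_dim_qk     // q_shmem
                             + kv_tile * head_dim_qk    // kv_shmem (max head dim)
                             + q_tile * head_dim_v      // o_shmem
                             + q_tile * kv_tile         // mask_shmem
                             + q_tile * kv_tile;        // inter_shmem
    // f32 element counts: per-row running max and exp-sum for the online softmax.
    const uint64_t f32_elems = 2 * q_tile;
    const uint64_t bytes = f16_elems * 2 + f32_elems * 4;
    std::printf("flash-attn workgroup memory: %llu bytes\n", (unsigned long long) bytes);  // 28800
}

At these sizes the tile set needs 28800 bytes, which fits within a typical 32 KiB workgroup-storage limit.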
static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) {
    const size_t limit_bytes  = context.wg_mem_limit_bytes;
    const size_t q_tile       = context.sg_mat_m;
    const size_t base_q_bytes = (context.head_dim_qk + context.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES +
                                2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES;
    size_t bytes_per_kv = 0;
    if (!context.kv_direct) {
        bytes_per_kv += std::max(context.head_dim_qk, context.head_dim_v);
    }
    if (context.has_mask) {
        bytes_per_kv += q_tile;
    }
    bytes_per_kv += q_tile;
    bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES;
    const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;
    return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n;
}
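Under the same assumed limits (32 KiB of workgroup memory, a 16x16x16 subgroup matrix shape, head dims of 128, mask enabled, staged KV), the bound above works out as follows; the concrete numbers are again illustrative rather than taken from the commit.

// Worked example of ggml_webgpu_flash_attn_max_kv_tile with illustrative limits.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t limit_bytes = 32u * 1024u;                 // assumed workgroup memory limit
    const uint32_t sg_mat_m = 16, sg_mat_n = 16;              // assumed subgroup matrix shape
    const uint32_t head_dim_qk = 128, head_dim_v = 128;
    const uint32_t q_tile = sg_mat_m;
    // Fixed cost: Q and O tiles in f16 plus two f32 softmax vectors.
    const uint32_t base_q_bytes = (head_dim_qk + head_dim_v) * q_tile * 2 + 2 * q_tile * 4;   // 8320
    // Per-KV-row cost: staged KV column, mask column, score column, all in f16.
    const uint32_t bytes_per_kv = (std::max(head_dim_qk, head_dim_v) + q_tile + q_tile) * 2;  // 320
    uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv;                       // 76
    max_kv_tile = (max_kv_tile / sg_mat_n) * sg_mat_n;                                        // 64
    std::printf("max KV tile: %u rows\n", max_kv_tile);
}

The preprocessing step below takes the minimum of this bound and sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES (128 in this example), so the chosen KV tile here would be 64.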
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
    pre_wgsl::Preprocessor &                          preprocessor,
    const char *                                      shader_src,
    const ggml_webgpu_flash_attn_shader_lib_context & context) {
    std::vector<std::string> defines;
    std::string              variant = "flash_attn";

    switch (context.kv_type) {
        case GGML_TYPE_F32:
            defines.push_back("KV_F32");
            break;
        case GGML_TYPE_F16:
            defines.push_back("KV_F16");
            break;
        case GGML_TYPE_Q4_0:
            defines.push_back("KV_Q4_0");
            break;
        case GGML_TYPE_Q8_0:
            defines.push_back("KV_Q8_0");
            break;
        default:
            GGML_ABORT("Unsupported KV type for flash attention shader");
    }
    variant += std::string("_") + ggml_type_name(context.kv_type);

    if (context.has_mask) {
        defines.push_back("MASK");
        variant += "_mask";
    }
    if (context.has_sinks) {
        defines.push_back("SINKS");
        variant += "_sinks";
    }
    if (context.uses_logit_softcap) {
        defines.push_back("LOGIT_SOFTCAP");
        variant += "_lgsc";
    }

    if (context.kv_direct) {
        defines.push_back("KV_DIRECT");
        variant += "_kvdirect";
    }

    defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.head_dim_qk));
    variant += std::string("_hsqk") + std::to_string(context.head_dim_qk);

    defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.head_dim_v));
    variant += std::string("_hsv") + std::to_string(context.head_dim_v);

    // For now these are not part of the variant name
    defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m));
    defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n));
    defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k));

    // Add chosen Q/KV tile sizes
    uint32_t q_tile  = context.sg_mat_m;
    uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context),
                                context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES);
    if (context.kv_direct) {
        GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD);
        // Avoids having to use bounds-checks and decreasing performance for direct KV loads
        while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) {
            kv_tile -= context.sg_mat_n;
        }
    }

    defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile));
    defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile));

    // workgroup size
    uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE);

    defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));

    ggml_webgpu_processed_shader result;
    result.wgsl              = preprocessor.preprocess(shader_src, defines);
    result.variant           = variant;
    result.decisions.q_tile  = q_tile;
    result.decisions.kv_tile = kv_tile;
    result.decisions.wg_size = wg_size;
    return result;
}

#endif  // GGML_WEBGPU_SHADER_LIB_HPP
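A hedged usage sketch of the entry point above: the include path, the flash_attn_src shader string, and the concrete limit/tile values are assumptions made for illustration, not part of this commit; only the struct fields and function names come from the header itself.

#include "ggml-webgpu-shader-lib.hpp"  // hypothetical include path for the header above

// Build one flash-attention shader variant for a given set of device limits and op parameters.
ggml_webgpu_processed_shader build_flash_attn_variant(pre_wgsl::Preprocessor & preprocessor,
                                                      const char *             flash_attn_src) {
    ggml_webgpu_flash_attn_shader_lib_context ctx = {};
    ctx.kv_type            = GGML_TYPE_F16;   // KV cache stored as f16 (example)
    ctx.head_dim_qk        = 128;
    ctx.head_dim_v         = 128;
    ctx.kv_direct          = false;           // stage KV through shared memory
    ctx.has_mask           = true;
    ctx.has_sinks          = false;
    ctx.uses_logit_softcap = false;
    ctx.sg_mat_m           = 16;              // assumed reported subgroup matrix shape
    ctx.sg_mat_n           = 16;
    ctx.sg_mat_k           = 16;
    ctx.wg_mem_limit_bytes = 32 * 1024;       // assumed device workgroup storage limit
    ctx.max_subgroup_size  = 32;

    ggml_webgpu_processed_shader shader = ggml_webgpu_preprocess_flash_attn_shader(preprocessor, flash_attn_src, ctx);
    // shader.variant names the pipeline (e.g. "flash_attn_f16_mask_hsqk128_hsv128");
    // shader.decisions.{q_tile, kv_tile, wg_size} drive pipeline creation and dispatch sizing.
    return shader;
}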
