From 946322378e470b3110f5651f81c2edaa77b6f2af Mon Sep 17 00:00:00 2001 From: mikepapadim Date: Sun, 7 Jun 2026 22:50:47 +0100 Subject: [PATCH] Add Gemma 4 model support with tokenizer, configuration, and inference layers --- .../gpullama3/inference/InferenceCore.java | 184 ++++++++ .../inference/state/Gemma4State.java | 195 ++++++++ .../standard/Gemma4StandardWeights.java | 134 ++++++ .../weights/tornado/Gemma4TornadoWeights.java | 108 +++++ .../beehive/gpullama3/model/ModelType.java | 8 + .../gpullama3/model/format/ChatFormat.java | 2 + .../model/format/Gemma4ChatFormat.java | 61 +++ .../gpullama3/model/gemma4/Gemma4.java | 96 ++++ .../model/gemma4/Gemma4Configuration.java | 116 +++++ .../model/loader/AbstractModelLoader.java | 3 + .../model/loader/Gemma4ModelLoader.java | 299 ++++++++++++ .../gpullama3/model/loader/ModelLoader.java | 88 ++++ .../org/beehive/gpullama3/tensor/GGUF.java | 26 +- .../tensor/standard/BF16FloatTensor.java | 57 +++ .../gpullama3/tokenizer/Gemma4Tokenizer.java | 146 ++++++ .../gpullama3/tokenizer/Vocabulary.java | 6 + .../tornadovm/kernels/Gemma4Kernels.java | 432 ++++++++++++++++++ .../QuantizationPlannerFactory.java | 3 + .../model/fp16/Gemma4FP16LayerPlanner.java | 29 ++ .../layers/type/fp16/Gemma4FP16FFNLayers.java | 375 +++++++++++++++ .../type/fp16/Gemma4LogitsFP16Layer.java | 114 +++++ 21 files changed, 2477 insertions(+), 5 deletions(-) create mode 100644 src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java create mode 100644 src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java create mode 100644 src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java create mode 100644 src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java create mode 100644 src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java create mode 100644 src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java create mode 100644 src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java create mode 100644 src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java create mode 100644 src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java index 9beade35..d54c581d 100644 --- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java +++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java @@ -2,8 +2,11 @@ import org.beehive.gpullama3.auxiliary.Parallel; import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.inference.state.Gemma4State; import org.beehive.gpullama3.inference.state.Phi3State; import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights; +import org.beehive.gpullama3.model.loader.ModelLoader; import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights; import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights; import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights; @@ -11,6 +14,7 @@ import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; import org.beehive.gpullama3.model.Configuration; import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; import org.beehive.gpullama3.model.granite.GraniteConfiguration; import org.beehive.gpullama3.model.devstral.DevstralConfiguration; import org.beehive.gpullama3.model.phi3.Phi3Configuration; @@ -534,6 +538,186 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token, return state.logits; } + /** RMS-normalizes without applying a learned scale (Gemma4 normalizes V with a plain, weight-less RMSNorm). */ + private static void rmsnormNoWeight(FloatTensor out, FloatTensor x, int offset, int size, float rmsNormEps) { + float ss = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi); + ss /= size; + ss += rmsNormEps; + ss = (float) (1.0 / Math.sqrt(ss)); + final float finalss = ss; + out.mapWithIndexInPlace(offset, size, (value, index) -> finalss * x.getFloat(index)); + } + + /** Tanh-approximation GELU, matching ggml's {@code ggml_gelu_f32} (used by Gemma4's GeGLU FFN and PLE gate). */ + private static float gelu(float x) { + return 0.5f * x * (1.0f + (float) Math.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))); + } + + /** NeoX-style RoPE: rotates pairs (i, i + headDim/2) within each head using precomputed cos/sin tables. */ + private static void ropeRotateNeox(FloatTensor vec, int nHeads, int headDim, int position, FloatTensor freqCisReal, FloatTensor freqCisImag) { + int nComplHead = headDim / 2; + for (int h = 0; h < nHeads; h++) { + int base = h * headDim; + for (int ic = 0; ic < nComplHead; ic++) { + float fcr = freqCisReal.getFloat(position * nComplHead + ic); + float fci = freqCisImag.getFloat(position * nComplHead + ic); + float v0 = vec.getFloat(base + ic); + float v1 = vec.getFloat(base + ic + nComplHead); + vec.setFloat(base + ic, v0 * fcr - v1 * fci); + vec.setFloat(base + ic + nComplHead, v0 * fci + v1 * fcr); + } + } + } + + public static FloatTensor forwardJavaGemma4(Model model, State state, int token, int position) { + final Gemma4Configuration config = (Gemma4Configuration) model.configuration(); + final Gemma4StandardWeights weights = (Gemma4StandardWeights) model.weights(); + final Gemma4State gs = (Gemma4State) state; + + final int dim = config.dim(); + final int nHead = config.numberOfHeads(); + final int nHeadKv = config.numberOfKeyValueHeads(); + final int kvMul = config.kvMul(); + final int nLayers = config.numberOfLayers(); + final int nEmbdPerLayer = config.embeddingLengthPerLayer(); + final int perLayerTotal = nLayers * nEmbdPerLayer; + final float attentionScale = 1.0f; // Gemma4 attention uses scaling = 1.0 (no 1/sqrt(headDim)) + + // 1. token embedding, scaled by sqrt(dim) + weights.tokenEmbeddingTable.copyTo(token * dim, state.x, 0, dim); + final float embedScale = (float) Math.sqrt(dim); + state.x.mapInPlace(v -> v * embedScale); + + // 2. per-layer embeddings (PLE): inp_per_layer[l] = (rmsnorm(proj(x) / sqrt(dim)) + tokEmbd[l]*sqrt(nEmbdPerLayer)) / sqrt(2) + // per_layer_token_embd is ~2.35B elements (too large for the int-indexed FloatTensor API), so it + // is addressed one embedding row at a time directly from its raw tensor entry. + ModelLoader.copyEmbeddingRow(weights.perLayerTokenEmbd, token, perLayerTotal, gs.perLayerInputs, 0); + final float perLayerTokEmbedScale = (float) Math.sqrt(nEmbdPerLayer); + gs.perLayerInputs.mapInPlace(v -> v * perLayerTokEmbedScale); + + weights.perLayerModelProj.matmul(state.x, gs.perLayerProjScratch, perLayerTotal, dim); + final float perLayerProjScale = (float) (1.0 / Math.sqrt(dim)); + gs.perLayerProjScratch.mapInPlace(v -> v * perLayerProjScale); + for (int l = 0; l < nLayers; l++) { + rmsnorm(gs.perLayerProjScratch, gs.perLayerProjScratch, weights.perLayerProjNorm, l * nEmbdPerLayer, nEmbdPerLayer, config.rmsNormEps()); + } + final float perLayerInputScale = (float) (1.0 / Math.sqrt(2.0)); + for (int i = 0; i < perLayerTotal; i++) { + float v = (gs.perLayerProjScratch.getFloat(i) + gs.perLayerInputs.getFloat(i)) * perLayerInputScale; + gs.perLayerInputs.setFloat(i, v); + } + + // 3. transformer layers + for (int l = 0; l < nLayers; l++) { + final int curLayer = l; + final int headDim = config.headDim(l); + final boolean isSwa = config.isSwa(l); + final int qDim = nHead * headDim; + final int kvDim = nHeadKv * headDim; + + FloatTensor freqCisReal = isSwa ? weights.freqCisRealSwa : weights.freqCisRealFull; + FloatTensor freqCisImag = isSwa ? weights.freqCisImagSwa : weights.freqCisImagFull; + + // attn_norm + rmsnorm(state.xb, state.x, weights.attnNorm[l], 0, dim, config.rmsNormEps()); + + // Q projection, per-head Q-norm, RoPE + weights.wq[l].matmul(state.xb, state.q, qDim, dim); + for (int h = 0; h < nHead; h++) { + rmsnorm(state.q, state.q, weights.attnQNorm[l], h * headDim, headDim, config.rmsNormEps()); + } + ropeRotateNeox(state.q, nHead, headDim, position, freqCisReal, freqCisImag); + + // K/V: either compute and cache them here, or reuse an earlier layer's cache ("shared KV layers") + final int kvSrcLayer; + if (config.hasOwnKv(l)) { + weights.wk[l].matmul(state.xb, state.k, kvDim, dim); + weights.wv[l].matmul(state.xb, state.v, kvDim, dim); + for (int h = 0; h < nHeadKv; h++) { + rmsnorm(state.k, state.k, weights.attnKNorm[l], h * headDim, headDim, config.rmsNormEps()); + rmsnormNoWeight(state.v, state.v, h * headDim, headDim, config.rmsNormEps()); + } + ropeRotateNeox(state.k, nHeadKv, headDim, position, freqCisReal, freqCisImag); + + state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim); + state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim); + kvSrcLayer = l; + } else { + kvSrcLayer = config.kvReuseLayer(l); + } + + // self-attention (causal; sliding-window layers additionally restrict to a local window) + final int windowStart = isSwa ? Math.max(0, position - config.slidingWindowSize() + 1) : 0; + Parallel.parallelFor(0, nHead, h -> { + int qOffset = h * headDim; + int attOffset = h * config.contextLength(); + int kvHeadOffset = (h / kvMul) * headDim; + + for (int t = windowStart; t <= position; t++) { + int kvOffset = t * kvDim + kvHeadOffset; + float score = state.q.dot(qOffset, state.keyCache[kvSrcLayer], kvOffset, headDim); + score *= attentionScale; + state.att.setFloat(attOffset + t, score); + } + + state.att.softmaxInPlace(attOffset + windowStart, position - windowStart + 1); + + int xbOffset = h * headDim; + state.xb.fillInPlace(xbOffset, headDim, 0f); + for (int t = windowStart; t <= position; t++) { + int kvOffset = t * kvDim + kvHeadOffset; + float a = state.att.getFloat(attOffset + t); + state.xb.saxpyInPlace(xbOffset, state.valueCache[kvSrcLayer], kvOffset, headDim, a); + } + }); + + // wo projection, post-attention norm, residual -> attn_out (kept in state.x) + weights.wo[curLayer].matmul(state.xb, state.xb2, dim, qDim); + rmsnorm(state.xb2, state.xb2, weights.attnPostNorm[curLayer], 0, dim, config.rmsNormEps()); + state.x.addInPlace(state.xb2); // state.x now holds attn_out = inpL + post_attn_norm(attn(...)) + + // FFN (GeGLU: down(gelu(gate(x)) * up(x))), post-FFN norm, residual -> cur (kept in state.x) + rmsnorm(state.xb, state.x, weights.ffnNorm[curLayer], 0, dim, config.rmsNormEps()); + weights.ffnGate[curLayer].matmul(state.xb, state.hb, config.feedForwardLength(curLayer), dim); + weights.ffnUp[curLayer].matmul(state.xb, state.hb2, config.feedForwardLength(curLayer), dim); + state.hb.mapInPlace(InferenceCore::gelu); + state.hb.multiplyInPlace(state.hb2); + weights.ffnDown[curLayer].matmul(state.hb, state.xb2, dim, config.feedForwardLength(curLayer)); + rmsnorm(state.xb2, state.xb2, weights.ffnPostNorm[curLayer], 0, dim, config.rmsNormEps()); + state.x.addInPlace(state.xb2); // state.x now holds cur = attn_out + post_ffn_norm(ffn(...)) + + // per-layer embedding (PLE): cur += per_layer_post_norm(proj(gelu(inp_gate(cur)) * inp_per_layer[l])) + weights.perLayerInpGate[curLayer].matmul(state.x, gs.perLayerGate, nEmbdPerLayer, dim); + gs.perLayerGate.mapInPlace(InferenceCore::gelu); + int peOffset = curLayer * nEmbdPerLayer; + for (int j = 0; j < nEmbdPerLayer; j++) { + gs.perLayerGate.setFloat(j, gs.perLayerGate.getFloat(j) * gs.perLayerInputs.getFloat(peOffset + j)); + } + weights.perLayerProj[curLayer].matmul(gs.perLayerGate, gs.perLayerOut, dim, nEmbdPerLayer); + rmsnorm(gs.perLayerOut, gs.perLayerOut, weights.perLayerPostNorm[curLayer], 0, dim, config.rmsNormEps()); + state.x.addInPlace(gs.perLayerOut); + + // optional learned per-layer output scale + FloatTensor outScale = weights.layerOutputScale[curLayer]; + if (outScale != null) { + final float scale = outScale.getFloat(0); + state.x.mapInPlace(v -> v * scale); + } + } + + // final norm, classifier, and logit soft-capping: logits = softcap * tanh(logits / softcap) + rmsnorm(state.x, state.x, weights.outputNorm, 0, dim, config.rmsNormEps()); + weights.outputWeight.matmul(state.x, state.logits, config.vocabularySize(), dim); + + final float softcap = config.finalLogitSoftcapping(); + if (softcap != 0.0f) { + final float invSoftcap = 1.0f / softcap; + state.logits.mapInPlace(v -> (float) Math.tanh(v * invSoftcap) * softcap); + } + + return state.logits; + } + public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) { Phi3Configuration config = (Phi3Configuration) model.configuration(); Phi3StandardWeights weights = (Phi3StandardWeights) model.weights(); diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java b/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java new file mode 100644 index 00000000..01f96ea0 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java @@ -0,0 +1,195 @@ +package org.beehive.gpullama3.inference.state; + +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.standard.FloatTensor; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * Inference state for Gemma 4 models. + * + *

In addition to the common buffers, Gemma 4 needs scratch space for its per-layer + * embedding (PLE) mechanism. Buffers that vary in size across layers (Q/K/V, attention + * output, FFN hidden state) are sized to the maximum across all layers. The KV cache is + * allocated per "physical" layer: layers that reuse an earlier layer's KV cache (Gemma4's + * "shared KV layers" feature) simply alias that layer's cache arrays.

+ * + *

The TornadoVM (GPU) wrapper buffers mirror the same scheme: {@link #wrapKeyCache}/ + * {@link #wrapValueCache} are laid out back-to-back only for layers that own a KV cache + * (see {@link #cacheLayerBaseOffset}), and the per-layer-embedding scratch buffers are + * exposed as flat {@link FloatArray}s for transfer to the GPU.

+ */ +public final class Gemma4State extends State { + + /** Per-layer projected input embeddings (PLE), laid out as [layer][embeddingLengthPerLayer]. */ + public final FloatTensor perLayerInputs; + /** Scratch buffer for the per-layer model projection output, same layout as {@link #perLayerInputs}. */ + public final FloatTensor perLayerProjScratch; + /** Scratch buffer for a single layer's gated per-layer-embedding contribution. */ + public final FloatTensor perLayerGate; + /** Scratch buffer for a single layer's projected per-layer-embedding output (dim-sized). */ + public final FloatTensor perLayerOut; + + /** + * For each layer {@code l}, the base element offset of its KV-cache slot inside + * {@link #wrapKeyCache}/{@link #wrapValueCache} (GPU path) and {@link #keyCache}/{@link #valueCache} + * (CPU path, where it doubles as the "physical" layer index used for cache aliasing). Layers that + * reuse an earlier layer's cache share that layer's offset, so attention kernels can address the + * (possibly shared) cache uniformly via {@code cacheLayerBaseOffset[l]} without branching on reuse. + */ + public final int[] cacheLayerBaseOffset; + + // GPU (TornadoVM) per-layer-embedding scratch buffers; mirror perLayerInputs/perLayerProjScratch/perLayerGate/perLayerOut. + public final FloatArray wrapPerLayerInputs; + public final FloatArray wrapPerLayerProjScratch; + public final FloatArray wrapPerLayerGate; + public final FloatArray wrapPerLayerOut; + /** Holds the current token's per-layer-token-embedding row (gathered on the host each step, then transferred to the GPU). */ + public final FloatArray wrapPerLayerTokenEmbedRow; + + // Extra RMSNorm reduction scratch buffers (GPU path): Gemma4's "sandwich norm" pattern needs five + // independent reductions per layer (attn-norm uses the inherited `temp`, FFN-norm `tempFFN`); each + // of the others gets its own buffer so consecutive reduce/apply pairs never alias. + public final FloatArray tempPostAttn; + public final FloatArray tempPostFfn; + public final FloatArray tempPostPle; + + public Gemma4State(Configuration config, int batchsize) { + super(config, batchsize); + + Gemma4Configuration gemma4config = (Gemma4Configuration) config; + int perLayerTotal = gemma4config.numberOfLayers() * gemma4config.embeddingLengthPerLayer(); + this.perLayerInputs = ArrayFloatTensor.allocate(perLayerTotal); + this.perLayerProjScratch = ArrayFloatTensor.allocate(perLayerTotal); + this.perLayerGate = ArrayFloatTensor.allocate(gemma4config.embeddingLengthPerLayer()); + this.perLayerOut = ArrayFloatTensor.allocate(gemma4config.dim()); + + this.cacheLayerBaseOffset = computeCacheLayerBaseOffsets(gemma4config); + + this.wrapPerLayerInputs = new FloatArray(perLayerTotal); + this.wrapPerLayerProjScratch = new FloatArray(perLayerTotal); + this.wrapPerLayerGate = new FloatArray(gemma4config.embeddingLengthPerLayer()); + this.wrapPerLayerOut = new FloatArray(gemma4config.dim()); + this.wrapPerLayerTokenEmbedRow = new FloatArray(perLayerTotal); + + int tempSize = 1 + ((gemma4config.dim() + localSize - 1) / localSize); + this.tempPostAttn = new FloatArray(tempSize); + this.tempPostFfn = new FloatArray(tempSize); + this.tempPostPle = new FloatArray(tempSize); + } + + /** + * Computes, for each layer, the base element offset of its KV-cache slot in a flat buffer that + * back-to-back concatenates only the caches of layers that own one ({@link Gemma4Configuration#hasOwnKv}). + * Reusing layers inherit their source layer's offset (and -- by construction -- its head dimension, + * since {@link Gemma4Configuration#kvReuseLayer} only ever points to a layer with the same {@code isSwa}-ness). + */ + private static int[] computeCacheLayerBaseOffsets(Gemma4Configuration config) { + int nHeadKv = config.numberOfKeyValueHeads(); + int[] offsets = new int[config.numberOfLayers()]; + int running = 0; + for (int l = 0; l < config.numberOfLayers(); l++) { + int reuse = config.kvReuseLayer(l); + if (reuse < 0) { + offsets[l] = running; + running += config.contextLength() * (nHeadKv * config.headDim(l)); + } else { + offsets[l] = offsets[reuse]; + } + } + return offsets; + } + + /** Total number of elements needed for the (deduplicated) flat KV cache buffer. */ + private static int totalCacheElements(Gemma4Configuration config, int[] cacheLayerBaseOffset) { + int nHeadKv = config.numberOfKeyValueHeads(); + int total = 0; + for (int l = 0; l < config.numberOfLayers(); l++) { + if (config.hasOwnKv(l)) { + total = Math.max(total, cacheLayerBaseOffset[l] + config.contextLength() * (nHeadKv * config.headDim(l))); + } + } + return total; + } + + @Override + protected StateFields createStateFields(Configuration configuration) { + StateFields fields = new StateFields(); + + Gemma4Configuration config = (Gemma4Configuration) configuration; + + int dim = config.dim(); + int nHead = config.numberOfHeads(); + int nHeadKv = config.numberOfKeyValueHeads(); + int maxHeadDim = config.maxHeadDim(); + int maxFFN = config.maxFeedForwardLength(); + + int qSize = nHead * maxHeadDim; + int kvSize = nHeadKv * maxHeadDim; + + fields.x = ArrayFloatTensor.allocate(dim); + fields.xb = ArrayFloatTensor.allocate(Math.max(dim, qSize)); + fields.xb2 = ArrayFloatTensor.allocate(dim); + fields.hb = ArrayFloatTensor.allocate(maxFFN); + fields.hb2 = ArrayFloatTensor.allocate(maxFFN); + fields.q = ArrayFloatTensor.allocate(qSize); + fields.k = ArrayFloatTensor.allocate(kvSize); + fields.v = ArrayFloatTensor.allocate(kvSize); + fields.att = ArrayFloatTensor.allocate(nHead, config.contextLength()); + fields.logits = ArrayFloatTensor.allocate(config.vocabularySize()); + + // KV cache: layers that own their KV get a fresh cache; layers that reuse an earlier + // layer's KV (Gemma4's "shared KV layers") alias that layer's arrays directly. + FloatTensor[] keyCache = new FloatTensor[config.numberOfLayers()]; + FloatTensor[] valueCache = new FloatTensor[config.numberOfLayers()]; + for (int l = 0; l < config.numberOfLayers(); l++) { + int reuse = config.kvReuseLayer(l); + if (reuse < 0) { + int layerKvDim = config.headDim(l) * nHeadKv; + keyCache[l] = ArrayFloatTensor.allocate(config.contextLength(), layerKvDim); + valueCache[l] = ArrayFloatTensor.allocate(config.contextLength(), layerKvDim); + } else { + keyCache[l] = keyCache[reuse]; + valueCache[l] = valueCache[reuse]; + } + } + fields.keyCache = keyCache; + fields.valueCache = valueCache; + + switch (config.quantization()) { + case "FP16" -> fields.createActivationFP16(dim); + case "Q8_0" -> fields.createActivationQ8_0(dim); + default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization()); + } + + fields.wrapX = new FloatArray(dim); + fields.wrapXb = new FloatArray(Math.max(dim, qSize)); + fields.wrapXbFP16 = new HalfFloatArray(Math.max(dim, qSize)); + fields.wrapXb2 = new FloatArray(dim); + fields.wrapHb = new FloatArray(maxFFN); + fields.wrapHb2 = new FloatArray(maxFFN); + fields.wrapLogits = new FloatArray(config.vocabularySize()); + fields.wrapQ = new FloatArray(qSize); + fields.wrapK = new FloatArray(kvSize); + fields.wrapV = new FloatArray(kvSize); + + // Flat GPU KV cache: back-to-back slots only for layers that own a cache (see cacheLayerBaseOffset). + int[] gpuCacheLayerBaseOffset = computeCacheLayerBaseOffsets(config); + int totalCacheElements = Math.max(1, totalCacheElements(config, gpuCacheLayerBaseOffset)); + fields.wrapKeyCache = new FloatArray(totalCacheElements); + fields.wrapValueCache = new FloatArray(totalCacheElements); + fields.wrapValueCache.init(0.f); + fields.wrapKeyCache.init(0.f); + fields.wrapAtt = new FloatArray(nHead * config.contextLength()); + fields.positionHolder = new IntArray(1); + + fields.temp = new FloatArray(1 + ((dim + localSize - 1) / localSize)); + fields.tempFFN = new FloatArray(1 + ((dim + localSize - 1) / localSize)); + fields.tempLogits = new FloatArray(1 + ((dim + localSize - 1) / localSize)); + + return fields; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java new file mode 100644 index 00000000..599ed0c1 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java @@ -0,0 +1,134 @@ +package org.beehive.gpullama3.inference.weights.standard; + +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.standard.FloatTensor; + +/** + * Weights for the Gemma 4 architecture in the standard (CPU) format. + * + *

Gemma 4's layer structure differs substantially from the "Llama-like" models that + * {@link StandardWeights} models, so this class implements {@link Weights} directly rather + * than extending it: every layer carries its own Q/K-norm, a "sandwich" of pre- and + * post-normalization around both attention and FFN, a per-layer-embedding (PLE) gate/projection/ + * norm, and an optional learned output scale. There are also two separate RoPE frequency tables + * (sliding-window vs. full/global attention layers use different bases, head dimensions, and — + * for full-attention layers — a per-dimension frequency scaling baked in from {@code rope_freqs}).

+ */ +public class Gemma4StandardWeights implements Weights { + + public final FloatTensor tokenEmbeddingTable; + public final FloatTensor outputWeight; + public final FloatTensor outputNorm; + + // per-layer attention + public final FloatTensor[] attnNorm; + public final FloatTensor[] wq; + public final FloatTensor[] wk; + public final FloatTensor[] wv; + public final FloatTensor[] wo; + public final FloatTensor[] attnQNorm; + public final FloatTensor[] attnKNorm; + public final FloatTensor[] attnPostNorm; // a.k.a. post_attention_norm + + // per-layer FFN + public final FloatTensor[] ffnNorm; + public final FloatTensor[] ffnGate; + public final FloatTensor[] ffnUp; + public final FloatTensor[] ffnDown; + public final FloatTensor[] ffnPostNorm; // a.k.a. post_ffw_norm + + // per-layer embedding (PLE) + public final FloatTensor[] perLayerInpGate; + public final FloatTensor[] perLayerProj; + public final FloatTensor[] perLayerPostNorm; // a.k.a. post_norm + public final FloatTensor[] layerOutputScale; // optional, may contain nulls + + // shared per-layer-embedding tensors + + /** + * The per-layer token embedding table ({@code [embeddingLengthPerLayer * numberOfLayers, vocabularySize]}, + * ~2.35 billion elements for Gemma-4-E2B). It is kept as a raw {@link GGMLTensorEntry} rather than a + * {@link FloatTensor} -- whose int-indexed API would overflow for a tensor this large -- and addressed + * one embedding row at a time via {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow}. + */ + public final GGMLTensorEntry perLayerTokenEmbd; + public final FloatTensor perLayerModelProj; + public final FloatTensor perLayerProjNorm; + + // RoPE tables: sliding-window (local) layers and full (global) attention layers use different + // bases/dimensions; full-attention layers additionally bake in the `rope_freqs` per-dimension scaling. + public final FloatTensor freqCisRealSwa; + public final FloatTensor freqCisImagSwa; + public final FloatTensor freqCisRealFull; + public final FloatTensor freqCisImagFull; + + private final GGMLType weightType; + + // @formatter:off + public Gemma4StandardWeights( + FloatTensor tokenEmbeddingTable, + FloatTensor outputWeight, + FloatTensor outputNorm, + FloatTensor[] attnNorm, + FloatTensor[] wq, + FloatTensor[] wk, + FloatTensor[] wv, + FloatTensor[] wo, + FloatTensor[] attnQNorm, + FloatTensor[] attnKNorm, + FloatTensor[] attnPostNorm, + FloatTensor[] ffnNorm, + FloatTensor[] ffnGate, + FloatTensor[] ffnUp, + FloatTensor[] ffnDown, + FloatTensor[] ffnPostNorm, + FloatTensor[] perLayerInpGate, + FloatTensor[] perLayerProj, + FloatTensor[] perLayerPostNorm, + FloatTensor[] layerOutputScale, + GGMLTensorEntry perLayerTokenEmbd, + FloatTensor perLayerModelProj, + FloatTensor perLayerProjNorm, + FloatTensor freqCisRealSwa, + FloatTensor freqCisImagSwa, + FloatTensor freqCisRealFull, + FloatTensor freqCisImagFull, + GGMLType weightType) { + this.tokenEmbeddingTable = tokenEmbeddingTable; + this.outputWeight = outputWeight; + this.outputNorm = outputNorm; + this.attnNorm = attnNorm; + this.wq = wq; + this.wk = wk; + this.wv = wv; + this.wo = wo; + this.attnQNorm = attnQNorm; + this.attnKNorm = attnKNorm; + this.attnPostNorm = attnPostNorm; + this.ffnNorm = ffnNorm; + this.ffnGate = ffnGate; + this.ffnUp = ffnUp; + this.ffnDown = ffnDown; + this.ffnPostNorm = ffnPostNorm; + this.perLayerInpGate = perLayerInpGate; + this.perLayerProj = perLayerProj; + this.perLayerPostNorm = perLayerPostNorm; + this.layerOutputScale = layerOutputScale; + this.perLayerTokenEmbd = perLayerTokenEmbd; + this.perLayerModelProj = perLayerModelProj; + this.perLayerProjNorm = perLayerProjNorm; + this.freqCisRealSwa = freqCisRealSwa; + this.freqCisImagSwa = freqCisImagSwa; + this.freqCisRealFull = freqCisRealFull; + this.freqCisImagFull = freqCisImagFull; + this.weightType = weightType; + } + // @formatter:on + + @Override + public GGMLType getWeightType() { + return weightType; + } +} diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java new file mode 100644 index 00000000..9da6eb4c --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java @@ -0,0 +1,108 @@ +package org.beehive.gpullama3.inference.weights.tornado; + +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; + +/** + * TornadoVM (GPU) weights for the Gemma 4 architecture. + * + *

Extends {@link TornadoWeights} (rather than implementing {@link org.beehive.gpullama3.inference.weights.Weights} + * directly, as its CPU counterpart {@link org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights} + * does) because the shared {@code AbstractLogitsLayer}/{@code LogitsFP16Layer} GPU infrastructure requires + * a {@link TornadoWeights}. The base class's "Llama-like" fields are reused for the closest equivalents + * (e.g. {@code rms_att_weightLayered} → {@code attnNorm}, {@code w1Layered}/{@code w3Layered} → + * {@code ffnGate}/{@code ffnUp}); every other Gemma 4-specific tensor (sandwich norms, Q/K-norm, + * per-layer-embedding (PLE) gate/proj/norm, dual RoPE tables, optional layer-output scale) is added here.

+ * + *

Note: {@code per_layer_token_embd} is intentionally not present here -- at ~2.35 billion + * elements it is far too large to keep resident on the GPU. Its per-token row is instead gathered on + * the host (see {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow}) and streamed + * to the GPU each step via {@link org.beehive.gpullama3.inference.state.Gemma4State#wrapPerLayerTokenEmbedRow}.

+ */ +public class Gemma4TornadoWeights extends TornadoWeights { + + // Gemma4-specific per-layer attention tensors (sandwich norm + Q/K-norm) + public final TornadoTensor[] attnQNorm; + public final TornadoTensor[] attnKNorm; + public final TornadoTensor[] attnPostNorm; + + // Gemma4-specific per-layer FFN tensor (sandwich norm) + public final TornadoTensor[] ffnPostNorm; + + // per-layer embedding (PLE) + public final TornadoTensor[] perLayerInpGate; + public final TornadoTensor[] perLayerProj; + public final TornadoTensor[] perLayerPostNorm; + public final TornadoTensor[] layerOutputScale; // optional, may contain nulls + + // shared per-layer-embedding tensors + + /** + * The per-layer token embedding table ({@code [embeddingLengthPerLayer * numberOfLayers, vocabularySize]}, + * ~2.35 billion elements for Gemma-4-E2B). Far too large to keep resident on the GPU, so it is kept + * as a raw {@link GGMLTensorEntry} and addressed one row at a time on the host -- via + * {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow} -- with the resulting + * row streamed to the GPU each step (see {@link org.beehive.gpullama3.inference.state.Gemma4State#wrapPerLayerTokenEmbedRow}). + */ + public final GGMLTensorEntry perLayerTokenEmbd; + public final TornadoTensor perLayerModelProj; + public final TornadoTensor perLayerProjNorm; + + // RoPE tables: sliding-window (local) layers and full (global) attention layers use different bases/dimensions. + public final TornadoTensor freqCisRealSwa; + public final TornadoTensor freqCisImagSwa; + public final TornadoTensor freqCisRealFull; + public final TornadoTensor freqCisImagFull; + + // @formatter:off + public Gemma4TornadoWeights( + TornadoTensor tokenEmbeddingTable, + TornadoTensor[] attnNorm, + TornadoTensor[] wq, + TornadoTensor[] wk, + TornadoTensor[] wv, + TornadoTensor[] wo, + TornadoTensor[] attnQNorm, + TornadoTensor[] attnKNorm, + TornadoTensor[] attnPostNorm, + TornadoTensor[] ffnNorm, + TornadoTensor[] ffnGate, + TornadoTensor[] ffnUp, + TornadoTensor[] ffnDown, + TornadoTensor[] ffnPostNorm, + TornadoTensor[] perLayerInpGate, + TornadoTensor[] perLayerProj, + TornadoTensor[] perLayerPostNorm, + TornadoTensor[] layerOutputScale, + GGMLTensorEntry perLayerTokenEmbd, + TornadoTensor perLayerModelProj, + TornadoTensor perLayerProjNorm, + TornadoTensor outputNorm, + TornadoTensor freqCisRealSwa, + TornadoTensor freqCisImagSwa, + TornadoTensor freqCisRealFull, + TornadoTensor freqCisImagFull, + TornadoTensor outputWeight, + GGMLType weightType) { + super(tokenEmbeddingTable, attnNorm, wq, wk, wv, wo, + ffnNorm, ffnGate, ffnDown, ffnUp, outputNorm, + freqCisRealFull, freqCisImagFull, outputWeight, weightType); + this.attnQNorm = attnQNorm; + this.attnKNorm = attnKNorm; + this.attnPostNorm = attnPostNorm; + this.ffnPostNorm = ffnPostNorm; + this.perLayerInpGate = perLayerInpGate; + this.perLayerProj = perLayerProj; + this.perLayerPostNorm = perLayerPostNorm; + this.layerOutputScale = layerOutputScale; + this.perLayerTokenEmbd = perLayerTokenEmbd; + this.perLayerModelProj = perLayerModelProj; + this.perLayerProjNorm = perLayerProjNorm; + this.freqCisRealSwa = freqCisRealSwa; + this.freqCisImagSwa = freqCisImagSwa; + this.freqCisRealFull = freqCisRealFull; + this.freqCisImagFull = freqCisImagFull; + } + // @formatter:on +} diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java index 0659da7d..fd85f4e5 100644 --- a/src/main/java/org/beehive/gpullama3/model/ModelType.java +++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.model; import org.beehive.gpullama3.model.loader.DevstralModelLoader; +import org.beehive.gpullama3.model.loader.Gemma4ModelLoader; import org.beehive.gpullama3.model.loader.GraniteLoader; import org.beehive.gpullama3.tensor.GGUF; import org.beehive.gpullama3.model.loader.LlamaModelLoader; @@ -80,6 +81,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo } }, + GEMMA_4 { + @Override + public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + return new Gemma4ModelLoader(fileChannel, gguf, contextLength, useTornadovm).loadModel(); + } + }, + UNKNOWN { @Override public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java index 827ad625..218d9945 100644 --- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java +++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.model.format; import org.beehive.gpullama3.tokenizer.DevstralTokenizer; +import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer; import org.beehive.gpullama3.tokenizer.GraniteTokenizer; import org.beehive.gpullama3.tokenizer.LlamaTokenizer; import org.beehive.gpullama3.tokenizer.MistralTokenizer; @@ -15,6 +16,7 @@ public interface ChatFormat { static ChatFormat create(Object tokenizer, ChatTokens chatTokens) { return switch (tokenizer) { case DevstralTokenizer devstralTokenizer -> new DevstralChatFormat(devstralTokenizer); + case Gemma4Tokenizer gemma4Tokenizer -> new Gemma4ChatFormat(gemma4Tokenizer); case GraniteTokenizer graniteTokenizer -> new GraniteChatFormat(graniteTokenizer); case LlamaTokenizer llamaTokenizer -> new LlamaChatFormat(llamaTokenizer); case MistralTokenizer mistralTokenizer -> new MistralChatFormat(mistralTokenizer); diff --git a/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java new file mode 100644 index 00000000..e8be80dc --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java @@ -0,0 +1,61 @@ +package org.beehive.gpullama3.model.format; + +import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Chat format for Gemma 4 models. + *

+ * Gemma 4 uses a {@code <|turn>{role}\n ... } turn structure (the assistant role is + * spelled "model" in the template), starts conversations with {@code }, and stops + * generation on {@code } (the model's configured EOS token). + */ +public class Gemma4ChatFormat implements ChatFormat { + + protected final Gemma4Tokenizer tokenizer; + protected final int beginOfText; + protected final int startTurn; + protected final int endTurn; + + public Gemma4ChatFormat(Gemma4Tokenizer tokenizer) { + this.tokenizer = tokenizer; + Map specialTokens = tokenizer.getSpecialTokens(); + this.beginOfText = specialTokens.getOrDefault("", -1); + this.startTurn = specialTokens.getOrDefault("<|turn>", -1); + this.endTurn = specialTokens.getOrDefault("", -1); + } + + @Override + public List encodeHeader(Message message) { + List tokens = new ArrayList<>(); + tokens.add(startTurn); + // The chat template spells the assistant role "model". + String role = Role.ASSISTANT.equals(message.role()) ? "model" : message.role().name(); + tokens.addAll(tokenizer.encodeAsList(role)); + tokens.addAll(tokenizer.encodeAsList("\n")); + return tokens; + } + + @Override + public List encodeMessage(Message message) { + List tokens = encodeHeader(message); + tokens.addAll(tokenizer.encodeAsList(message.content().strip())); + tokens.add(endTurn); + tokens.addAll(tokenizer.encodeAsList("\n")); + return tokens; + } + + @Override + public int getBeginOfText() { + return beginOfText; + } + + @Override + public Set getStopTokens() { + return Set.of(endTurn); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java new file mode 100644 index 00000000..24fd42e7 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java @@ -0,0 +1,96 @@ +package org.beehive.gpullama3.model.gemma4; + +import org.beehive.gpullama3.inference.InferenceCore; +import org.beehive.gpullama3.inference.InferenceEngine; +import org.beehive.gpullama3.inference.sampler.Sampler; +import org.beehive.gpullama3.inference.state.Gemma4State; +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights; +import org.beehive.gpullama3.model.AbstractModel; +import org.beehive.gpullama3.model.ModelType; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.loader.ModelLoader; +import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; + +import java.util.List; +import java.util.Set; +import java.util.function.IntConsumer; + +public class Gemma4 extends AbstractModel { + + Gemma4Configuration configuration; + + public Gemma4(Gemma4Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) { + super(tokenizer, weights, chatFormat, null); + this.configuration = configuration; + } + + @Override + public Gemma4Configuration configuration() { + return configuration; + } + + @Override + public ModelType getModelType() { + return ModelType.GEMMA_4; + } + + @Override + public Gemma4Tokenizer tokenizer() { + return (Gemma4Tokenizer) tokenizer; + } + + @Override + public State createNewState() { + State state = new Gemma4State(configuration(), -1); + state.latestToken = chatFormat.getBeginOfText(); + return state; + } + + @Override + public State createNewState(int batchsize) { + State state = new Gemma4State(configuration(), batchsize); + state.latestToken = chatFormat.getBeginOfText(); + return state; + } + + @Override + public void forward(State state, int token, int position) { + if (plan == null) { + InferenceCore.forwardJavaGemma4(this, state, token, position); + } else { + gatherPerLayerTokenEmbeddingRow((Gemma4State) state, token); + InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan()); + } + } + + /** + * Gathers the current token's row out of {@code per_layer_token_embd} (~2.35 billion elements -- + * far too large to keep resident on the GPU, see {@link Gemma4TornadoWeights#perLayerTokenEmbd}) + * directly into {@link Gemma4State#wrapPerLayerTokenEmbedRow}, pre-scaled by {@code sqrt(embeddingLengthPerLayer)} + * (mirroring step 2 of {@link InferenceCore#forwardJavaGemma4}), ready for transfer to the GPU as + * part of layer 0's per-layer-embedding setup. + */ + private void gatherPerLayerTokenEmbeddingRow(Gemma4State state, int token) { + Gemma4TornadoWeights gemma4Weights = (Gemma4TornadoWeights) weights; + int nEmbdPerLayer = configuration.embeddingLengthPerLayer(); + int perLayerTotal = configuration.numberOfLayers() * nEmbdPerLayer; + float scale = (float) Math.sqrt(nEmbdPerLayer); + ModelLoader.copyEmbeddingRowToFloatArray(gemma4Weights.perLayerTokenEmbd, token, perLayerTotal, state.wrapPerLayerTokenEmbedRow, scale); + } + + @Override + public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated) { + return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated); + } + + @Override + public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo, + IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) { + return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java new file mode 100644 index 00000000..dd06049a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java @@ -0,0 +1,116 @@ +package org.beehive.gpullama3.model.gemma4; + +import org.beehive.gpullama3.model.Configuration; + +/** + * Configuration for the Gemma 4 architecture (e.g. Gemma-4-E2B-It). + * + *

Gemma 4 alternates sliding-window and full (global) attention layers, each with their + * own head dimensions, RoPE base/scaling, and a subset of layers reusing the KV cache produced + * by an earlier layer ("shared KV layers"). It also augments every layer with a per-layer + * embedding (PLE) mechanism and applies a final logit soft-cap.

+ */ +// @formatter:off +public record Gemma4Configuration(String quantization, + int dim, + int numberOfLayers, + int numberOfHeads, + int numberOfKeyValueHeads, + int headDimSwa, + int headDimFull, + int[] feedForwardLength, + boolean[] slidingWindowPattern, + int slidingWindowSize, + int sharedKvLayers, + int embeddingLengthPerLayer, + int vocabularySize, + int contextLengthModel, + int contextLength, + float rmsNormEps, + float ropeTheta, + float ropeThetaSwa, + float finalLogitSoftcapping) implements Configuration { + + @Override + public String quantization() { + return quantization; + } + + @Override + public int hiddenDim() { + throw new UnsupportedOperationException("Gemma4 has per-layer feed-forward dimensions; use feedForwardLength(layer)."); + } + + @Override + public int numberOfHeadsKey() { + throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer)."); + } + + @Override + public int headSize() { + throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer)."); + } + + @Override + public int kvDim() { + throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer) * numberOfKeyValueHeads()."); + } + + @Override + public int kvMul() { + return numberOfHeads / numberOfKeyValueHeads; + } + + @Override + public int contextLengthModel() { + return contextLengthModel; + } + + /** Returns the feed-forward (FFN hidden) dimension for the given layer. */ + public int feedForwardLength(int layer) { + return feedForwardLength[layer]; + } + + /** Whether the given layer uses sliding-window (local) attention as opposed to full (global) attention. */ + public boolean isSwa(int layer) { + return slidingWindowPattern[layer]; + } + + /** Returns the attention head dimension for the given layer (depends on whether it is a sliding-window or full layer). */ + public int headDim(int layer) { + return isSwa(layer) ? headDimSwa : headDimFull; + } + + /** The maximum head dimension across all layers; used to size shared scratch buffers. */ + public int maxHeadDim() { + return Math.max(headDimSwa, headDimFull); + } + + /** The maximum feed-forward dimension across all layers; used to size shared scratch buffers. */ + public int maxFeedForwardLength() { + int max = 0; + for (int ff : feedForwardLength) { + max = Math.max(max, ff); + } + return max; + } + + /** Number of (initial) layers that own and populate their own KV cache; later layers reuse one of these. */ + public int nLayerKvFromStart() { + return numberOfLayers - sharedKvLayers; + } + + /** Whether the given layer computes and stores its own K/V (as opposed to reusing an earlier layer's KV cache). */ + public boolean hasOwnKv(int layer) { + return layer < nLayerKvFromStart(); + } + + /** Returns the index of the layer whose KV cache this layer reuses, or -1 if this layer owns its KV cache. */ + public int kvReuseLayer(int layer) { + if (hasOwnKv(layer)) { + return -1; + } + return nLayerKvFromStart() - (isSwa(layer) ? 2 : 1); + } +} +// @formatter:on diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java index 9bbefcad..e8d18891 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java @@ -41,6 +41,7 @@ protected String getModelQuantization(Map metadata) { int modelQuantizationAsInt = (int) metadata.get("general.file_type"); return switch (modelQuantizationAsInt) { case 1 -> "FP16"; + case 32 -> "FP16"; // MOSTLY_BF16 (treated like FP16 for activation buffers) case 7 -> "Q8_0"; case 14, 15 -> "Q8_0"; // Q4_K_S, Q4_K_M (K-quants use Q8_0 activations) case 16, 17 -> "Q8_0"; // Q5_K_S, Q5_K_M @@ -56,6 +57,7 @@ protected String getModelQuantization(Map metadata) { protected static GGMLType effectiveGpuWeightType(GGMLType ggmlType) { return switch (ggmlType) { case F16, F32, Q8_0 -> ggmlType; + case BF16 -> GGMLType.F16; // widened to FP16 at load time; see ModelLoader#loadTornadoTensor case Q4_K, Q5_K, Q6_K -> GGMLType.Q8_0; default -> ggmlType; }; @@ -65,6 +67,7 @@ private static String fileTypeName(int fileType) { return switch (fileType) { case 0 -> "F32"; case 1 -> "F16"; + case 32 -> "BF16"; case 7 -> "Q8_0"; case 14 -> "Q4_K_S"; case 15 -> "Q4_K_M"; diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java new file mode 100644 index 00000000..63af9b7a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java @@ -0,0 +1,299 @@ +package org.beehive.gpullama3.model.loader; + +import org.beehive.gpullama3.auxiliary.Pair; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights; +import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.gemma4.Gemma4; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; +import org.beehive.gpullama3.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.tensor.GGMLType; +import org.beehive.gpullama3.tensor.GGUF; +import org.beehive.gpullama3.tensor.GGUF.GGUFTensorInfo; +import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor; +import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor; +import org.beehive.gpullama3.tensor.tornado.TornadoTensor; +import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer; +import org.beehive.gpullama3.tokenizer.Tokenizer; +import org.beehive.gpullama3.tokenizer.Vocabulary; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.util.Map; +import java.util.function.IntFunction; + +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTensors; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTornadoTensors; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensor; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadTornadoTensor; + +/** + * Loader for Gemma 4 models (e.g. Gemma-4-E2B-It). + * + *

Gemma 4 needs two distinct precomputed RoPE tables (sliding-window vs. full/global attention + * layers use different bases and head dimensions, and full-attention layers additionally apply a + * per-dimension frequency scaling stored in the {@code rope_freqs} tensor), so RoPE frequencies are + * computed directly here -- where the tensor entries are available -- rather than through the + * generic {@link #precomputeRopeFrequencies} hook.

+ */ +public class Gemma4ModelLoader extends AbstractModelLoader { + + public Gemma4ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) { + super(fileChannel, gguf, contextLength, useTornadovm); + } + + @Override + protected Vocabulary loadVocabulary(Map metadata) { + return Vocabulary.loadGemma4Vocabulary(metadata); + } + + @Override + protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) { + return new Gemma4Tokenizer(metadata, vocabulary); + } + + // @formatter:off + @Override + protected Gemma4Configuration createConfiguration(Map metadata) { + int modelContextLength = (int) metadata.get("gemma4.context_length"); + int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; + int numberOfLayers = (int) metadata.get("gemma4.block_count"); + + return new Gemma4Configuration( + getModelQuantization(metadata), + (int) metadata.get("gemma4.embedding_length"), + numberOfLayers, + (int) metadata.get("gemma4.attention.head_count"), + (int) metadata.get("gemma4.attention.head_count_kv"), + (int) metadata.get("gemma4.attention.key_length_swa"), + (int) metadata.get("gemma4.attention.key_length"), + (int[]) metadata.get("gemma4.feed_forward_length"), + (boolean[]) metadata.get("gemma4.attention.sliding_window_pattern"), + (int) metadata.get("gemma4.attention.sliding_window"), + (int) metadata.get("gemma4.attention.shared_kv_layers"), + (int) metadata.get("gemma4.embedding_length_per_layer_input"), + vocabulary.size(), + modelContextLength, + finalContextLength, + (float) metadata.get("gemma4.attention.layer_norm_rms_epsilon"), + (float) metadata.get("gemma4.rope.freq_base"), + (float) metadata.get("gemma4.rope.freq_base_swa"), + (float) metadata.get("gemma4.final_logit_softcapping") + ); + } + // @formatter:on + + /** Gemma4 needs two RoPE tables computed with tensor data (rope_freqs); see {@link #ropeTables}. */ + @Override + protected Pair precomputeRopeFrequencies(Gemma4Configuration config) { + return null; + } + + @Override + protected Gemma4 createModel(Gemma4Configuration config, Tokenizer tokenizer, Weights weights) { + return new Gemma4(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); + } + + // @formatter:off + @Override + protected Weights createStandardWeights(Map tensorEntries, Gemma4Configuration config, Pair ropeFreqs, + GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + final int nl = config.numberOfLayers(); + RopeTables ropeTables = computeRopeTables(tensorEntries, config); + + return new Gemma4StandardWeights( + loadTensor(tokenEmbeddings), + tensorEntries.containsKey("output.weight") ? loadTensor(tensorEntries.get("output.weight")) : loadTensor(tokenEmbeddings), + loadTensor(tensorEntries.get("output_norm.weight")), + + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_attention_norm.weight")), + + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_ffw_norm.weight")), + + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".inp_gate.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".proj.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_norm.weight")), + loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".layer_output_scale.weight")), + + tensorEntries.get("per_layer_token_embd.weight"), + loadTensor(tensorEntries.get("per_layer_model_proj.weight")), + loadTensor(tensorEntries.get("per_layer_proj_norm.weight")), + + new ArrayFloatTensor(ropeTables.realSwa), + new ArrayFloatTensor(ropeTables.imagSwa), + new ArrayFloatTensor(ropeTables.realFull), + new ArrayFloatTensor(ropeTables.imagFull), + + null + ); + } + // @formatter:on + + // @formatter:off + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, Gemma4Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + final int nl = config.numberOfLayers(); + GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); + RopeTables ropeTables = computeRopeTables(tensorEntries, config); + + return new Gemma4TornadoWeights( + loadTornadoTensor(tokenEmbeddings), + + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_attention_norm.weight")), + + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_ffw_norm.weight")), + + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".inp_gate.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".proj.weight")), + loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_norm.weight")), + loadArrayOfTornadoTensorsNullable(nl, i -> tensorEntries.get("blk." + i + ".layer_output_scale.weight")), + + stripTornadoArrayHeader(tensorEntries.get("per_layer_token_embd.weight")), + loadTornadoTensor(tensorEntries.get("per_layer_model_proj.weight")), + loadTornadoTensor(tensorEntries.get("per_layer_proj_norm.weight")), + loadTornadoTensor(tensorEntries.get("output_norm.weight")), + + new FP32TornadoTensor(FloatArray.fromArray(ropeTables.realSwa)), + new FP32TornadoTensor(FloatArray.fromArray(ropeTables.imagSwa)), + new FP32TornadoTensor(FloatArray.fromArray(ropeTables.realFull)), + new FP32TornadoTensor(FloatArray.fromArray(ropeTables.imagFull)), + + tensorEntries.containsKey("output.weight") ? loadTornadoTensor(tensorEntries.get("output.weight")) : loadTornadoTensor(tokenEmbeddings), + ggmlType + ); + } + // @formatter:on + + /** + * Tensor entries produced by {@link GGUF#loadTensorsTornado} prefix every {@code memorySegment()} with a + * 16-byte {@link TornadoNativeArray#ARRAY_HEADER} (so the data can be wrapped as a TornadoVM native array + * without copying) -- but {@code per_layer_token_embd} is kept as a raw entry and addressed with + * byte-offset arithmetic that assumes the segment starts at the tensor's actual data (see + * {@link ModelLoader#copyEmbeddingRowToFloatArray}, mirroring the CPU path's {@link ModelLoader#copyEmbeddingRow} + * over a {@link GGUF#loadTensorsStandard}-produced entry, which has no such header). Slice past the + * header here so both code paths see the same layout. + */ + private static GGMLTensorEntry stripTornadoArrayHeader(GGMLTensorEntry entry) { + long headerBytes = TornadoNativeArray.ARRAY_HEADER; + return new GGMLTensorEntry(entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(), entry.memorySegment().asSlice(headerBytes)); + } + + /** Like {@link ModelLoader#loadArrayOfTornadoTensors}, but tolerates missing entries (Gemma4's optional per-layer output scale). */ + private static TornadoTensor[] loadArrayOfTornadoTensorsNullable(int size, IntFunction getTensorEntry) { + TornadoTensor[] array = new TornadoTensor[size]; + for (int i = 0; i < size; i++) { + GGMLTensorEntry entry = getTensorEntry.apply(i); + array[i] = (entry == null) ? null : loadTornadoTensor(entry); + } + return array; + } + + private record RopeTables(float[] realSwa, float[] imagSwa, float[] realFull, float[] imagFull) { + } + + /** + * Computes the two RoPE frequency tables Gemma4 needs. + *

+ * Sliding-window layers use {@code rope_theta_swa} with {@code headDimSwa} and no extra scaling. + * Full/global-attention layers use {@code rope_theta} with {@code headDimFull}, additionally + * dividing each rotation angle by the corresponding entry of the (single, shared) {@code + * rope_freqs} tensor -- this is how the GGUF encodes "partial RoPE" (entries are 1.0 for the + * active low-frequency dimensions and effectively infinite for the inactive ones, which zeroes + * out their rotation). + */ + private RopeTables computeRopeTables(Map tensorEntries, Gemma4Configuration config) { + Pair swa = precomputeFreqsCisWithFactors(config.contextLengthModel(), config.headDimSwa(), config.ropeThetaSwa(), null); + + // rope_freqs.weight is intentionally excluded from tensorEntries by GGUF.loadTensorsStandard/ + // loadTensorsTornado (it isn't needed by most architectures), so read it directly here. + float[] freqFactors = readFloat32TensorDirect("rope_freqs.weight"); + Pair full = precomputeFreqsCisWithFactors(config.contextLengthModel(), config.headDimFull(), config.ropeTheta(), freqFactors); + + return new RopeTables(swa.first(), swa.second(), full.first(), full.second()); + } + + /** Reads a small F32 tensor's raw data directly from the GGUF file, bypassing the {@code tensorEntries} map. */ + private float[] readFloat32TensorDirect(String tensorName) { + GGUFTensorInfo info = gguf.getTensorInfos().get(tensorName); + if (info == null) { + return null; + } + if (info.ggmlType() != GGMLType.F32) { + throw new UnsupportedOperationException("Expected F32 tensor for " + tensorName + ", got " + info.ggmlType()); + } + + int numberOfElements = 1; + for (int dimension : info.dimensions()) { + numberOfElements *= dimension; + } + + long byteOffset = gguf.getTensorDataOffset() + info.offset(); + ByteBuffer buffer = ByteBuffer.allocate(numberOfElements * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + try { + while (buffer.hasRemaining()) { + if (gguf.getFileChannel().read(buffer, byteOffset + buffer.position()) < 0) { + throw new EOFException("Unexpected end of file while reading " + tensorName); + } + } + } catch (IOException e) { + throw new ModelLoadException("Failed to read " + tensorName + " from GGUF file", e); + } + buffer.flip(); + + float[] result = new float[numberOfElements]; + buffer.asFloatBuffer().get(result); + return result; + } + + /** Like {@link org.beehive.gpullama3.inference.operation.RoPE#precomputeFreqsCis}, but allows dividing each pair's frequency by a per-dimension scaling factor (NeoX-style RoPE). */ + private static Pair precomputeFreqsCisWithFactors(int contextLength, int headSize, double theta, float[] freqFactors) { + assert headSize % 2 == 0; + float[] cr = new float[contextLength * (headSize / 2)]; + float[] ci = new float[contextLength * (headSize / 2)]; + int n = 0; + for (int pos = 0; pos < contextLength; ++pos) { + for (int i = 0; i < headSize; i += 2) { + int pairIndex = i / 2; + float freq = (float) (1.0 / Math.pow(theta, i / (double) headSize)); + if (freqFactors != null) { + freq = freq / freqFactors[pairIndex]; + } + float val = pos * freq; + cr[n] = (float) Math.cos(val); + ci[n] = (float) Math.sin(val); + n++; + } + } + assert contextLength * (headSize / 2) == n; + return new Pair<>(cr, ci); + } +} diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java index 353aea91..b2889785 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java @@ -52,6 +52,8 @@ private static ModelType detectModelType(Map metadata) { String lowerName = name.toLowerCase(); if (lowerName.contains("granite")) { return ModelType.GRANITE; + } else if (lowerName.contains("gemma-4") || lowerName.contains("gemma 4")) { + return ModelType.GEMMA_4; } else if (lowerName.contains("devstral")) { return ModelType.DEVSTRAL_2; } else if (lowerName.contains("mistral")) { @@ -73,6 +75,9 @@ private static ModelType detectModelType(Map metadata) { if (metadata.containsKey("granite.block_count")) { return ModelType.GRANITE; } + if ("gemma4".equals(metadata.get("general.architecture")) || metadata.containsKey("gemma4.block_count")) { + return ModelType.GEMMA_4; + } return ModelType.UNKNOWN; } @@ -127,6 +132,7 @@ public static FloatTensor loadTensor(GGMLTensorEntry entry) { case Q5_K -> new Q5_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case Q6_K -> new Q6_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); + case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()); default -> throw new UnsupportedOperationException("Quantization format " + ggmlType); }; } @@ -153,6 +159,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) { return switch (ggmlType) { case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); + case BF16 -> convertBF16ToFP16TornadoTensor(entry); case Q8_0 -> Q8_0TornadoTensor.fromTornadoMemorySegment(entry.memorySegment()); case Q4_K, Q5_K, Q6_K -> dequantizeToQ8_0TornadoTensor(entry); case Q4_0 -> throw new UnsupportedOperationException("Q4_0 format not supported for TornadoVM yet"); @@ -217,6 +224,31 @@ private static Q8_0TornadoTensor dequantizeToQ8_0TornadoTensor(GGMLTensorEntry e return Q8_0TornadoTensor.fromTornadoMemorySegment(nativeSegment); } + /** + * Converts a BF16 tensor to an FP16 {@link FP16TornadoTensor} for TornadoVM/GPU execution. + * TornadoVM has no native BF16 kernel support, so weights are widened to FP32 (a lossless, + * simple bit-shift for BF16) and narrowed to IEEE FP16 at load time -- the same representation + * the existing FP16 GPU kernels already expect (see {@link #loadTornadoTensor}). + */ + private static FP16TornadoTensor convertBF16ToFP16TornadoTensor(GGMLTensorEntry entry) { + long headerBytes = TornadoNativeArray.ARRAY_HEADER; + GGMLTensorEntry dataEntry = new GGMLTensorEntry( + entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(), + entry.memorySegment().asSlice(headerBytes)); + FloatTensor source = loadTensor(dataEntry); + int numElements = source.size(); + + MemorySegment nativeSegment = Arena.ofAuto().allocate(headerBytes + (long) numElements * Short.BYTES, 4); + for (long i = 0; i < headerBytes; i++) { + nativeSegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0); + } + for (int i = 0; i < numElements; i++) { + short f16Bits = Float.floatToFloat16(source.getFloat(i)); + nativeSegment.set(ValueLayout.JAVA_SHORT_UNALIGNED, headerBytes + (long) i * Short.BYTES, f16Bits); + } + return FP16TornadoTensor.fromTornadoMemorySegment(nativeSegment); + } + /** * Dispatcher method for loading a TornadoVM tensor array based on type. * Used in GPU-path. @@ -229,6 +261,62 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunctionSome tensors (e.g. Gemma4's {@code per_layer_token_embd}, with ~2.35 billion elements) exceed + * {@link Integer#MAX_VALUE} elements/bytes, which would overflow the int-based + * {@link FloatTensor#numberOfElements} / {@link GGMLType#byteSizeFor} used to wrap a tensor entry in + * a {@link FloatTensor}. Such tensors are kept as raw {@link GGMLTensorEntry}s and addressed here with + * {@code long} byte offsets instead -- since only single-row (embedding lookup) access is needed.

+ */ + public static void copyEmbeddingRow(GGMLTensorEntry entry, long rowIndex, int rowSize, FloatTensor dest, int destOffset) { + GGMLType type = entry.ggmlType(); + if (type.getBlockSize() != 1) { + throw new UnsupportedOperationException("copyEmbeddingRow only supports unblocked (per-element) types, got " + type); + } + MemorySegment segment = entry.memorySegment(); + long elementBytes = type.getTypeSize(); + long rowByteOffset = rowIndex * rowSize * elementBytes; + for (int i = 0; i < rowSize; i++) { + long byteOffset = rowByteOffset + (long) i * elementBytes; + float value = switch (type) { + case F32 -> segment.get(ValueLayout.JAVA_FLOAT_UNALIGNED, byteOffset); + case F16 -> Float.float16ToFloat(segment.get(ValueLayout.JAVA_SHORT_UNALIGNED, byteOffset)); + case BF16 -> Float.intBitsToFloat(((int) segment.get(ValueLayout.JAVA_SHORT_UNALIGNED, byteOffset)) << 16); + default -> throw new UnsupportedOperationException("copyEmbeddingRow: unsupported type " + type); + }; + dest.setFloat(destOffset + i, value); + } + } + + /** + * Like {@link #copyEmbeddingRow(GGMLTensorEntry, long, int, FloatTensor, int)}, but writes into a + * TornadoVM {@link FloatArray} (optionally scaling each element) -- used by the GPU path to gather + * a per-token embedding row directly into a buffer ready for transfer to the device. + */ + public static void copyEmbeddingRowToFloatArray(GGMLTensorEntry entry, long rowIndex, int rowSize, FloatArray dest, float scale) { + GGMLType type = entry.ggmlType(); + if (type.getBlockSize() != 1) { + throw new UnsupportedOperationException("copyEmbeddingRowToFloatArray only supports unblocked (per-element) types, got " + type); + } + MemorySegment segment = entry.memorySegment(); + long elementBytes = type.getTypeSize(); + long rowByteOffset = rowIndex * rowSize * elementBytes; + for (int i = 0; i < rowSize; i++) { + long byteOffset = rowByteOffset + (long) i * elementBytes; + float value = switch (type) { + case F32 -> segment.get(ValueLayout.JAVA_FLOAT_UNALIGNED, byteOffset); + case F16 -> Float.float16ToFloat(segment.get(ValueLayout.JAVA_SHORT_UNALIGNED, byteOffset)); + case BF16 -> Float.intBitsToFloat(((int) segment.get(ValueLayout.JAVA_SHORT_UNALIGNED, byteOffset)) << 16); + default -> throw new UnsupportedOperationException("copyEmbeddingRowToFloatArray: unsupported type " + type); + }; + dest.set(i, value * scale); + } + } + // Helper methods public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction getTensorEntry) { diff --git a/src/main/java/org/beehive/gpullama3/tensor/GGUF.java b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java index 9cdc5b7d..9da00b76 100644 --- a/src/main/java/org/beehive/gpullama3/tensor/GGUF.java +++ b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java @@ -1,6 +1,5 @@ package org.beehive.gpullama3.tensor; -import org.beehive.gpullama3.tensor.standard.FloatTensor; import org.beehive.gpullama3.auxiliary.Pair; import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray; @@ -122,8 +121,12 @@ public static Map loadTensorsStandard(FileChannel fileC continue; } - int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); - int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); + // Long arithmetic: some tensors (e.g. Gemma4's per-layer token embedding table) exceed + // Integer.MAX_VALUE elements/bytes, which would overflow the int-based + // FloatTensor.numberOfElements/GGMLType.byteSizeFor. Such tensors are read directly via + // long-offset MemorySegment access rather than wrapped in a FloatTensor, so only the + // slice size (which MemorySegment#asSlice accepts as a long) matters here. + long sizeInBytes = tensorByteSize(ti); // per-tensor slice offset; ti.offset() is relative to tensor-data start long offset = ti.offset(); @@ -167,8 +170,9 @@ public static Map loadTensorsTornado(FileChannel fileCh continue; } - int numberOfElements = FloatTensor.numberOfElements(ti.dimensions()); - int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements)); + // see loadTensorsStandard for why this uses long arithmetic instead of + // FloatTensor.numberOfElements/GGMLType.byteSizeFor + long sizeInBytes = tensorByteSize(ti); // absolute tensor offset - relative to start of the file long mappingOffset = tensorDataOffset + ti.offset(); @@ -193,6 +197,18 @@ public static Map loadTensorsTornado(FileChannel fileCh return tensorEntries; } + /** Computes a tensor's byte size with {@code long} arithmetic, avoiding int-overflow for very large tensors. */ + private static long tensorByteSize(GGUFTensorInfo ti) { + long numberOfElements = 1L; + for (int dimension : ti.dimensions()) { + numberOfElements *= dimension; + } + long typeSize = ti.ggmlType().getTypeSize(); + long blockSize = ti.ggmlType().getBlockSize(); + assert (numberOfElements * typeSize) % blockSize == 0; + return numberOfElements * typeSize / blockSize; + } + public Map getTensorInfos() { return tensorInfos; } diff --git a/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java new file mode 100644 index 00000000..1d370c06 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java @@ -0,0 +1,57 @@ +package org.beehive.gpullama3.tensor.standard; + +import org.beehive.gpullama3.tensor.GGMLType; +import jdk.incubator.vector.FloatVector; +import jdk.incubator.vector.VectorSpecies; + +import java.lang.foreign.MemorySegment; + +/** + * {@link FloatTensor} backed by raw BF16 (bfloat16) data. + * + *

BF16 stores the upper 16 bits of an IEEE-754 binary32 value (same sign/exponent layout, + * truncated mantissa), so widening to float32 is a plain left-shift by 16 bits -- no exponent + * rebiasing is needed, unlike IEEE binary16 (F16).

+ */ +public final class BF16FloatTensor extends FloatTensor { + + final int size; + final MemorySegment memorySegment; + + public BF16FloatTensor(int size, MemorySegment memorySegment) { + this.size = size; + this.memorySegment = memorySegment; + } + + @Override + public int size() { + return size; + } + + @Override + public void setFloat(int index, float value) { + throw new UnsupportedOperationException("setFloat"); + } + + @Override + protected FloatVector getFloatVector(VectorSpecies species, int index) { + throw new UnsupportedOperationException("getFloatVector"); + } + + @Override + public GGMLType type() { + return GGMLType.BF16; + } + + @Override + public MemorySegment asMemorySegment() { + return null; + } + + @Override + public float getFloat(int index) { + assert 0 <= index && index < size; + short bits = readShort(memorySegment, index * (long) GGMLType.BFLOAT16_BYTES); + return Float.intBitsToFloat(((int) bits) << 16); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java new file mode 100644 index 00000000..1fe9da5a --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java @@ -0,0 +1,146 @@ +package org.beehive.gpullama3.tokenizer; + +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * SentencePiece-style BPE tokenizer with byte fallback, used by Gemma 4 models. + *

+ * Spaces are represented with the SentencePiece marker {@code ▁}, and any codepoint missing from + * the vocabulary falls back to its individual UTF-8 bytes encoded as {@code <0xXX>} tokens. Pairs + * are greedily merged according to the highest {@code tokenizer.ggml.scores} value, mirroring + * {@link MistralTokenizer}. + */ +public class Gemma4Tokenizer implements Tokenizer { + + private final Vocabulary vocabulary; + private final Map specialTokens; + private final int[] tokenType; + private final int byte0; + + public Gemma4Tokenizer(Map metadata, Vocabulary vocabulary) { + int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); + + // Special tokens are anything that isn't a regular sub-word (NORMAL, type 1) or a raw byte-fallback token (BYTE, type 6). + Map specialTokens = IntStream.range(0, vocabulary.size()) + .filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6) + .boxed() + .collect(Collectors.toMap(vocabulary::get, t -> t, (first, second) -> first)); + + this.vocabulary = vocabulary; + this.specialTokens = new HashMap<>(specialTokens); + this.tokenType = tokenTypes; + this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow(); + } + + @Override + public String regexPattern() { + return null; + } + + @Override + public Map getSpecialTokens() { + return specialTokens; + } + + @Override + public boolean isSpecialToken(int tokenIndex) { + return getTokenType(tokenIndex) != 1; + } + + @Override + public boolean shouldDisplayToken(int token) { + int type = getTokenType(token); + return type == 1 || type == 6; + } + + public int getTokenType(int tokenIndex) { + return tokenType[tokenIndex]; + } + + private List encodeImpl(String text) { + List tokens = new ArrayList<>(); + + // first encode every individual codepoint in the input string + for (int i = 0, cpi; i < text.length(); i += Character.charCount(cpi)) { + cpi = text.codePointAt(i); + + String singleCodepoint = Character.toString(cpi); + int id = vocabulary.getIndex(singleCodepoint).orElse(-1); + + if (id != -1) { + tokens.add(id); + } else { + // byte fallback: encode each UTF-8 byte as a <0xXX> token (offset by the index of <0x00>) + for (byte b : singleCodepoint.getBytes(StandardCharsets.UTF_8)) { + tokens.add(Byte.toUnsignedInt(b) + byte0); + } + } + } + + // greedily merge the highest-scoring adjacent pair until no more merges apply + while (true) { + float bestScore = -1e10f; + int bestId = -1; + int bestIdx = -1; + + for (int i = 0; i < tokens.size() - 1; ++i) { + String merged = vocabulary.get(tokens.get(i)) + vocabulary.get(tokens.get(i + 1)); + int id = vocabulary.getIndex(merged).orElse(-1); + if (id != -1 && vocabulary.getScore(id) > bestScore) { + bestScore = vocabulary.getScore(id); + bestId = id; + bestIdx = i; + } + } + + if (bestIdx == -1) { + break; + } + + tokens.set(bestIdx, bestId); + tokens.remove(bestIdx + 1); + } + + return tokens; + } + + @Override + public List encode(String text, Set allowedSpecial) { + return encodeImpl(text.replace(' ', '▁')); + } + + @Override + public List encodeAsList(String text) { + return encode(text, Collections.emptySet()); + } + + @Override + public String decode(List tokens) { + StringBuilder sb = new StringBuilder(); + for (int token : tokens) { + String tokenString = vocabulary.get(token); + if (isSpecialToken(token)) { + // byte-fallback tokens decode back to their raw byte/codepoint + String prefix = "<0x"; + String suffix = ">"; + if (tokenString.length() == 6 && tokenString.startsWith(prefix) && tokenString.endsWith(suffix)) { + String code = tokenString.substring(prefix.length(), tokenString.length() - suffix.length()); + int cp = Integer.parseInt(code, 16); + tokenString = Character.toString(cp); + } + } else { + tokenString = tokenString.replace('▁', ' '); + } + sb.append(tokenString); + } + return sb.toString(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java index b29f3576..be10d4b0 100644 --- a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java +++ b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java @@ -36,6 +36,12 @@ public static Vocabulary loadQwen3Vocabulary(Map metadata) { return new Vocabulary(tokens, scores); } + public static Vocabulary loadGemma4Vocabulary(Map metadata) { + String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); + float[] scores = (float[]) metadata.get("tokenizer.ggml.scores"); + return new Vocabulary(tokens, scores); + } + public static Vocabulary loadDevstralVocabulary(Map metadata) { String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens"); return new Vocabulary(tokens, null); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java new file mode 100644 index 00000000..47a12ec3 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java @@ -0,0 +1,432 @@ +package org.beehive.gpullama3.tornadovm.kernels; + +import uk.ac.manchester.tornado.api.KernelContext; +import uk.ac.manchester.tornado.api.annotations.Parallel; +import uk.ac.manchester.tornado.api.math.TornadoMath; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; +import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray; +import uk.ac.manchester.tornado.api.types.arrays.IntArray; + +/** + * Custom GPU kernels for the Gemma 4 architecture. + * + *

Gemma 4's computation graph differs substantially from the "Llama-like" models the rest of the + * {@code tornadovm.kernels} package targets: every layer carries its own Q/K-norm and a "sandwich" of + * pre/post normalization around both attention and FFN, attention alternates between sliding-window + * (local) and full (global) variants -- with different head dimensions and RoPE tables -- some layers + * reuse an earlier layer's KV cache, the FFN uses a GeGLU activation, and every layer additionally + * mixes in a per-layer embedding (PLE). None of the existing fused kernels match this shape, so this + * class provides purpose-built (but otherwise unfused/modular) replacements; see + * {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4} for the reference computation + * each of these mirrors.

+ */ +// @formatter:off +public class Gemma4Kernels { + + /** Materializes {@code out = weight * (rmsScale[0] * x)} -- i.e. RMSNorm with a learned scale, written to a separate buffer. */ + public static void applyRmsNorm(KernelContext context, FloatArray out, FloatArray x, FloatArray weight, FloatArray rmsScale, int size) { + int gid = context.globalIdx; + if (gid < size) { + float scale = rmsScale.get(0); + out.set(gid, weight.get(gid) * (scale * x.get(gid))); + } + } + + /** {@code x[i] *= scale} (used for embedding scaling). */ + public static void scaleInPlace(KernelContext context, FloatArray x, float scale, int size) { + int gid = context.globalIdx; + if (gid < size) { + x.set(gid, x.get(gid) * scale); + } + } + + /** {@code x[i] *= scaleTensor[0]} -- like {@link #scaleInPlace}, but the (learned, per-layer) scale is read from a 1-element tensor at kernel time. */ + public static void scaleInPlaceFromTensor(KernelContext context, FloatArray x, FloatArray scaleTensor, int size) { + int gid = context.globalIdx; + if (gid < size) { + x.set(gid, x.get(gid) * scaleTensor.get(0)); + } + } + + /** {@code out[i] = (a[i] + b[i]) * scale} (used to merge the per-layer projection with the per-layer token embedding). */ + public static void addAndScale(KernelContext context, FloatArray out, FloatArray a, FloatArray b, float scale, int size) { + int gid = context.globalIdx; + if (gid < size) { + out.set(gid, (a.get(gid) + b.get(gid)) * scale); + } + } + + /** + * Sandwich-norm + residual: {@code x[i] += weight[i] * (rmsScale[0] * delta[i])}. + * Used for post-attention-norm, post-FFN-norm, and the per-layer-embedding post-norm, each of + * which normalizes a freshly computed branch output and adds it back onto the running residual. + */ + public static void rmsNormApplyWithResidual(KernelContext context, FloatArray x, FloatArray delta, FloatArray weight, FloatArray rmsScale, int size) { + int gid = context.globalIdx; + if (gid < size) { + float scale = rmsScale.get(0); + float normalized = weight.get(gid) * (scale * delta.get(gid)); + x.set(gid, x.get(gid) + normalized); + } + } + + /** + * Per-head RMSNorm with a learned scale (Q-norm / K-norm): each workgroup normalizes one head + * of {@code vec} in place, mirroring {@code rmsnorm(vec, vec, weight, h*headDim, headDim, eps)} + * applied independently for every head {@code h}. + */ + public static void rmsNormPerHead(KernelContext context, FloatArray vec, FloatArray weight, int nHeads, int headDim, int localMemSize, float rmsNormEps) { + int headIdx = context.groupIdx; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + if (headIdx >= nHeads) { + return; + } + int base = headIdx * headDim; + + float[] localSum = context.allocateFloatLocalArray(localMemSize); + float partial = 0f; + for (int i = localId; i < headDim; i += localSize) { + float v = vec.get(base + i); + partial += v * v; + } + localSum[localId] = partial; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float ss = localSum[0] / headDim + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < headDim; i += localSize) { + float normalized = ss * vec.get(base + i); + vec.set(base + i, weight.get(i) * normalized); + } + } + + /** Like {@link #rmsNormPerHead}, but without a learned scale (Gemma4 normalizes V with a plain, weight-less RMSNorm). */ + public static void rmsNormPerHeadNoWeight(KernelContext context, FloatArray vec, int nHeads, int headDim, int localMemSize, float rmsNormEps) { + int headIdx = context.groupIdx; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + if (headIdx >= nHeads) { + return; + } + int base = headIdx * headDim; + + float[] localSum = context.allocateFloatLocalArray(localMemSize); + float partial = 0f; + for (int i = localId; i < headDim; i += localSize) { + float v = vec.get(base + i); + partial += v * v; + } + localSum[localId] = partial; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float ss = localSum[0] / headDim + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < headDim; i += localSize) { + vec.set(base + i, ss * vec.get(base + i)); + } + } + + /** + * NeoX-style RoPE rotation (split-half pairs, using precomputed cos/sin tables) for Q only -- + * used by layers that reuse an earlier layer's KV cache (so K is never computed/rotated here). + * Launched on a 2D grid of (nHeads, headDim/2). + */ + public static void ropeNeoxRotateQOnly(KernelContext context, IntArray positionHolder, FloatArray q, FloatArray freqCisReal, FloatArray freqCisImag, int headDim) { + int h = context.globalIdx; + int ic = context.globalIdy; + int half = headDim / 2; + int pos = positionHolder.get(0); + + float fcr = freqCisReal.get(pos * half + ic); + float fci = freqCisImag.get(pos * half + ic); + + int base = h * headDim; + float v0 = q.get(base + ic); + float v1 = q.get(base + ic + half); + q.set(base + ic, v0 * fcr - v1 * fci); + q.set(base + ic + half, v0 * fci + v1 * fcr); + } + + /** + * NeoX-style RoPE rotation for Q and K, fused with the KV-cache write (K rotated then cached, + * V copied as-is) -- used by layers that own their KV cache. Launched on a 2D grid of + * (nHeads, headDim/2); K/V handling is gated on {@code h < nHeadKv} (mirrors + * {@code Qwen3Kernels.ropeRotationWithCacheCopy}'s {@code rotn} pattern for GQA). + * + *

{@code cacheBaseOffset} is the (possibly shared, see {@link org.beehive.gpullama3.inference.state.Gemma4State#cacheLayerBaseOffset}) + * base element offset of this layer's slot in the flat {@code keyCache}/{@code valueCache} buffers.

+ */ + public static void ropeNeoxRotateAndCacheCopy( + KernelContext context, + IntArray positionHolder, + FloatArray q, + FloatArray k, + FloatArray v, + FloatArray keyCache, + FloatArray valueCache, + FloatArray freqCisReal, + FloatArray freqCisImag, + int nHeadKv, + int headDim, + int kvDim, + int cacheBaseOffset) { + + int h = context.globalIdx; + int ic = context.globalIdy; + int half = headDim / 2; + int pos = positionHolder.get(0); + + float fcr = freqCisReal.get(pos * half + ic); + float fci = freqCisImag.get(pos * half + ic); + + // Rotate Q (all heads) + int qBase = h * headDim; + float v0q = q.get(qBase + ic); + float v1q = q.get(qBase + ic + half); + q.set(qBase + ic, v0q * fcr - v1q * fci); + q.set(qBase + ic + half, v0q * fci + v1q * fcr); + + // Rotate K and write rotated-K / raw-V into the cache (KV heads only) + if (h < nHeadKv) { + int kBase = h * headDim; + float v0k = k.get(kBase + ic); + float v1k = k.get(kBase + ic + half); + float rotatedK0 = v0k * fcr - v1k * fci; + float rotatedK1 = v0k * fci + v1k * fcr; + k.set(kBase + ic, rotatedK0); + k.set(kBase + ic + half, rotatedK1); + + int cacheOffset = cacheBaseOffset + pos * kvDim + h * headDim; + keyCache.set(cacheOffset + ic, rotatedK0); + keyCache.set(cacheOffset + ic + half, rotatedK1); + valueCache.set(cacheOffset + ic, v.get(kBase + ic)); + valueCache.set(cacheOffset + ic + half, v.get(kBase + ic + half)); + } + } + + /** + * Causal self-attention restricted to a (possibly sliding) window: scores/softmax/weighted-sum + * over {@code t} in {@code [windowStart, pos]}, where {@code windowStart = max(0, pos - windowSize + 1)}. + * Full-attention layers pass {@code windowSize >= contextLength} so that {@code windowStart} is + * always {@code 0} (plain causal attention) -- see {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}. + * Gemma4 uses an attention scale of {@code 1.0} (no {@code 1/sqrt(headDim)}). + * + *

{@code cacheBaseOffset} addresses the (possibly shared) KV-cache slot for this layer, see + * {@link #ropeNeoxRotateAndCacheCopy}.

+ */ + public static void attentionWithSlidingWindow( + FloatArray q, + FloatArray keyCache, + FloatArray valueCache, + FloatArray xb, + FloatArray wrapAtt, + int nHeads, + int headDim, + int kvDim, + int kvMul, + IntArray positionHolder, + int cacheBaseOffset, + int windowSize, + int contextLength) { + + int pos = positionHolder.get(0); + int windowStart = Math.max(0, pos - windowSize + 1); + + for (@Parallel int h = 0; h < nHeads; h++) { + gemma4ProcessHead(q, keyCache, valueCache, xb, wrapAtt, h, headDim, kvDim, kvMul, cacheBaseOffset, pos, windowStart, contextLength); + } + } + + private static void gemma4ProcessHead( + FloatArray q, + FloatArray keyCache, + FloatArray valueCache, + FloatArray xb, + FloatArray wrapAtt, + int h, + int headDim, + int kvDim, + int kvMul, + int cacheBaseOffset, + int pos, + int windowStart, + int contextLength) { + + // wrapAtt is sized (nHeads * contextLength); index by absolute time t with a per-head stride of contextLength. + int hOff = h * contextLength; + int kvHeadIdx = h / kvMul; + int qOffset = h * headDim; + + // STEP 1: scores for t in [windowStart, pos] + for (int t = windowStart; t <= pos; t++) { + int keyOffset = cacheBaseOffset + t * kvDim + kvHeadIdx * headDim; + float score = 0.0f; + for (int i = 0; i < headDim; i++) { + score += q.get(qOffset + i) * keyCache.get(keyOffset + i); + } + // Gemma4 attention scaling = 1.0 (no 1/sqrt(headDim)) + wrapAtt.set(hOff + t, score); + } + + // STEP 2: softmax over [windowStart, pos] + float maxScore = wrapAtt.get(hOff + windowStart); + for (int t = windowStart + 1; t <= pos; t++) { + float val = wrapAtt.get(hOff + t); + if (val > maxScore) { + maxScore = val; + } + } + float sum = 0.0f; + for (int t = windowStart; t <= pos; t++) { + int idx = hOff + t; + float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); + wrapAtt.set(idx, expScore); + sum += expScore; + } + float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos - windowStart + 1)); + for (int t = windowStart; t <= pos; t++) { + int idx = hOff + t; + wrapAtt.set(idx, wrapAtt.get(idx) * normFactor); + } + + // STEP 3: weighted sum of values + for (int i = 0; i < headDim; i++) { + float weightedSum = 0.0f; + for (int t = windowStart; t <= pos; t++) { + int valueOffset = cacheBaseOffset + t * kvDim + kvHeadIdx * headDim; + weightedSum += wrapAtt.get(hOff + t) * valueCache.get(valueOffset + i); + } + xb.set(h * headDim + i, weightedSum); + } + } + + /** + * Fused GeGLU FFN gate/up projection: {@code hb[row] = gelu(W1[row] . xNorm) * (W3[row] . xNorm)}. + * Mirrors {@code TransformerComputeKernelsLayered.fusedRmsNormFFNGateUp} but (a) takes an + * already-normalized input -- Gemma4 materializes the normalized branch separately via + * {@link #applyRmsNorm} since the same normalized {@code xb} also feeds the attention QKV + * projections -- and (b) uses GELU rather than SiLU (see {@link TransformerComputeKernelsLayered#geluActivation}). + */ + public static void fusedGateUpGeGLU( + KernelContext context, + FloatArray xNorm, + FloatArray hb, + HalfFloatArray w1, + HalfFloatArray w3, + int dim, + int hiddenDim, + int localWorkGroupSize) { + + int rowId = context.groupIdx; + int localId = context.localIdx; + if (rowId >= hiddenDim) { + return; + } + + float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize); + int rowOffset = rowId * dim; + + // === W1 (gate) === + float sum1 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + sum1 += w1.get(rowOffset + j).getFloat32() * xNorm.get(j); + } + localSum[localId] = sum1; + context.localBarrier(); + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result1 = localSum[0]; + context.localBarrier(); + + // === W3 (up) === + float sum3 = 0.0f; + for (int j = localId; j < dim; j += localWorkGroupSize) { + sum3 += w3.get(rowOffset + j).getFloat32() * xNorm.get(j); + } + localSum[localId] = sum3; + context.localBarrier(); + for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float result3 = localSum[0]; + + if (localId == 0) { + hb.set(rowId, TransformerComputeKernelsLayered.geluActivation(result1) * result3); + } + } + + /** {@code gate[i] = gelu(gate[i]) * perLayerInputs[peOffset + i]} -- the PLE gating step. */ + public static void pleGateGeluMul(KernelContext context, FloatArray gate, FloatArray perLayerInputs, int peOffset, int size) { + int gid = context.globalIdx; + if (gid < size) { + float gated = TransformerComputeKernelsLayered.geluActivation(gate.get(gid)); + gate.set(gid, gated * perLayerInputs.get(peOffset + gid)); + } + } + + /** + * Per-segment scale + RMSNorm with a single shared learned scale, used for the per-layer + * projection's normalization: {@code perLayerProjScratch} is laid out as {@code [numLayers][segmentSize]}, + * and {@code weight} (size {@code segmentSize}) is reused identically for every segment. One + * workgroup processes one segment (segment index = {@code groupIdx}), mirroring + * {@code rmsnorm(scratch, scratch, perLayerProjNorm, l*segmentSize, segmentSize, eps)} for every {@code l}. + */ + public static void pleProjScaleAndNormalize(KernelContext context, FloatArray x, FloatArray weight, int segmentSize, int localMemSize, float preScale, float rmsNormEps) { + int segIdx = context.groupIdx; + int localId = context.localIdx; + int localSize = context.localGroupSizeX; + int base = segIdx * segmentSize; + + float[] localSum = context.allocateFloatLocalArray(localMemSize); + float partial = 0f; + for (int i = localId; i < segmentSize; i += localSize) { + float v = x.get(base + i) * preScale; + x.set(base + i, v); + partial += v * v; + } + localSum[localId] = partial; + context.localBarrier(); + for (int stride = localSize / 2; stride > 0; stride >>= 1) { + if (localId < stride) { + localSum[localId] += localSum[localId + stride]; + } + context.localBarrier(); + } + float ss = localSum[0] / segmentSize + rmsNormEps; + ss = 1.0f / TornadoMath.sqrt(ss); + context.localBarrier(); + for (int i = localId; i < segmentSize; i += localSize) { + float normalized = ss * x.get(base + i); + x.set(base + i, weight.get(i) * normalized); + } + } + + /** Final logit soft-capping: {@code logits[i] = softcap * tanh(logits[i] / softcap)}. */ + public static void applyLogitSoftcap(KernelContext context, FloatArray logits, float softcap, int size) { + int gid = context.globalIdx; + if (gid < size) { + float v = logits.get(gid); + logits.set(gid, TornadoMath.tanh(v / softcap) * softcap); + } + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java index 42d2dc0c..64ba2815 100644 --- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java @@ -1,6 +1,7 @@ package org.beehive.gpullama3.tornadovm.layerplanner; import org.beehive.gpullama3.inference.state.DevstralState; +import org.beehive.gpullama3.inference.state.Gemma4State; import org.beehive.gpullama3.inference.state.GraniteState; import org.beehive.gpullama3.tensor.GGMLType; import org.beehive.gpullama3.inference.state.LlamaState; @@ -10,6 +11,7 @@ import org.beehive.gpullama3.inference.state.State; import org.beehive.gpullama3.model.Model; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.DevstralFP16LayerPlanner; +import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Gemma4FP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner; import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner; @@ -63,6 +65,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) { case DEVSTRAL_2 -> new DevstralFP16LayerPlanner((DevstralState) state, model); case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model); + case GEMMA_4 -> new Gemma4FP16LayerPlanner((Gemma4State) state, model); case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model); case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model); case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model); diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java new file mode 100644 index 00000000..9b476cd4 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java @@ -0,0 +1,29 @@ +package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16; + +import org.beehive.gpullama3.inference.state.Gemma4State; +import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights; +import org.beehive.gpullama3.model.Model; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; +import org.beehive.gpullama3.tornadovm.layers.Activation; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Gemma4LogitsFP16Layer; +import org.beehive.gpullama3.tornadovm.layers.type.fp16.Gemma4FP16FFNLayers; + +/** + * Gemma4FP16LayerPlanner: Gemma 4 model with FP16 weights. + * + * Follows the same pattern as Qwen3FP16LayerPlanner: wires together the (model-agnostic) Activation + * layer, Gemma4-specific FFN layers, and a Gemma4-specific logits layer (which adds the final + * logit soft-cap), then assembles the inference plan. + * + * Inherits from FP16LayerPlanner + */ +public class Gemma4FP16LayerPlanner extends FP16LayerPlanner { + + public Gemma4FP16LayerPlanner(Gemma4State state, Model model) { + super(state, model); + this.activationLayer = new Activation("activationUpdate", state, weights, config); + this.ffnLayers = new Gemma4FP16FFNLayers("gemma4FFN", state, weights, config, schedulerType); + this.logitsLayer = new Gemma4LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType); + createTornadoInferencePlan(); + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java new file mode 100644 index 00000000..2b97c241 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java @@ -0,0 +1,375 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.Gemma4State; +import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Gemma4Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.WorkerGrid; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Gemma4FP16FFNLayers: FP16 transformer-layer task graphs for the Gemma 4 architecture. + * + *

Gemma 4's layers differ enough from the "Llama-like" models that nothing here is fused the way + * {@code Qwen3FP16FFNLayers} is -- each layer carries its own Q/K-norm and a "sandwich" of pre/post + * normalization around both attention and FFN, attention head dimensions and RoPE tables differ + * between sliding-window and full-attention layers (and are baked into each layer's task graph as + * compile-time constants -- see {@link Gemma4Configuration#headDim}), some layers reuse an earlier + * layer's KV cache instead of computing their own, the FFN uses GeGLU, and every layer mixes in a + * per-layer embedding (PLE) contribution. See {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4} + * for the reference computation each task mirrors.

+ * + *

Layer 0's task graph additionally carries one-time-per-token setup that the reference + * implementation performs before the layer loop: scaling the token embedding by {@code sqrt(dim)}, + * and computing the per-layer-embedding inputs ({@code perLayerInputs}) from the per-layer model + * projection and the (host-gathered) per-layer token embedding row -- see {@link #appendPLESetupTasks}.

+ */ +public class Gemma4FP16FFNLayers extends AbstractFFNLayers { + + /** Local memory size for per-head Q/K/V-norm reductions; must evenly divide both head dimensions (256, 512). */ + private static final int HEAD_NORM_LOCAL_SIZE = 64; + + private final Gemma4State gemma4State; + private final int nHead; + private final int nHeadKv; + private final int kvMul; + private final int dim; + private final int nEmbdPerLayer; + private final int perLayerTotal; + private final float embedScale; + private final float perLayerTokEmbedScale; + private final float perLayerProjScale; + private final float perLayerInputScale; + + public Gemma4FP16FFNLayers(String taskGraphName, Gemma4State state, Gemma4TornadoWeights weights, Gemma4Configuration config, SchedulerType schedulerType) { + super(taskGraphName, state, weights, config, schedulerType); + this.gemma4State = state; + this.nHead = config.numberOfHeads(); + this.nHeadKv = config.numberOfKeyValueHeads(); + this.kvMul = config.kvMul(); + this.dim = config.dim(); + this.nEmbdPerLayer = config.embeddingLengthPerLayer(); + this.perLayerTotal = config.numberOfLayers() * nEmbdPerLayer; + this.embedScale = (float) Math.sqrt(dim); + this.perLayerTokEmbedScale = (float) Math.sqrt(nEmbdPerLayer); + this.perLayerProjScale = (float) (1.0 / Math.sqrt(dim)); + this.perLayerInputScale = (float) (1.0 / Math.sqrt(2.0)); + setupFFNLayers(); + } + + // ═══════════════════════════════════════════════════════════════════════════════════ + // TASK GRAPH + // ═══════════════════════════════════════════════════════════════════════════════════ + + @Override + protected TaskGraph createFFNLayerTaskGraph(int layerIndex) { + var taskGraphName = "layer_" + layerIndex; + final int headDim = config.headDim(layerIndex); + final boolean isSwa = config.isSwa(layerIndex); + final boolean hasOwnKv = config.hasOwnKv(layerIndex); + final int qDim = nHead * headDim; + final int kvDim = nHeadKv * headDim; + final int ffnLen = config.feedForwardLength(layerIndex); + final int cacheBaseOffset = gemma4State.cacheLayerBaseOffset[layerIndex]; + final int windowSize = isSwa ? config.slidingWindowSize() : config.contextLength(); + final var freqCisReal = (isSwa ? weights.freqCisRealSwa : weights.freqCisRealFull).asFloatArray(); + final var freqCisImag = (isSwa ? weights.freqCisImagSwa : weights.freqCisImagFull).asFloatArray(); + final int peOffset = layerIndex * nEmbdPerLayer; + + var unifiedLayer = new TaskGraph(taskGraphName); + unifiedLayer.consumeFromDevice(gemma4State.wrapX); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.rms_att_weightLayered[layerIndex].asFloatArray(), + weights.wqLayered[layerIndex].asHalfFloatArray(), + weights.wkLayered[layerIndex].asHalfFloatArray(), + weights.wvLayered[layerIndex].asHalfFloatArray(), + weights.woLayered[layerIndex].asHalfFloatArray(), + weights.attnQNorm[layerIndex].asFloatArray(), + weights.attnKNorm[layerIndex].asFloatArray(), + weights.attnPostNorm[layerIndex].asFloatArray(), + weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), + weights.w1Layered[layerIndex].asHalfFloatArray(), + weights.w3Layered[layerIndex].asHalfFloatArray(), + weights.w2Layered[layerIndex].asHalfFloatArray(), + weights.ffnPostNorm[layerIndex].asFloatArray(), + weights.perLayerInpGate[layerIndex].asHalfFloatArray(), + weights.perLayerProj[layerIndex].asHalfFloatArray(), + weights.perLayerPostNorm[layerIndex].asFloatArray()); + if (weights.layerOutputScale[layerIndex] != null) { + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.layerOutputScale[layerIndex].asFloatArray()); + } + unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex); + + if (layerIndex == 0) { + appendPLESetupTasks(unifiedLayer); + } + + // ═══════════════════════════════════ ATTENTION ═══════════════════════════════════ + unifiedLayer.task("attn_norm_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, gemma4State.temp, gemma4State.wrapX, dim, config.rmsNormEps(), gemma4State.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("attn_norm_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, gemma4State.temp, dim, config.rmsNormEps()); + } + unifiedLayer.task("attn_norm_apply", + Gemma4Kernels::applyRmsNorm, + context, gemma4State.wrapXb, gemma4State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), gemma4State.temp, dim); + + unifiedLayer.task("q_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapXb, gemma4State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), dim, qDim, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("q_norm", + Gemma4Kernels::rmsNormPerHead, + context, gemma4State.wrapQ, weights.attnQNorm[layerIndex].asFloatArray(), nHead, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps()); + + if (hasOwnKv) { + unifiedLayer.task("k_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapXb, gemma4State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), dim, kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("k_norm", + Gemma4Kernels::rmsNormPerHead, + context, gemma4State.wrapK, weights.attnKNorm[layerIndex].asFloatArray(), nHeadKv, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps()); + unifiedLayer.task("v_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapXb, gemma4State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), dim, kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("v_norm", + Gemma4Kernels::rmsNormPerHeadNoWeight, + context, gemma4State.wrapV, nHeadKv, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps()); + unifiedLayer.task("rope_and_cache", + Gemma4Kernels::ropeNeoxRotateAndCacheCopy, + context, gemma4State.positionHolder, gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV, + gemma4State.wrapKeyCache, gemma4State.wrapValueCache, freqCisReal, freqCisImag, + nHeadKv, headDim, kvDim, cacheBaseOffset); + } else { + unifiedLayer.task("rope_q_only", + Gemma4Kernels::ropeNeoxRotateQOnly, + context, gemma4State.positionHolder, gemma4State.wrapQ, freqCisReal, freqCisImag, headDim); + } + + unifiedLayer.task("attention", + Gemma4Kernels::attentionWithSlidingWindow, + gemma4State.wrapQ, gemma4State.wrapKeyCache, gemma4State.wrapValueCache, gemma4State.wrapXb, gemma4State.wrapAtt, + nHead, headDim, kvDim, kvMul, gemma4State.positionHolder, cacheBaseOffset, windowSize, config.contextLength()); + + unifiedLayer.task("wo_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapXb, gemma4State.wrapXb2, weights.woLayered[layerIndex].asHalfFloatArray(), qDim, dim, LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("post_attn_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, gemma4State.tempPostAttn, gemma4State.wrapXb2, dim, config.rmsNormEps(), gemma4State.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("post_attn_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, gemma4State.tempPostAttn, dim, config.rmsNormEps()); + } + unifiedLayer.task("post_attn_apply", + Gemma4Kernels::rmsNormApplyWithResidual, + context, gemma4State.wrapX, gemma4State.wrapXb2, weights.attnPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostAttn, dim); + + // ═══════════════════════════════════════ FFN ═════════════════════════════════════ + unifiedLayer.task("ffn_norm_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, gemma4State.tempFFN, gemma4State.wrapX, dim, config.rmsNormEps(), gemma4State.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ffn_norm_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, gemma4State.tempFFN, dim, config.rmsNormEps()); + } + unifiedLayer.task("ffn_norm_apply", + Gemma4Kernels::applyRmsNorm, + context, gemma4State.wrapXb, gemma4State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), gemma4State.tempFFN, dim); + + unifiedLayer.task("ffn_gate_up", + Gemma4Kernels::fusedGateUpGeGLU, + context, gemma4State.wrapXb, gemma4State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(), + dim, ffnLen, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("ffn_down_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapHb, gemma4State.wrapXb2, weights.w2Layered[layerIndex].asHalfFloatArray(), ffnLen, dim, LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("post_ffn_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, gemma4State.tempPostFfn, gemma4State.wrapXb2, dim, config.rmsNormEps(), gemma4State.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("post_ffn_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, gemma4State.tempPostFfn, dim, config.rmsNormEps()); + } + unifiedLayer.task("post_ffn_apply", + Gemma4Kernels::rmsNormApplyWithResidual, + context, gemma4State.wrapX, gemma4State.wrapXb2, weights.ffnPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostFfn, dim); + + // ═══════════════════════════ PER-LAYER EMBEDDING (PLE) ═══════════════════════════ + unifiedLayer.task("ple_gate_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapX, gemma4State.wrapPerLayerGate, weights.perLayerInpGate[layerIndex].asHalfFloatArray(), dim, nEmbdPerLayer, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("ple_gate_gelu_mul", + Gemma4Kernels::pleGateGeluMul, + context, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerInputs, peOffset, nEmbdPerLayer); + unifiedLayer.task("ple_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut, weights.perLayerProj[layerIndex].asHalfFloatArray(), nEmbdPerLayer, dim, LOCAL_WORK_GROUP_SIZE_ALLOC); + + unifiedLayer.task("ple_post_reduce", + TransformerComputeKernelsLayered::reductionOneBlockWithLayer, + context, gemma4State.tempPostPle, gemma4State.wrapPerLayerOut, dim, config.rmsNormEps(), gemma4State.localSize); + if (shouldUseFinalNormalization()) { + unifiedLayer.task("ple_post_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, gemma4State.tempPostPle, dim, config.rmsNormEps()); + } + unifiedLayer.task("ple_post_apply", + Gemma4Kernels::rmsNormApplyWithResidual, + context, gemma4State.wrapX, gemma4State.wrapPerLayerOut, weights.perLayerPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostPle, dim); + + if (weights.layerOutputScale[layerIndex] != null) { + unifiedLayer.task("layer_output_scale", + Gemma4Kernels::scaleInPlaceFromTensor, + context, gemma4State.wrapX, weights.layerOutputScale[layerIndex].asFloatArray(), dim); + } + + unifiedLayer.persistOnDevice(gemma4State.wrapX); + return unifiedLayer; + } + + /** + * One-time-per-token setup tasks, prepended to layer 0's graph: scales the token embedding by + * {@code sqrt(dim)} (Gemma4 scales embeddings on input -- the generic {@link org.beehive.gpullama3.tornadovm.layers.Activation} + * task graph that produced {@code wrapX} doesn't know about this), then computes the per-layer + * embedding inputs from the per-layer model projection and the (host-gathered) per-token + * per-layer-token-embedding row. Mirrors steps 1-2 of {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}. + */ + private void appendPLESetupTasks(TaskGraph unifiedLayer) { + unifiedLayer.task("scale_embedding", + Gemma4Kernels::scaleInPlace, + context, gemma4State.wrapX, embedScale, dim); + + unifiedLayer.task("ple_model_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, gemma4State.wrapX, gemma4State.wrapPerLayerProjScratch, weights.perLayerModelProj.asHalfFloatArray(), dim, perLayerTotal, LOCAL_WORK_GROUP_SIZE_ALLOC); + unifiedLayer.task("ple_proj_scale_norm", + Gemma4Kernels::pleProjScaleAndNormalize, + context, gemma4State.wrapPerLayerProjScratch, weights.perLayerProjNorm.asFloatArray(), nEmbdPerLayer, HEAD_NORM_LOCAL_SIZE, perLayerProjScale, config.rmsNormEps()); + unifiedLayer.task("ple_merge", + Gemma4Kernels::addAndScale, + context, gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerProjScratch, gemma4State.wrapPerLayerTokenEmbedRow, perLayerInputScale, perLayerTotal); + } + + /** + * Configure data transfers for first and subsequent layers. + */ + protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) { + if (layerIndex == 0) { + unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION, + gemma4State.positionHolder, gemma4State.wrapPerLayerTokenEmbedRow, + gemma4State.temp, gemma4State.tempFFN, gemma4State.tempPostAttn, gemma4State.tempPostFfn, gemma4State.tempPostPle); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + weights.perLayerModelProj.asHalfFloatArray(), weights.perLayerProjNorm.asFloatArray(), + weights.freqCisRealSwa.asFloatArray(), weights.freqCisImagSwa.asFloatArray(), + weights.freqCisRealFull.asFloatArray(), weights.freqCisImagFull.asFloatArray()); + unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, gemma4State.wrapXb, gemma4State.wrapXb2, + gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV, + gemma4State.wrapKeyCache, gemma4State.wrapValueCache, + gemma4State.wrapAtt, gemma4State.wrapHb, + gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerProjScratch, + gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut); + } else { + unifiedLayer.consumeFromDevice(context, gemma4State.wrapXb, gemma4State.wrapXb2, + gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV, + gemma4State.wrapKeyCache, gemma4State.wrapValueCache, + gemma4State.wrapAtt, gemma4State.wrapHb, + gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut, + gemma4State.positionHolder); + } + return unifiedLayer; + } + + // ═══════════════════════════════════════════════════════════════════════════════════ + // GRID SCHEDULER + // ═══════════════════════════════════════════════════════════════════════════════════ + + @Override + public GridScheduler updateGridScheduler(GridScheduler gridScheduler) { + WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(dim, gemma4State.localSize); + WorkerGrid dimElementWiseWorker = WorkerGridFactory.genericWorker(dim, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid woProjWorker = WorkerGridFactory.genericWorker(dim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid pleGateProjWorker = WorkerGridFactory.genericWorker(nEmbdPerLayer * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid pleGateGeluWorker = WorkerGridFactory.genericWorker(nEmbdPerLayer, LOCAL_WORK_GROUP_SIZE_ALLOC); + + // === Layer-0 PLE setup === + gridScheduler.addWorkerGrid("layer_0.scale_embedding", dimElementWiseWorker); + gridScheduler.addWorkerGrid("layer_0.ple_model_proj", WorkerGridFactory.genericWorker(perLayerTotal * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC)); + gridScheduler.addWorkerGrid("layer_0.ple_proj_scale_norm", WorkerGridFactory.genericWorker(config.numberOfLayers() * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE)); + gridScheduler.addWorkerGrid("layer_0.ple_merge", WorkerGridFactory.genericWorker(perLayerTotal, LOCAL_WORK_GROUP_SIZE_ALLOC)); + + for (int i = 0; i < config.numberOfLayers(); i++) { + String prefix = "layer_" + i + "."; + int headDim = config.headDim(i); + boolean hasOwnKv = config.hasOwnKv(i); + int qDim = nHead * headDim; + int kvDim = nHeadKv * headDim; + int ffnLen = config.feedForwardLength(i); + + WorkerGrid headNormWorker = WorkerGridFactory.genericWorker(nHead * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE); + WorkerGrid kvHeadNormWorker = WorkerGridFactory.genericWorker(nHeadKv * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE); + WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(nHead, headDim); + WorkerGrid attentionWorker = WorkerGridFactory.createAttentionWorker(nHead, headDim); + WorkerGrid qProjWorker = WorkerGridFactory.genericWorker(qDim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid kvProjWorker = WorkerGridFactory.genericWorker(kvDim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); + WorkerGrid ffnGateUpWorker = WorkerGridFactory.genericWorker(ffnLen * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); + + gridScheduler.addWorkerGrid(prefix + "attn_norm_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "attn_norm_apply", dimElementWiseWorker); + gridScheduler.addWorkerGrid(prefix + "q_proj", qProjWorker); + gridScheduler.addWorkerGrid(prefix + "q_norm", headNormWorker); + if (hasOwnKv) { + gridScheduler.addWorkerGrid(prefix + "k_proj", kvProjWorker); + gridScheduler.addWorkerGrid(prefix + "k_norm", kvHeadNormWorker); + gridScheduler.addWorkerGrid(prefix + "v_proj", kvProjWorker); + gridScheduler.addWorkerGrid(prefix + "v_norm", kvHeadNormWorker); + gridScheduler.addWorkerGrid(prefix + "rope_and_cache", ropeWorker); + } else { + gridScheduler.addWorkerGrid(prefix + "rope_q_only", ropeWorker); + } + gridScheduler.addWorkerGrid(prefix + "attention", attentionWorker); + gridScheduler.addWorkerGrid(prefix + "wo_proj", woProjWorker); + gridScheduler.addWorkerGrid(prefix + "post_attn_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "post_attn_apply", dimElementWiseWorker); + + gridScheduler.addWorkerGrid(prefix + "ffn_norm_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "ffn_norm_apply", dimElementWiseWorker); + gridScheduler.addWorkerGrid(prefix + "ffn_gate_up", ffnGateUpWorker); + gridScheduler.addWorkerGrid(prefix + "ffn_down_proj", woProjWorker); + gridScheduler.addWorkerGrid(prefix + "post_ffn_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "post_ffn_apply", dimElementWiseWorker); + + gridScheduler.addWorkerGrid(prefix + "ple_gate_proj", pleGateProjWorker); + gridScheduler.addWorkerGrid(prefix + "ple_gate_gelu_mul", pleGateGeluWorker); + gridScheduler.addWorkerGrid(prefix + "ple_proj", woProjWorker); + gridScheduler.addWorkerGrid(prefix + "ple_post_reduce", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "ple_post_apply", dimElementWiseWorker); + + if (shouldUseFinalNormalization()) { + gridScheduler.addWorkerGrid(prefix + "attn_norm_finalize", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "post_attn_finalize", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "ffn_norm_finalize", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "post_ffn_finalize", rmsNormWorker); + gridScheduler.addWorkerGrid(prefix + "ple_post_finalize", rmsNormWorker); + } + if (weights.layerOutputScale[i] != null) { + gridScheduler.addWorkerGrid(prefix + "layer_output_scale", dimElementWiseWorker); + } + } + return gridScheduler; + } +} diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java new file mode 100644 index 00000000..baa76009 --- /dev/null +++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java @@ -0,0 +1,114 @@ +package org.beehive.gpullama3.tornadovm.layers.type.fp16; + +import org.beehive.gpullama3.inference.state.State; +import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights; +import org.beehive.gpullama3.model.Configuration; +import org.beehive.gpullama3.model.gemma4.Gemma4Configuration; +import org.beehive.gpullama3.tornadovm.kernels.Gemma4Kernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels; +import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered; +import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory; +import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType; +import uk.ac.manchester.tornado.api.GridScheduler; +import uk.ac.manchester.tornado.api.TaskGraph; +import uk.ac.manchester.tornado.api.enums.DataTransferMode; + +/** + * Gemma4-specific FP16 logits layer. + * + * Identical to {@link LogitsFP16Layer} except for one addition: Gemma4 applies a final + * logit soft-cap, {@code logits = softcap * tanh(logits / softcap)}, after the vocabulary + * projection (see {@code gemma4.final_logit_softcapping} and + * {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}). + */ +public class Gemma4LogitsFP16Layer extends LogitsFP16Layer { + + private static final String SOFTCAP_TASK = "logit_softcap"; + + public Gemma4LogitsFP16Layer(String name, State state, Weights weights, Configuration config, + String lastTaskGraphID, SchedulerType schedulerType) { + super(name, state, weights, config, lastTaskGraphID, schedulerType); + } + + private float softcap() { + return ((Gemma4Configuration) config).finalLogitSoftcapping(); + } + + // @formatter:off + @Override + protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) { + var logits = new TaskGraph("logits"); + // === Data Setup === + logits.consumeFromDevice(lastTaskGraphID, state.wrapX); + logits.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits); + logits.transferToDevice(DataTransferMode.FIRST_EXECUTION, + context, + state.wrapLogits, + state.wrapXbFP16, + weights.wclsByteArray.asHalfFloatArray(), + weights.rms_final_weight_as_floatArray.asFloatArray()); + + // === Final RMS Normalization === + logits.task("rms_reduce", + TransformerComputeKernels::reductionOneBlockWithLayer, + context, + state.tempLogits, + state.wrapX, + config.dim(), + config.rmsNormEps(), + state.localSize); + + if (schedulerType == SchedulerType.NON_NVIDIA) { + logits.task("rms_finalize", + TransformerComputeKernelsLayered::reductionFinalNormalization, + context, + state.tempLogits, + config.dim(), + config.rmsNormEps()); + } + + logits.task("rms_apply_fp16", + TransformerComputeKernels::mapContextWithQuantizeLogits, + context, + state.wrapXbFP16, + state.wrapX, + weights.rms_final_weight_as_floatArray.asFloatArray(), + state.tempLogits); + + // === Vocabulary Projection === + logits.task("vocab_proj", + TransformerComputeKernelsLayered::matrixVectorGeneric, + context, + state.wrapXbFP16, + state.wrapLogits, + weights.wclsByteArray.asHalfFloatArray(), + config.dim(), + config.vocabularySize(), + LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS); + + // === Final logit soft-capping (Gemma4-specific) === + if (softcap() != 0.0f) { + logits.task(SOFTCAP_TASK, + Gemma4Kernels::applyLogitSoftcap, + context, + state.wrapLogits, + softcap(), + config.vocabularySize()); + } + + // === Transfer Results to Host === + logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits); + return logits; + } + // @formatter:on + + @Override + public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) { + var scheduler = super.updateGridScheduler(tornadoForwardScheduler); + if (softcap() != 0.0f) { + scheduler.addWorkerGrid("logits." + SOFTCAP_TASK, WorkerGridFactory.genericWorker(config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC)); + } + return scheduler; + } +}