diff --git a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
index 9beade35..d54c581d 100644
--- a/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
+++ b/src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
@@ -2,8 +2,11 @@
import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
+import org.beehive.gpullama3.inference.state.Gemma4State;
import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights;
+import org.beehive.gpullama3.model.loader.ModelLoader;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
@@ -11,6 +14,7 @@
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
@@ -534,6 +538,186 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token,
return state.logits;
}
+ /**
+  * RMS-normalizes a slice of {@code x} into {@code out} without applying a learned scale
+  * (Gemma4 normalizes V with a plain, weight-less RMSNorm).
+  *
+  * The sum of squares is reduced over {@code size} elements of {@code x} starting at {@code offset};
+  * each output value is the matching {@code x} value times {@code 1 / sqrt(meanSquare + rmsNormEps)}.
+  *
+  * @param out        destination tensor; callers in this file pass the same tensor as {@code x}
+  * @param x          source tensor
+  * @param offset     index of the first element of the slice
+  * @param size       number of elements in the slice
+  * @param rmsNormEps epsilon added to the mean square before the reciprocal square root
+  */
+ private static void rmsnormNoWeight(FloatTensor out, FloatTensor x, int offset, int size, float rmsNormEps) {
+     final float sumOfSquares = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi);
+     // Same float evaluation order as dividing first and then adding epsilon.
+     final float meanSquare = sumOfSquares / size + rmsNormEps;
+     final float invRms = (float) (1.0 / Math.sqrt(meanSquare));
+     out.mapWithIndexInPlace(offset, size, (value, index) -> invRms * x.getFloat(index));
+ }
+
+ /**
+  * Tanh-approximation GELU, matching ggml's {@code ggml_gelu_f32} (used by Gemma4's GeGLU FFN and PLE gate).
+  *
+  * The tanh argument is evaluated in double precision and the tanh result is narrowed to float
+  * before the final float multiply.
+  *
+  * @param x input activation
+  * @return {@code 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2)))}
+  */
+ private static float gelu(float x) {
+     final double cubicTerm = 1.0 + 0.044715 * x * x;
+     final double tanhArg = 0.7978845608028654 * x * cubicTerm;
+     final float tanhValue = (float) Math.tanh(tanhArg);
+     return 0.5f * x * (1.0f + tanhValue);
+ }
+
+ /**
+  * NeoX-style RoPE, applied in place: within each head, element {@code i} is paired with element
+  * {@code i + headDim/2} and the pair is rotated using the table entries at
+  * {@code position * (headDim/2) + i}.
+  *
+  * @param vec         tensor holding {@code nHeads} consecutive heads of {@code headDim} elements; modified in place
+  * @param nHeads      number of heads to rotate
+  * @param headDim     number of elements per head
+  * @param position    sequence position selecting the table row
+  * @param freqCisReal precomputed table supplying the real (cos) factor of each rotation
+  * @param freqCisImag precomputed table supplying the imaginary (sin) factor of each rotation
+  */
+ private static void ropeRotateNeox(FloatTensor vec, int nHeads, int headDim, int position, FloatTensor freqCisReal, FloatTensor freqCisImag) {
+     final int halfDim = headDim / 2;
+     final int tableRow = position * halfDim;
+     for (int head = 0; head < nHeads; head++) {
+         final int headStart = head * headDim;
+         for (int pair = 0; pair < halfDim; pair++) {
+             final float re = freqCisReal.getFloat(tableRow + pair);
+             final float im = freqCisImag.getFloat(tableRow + pair);
+             final int lo = headStart + pair;
+             final int hi = lo + halfDim;
+             final float a = vec.getFloat(lo);
+             final float b = vec.getFloat(hi);
+             vec.setFloat(lo, a * re - b * im);
+             vec.setFloat(hi, a * im + b * re);
+         }
+     }
+ }
+
+ /**
+ * Runs one CPU forward pass of a Gemma 4 model for a single token and returns the logits.
+ *
+ * Steps, in order: (1) token embedding scaled by sqrt(dim); (2) per-layer embedding (PLE) inputs
+ * for all layers; (3) for every layer: attention (Q/K norm, NeoX RoPE, own or reused KV cache,
+ * optional sliding window), GeGLU FFN, PLE contribution and optional output scale, each followed
+ * by a post-norm and a residual add into {@code state.x}; (4) final norm, classifier matmul and
+ * optional logit soft-capping.
+ *
+ * The scratch buffers {@code state.xb}, {@code state.xb2}, {@code state.hb} and {@code state.hb2}
+ * are overwritten several times per layer, so the statement order below matters.
+ *
+ * @param model supplies the configuration and weights, which are cast to their Gemma 4 types
+ * @param state scratch buffers and KV cache, cast to {@link Gemma4State} and mutated in place
+ * @param token id of the input token, used to index the embedding tables
+ * @param position position of the token in the sequence; selects the RoPE row and the KV-cache slot
+ * @return {@code state.logits}, overwritten by this call
+ */
+ public static FloatTensor forwardJavaGemma4(Model model, State state, int token, int position) {
+ final Gemma4Configuration config = (Gemma4Configuration) model.configuration();
+ final Gemma4StandardWeights weights = (Gemma4StandardWeights) model.weights();
+ final Gemma4State gs = (Gemma4State) state;
+
+ final int dim = config.dim();
+ final int nHead = config.numberOfHeads();
+ final int nHeadKv = config.numberOfKeyValueHeads();
+ final int kvMul = config.kvMul();
+ final int nLayers = config.numberOfLayers();
+ final int nEmbdPerLayer = config.embeddingLengthPerLayer();
+ final int perLayerTotal = nLayers * nEmbdPerLayer;
+ final float attentionScale = 1.0f; // Gemma4 attention uses scaling = 1.0 (no 1/sqrt(headDim))
+
+ // 1. token embedding, scaled by sqrt(dim)
+ weights.tokenEmbeddingTable.copyTo(token * dim, state.x, 0, dim);
+ final float embedScale = (float) Math.sqrt(dim);
+ state.x.mapInPlace(v -> v * embedScale);
+
+ // 2. per-layer embeddings (PLE): inp_per_layer[l] = (rmsnorm(proj(x) / sqrt(dim)) + tokEmbd[l]*sqrt(nEmbdPerLayer)) / sqrt(2)
+ // per_layer_token_embd is ~2.35B elements (too large for the int-indexed FloatTensor API), so it
+ // is addressed one embedding row at a time directly from its raw tensor entry.
+ ModelLoader.copyEmbeddingRow(weights.perLayerTokenEmbd, token, perLayerTotal, gs.perLayerInputs, 0);
+ final float perLayerTokEmbedScale = (float) Math.sqrt(nEmbdPerLayer);
+ gs.perLayerInputs.mapInPlace(v -> v * perLayerTokEmbedScale);
+
+ weights.perLayerModelProj.matmul(state.x, gs.perLayerProjScratch, perLayerTotal, dim);
+ final float perLayerProjScale = (float) (1.0 / Math.sqrt(dim));
+ gs.perLayerProjScratch.mapInPlace(v -> v * perLayerProjScale);
+ // Normalize each layer's nEmbdPerLayer-sized slice of the projection separately, with the same norm weight.
+ for (int l = 0; l < nLayers; l++) {
+ rmsnorm(gs.perLayerProjScratch, gs.perLayerProjScratch, weights.perLayerProjNorm, l * nEmbdPerLayer, nEmbdPerLayer, config.rmsNormEps());
+ }
+ // Combine projection and token-embedding halves; the result stays in gs.perLayerInputs for step 3.
+ final float perLayerInputScale = (float) (1.0 / Math.sqrt(2.0));
+ for (int i = 0; i < perLayerTotal; i++) {
+ float v = (gs.perLayerProjScratch.getFloat(i) + gs.perLayerInputs.getFloat(i)) * perLayerInputScale;
+ gs.perLayerInputs.setFloat(i, v);
+ }
+
+ // 3. transformer layers
+ for (int l = 0; l < nLayers; l++) {
+ // final copy of the loop index, used for the rest of this iteration
+ final int curLayer = l;
+ // head dimension and attention type are per-layer properties
+ final int headDim = config.headDim(l);
+ final boolean isSwa = config.isSwa(l);
+ final int qDim = nHead * headDim;
+ final int kvDim = nHeadKv * headDim;
+
+ // sliding-window and full-attention layers read different RoPE tables
+ FloatTensor freqCisReal = isSwa ? weights.freqCisRealSwa : weights.freqCisRealFull;
+ FloatTensor freqCisImag = isSwa ? weights.freqCisImagSwa : weights.freqCisImagFull;
+
+ // attn_norm
+ rmsnorm(state.xb, state.x, weights.attnNorm[l], 0, dim, config.rmsNormEps());
+
+ // Q projection, per-head Q-norm, RoPE
+ weights.wq[l].matmul(state.xb, state.q, qDim, dim);
+ for (int h = 0; h < nHead; h++) {
+ rmsnorm(state.q, state.q, weights.attnQNorm[l], h * headDim, headDim, config.rmsNormEps());
+ }
+ ropeRotateNeox(state.q, nHead, headDim, position, freqCisReal, freqCisImag);
+
+ // K/V: either compute and cache them here, or reuse an earlier layer's cache ("shared KV layers")
+ final int kvSrcLayer;
+ if (config.hasOwnKv(l)) {
+ weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
+ weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
+ // K gets a weighted per-head norm; V gets the weight-less variant. Only K is rotated.
+ for (int h = 0; h < nHeadKv; h++) {
+ rmsnorm(state.k, state.k, weights.attnKNorm[l], h * headDim, headDim, config.rmsNormEps());
+ rmsnormNoWeight(state.v, state.v, h * headDim, headDim, config.rmsNormEps());
+ }
+ ropeRotateNeox(state.k, nHeadKv, headDim, position, freqCisReal, freqCisImag);
+
+ // store this position's K and V in the layer's cache at slot `position`
+ state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
+ state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
+ kvSrcLayer = l;
+ } else {
+ // no K/V is computed for this layer; attention below reads the source layer's cache
+ kvSrcLayer = config.kvReuseLayer(l);
+ }
+
+ // self-attention (causal; sliding-window layers additionally restrict to a local window)
+ final int windowStart = isSwa ? Math.max(0, position - config.slidingWindowSize() + 1) : 0;
+ Parallel.parallelFor(0, nHead, h -> {
+ int qOffset = h * headDim;
+ int attOffset = h * config.contextLength();
+ // kvMul consecutive query heads map to the same KV head
+ int kvHeadOffset = (h / kvMul) * headDim;
+
+ for (int t = windowStart; t <= position; t++) {
+ int kvOffset = t * kvDim + kvHeadOffset;
+ float score = state.q.dot(qOffset, state.keyCache[kvSrcLayer], kvOffset, headDim);
+ score *= attentionScale;
+ state.att.setFloat(attOffset + t, score);
+ }
+
+ // softmax over exactly the scores written above (positions windowStart..position)
+ state.att.softmaxInPlace(attOffset + windowStart, position - windowStart + 1);
+
+ // state.xb is reused as the attention output: each head zeroes and accumulates its own headDim slice
+ int xbOffset = h * headDim;
+ state.xb.fillInPlace(xbOffset, headDim, 0f);
+ for (int t = windowStart; t <= position; t++) {
+ int kvOffset = t * kvDim + kvHeadOffset;
+ float a = state.att.getFloat(attOffset + t);
+ state.xb.saxpyInPlace(xbOffset, state.valueCache[kvSrcLayer], kvOffset, headDim, a);
+ }
+ });
+
+ // wo projection, post-attention norm, residual -> attn_out (kept in state.x)
+ weights.wo[curLayer].matmul(state.xb, state.xb2, dim, qDim);
+ rmsnorm(state.xb2, state.xb2, weights.attnPostNorm[curLayer], 0, dim, config.rmsNormEps());
+ state.x.addInPlace(state.xb2); // state.x now holds attn_out = inpL + post_attn_norm(attn(...))
+
+ // FFN (GeGLU: down(gelu(gate(x)) * up(x))), post-FFN norm, residual -> cur (kept in state.x)
+ rmsnorm(state.xb, state.x, weights.ffnNorm[curLayer], 0, dim, config.rmsNormEps());
+ weights.ffnGate[curLayer].matmul(state.xb, state.hb, config.feedForwardLength(curLayer), dim);
+ weights.ffnUp[curLayer].matmul(state.xb, state.hb2, config.feedForwardLength(curLayer), dim);
+ state.hb.mapInPlace(InferenceCore::gelu);
+ state.hb.multiplyInPlace(state.hb2);
+ weights.ffnDown[curLayer].matmul(state.hb, state.xb2, dim, config.feedForwardLength(curLayer));
+ rmsnorm(state.xb2, state.xb2, weights.ffnPostNorm[curLayer], 0, dim, config.rmsNormEps());
+ state.x.addInPlace(state.xb2); // state.x now holds cur = attn_out + post_ffn_norm(ffn(...))
+
+ // per-layer embedding (PLE): cur += per_layer_post_norm(proj(gelu(inp_gate(cur)) * inp_per_layer[l]))
+ weights.perLayerInpGate[curLayer].matmul(state.x, gs.perLayerGate, nEmbdPerLayer, dim);
+ gs.perLayerGate.mapInPlace(InferenceCore::gelu);
+ // this layer's slice of the PLE inputs computed in step 2
+ int peOffset = curLayer * nEmbdPerLayer;
+ for (int j = 0; j < nEmbdPerLayer; j++) {
+ gs.perLayerGate.setFloat(j, gs.perLayerGate.getFloat(j) * gs.perLayerInputs.getFloat(peOffset + j));
+ }
+ weights.perLayerProj[curLayer].matmul(gs.perLayerGate, gs.perLayerOut, dim, nEmbdPerLayer);
+ rmsnorm(gs.perLayerOut, gs.perLayerOut, weights.perLayerPostNorm[curLayer], 0, dim, config.rmsNormEps());
+ state.x.addInPlace(gs.perLayerOut);
+
+ // optional learned per-layer output scale
+ FloatTensor outScale = weights.layerOutputScale[curLayer];
+ if (outScale != null) {
+ // only element 0 of the scale tensor is read and applied to every element of state.x
+ final float scale = outScale.getFloat(0);
+ state.x.mapInPlace(v -> v * scale);
+ }
+ }
+
+ // final norm, classifier, and logit soft-capping: logits = softcap * tanh(logits / softcap)
+ rmsnorm(state.x, state.x, weights.outputNorm, 0, dim, config.rmsNormEps());
+ weights.outputWeight.matmul(state.x, state.logits, config.vocabularySize(), dim);
+
+ // a soft-cap of exactly 0 disables capping and leaves the raw logits untouched
+ final float softcap = config.finalLogitSoftcapping();
+ if (softcap != 0.0f) {
+ final float invSoftcap = 1.0f / softcap;
+ state.logits.mapInPlace(v -> (float) Math.tanh(v * invSoftcap) * softcap);
+ }
+
+ return state.logits;
+ }
+
public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) {
Phi3Configuration config = (Phi3Configuration) model.configuration();
Phi3StandardWeights weights = (Phi3StandardWeights) model.weights();
diff --git a/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java b/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java
new file mode 100644
index 00000000..01f96ea0
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/inference/state/Gemma4State.java
@@ -0,0 +1,195 @@
+package org.beehive.gpullama3.inference.state;
+
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * Inference state for Gemma 4 models.
+ *
+ * In addition to the common buffers, Gemma 4 needs scratch space for its per-layer
+ * embedding (PLE) mechanism. Buffers that vary in size across layers (Q/K/V, attention
+ * output, FFN hidden state) are sized to the maximum across all layers. The KV cache is
+ * allocated per "physical" layer: layers that reuse an earlier layer's KV cache (Gemma4's
+ * "shared KV layers" feature) simply alias that layer's cache arrays.
+ *
+ * The TornadoVM (GPU) wrapper buffers mirror the same scheme: {@link #wrapKeyCache}/
+ * {@link #wrapValueCache} are laid out back-to-back only for layers that own a KV cache
+ * (see {@link #cacheLayerBaseOffset}), and the per-layer-embedding scratch buffers are
+ * exposed as flat {@link FloatArray}s for transfer to the GPU.
+ */
+public final class Gemma4State extends State {
+
+    /** Per-layer projected input embeddings (PLE), laid out as [layer][embeddingLengthPerLayer]. */
+    public final FloatTensor perLayerInputs;
+    /** Scratch buffer for the per-layer model projection output, same layout as {@link #perLayerInputs}. */
+    public final FloatTensor perLayerProjScratch;
+    /** Scratch buffer for a single layer's gated per-layer-embedding contribution. */
+    public final FloatTensor perLayerGate;
+    /** Scratch buffer for a single layer's projected per-layer-embedding output (dim-sized). */
+    public final FloatTensor perLayerOut;
+
+    /**
+     * For each layer {@code l}, the base element offset of its KV-cache slot inside the flat
+     * {@link #wrapKeyCache}/{@link #wrapValueCache} buffers (GPU path). Layers that reuse an earlier
+     * layer's cache share that layer's offset, so attention kernels can address the (possibly shared)
+     * cache uniformly via {@code cacheLayerBaseOffset[l]} without branching on reuse.
+     *
+     * The CPU-side {@link #keyCache}/{@link #valueCache} arrays built by this class do not use these
+     * offsets: there, a reusing layer's array entry is the same tensor object as its source layer's.
+     */
+    public final int[] cacheLayerBaseOffset;
+
+    // GPU (TornadoVM) per-layer-embedding scratch buffers; mirror perLayerInputs/perLayerProjScratch/perLayerGate/perLayerOut.
+    public final FloatArray wrapPerLayerInputs;
+    public final FloatArray wrapPerLayerProjScratch;
+    public final FloatArray wrapPerLayerGate;
+    public final FloatArray wrapPerLayerOut;
+    /** Holds the current token's per-layer-token-embedding row (gathered on the host each step, then transferred to the GPU). */
+    public final FloatArray wrapPerLayerTokenEmbedRow;
+
+    // Extra RMSNorm reduction scratch buffers (GPU path): Gemma4's "sandwich norm" pattern needs five
+    // independent reductions per layer (attn-norm uses the inherited `temp`, FFN-norm `tempFFN`); each
+    // of the others gets its own buffer so consecutive reduce/apply pairs never alias.
+    public final FloatArray tempPostAttn;
+    public final FloatArray tempPostFfn;
+    public final FloatArray tempPostPle;
+
+    /**
+     * Allocates the common state through the superclass and then the Gemma 4-specific PLE scratch
+     * buffers, the KV-cache offset table and the extra reduction buffers.
+     *
+     * @param config    model configuration; must be a {@link Gemma4Configuration}
+     * @param batchsize forwarded unchanged to the superclass constructor
+     * @throws ArithmeticException if the flat KV-cache layout does not fit in an {@code int}
+     */
+    public Gemma4State(Configuration config, int batchsize) {
+        super(config, batchsize);
+
+        Gemma4Configuration gemma4config = (Gemma4Configuration) config;
+        int perLayerTotal = gemma4config.numberOfLayers() * gemma4config.embeddingLengthPerLayer();
+        this.perLayerInputs = ArrayFloatTensor.allocate(perLayerTotal);
+        this.perLayerProjScratch = ArrayFloatTensor.allocate(perLayerTotal);
+        this.perLayerGate = ArrayFloatTensor.allocate(gemma4config.embeddingLengthPerLayer());
+        this.perLayerOut = ArrayFloatTensor.allocate(gemma4config.dim());
+
+        this.cacheLayerBaseOffset = computeCacheLayerBaseOffsets(gemma4config);
+
+        this.wrapPerLayerInputs = new FloatArray(perLayerTotal);
+        this.wrapPerLayerProjScratch = new FloatArray(perLayerTotal);
+        this.wrapPerLayerGate = new FloatArray(gemma4config.embeddingLengthPerLayer());
+        this.wrapPerLayerOut = new FloatArray(gemma4config.dim());
+        this.wrapPerLayerTokenEmbedRow = new FloatArray(perLayerTotal);
+
+        // Same sizing formula as the temp/tempFFN buffers allocated in createStateFields.
+        int tempSize = 1 + ((gemma4config.dim() + localSize - 1) / localSize);
+        this.tempPostAttn = new FloatArray(tempSize);
+        this.tempPostFfn = new FloatArray(tempSize);
+        this.tempPostPle = new FloatArray(tempSize);
+    }
+
+    /**
+     * Number of elements one layer's KV-cache slot occupies:
+     * {@code contextLength * numberOfKeyValueHeads * headDim(layer)}.
+     *
+     * @throws ArithmeticException if the product overflows {@code int}; without this check the
+     *         flat-buffer size and offsets would silently wrap around
+     */
+    private static int layerCacheElements(Gemma4Configuration config, int layer) {
+        int layerKvDim = Math.multiplyExact(config.numberOfKeyValueHeads(), config.headDim(layer));
+        return Math.multiplyExact(config.contextLength(), layerKvDim);
+    }
+
+    /**
+     * Computes, for each layer, the base element offset of its KV-cache slot in a flat buffer that
+     * back-to-back concatenates only the caches of layers that own one (those for which
+     * {@link Gemma4Configuration#kvReuseLayer} is negative). Reusing layers inherit their source
+     * layer's offset (and -- by construction -- its head dimension, since
+     * {@link Gemma4Configuration#kvReuseLayer} only ever points to a layer with the same {@code isSwa}-ness).
+     *
+     * @throws ArithmeticException if the running offset overflows {@code int}
+     */
+    private static int[] computeCacheLayerBaseOffsets(Gemma4Configuration config) {
+        int[] offsets = new int[config.numberOfLayers()];
+        int running = 0;
+        for (int l = 0; l < config.numberOfLayers(); l++) {
+            int reuse = config.kvReuseLayer(l);
+            if (reuse < 0) {
+                offsets[l] = running;
+                running = Math.addExact(running, layerCacheElements(config, l));
+            } else {
+                offsets[l] = offsets[reuse];
+            }
+        }
+        return offsets;
+    }
+
+    /**
+     * Total number of elements needed for the (deduplicated) flat KV cache buffer: the largest
+     * {@code offset + slot size} over all layers that own a cache, or 0 if no layer does.
+     *
+     * @throws ArithmeticException if a slot's end offset overflows {@code int}
+     */
+    private static int totalCacheElements(Gemma4Configuration config, int[] cacheLayerBaseOffset) {
+        int total = 0;
+        for (int l = 0; l < config.numberOfLayers(); l++) {
+            if (config.hasOwnKv(l)) {
+                int slotEnd = Math.addExact(cacheLayerBaseOffset[l], layerCacheElements(config, l));
+                total = Math.max(total, slotEnd);
+            }
+        }
+        return total;
+    }
+
+    /**
+     * Allocates the CPU tensors and the TornadoVM wrapper arrays shared with other model states,
+     * sized for Gemma 4's per-layer head dimensions and FFN widths.
+     *
+     * @param configuration model configuration; must be a {@link Gemma4Configuration}
+     * @return the populated field bundle
+     * @throws UnsupportedOperationException if the quantization is neither {@code "FP16"} nor {@code "Q8_0"}
+     * @throws ArithmeticException if the flat KV-cache layout does not fit in an {@code int}
+     */
+    @Override
+    protected StateFields createStateFields(Configuration configuration) {
+        StateFields fields = new StateFields();
+
+        Gemma4Configuration config = (Gemma4Configuration) configuration;
+
+        int dim = config.dim();
+        int nHead = config.numberOfHeads();
+        int nHeadKv = config.numberOfKeyValueHeads();
+        int maxHeadDim = config.maxHeadDim();
+        int maxFFN = config.maxFeedForwardLength();
+
+        // Q/K/V buffers are sized for the widest layer.
+        int qSize = nHead * maxHeadDim;
+        int kvSize = nHeadKv * maxHeadDim;
+
+        fields.x = ArrayFloatTensor.allocate(dim);
+        // xb must hold either a dim-sized or a qSize-sized vector, whichever is larger.
+        fields.xb = ArrayFloatTensor.allocate(Math.max(dim, qSize));
+        fields.xb2 = ArrayFloatTensor.allocate(dim);
+        fields.hb = ArrayFloatTensor.allocate(maxFFN);
+        fields.hb2 = ArrayFloatTensor.allocate(maxFFN);
+        fields.q = ArrayFloatTensor.allocate(qSize);
+        fields.k = ArrayFloatTensor.allocate(kvSize);
+        fields.v = ArrayFloatTensor.allocate(kvSize);
+        fields.att = ArrayFloatTensor.allocate(nHead, config.contextLength());
+        fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());
+
+        // KV cache: layers that own their KV get a fresh cache; layers that reuse an earlier
+        // layer's KV (Gemma4's "shared KV layers") alias that layer's arrays directly.
+        FloatTensor[] keyCache = new FloatTensor[config.numberOfLayers()];
+        FloatTensor[] valueCache = new FloatTensor[config.numberOfLayers()];
+        for (int l = 0; l < config.numberOfLayers(); l++) {
+            int reuse = config.kvReuseLayer(l);
+            if (reuse < 0) {
+                int layerKvDim = config.headDim(l) * nHeadKv;
+                keyCache[l] = ArrayFloatTensor.allocate(config.contextLength(), layerKvDim);
+                valueCache[l] = ArrayFloatTensor.allocate(config.contextLength(), layerKvDim);
+            } else {
+                keyCache[l] = keyCache[reuse];
+                valueCache[l] = valueCache[reuse];
+            }
+        }
+        fields.keyCache = keyCache;
+        fields.valueCache = valueCache;
+
+        switch (config.quantization()) {
+            case "FP16" -> fields.createActivationFP16(dim);
+            case "Q8_0" -> fields.createActivationQ8_0(dim);
+            default -> throw new UnsupportedOperationException("Unsupported quantization format: " + config.quantization());
+        }
+
+        fields.wrapX = new FloatArray(dim);
+        fields.wrapXb = new FloatArray(Math.max(dim, qSize));
+        fields.wrapXbFP16 = new HalfFloatArray(Math.max(dim, qSize));
+        fields.wrapXb2 = new FloatArray(dim);
+        fields.wrapHb = new FloatArray(maxFFN);
+        fields.wrapHb2 = new FloatArray(maxFFN);
+        fields.wrapLogits = new FloatArray(config.vocabularySize());
+        fields.wrapQ = new FloatArray(qSize);
+        fields.wrapK = new FloatArray(kvSize);
+        fields.wrapV = new FloatArray(kvSize);
+
+        // Flat GPU KV cache: back-to-back slots only for layers that own a cache (see cacheLayerBaseOffset).
+        // The offsets are recomputed locally rather than read from the cacheLayerBaseOffset field.
+        int[] gpuCacheLayerBaseOffset = computeCacheLayerBaseOffsets(config);
+        // Allocate at least one element even when no layer owns a cache.
+        int totalCacheElements = Math.max(1, totalCacheElements(config, gpuCacheLayerBaseOffset));
+        fields.wrapKeyCache = new FloatArray(totalCacheElements);
+        fields.wrapValueCache = new FloatArray(totalCacheElements);
+        fields.wrapValueCache.init(0.f);
+        fields.wrapKeyCache.init(0.f);
+        fields.wrapAtt = new FloatArray(nHead * config.contextLength());
+        fields.positionHolder = new IntArray(1);
+
+        fields.temp = new FloatArray(1 + ((dim + localSize - 1) / localSize));
+        fields.tempFFN = new FloatArray(1 + ((dim + localSize - 1) / localSize));
+        fields.tempLogits = new FloatArray(1 + ((dim + localSize - 1) / localSize));
+
+        return fields;
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java
new file mode 100644
index 00000000..599ed0c1
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/standard/Gemma4StandardWeights.java
@@ -0,0 +1,134 @@
+package org.beehive.gpullama3.inference.weights.standard;
+
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.standard.FloatTensor;
+
+/**
+ * Standard (CPU) weight container for the Gemma 4 architecture.
+ *
+ * This class implements {@link Weights} directly instead of extending {@link StandardWeights},
+ * because Gemma 4's layer structure differs substantially from the "Llama-like" layout that class
+ * models. Every layer carries its own Q/K-norm, a "sandwich" of pre- and post-normalization around
+ * both attention and FFN, a per-layer-embedding (PLE) gate/projection/norm, and an optional learned
+ * output scale. There are also two separate RoPE frequency tables (sliding-window vs. full/global
+ * attention layers use different bases, head dimensions, and — for full-attention layers — a
+ * per-dimension frequency scaling baked in from {@code rope_freqs}).
+ *
+ * All fields are set once by the constructor; array-typed fields are stored as passed, without copying.
+ */
+public class Gemma4StandardWeights implements Weights {
+
+    // ---- model-level tensors ----
+    public final FloatTensor tokenEmbeddingTable;
+    public final FloatTensor outputWeight;
+    public final FloatTensor outputNorm;
+
+    // ---- attention, one entry per layer ----
+    public final FloatTensor[] attnNorm;
+    public final FloatTensor[] wq;
+    public final FloatTensor[] wk;
+    public final FloatTensor[] wv;
+    public final FloatTensor[] wo;
+    public final FloatTensor[] attnQNorm;
+    public final FloatTensor[] attnKNorm;
+    /** Also known as {@code post_attention_norm}. */
+    public final FloatTensor[] attnPostNorm;
+
+    // ---- feed-forward network, one entry per layer ----
+    public final FloatTensor[] ffnNorm;
+    public final FloatTensor[] ffnGate;
+    public final FloatTensor[] ffnUp;
+    public final FloatTensor[] ffnDown;
+    /** Also known as {@code post_ffw_norm}. */
+    public final FloatTensor[] ffnPostNorm;
+
+    // ---- per-layer embedding (PLE), one entry per layer ----
+    public final FloatTensor[] perLayerInpGate;
+    public final FloatTensor[] perLayerProj;
+    /** Also known as {@code post_norm}. */
+    public final FloatTensor[] perLayerPostNorm;
+    /** Optional learned output scale; individual entries may be {@code null}. */
+    public final FloatTensor[] layerOutputScale;
+
+    // ---- per-layer-embedding tensors shared by all layers ----
+
+    /**
+     * The per-layer token embedding table ({@code [embeddingLengthPerLayer * numberOfLayers, vocabularySize]},
+     * ~2.35 billion elements for Gemma-4-E2B). It is kept as a raw {@link GGMLTensorEntry} rather than a
+     * {@link FloatTensor} -- whose int-indexed API would overflow for a tensor this large -- and addressed
+     * one embedding row at a time via {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow}.
+     */
+    public final GGMLTensorEntry perLayerTokenEmbd;
+    public final FloatTensor perLayerModelProj;
+    public final FloatTensor perLayerProjNorm;
+
+    // ---- RoPE tables ----
+    // Sliding-window (local) layers and full (global) attention layers use different bases/dimensions;
+    // full-attention layers additionally bake in the `rope_freqs` per-dimension scaling.
+    public final FloatTensor freqCisRealSwa;
+    public final FloatTensor freqCisImagSwa;
+    public final FloatTensor freqCisRealFull;
+    public final FloatTensor freqCisImagFull;
+
+    /** Value returned by {@link #getWeightType()}. */
+    private final GGMLType weightType;
+
+    /**
+     * Stores every argument in the field of the same name. No validation or copying is performed.
+     */
+    // @formatter:off
+    public Gemma4StandardWeights(
+            FloatTensor tokenEmbeddingTable,
+            FloatTensor outputWeight,
+            FloatTensor outputNorm,
+            FloatTensor[] attnNorm,
+            FloatTensor[] wq,
+            FloatTensor[] wk,
+            FloatTensor[] wv,
+            FloatTensor[] wo,
+            FloatTensor[] attnQNorm,
+            FloatTensor[] attnKNorm,
+            FloatTensor[] attnPostNorm,
+            FloatTensor[] ffnNorm,
+            FloatTensor[] ffnGate,
+            FloatTensor[] ffnUp,
+            FloatTensor[] ffnDown,
+            FloatTensor[] ffnPostNorm,
+            FloatTensor[] perLayerInpGate,
+            FloatTensor[] perLayerProj,
+            FloatTensor[] perLayerPostNorm,
+            FloatTensor[] layerOutputScale,
+            GGMLTensorEntry perLayerTokenEmbd,
+            FloatTensor perLayerModelProj,
+            FloatTensor perLayerProjNorm,
+            FloatTensor freqCisRealSwa,
+            FloatTensor freqCisImagSwa,
+            FloatTensor freqCisRealFull,
+            FloatTensor freqCisImagFull,
+            GGMLType weightType) {
+        // model-level
+        this.weightType          = weightType;
+        this.tokenEmbeddingTable = tokenEmbeddingTable;
+        this.outputNorm          = outputNorm;
+        this.outputWeight        = outputWeight;
+
+        // attention
+        this.attnNorm     = attnNorm;
+        this.wq           = wq;
+        this.wk           = wk;
+        this.wv           = wv;
+        this.wo           = wo;
+        this.attnQNorm    = attnQNorm;
+        this.attnKNorm    = attnKNorm;
+        this.attnPostNorm = attnPostNorm;
+
+        // feed-forward
+        this.ffnNorm     = ffnNorm;
+        this.ffnGate     = ffnGate;
+        this.ffnUp       = ffnUp;
+        this.ffnDown     = ffnDown;
+        this.ffnPostNorm = ffnPostNorm;
+
+        // per-layer embedding
+        this.perLayerInpGate   = perLayerInpGate;
+        this.perLayerProj      = perLayerProj;
+        this.perLayerPostNorm  = perLayerPostNorm;
+        this.layerOutputScale  = layerOutputScale;
+        this.perLayerTokenEmbd = perLayerTokenEmbd;
+        this.perLayerModelProj = perLayerModelProj;
+        this.perLayerProjNorm  = perLayerProjNorm;
+
+        // RoPE
+        this.freqCisRealSwa  = freqCisRealSwa;
+        this.freqCisImagSwa  = freqCisImagSwa;
+        this.freqCisRealFull = freqCisRealFull;
+        this.freqCisImagFull = freqCisImagFull;
+    }
+    // @formatter:on
+
+    /** Returns the weight type supplied to the constructor. */
+    @Override
+    public GGMLType getWeightType() {
+        return weightType;
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java
new file mode 100644
index 00000000..9da6eb4c
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Gemma4TornadoWeights.java
@@ -0,0 +1,108 @@
+package org.beehive.gpullama3.inference.weights.tornado;
+
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.tornado.TornadoTensor;
+
+/**
+ * TornadoVM (GPU) weight container for the Gemma 4 architecture.
+ *
+ * Unlike its CPU counterpart {@link org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights},
+ * which implements {@link org.beehive.gpullama3.inference.weights.Weights} directly, this class extends
+ * {@link TornadoWeights} because the shared {@code AbstractLogitsLayer}/{@code LogitsFP16Layer} GPU
+ * infrastructure requires a {@link TornadoWeights}. The base class's "Llama-like" fields are reused for
+ * the closest equivalents (e.g. {@code rms_att_weightLayered} → {@code attnNorm},
+ * {@code w1Layered}/{@code w3Layered} → {@code ffnGate}/{@code ffnUp}); every other Gemma 4-specific
+ * tensor (sandwich norms, Q/K-norm, per-layer-embedding (PLE) gate/proj/norm, dual RoPE tables,
+ * optional layer-output scale) is declared here.
+ *
+ * Note: {@code per_layer_token_embd} is intentionally not uploaded as a tensor -- at ~2.35 billion
+ * elements it is far too large to keep resident on the GPU. Its per-token row is instead gathered on
+ * the host (see {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow}) and streamed
+ * to the GPU each step via {@link org.beehive.gpullama3.inference.state.Gemma4State#wrapPerLayerTokenEmbedRow}.
+ */
+public class Gemma4TornadoWeights extends TornadoWeights {
+
+    // ---- attention extras, one entry per layer (Q/K-norm and the post-attention half of the sandwich norm) ----
+    public final TornadoTensor[] attnQNorm;
+    public final TornadoTensor[] attnKNorm;
+    public final TornadoTensor[] attnPostNorm;
+
+    // ---- FFN extra, one entry per layer (post-FFN half of the sandwich norm) ----
+    public final TornadoTensor[] ffnPostNorm;
+
+    // ---- per-layer embedding (PLE), one entry per layer ----
+    public final TornadoTensor[] perLayerInpGate;
+    public final TornadoTensor[] perLayerProj;
+    public final TornadoTensor[] perLayerPostNorm;
+    /** Optional learned output scale; individual entries may be {@code null}. */
+    public final TornadoTensor[] layerOutputScale;
+
+    // ---- per-layer-embedding tensors shared by all layers ----
+
+    /**
+     * The per-layer token embedding table ({@code [embeddingLengthPerLayer * numberOfLayers, vocabularySize]},
+     * ~2.35 billion elements for Gemma-4-E2B). Far too large to keep resident on the GPU, so it is kept
+     * as a raw {@link GGMLTensorEntry} and addressed one row at a time on the host -- via
+     * {@link org.beehive.gpullama3.model.loader.ModelLoader#copyEmbeddingRow} -- with the resulting
+     * row streamed to the GPU each step (see {@link org.beehive.gpullama3.inference.state.Gemma4State#wrapPerLayerTokenEmbedRow}).
+     */
+    public final GGMLTensorEntry perLayerTokenEmbd;
+    public final TornadoTensor perLayerModelProj;
+    public final TornadoTensor perLayerProjNorm;
+
+    // ---- RoPE tables ----
+    // Sliding-window (local) layers and full (global) attention layers use different bases/dimensions.
+    // The "full" pair is also handed to the superclass constructor as its single RoPE table.
+    public final TornadoTensor freqCisRealSwa;
+    public final TornadoTensor freqCisImagSwa;
+    public final TornadoTensor freqCisRealFull;
+    public final TornadoTensor freqCisImagFull;
+
+    /**
+     * Forwards the Llama-like tensors to {@link TornadoWeights} and stores the Gemma 4-specific ones
+     * in the fields of the same name. No validation or copying is performed.
+     *
+     * The superclass receives the FFN tensors in the order {@code ffnGate, ffnDown, ffnUp}, which
+     * differs from this constructor's {@code ffnGate, ffnUp, ffnDown} parameter order.
+     */
+    // @formatter:off
+    public Gemma4TornadoWeights(
+            TornadoTensor tokenEmbeddingTable,
+            TornadoTensor[] attnNorm,
+            TornadoTensor[] wq,
+            TornadoTensor[] wk,
+            TornadoTensor[] wv,
+            TornadoTensor[] wo,
+            TornadoTensor[] attnQNorm,
+            TornadoTensor[] attnKNorm,
+            TornadoTensor[] attnPostNorm,
+            TornadoTensor[] ffnNorm,
+            TornadoTensor[] ffnGate,
+            TornadoTensor[] ffnUp,
+            TornadoTensor[] ffnDown,
+            TornadoTensor[] ffnPostNorm,
+            TornadoTensor[] perLayerInpGate,
+            TornadoTensor[] perLayerProj,
+            TornadoTensor[] perLayerPostNorm,
+            TornadoTensor[] layerOutputScale,
+            GGMLTensorEntry perLayerTokenEmbd,
+            TornadoTensor perLayerModelProj,
+            TornadoTensor perLayerProjNorm,
+            TornadoTensor outputNorm,
+            TornadoTensor freqCisRealSwa,
+            TornadoTensor freqCisImagSwa,
+            TornadoTensor freqCisRealFull,
+            TornadoTensor freqCisImagFull,
+            TornadoTensor outputWeight,
+            GGMLType weightType) {
+        super(
+                tokenEmbeddingTable,
+                attnNorm,
+                wq, wk, wv, wo,
+                ffnNorm,
+                ffnGate, ffnDown, ffnUp,
+                outputNorm,
+                freqCisRealFull, freqCisImagFull,
+                outputWeight,
+                weightType);
+
+        // attention / FFN extras
+        this.attnQNorm    = attnQNorm;
+        this.attnKNorm    = attnKNorm;
+        this.attnPostNorm = attnPostNorm;
+        this.ffnPostNorm  = ffnPostNorm;
+
+        // per-layer embedding
+        this.perLayerInpGate   = perLayerInpGate;
+        this.perLayerProj      = perLayerProj;
+        this.perLayerPostNorm  = perLayerPostNorm;
+        this.layerOutputScale  = layerOutputScale;
+        this.perLayerTokenEmbd = perLayerTokenEmbd;
+        this.perLayerModelProj = perLayerModelProj;
+        this.perLayerProjNorm  = perLayerProjNorm;
+
+        // RoPE
+        this.freqCisRealSwa  = freqCisRealSwa;
+        this.freqCisImagSwa  = freqCisImagSwa;
+        this.freqCisRealFull = freqCisRealFull;
+        this.freqCisImagFull = freqCisImagFull;
+    }
+    // @formatter:on
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/ModelType.java b/src/main/java/org/beehive/gpullama3/model/ModelType.java
index 0659da7d..fd85f4e5 100644
--- a/src/main/java/org/beehive/gpullama3/model/ModelType.java
+++ b/src/main/java/org/beehive/gpullama3/model/ModelType.java
@@ -1,6 +1,7 @@
package org.beehive.gpullama3.model;
import org.beehive.gpullama3.model.loader.DevstralModelLoader;
+import org.beehive.gpullama3.model.loader.Gemma4ModelLoader;
import org.beehive.gpullama3.model.loader.GraniteLoader;
import org.beehive.gpullama3.tensor.GGUF;
import org.beehive.gpullama3.model.loader.LlamaModelLoader;
@@ -80,6 +81,13 @@ public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, bo
}
},
+ GEMMA_4 {
+     /** Loads a Gemma 4 model by delegating to {@link Gemma4ModelLoader}, passing all arguments through unchanged. */
+     @Override
+     public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
+         final Gemma4ModelLoader loader = new Gemma4ModelLoader(fileChannel, gguf, contextLength, useTornadovm);
+         return loader.loadModel();
+     }
+ },
+
UNKNOWN {
@Override
public Model loadModel(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
diff --git a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java
index 827ad625..218d9945 100644
--- a/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java
+++ b/src/main/java/org/beehive/gpullama3/model/format/ChatFormat.java
@@ -1,6 +1,7 @@
package org.beehive.gpullama3.model.format;
import org.beehive.gpullama3.tokenizer.DevstralTokenizer;
+import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer;
import org.beehive.gpullama3.tokenizer.GraniteTokenizer;
import org.beehive.gpullama3.tokenizer.LlamaTokenizer;
import org.beehive.gpullama3.tokenizer.MistralTokenizer;
@@ -15,6 +16,7 @@ public interface ChatFormat {
static ChatFormat create(Object tokenizer, ChatTokens chatTokens) {
return switch (tokenizer) {
case DevstralTokenizer devstralTokenizer -> new DevstralChatFormat(devstralTokenizer);
+ case Gemma4Tokenizer gemma4Tokenizer -> new Gemma4ChatFormat(gemma4Tokenizer);
case GraniteTokenizer graniteTokenizer -> new GraniteChatFormat(graniteTokenizer);
case LlamaTokenizer llamaTokenizer -> new LlamaChatFormat(llamaTokenizer);
case MistralTokenizer mistralTokenizer -> new MistralChatFormat(mistralTokenizer);
diff --git a/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java b/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java
new file mode 100644
index 00000000..e8be80dc
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/format/Gemma4ChatFormat.java
@@ -0,0 +1,61 @@
+package org.beehive.gpullama3.model.format;
+
+import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * {@link ChatFormat} implementation for Gemma 4 models.
+ *
+ * A message is encoded as the {@code <|turn>} special token, the role name and a newline, then the
+ * stripped message text, the end-of-turn special token and a final newline. The assistant role is
+ * written as "model"; the end-of-turn token is the only stop token reported by this format.
+ */
+public class Gemma4ChatFormat implements ChatFormat {
+
+    protected final Gemma4Tokenizer tokenizer;
+    protected final int beginOfText;
+    protected final int startTurn;
+    protected final int endTurn;
+
+    public Gemma4ChatFormat(Gemma4Tokenizer tokenizer) {
+        Map specialTokens = tokenizer.getSpecialTokens();
+        this.tokenizer = tokenizer;
+        this.beginOfText = specialTokens.getOrDefault("", -1); // -1 is stored when the lookup finds nothing
+        this.startTurn = specialTokens.getOrDefault("<|turn>", -1);
+        this.endTurn = specialTokens.getOrDefault("", -1);
+    }
+
+    @Override
+    public List encodeHeader(Message message) {
+        // "model" is how the chat template names the assistant; any other role uses its own name().
+        String roleName = Role.ASSISTANT.equals(message.role()) ? "model" : message.role().name();
+        List header = new ArrayList<>();
+        header.add(startTurn);
+        header.addAll(tokenizer.encodeAsList(roleName));
+        header.addAll(tokenizer.encodeAsList("\n"));
+        return header;
+    }
+
+    @Override
+    public List encodeMessage(Message message) {
+        List encoded = encodeHeader(message);
+        encoded.addAll(tokenizer.encodeAsList(message.content().strip()));
+        encoded.add(endTurn);
+        encoded.addAll(tokenizer.encodeAsList("\n"));
+        return encoded;
+    }
+
+    @Override
+    public int getBeginOfText() {
+        return beginOfText;
+    }
+
+    @Override
+    public Set getStopTokens() {
+        return Set.of(endTurn);
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java
new file mode 100644
index 00000000..24fd42e7
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4.java
@@ -0,0 +1,96 @@
+package org.beehive.gpullama3.model.gemma4;
+
+import org.beehive.gpullama3.inference.InferenceCore;
+import org.beehive.gpullama3.inference.InferenceEngine;
+import org.beehive.gpullama3.inference.sampler.Sampler;
+import org.beehive.gpullama3.inference.state.Gemma4State;
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights;
+import org.beehive.gpullama3.model.AbstractModel;
+import org.beehive.gpullama3.model.ModelType;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.loader.ModelLoader;
+import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+
+import java.util.List;
+import java.util.Set;
+import java.util.function.IntConsumer;
+
+public class Gemma4 extends AbstractModel {
+
+    Gemma4Configuration configuration;
+
+    public Gemma4(Gemma4Configuration configuration, Tokenizer tokenizer, Weights weights, ChatFormat chatFormat) {
+        super(tokenizer, weights, chatFormat, null);
+        this.configuration = configuration;
+    }
+
+    @Override
+    public Gemma4Configuration configuration() {
+        return this.configuration;
+    }
+
+    @Override
+    public ModelType getModelType() {
+        return ModelType.GEMMA_4;
+    }
+
+    @Override
+    public Gemma4Tokenizer tokenizer() {
+        return (Gemma4Tokenizer) this.tokenizer;
+    }
+
+    @Override
+    public State createNewState() {
+        State fresh = new Gemma4State(configuration(), -1);
+        fresh.latestToken = chatFormat.getBeginOfText();
+        return fresh;
+    }
+
+    @Override
+    public State createNewState(int batchsize) {
+        State fresh = new Gemma4State(configuration(), batchsize);
+        fresh.latestToken = chatFormat.getBeginOfText();
+        return fresh;
+    }
+
+    @Override
+    public void forward(State state, int token, int position) {
+        if (plan == null) { // no plan set: take the pure-Java forward path
+            InferenceCore.forwardJavaGemma4(this, state, token, position);
+            return;
+        }
+        gatherPerLayerTokenEmbeddingRow((Gemma4State) state, token);
+        InferenceCore.forwardTornadoVM(this, state, token, position, tornadoVMPlan());
+    }
+
+    /**
+     * Copies the row of {@code per_layer_token_embd} that belongs to {@code token} into
+     * {@link Gemma4State#wrapPerLayerTokenEmbedRow}, multiplying every element by
+     * {@code sqrt(embeddingLengthPerLayer)} (mirroring step 2 of {@link InferenceCore#forwardJavaGemma4}).
+     * The table (~2.35 billion elements, see {@link Gemma4TornadoWeights#perLayerTokenEmbd}) is far too
+     * large to keep resident on the GPU, so only this one row is staged on the host before the GPU forward pass.
+     */
+    private void gatherPerLayerTokenEmbeddingRow(Gemma4State state, int token) {
+        int perLayerWidth = configuration.embeddingLengthPerLayer();
+        int rowLength = configuration.numberOfLayers() * perLayerWidth;
+        float preScale = (float) Math.sqrt(perLayerWidth);
+        Gemma4TornadoWeights tornadoWeights = (Gemma4TornadoWeights) weights;
+        ModelLoader.copyEmbeddingRowToFloatArray(tornadoWeights.perLayerTokenEmbd, token, rowLength, state.wrapPerLayerTokenEmbedRow, preScale);
+    }
+
+    @Override
+    public List generateTokens(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated) {
+        return InferenceEngine.generateTokensQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
+    }
+
+    @Override
+    public List generateTokensGPU(State state, int startPosition, List promptTokens, Set stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        return InferenceEngine.generateTokensGPUQwen3(this, state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, tornadoVMPlan);
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java
new file mode 100644
index 00000000..dd06049a
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/gemma4/Gemma4Configuration.java
@@ -0,0 +1,116 @@
+package org.beehive.gpullama3.model.gemma4;
+
+import org.beehive.gpullama3.model.Configuration;
+
+/**
+ * Hyper-parameters of a Gemma 4 model (e.g. Gemma-4-E2B-It).
+ *
+ * Each layer is either a sliding-window or a full (global) attention layer, as recorded in
+ * {@code slidingWindowPattern}; the two kinds have their own head dimension and RoPE base.
+ * The last {@code sharedKvLayers} layers reuse the KV cache written by an earlier layer, every
+ * layer has its own feed-forward width, and the per-layer embedding (PLE) width and the final
+ * logit soft-cap are carried here as well.
+ */
+// @formatter:off
+public record Gemma4Configuration(String quantization,
+        int dim,
+        int numberOfLayers,
+        int numberOfHeads,
+        int numberOfKeyValueHeads,
+        int headDimSwa,
+        int headDimFull,
+        int[] feedForwardLength,
+        boolean[] slidingWindowPattern,
+        int slidingWindowSize,
+        int sharedKvLayers,
+        int embeddingLengthPerLayer,
+        int vocabularySize,
+        int contextLengthModel,
+        int contextLength,
+        float rmsNormEps,
+        float ropeTheta,
+        float ropeThetaSwa,
+        float finalLogitSoftcapping) implements Configuration {
+
+    @Override
+    public String quantization() {
+        return this.quantization;
+    }
+
+    @Override
+    public int hiddenDim() { // no single value exists; callers must ask per layer
+        throw new UnsupportedOperationException("Gemma4 has per-layer feed-forward dimensions; use feedForwardLength(layer).");
+    }
+
+    @Override
+    public int numberOfHeadsKey() { // no single value exists; callers must ask per layer
+        throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer).");
+    }
+
+    @Override
+    public int headSize() { // no single value exists; callers must ask per layer
+        throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer).");
+    }
+
+    @Override
+    public int kvDim() { // no single value exists; callers must ask per layer
+        throw new UnsupportedOperationException("Gemma4 has per-layer head dimensions; use headDim(layer) * numberOfKeyValueHeads().");
+    }
+
+    @Override
+    public int kvMul() {
+        return numberOfHeads / numberOfKeyValueHeads;
+    }
+
+    @Override
+    public int contextLengthModel() {
+        return this.contextLengthModel;
+    }
+
+    /** Feed-forward (FFN hidden) width of {@code layer}. */
+    public int feedForwardLength(int layer) {
+        return this.feedForwardLength[layer];
+    }
+
+    /** {@code true} when {@code layer} is a sliding-window (local) attention layer, {@code false} for a full (global) one. */
+    public boolean isSwa(int layer) {
+        return this.slidingWindowPattern[layer];
+    }
+
+    /** Attention head dimension of {@code layer}: {@code headDimSwa} for sliding-window layers, {@code headDimFull} otherwise. */
+    public int headDim(int layer) {
+        return slidingWindowPattern[layer] ? headDimSwa : headDimFull;
+    }
+
+    /** Larger of the two head dimensions; used to size scratch buffers shared by all layers. */
+    public int maxHeadDim() {
+        return (headDimSwa >= headDimFull) ? headDimSwa : headDimFull;
+    }
+
+    /** Largest feed-forward width over all layers (0 when there are none); used to size shared scratch buffers. */
+    public int maxFeedForwardLength() {
+        int widest = 0;
+        for (int layer = 0; layer < feedForwardLength.length; layer++) {
+            widest = Math.max(widest, feedForwardLength[layer]);
+        }
+        return widest;
+    }
+
+    /** Count of leading layers that own and fill their own KV cache; the remaining layers reuse one of these. */
+    public int nLayerKvFromStart() {
+        return numberOfLayers - sharedKvLayers;
+    }
+
+    /** {@code true} when {@code layer} computes and stores its own K/V instead of reusing an earlier layer's cache. */
+    public boolean hasOwnKv(int layer) {
+        return layer < numberOfLayers - sharedKvLayers;
+    }
+
+    /** Index of the layer whose KV cache {@code layer} reuses, or -1 when {@code layer} owns its KV cache. */
+    public int kvReuseLayer(int layer) {
+        final int ownedLayers = nLayerKvFromStart();
+        if (layer < ownedLayers) {
+            return -1;
+        }
+        return isSwa(layer) ? ownedLayers - 2 : ownedLayers - 1;
+    }
+}
+// @formatter:on
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
index 9bbefcad..e8d18891 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
@@ -41,6 +41,7 @@ protected String getModelQuantization(Map metadata) {
int modelQuantizationAsInt = (int) metadata.get("general.file_type");
return switch (modelQuantizationAsInt) {
case 1 -> "FP16";
+ case 32 -> "FP16"; // MOSTLY_BF16 (treated like FP16 for activation buffers)
case 7 -> "Q8_0";
case 14, 15 -> "Q8_0"; // Q4_K_S, Q4_K_M (K-quants use Q8_0 activations)
case 16, 17 -> "Q8_0"; // Q5_K_S, Q5_K_M
@@ -56,6 +57,7 @@ protected String getModelQuantization(Map metadata) {
protected static GGMLType effectiveGpuWeightType(GGMLType ggmlType) {
return switch (ggmlType) {
case F16, F32, Q8_0 -> ggmlType;
+ case BF16 -> GGMLType.F16; // widened to FP16 at load time; see ModelLoader#loadTornadoTensor
case Q4_K, Q5_K, Q6_K -> GGMLType.Q8_0;
default -> ggmlType;
};
@@ -65,6 +67,7 @@ private static String fileTypeName(int fileType) {
return switch (fileType) {
case 0 -> "F32";
case 1 -> "F16";
+ case 32 -> "BF16";
case 7 -> "Q8_0";
case 14 -> "Q4_K_S";
case 15 -> "Q4_K_M";
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java
new file mode 100644
index 00000000..63af9b7a
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Gemma4ModelLoader.java
@@ -0,0 +1,299 @@
+package org.beehive.gpullama3.model.loader;
+
+import org.beehive.gpullama3.auxiliary.Pair;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights;
+import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights;
+import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.gemma4.Gemma4;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
+import org.beehive.gpullama3.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.tensor.GGMLType;
+import org.beehive.gpullama3.tensor.GGUF;
+import org.beehive.gpullama3.tensor.GGUF.GGUFTensorInfo;
+import org.beehive.gpullama3.tensor.standard.ArrayFloatTensor;
+import org.beehive.gpullama3.tensor.tornado.FP32TornadoTensor;
+import org.beehive.gpullama3.tensor.tornado.TornadoTensor;
+import org.beehive.gpullama3.tokenizer.Gemma4Tokenizer;
+import org.beehive.gpullama3.tokenizer.Tokenizer;
+import org.beehive.gpullama3.tokenizer.Vocabulary;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+import java.util.function.IntFunction;
+
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTensors;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayOfTornadoTensors;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensor;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTornadoTensor;
+
+/**
+ * Loader for Gemma 4 models (e.g. Gemma-4-E2B-It).
+ *
+ * Gemma 4 needs two distinct precomputed RoPE tables (sliding-window vs. full/global attention
+ * layers use different bases and head dimensions, and full-attention layers additionally apply a
+ * per-dimension frequency scaling stored in the {@code rope_freqs} tensor), so RoPE frequencies are
+ * computed directly here -- where the tensor entries are available -- rather than through the
+ * generic {@link #precomputeRopeFrequencies} hook.
+ */
+public class Gemma4ModelLoader extends AbstractModelLoader {
+
+ public Gemma4ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean useTornadovm) {
+ super(fileChannel, gguf, contextLength, useTornadovm); // no Gemma4-specific state: all arguments go to the base loader unchanged
+ }
+
+ @Override
+ protected Vocabulary loadVocabulary(Map metadata) {
+ return Vocabulary.loadGemma4Vocabulary(metadata); // delegates to the Gemma4-specific vocabulary factory, passing the metadata map through
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) {
+ return new Gemma4Tokenizer(metadata, vocabulary); // the tokenizer type also selects Gemma4ChatFormat in ChatFormat.create
+ }
+
+ // @formatter:off
+ @Override
+ protected Gemma4Configuration createConfiguration(Map metadata) { // Builds the configuration from the "gemma4.*" metadata keys.
+ int modelContextLength = (int) metadata.get("gemma4.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength; // negative or too-large request -> model limit
+ int numberOfLayers = (int) metadata.get("gemma4.block_count");
+
+ return new Gemma4Configuration(
+ getModelQuantization(metadata),
+ (int) metadata.get("gemma4.embedding_length"), // dim
+ numberOfLayers,
+ (int) metadata.get("gemma4.attention.head_count"), // numberOfHeads
+ (int) metadata.get("gemma4.attention.head_count_kv"), // numberOfKeyValueHeads
+ (int) metadata.get("gemma4.attention.key_length_swa"), // headDimSwa
+ (int) metadata.get("gemma4.attention.key_length"), // headDimFull
+ (int[]) metadata.get("gemma4.feed_forward_length"), // one entry per layer (read via feedForwardLength(layer))
+ (boolean[]) metadata.get("gemma4.attention.sliding_window_pattern"), // one entry per layer (read via isSwa(layer))
+ (int) metadata.get("gemma4.attention.sliding_window"), // slidingWindowSize
+ (int) metadata.get("gemma4.attention.shared_kv_layers"), // sharedKvLayers
+ (int) metadata.get("gemma4.embedding_length_per_layer_input"), // embeddingLengthPerLayer
+ vocabulary.size(),
+ modelContextLength,
+ finalContextLength,
+ (float) metadata.get("gemma4.attention.layer_norm_rms_epsilon"), // rmsNormEps
+ (float) metadata.get("gemma4.rope.freq_base"), // ropeTheta (full-attention layers)
+ (float) metadata.get("gemma4.rope.freq_base_swa"), // ropeThetaSwa (sliding-window layers)
+ (float) metadata.get("gemma4.final_logit_softcapping")
+ );
+ }
+ // @formatter:on
+
+ /** Always returns {@code null}: Gemma4 needs two RoPE tables computed with tensor data (rope_freqs), which are built in {@link #computeRopeTables} instead. */
+ @Override
+ protected Pair precomputeRopeFrequencies(Gemma4Configuration config) {
+ return null;
+ }
+
+ @Override
+ protected Gemma4 createModel(Gemma4Configuration config, Tokenizer tokenizer, Weights weights) {
+ return new Gemma4(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); // chat format is picked from the tokenizer type; no ChatTokens are supplied
+ }
+
+ // @formatter:off
+ @Override
+ protected Weights createStandardWeights(Map tensorEntries, Gemma4Configuration config, Pair ropeFreqs,
+ GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { // Builds the CPU-path weights; ropeFreqs and outputWeight are not used here.
+ final int nl = config.numberOfLayers();
+ RopeTables ropeTables = computeRopeTables(tensorEntries, config); // sliding-window and full-attention cos/sin tables
+
+ return new Gemma4StandardWeights(
+ loadTensor(tokenEmbeddings),
+ tensorEntries.containsKey("output.weight") ? loadTensor(tensorEntries.get("output.weight")) : loadTensor(tokenEmbeddings), // falls back to the token embeddings when there is no output.weight
+ loadTensor(tensorEntries.get("output_norm.weight")),
+
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_attention_norm.weight")),
+
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_ffw_norm.weight")),
+
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".inp_gate.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".proj.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".post_norm.weight")),
+ loadArrayOfTensors(nl, i -> tensorEntries.get("blk." + i + ".layer_output_scale.weight")), // NOTE(review): the GPU path treats this tensor as optional (nullable loader) -- confirm a missing entry is handled here too
+
+ tensorEntries.get("per_layer_token_embd.weight"), // passed as a raw entry, not wrapped by loadTensor
+ loadTensor(tensorEntries.get("per_layer_model_proj.weight")),
+ loadTensor(tensorEntries.get("per_layer_proj_norm.weight")),
+
+ new ArrayFloatTensor(ropeTables.realSwa),
+ new ArrayFloatTensor(ropeTables.imagSwa),
+ new ArrayFloatTensor(ropeTables.realFull),
+ new ArrayFloatTensor(ropeTables.imagFull),
+
+ null // NOTE(review): the meaning of this last argument is not visible here -- confirm against Gemma4StandardWeights
+ );
+ }
+ // @formatter:on
+
+ // @formatter:off
+ @Override
+ protected Weights createTornadoVMWeights(Map tensorEntries, Gemma4Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) { // Builds the GPU-path weights; ropeFreqs is unused and outputWeight only supplies the weight type.
+ final int nl = config.numberOfLayers();
+ GGMLType ggmlType = effectiveGpuWeightType(outputWeight.ggmlType()); // the type the output weight has after any load-time conversion
+ RopeTables ropeTables = computeRopeTables(tensorEntries, config); // sliding-window and full-attention cos/sin tables
+
+ return new Gemma4TornadoWeights(
+ loadTornadoTensor(tokenEmbeddings),
+
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_attention_norm.weight")),
+
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_ffw_norm.weight")),
+
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".inp_gate.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".proj.weight")),
+ loadArrayOfTornadoTensors(nl, i -> tensorEntries.get("blk." + i + ".post_norm.weight")),
+ loadArrayOfTornadoTensorsNullable(nl, i -> tensorEntries.get("blk." + i + ".layer_output_scale.weight")), // optional: a missing entry becomes a null slot
+
+ stripTornadoArrayHeader(tensorEntries.get("per_layer_token_embd.weight")), // raw entry with the TornadoVM array header sliced off
+ loadTornadoTensor(tensorEntries.get("per_layer_model_proj.weight")),
+ loadTornadoTensor(tensorEntries.get("per_layer_proj_norm.weight")),
+ loadTornadoTensor(tensorEntries.get("output_norm.weight")),
+
+ new FP32TornadoTensor(FloatArray.fromArray(ropeTables.realSwa)),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeTables.imagSwa)),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeTables.realFull)),
+ new FP32TornadoTensor(FloatArray.fromArray(ropeTables.imagFull)),
+
+ tensorEntries.containsKey("output.weight") ? loadTornadoTensor(tensorEntries.get("output.weight")) : loadTornadoTensor(tokenEmbeddings), // falls back to the token embeddings when there is no output.weight
+ ggmlType
+ );
+ }
+ // @formatter:on
+
+    /**
+     * Returns a copy of {@code entry} whose memory segment starts right after the TornadoVM array header.
+     *
+     * Entries produced by {@link GGUF#loadTensorsTornado} begin with a {@link TornadoNativeArray#ARRAY_HEADER}-sized
+     * prefix (16 bytes) so the data can be wrapped as a TornadoVM native array without copying.
+     * {@code per_layer_token_embd} is instead kept as a raw entry and read with byte-offset arithmetic that
+     * assumes the segment starts at the tensor data (see {@link ModelLoader#copyEmbeddingRowToFloatArray}),
+     * which is the layout the CPU path gets from {@link GGUF#loadTensorsStandard} and {@link ModelLoader#copyEmbeddingRow}.
+     * Slicing the prefix away gives both paths the same view; every other field is carried over unchanged.
+     */
+    private static GGMLTensorEntry stripTornadoArrayHeader(GGMLTensorEntry entry) {
+        return new GGMLTensorEntry(entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(), entry.memorySegment().asSlice(TornadoNativeArray.ARRAY_HEADER));
+    }
+
+    /** Variant of {@link ModelLoader#loadArrayOfTornadoTensors} that leaves a {@code null} slot wherever the supplier yields no entry (Gemma4's optional per-layer output scale). */
+    private static TornadoTensor[] loadArrayOfTornadoTensorsNullable(int size, IntFunction getTensorEntry) {
+        TornadoTensor[] tensors = new TornadoTensor[size];
+        for (int layer = 0; layer < size; layer++) {
+            GGMLTensorEntry candidate = getTensorEntry.apply(layer);
+            tensors[layer] = (candidate != null) ? loadTornadoTensor(candidate) : null;
+        }
+        return tensors;
+    }
+
+ private record RopeTables(float[] realSwa, float[] imagSwa, float[] realFull, float[] imagFull) { // cos ("real") and sin ("imag") RoPE tables for sliding-window and full-attention layers
+ }
+
+    /**
+     * Builds the cos/sin RoPE tables for both Gemma4 layer kinds, each covering {@code contextLengthModel} positions.
+     *
+     * Sliding-window layers: {@code rope_theta_swa} with {@code headDimSwa}, no extra scaling.
+     * Full/global-attention layers: {@code rope_theta} with {@code headDimFull}, where every rotation
+     * angle is additionally divided by the matching entry of the (single, shared) {@code rope_freqs}
+     * tensor. That is how the GGUF encodes "partial RoPE": entries are 1.0 for the active
+     * low-frequency dimensions and effectively infinite for the inactive ones, which zeroes out
+     * their rotation. When the tensor is absent the full table is computed without factors.
+     */
+    private RopeTables computeRopeTables(Map tensorEntries, Gemma4Configuration config) {
+        final int positions = config.contextLengthModel();
+        // GGUF.loadTensorsStandard/loadTensorsTornado intentionally leave rope_freqs.weight out of
+        // tensorEntries (most architectures do not need it), so it is read straight from the file.
+        float[] fullFreqFactors = readFloat32TensorDirect("rope_freqs.weight");
+
+        Pair swaTable = precomputeFreqsCisWithFactors(positions, config.headDimSwa(), config.ropeThetaSwa(), null);
+        Pair fullTable = precomputeFreqsCisWithFactors(positions, config.headDimFull(), config.ropeTheta(), fullFreqFactors);
+        return new RopeTables(swaTable.first(), swaTable.second(), fullTable.first(), fullTable.second());
+    }
+
+    /** Reads a small F32 tensor's raw data directly from the GGUF file, bypassing the {@code tensorEntries} map; returns {@code null} when the file has no tensor of that name. */
+    private float[] readFloat32TensorDirect(String tensorName) {
+        GGUFTensorInfo info = gguf.getTensorInfos().get(tensorName);
+        if (info == null) {
+            return null;
+        }
+        if (info.ggmlType() != GGMLType.F32) {
+            throw new UnsupportedOperationException("Expected F32 tensor for " + tensorName + ", got " + info.ggmlType());
+        }
+        // Exact arithmetic: oversized or corrupt dimensions fail with ArithmeticException instead of silently wrapping.
+        int numberOfElements = 1;
+        for (int dimension : info.dimensions()) {
+            numberOfElements = Math.multiplyExact(numberOfElements, dimension);
+        }
+
+        long byteOffset = gguf.getTensorDataOffset() + info.offset();
+        ByteBuffer buffer = ByteBuffer.allocate(Math.multiplyExact(numberOfElements, Float.BYTES)).order(ByteOrder.LITTLE_ENDIAN);
+        try {
+            while (buffer.hasRemaining()) { // a positional read may be short, so keep reading until the buffer is full
+                if (gguf.getFileChannel().read(buffer, byteOffset + buffer.position()) < 0) {
+                    throw new EOFException("Unexpected end of file while reading " + tensorName);
+                }
+            }
+        } catch (IOException e) {
+            throw new ModelLoadException("Failed to read " + tensorName + " from GGUF file", e);
+        }
+        buffer.flip();
+
+        float[] result = new float[numberOfElements];
+        buffer.asFloatBuffer().get(result);
+        return result;
+    }
+
+    /** Like {@link org.beehive.gpullama3.inference.operation.RoPE#precomputeFreqsCis}, but allows dividing each pair's frequency by a per-dimension scaling factor (NeoX-style RoPE); a {@code null} {@code freqFactors} means no scaling. Returns the (cos, sin) tables indexed by {@code pos * (headSize / 2) + pair}. */
+    private static Pair precomputeFreqsCisWithFactors(int contextLength, int headSize, double theta, float[] freqFactors) {
+        assert headSize % 2 == 0;
+        final int pairs = headSize / 2;
+        final float[] freqs = new float[pairs]; // a pair's frequency is position-independent, so compute it once per pair rather than once per (position, pair)
+        for (int pair = 0; pair < pairs; pair++) {
+            float freq = (float) (1.0 / Math.pow(theta, (2 * pair) / (double) headSize));
+            freqs[pair] = (freqFactors != null) ? freq / freqFactors[pair] : freq;
+        }
+        float[] cr = new float[contextLength * pairs];
+        float[] ci = new float[contextLength * pairs];
+        int n = 0;
+        for (int pos = 0; pos < contextLength; ++pos) {
+            for (int pair = 0; pair < pairs; pair++, n++) {
+                float val = pos * freqs[pair];
+                cr[n] = (float) Math.cos(val);
+                ci[n] = (float) Math.sin(val);
+            }
+        }
+        assert contextLength * pairs == n;
+        return new Pair<>(cr, ci);
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
index 353aea91..b2889785 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
@@ -52,6 +52,8 @@ private static ModelType detectModelType(Map metadata) {
String lowerName = name.toLowerCase();
if (lowerName.contains("granite")) {
return ModelType.GRANITE;
+ } else if (lowerName.contains("gemma-4") || lowerName.contains("gemma 4")) {
+ return ModelType.GEMMA_4;
} else if (lowerName.contains("devstral")) {
return ModelType.DEVSTRAL_2;
} else if (lowerName.contains("mistral")) {
@@ -73,6 +75,9 @@ private static ModelType detectModelType(Map metadata) {
if (metadata.containsKey("granite.block_count")) {
return ModelType.GRANITE;
}
+ if ("gemma4".equals(metadata.get("general.architecture")) || metadata.containsKey("gemma4.block_count")) {
+ return ModelType.GEMMA_4;
+ }
return ModelType.UNKNOWN;
}
@@ -127,6 +132,7 @@ public static FloatTensor loadTensor(GGMLTensorEntry entry) {
case Q5_K -> new Q5_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case Q6_K -> new Q6_KFloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
case F16 -> new FP16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
+ case BF16 -> new BF16FloatTensor(FloatTensor.numberOfElements(entry.shape()), entry.memorySegment());
default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
};
}
@@ -153,6 +159,7 @@ public static TornadoTensor loadTornadoTensor(GGMLTensorEntry entry) {
return switch (ggmlType) {
case F32 -> FP32TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
case F16 -> FP16TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
+ case BF16 -> convertBF16ToFP16TornadoTensor(entry);
case Q8_0 -> Q8_0TornadoTensor.fromTornadoMemorySegment(entry.memorySegment());
case Q4_K, Q5_K, Q6_K -> dequantizeToQ8_0TornadoTensor(entry);
case Q4_0 -> throw new UnsupportedOperationException("Q4_0 format not supported for TornadoVM yet");
@@ -217,6 +224,31 @@ private static Q8_0TornadoTensor dequantizeToQ8_0TornadoTensor(GGMLTensorEntry e
return Q8_0TornadoTensor.fromTornadoMemorySegment(nativeSegment);
}
+ /**
+ * Converts a BF16 tensor to an FP16 {@link FP16TornadoTensor} for TornadoVM/GPU execution.
+ * TornadoVM has no native BF16 kernel support, so weights are widened to FP32 (a lossless,
+ * simple bit-shift for BF16) and narrowed to IEEE FP16 at load time -- the same representation
+ * the existing FP16 GPU kernels already expect (see {@link #loadTornadoTensor}). NOTE(review): FP16 has a narrower exponent range than BF16, so large magnitudes may become infinite -- confirm this is acceptable for the weights in use.
+ */
+ private static FP16TornadoTensor convertBF16ToFP16TornadoTensor(GGMLTensorEntry entry) {
+ long headerBytes = TornadoNativeArray.ARRAY_HEADER;
+ GGMLTensorEntry dataEntry = new GGMLTensorEntry(
+ entry.mappedFile(), entry.name(), entry.ggmlType(), entry.shape(),
+ entry.memorySegment().asSlice(headerBytes)); // view of the source segment without its leading array header
+ FloatTensor source = loadTensor(dataEntry); // CPU-side tensor whose getFloat() yields FP32 values
+ int numElements = source.size();
+
+ MemorySegment nativeSegment = Arena.ofAuto().allocate(headerBytes + (long) numElements * Short.BYTES, 4); // room for a fresh header plus 2 bytes per element
+ for (long i = 0; i < headerBytes; i++) { // explicitly zero the header bytes of the new segment
+ nativeSegment.set(ValueLayout.JAVA_BYTE, i, (byte) 0);
+ }
+ for (int i = 0; i < numElements; i++) {
+ short f16Bits = Float.floatToFloat16(source.getFloat(i)); // FP32 -> IEEE FP16 bit pattern
+ nativeSegment.set(ValueLayout.JAVA_SHORT_UNALIGNED, headerBytes + (long) i * Short.BYTES, f16Bits);
+ }
+ return FP16TornadoTensor.fromTornadoMemorySegment(nativeSegment);
+ }
+
/**
* Dispatcher method for loading a TornadoVM tensor array based on type.
* Used in GPU-path.
@@ -229,6 +261,62 @@ public static TornadoTensor[] loadArrayOfTornadoTensors(int size, IntFunctionSome tensors (e.g. Gemma4's {@code per_layer_token_embd}, with ~2.35 billion elements) exceed
+ * {@link Integer#MAX_VALUE} elements/bytes, which would overflow the int-based
+ * {@link FloatTensor#numberOfElements} / {@link GGMLType#byteSizeFor} used to wrap a tensor entry in
+ * a {@link FloatTensor}. Such tensors are kept as raw {@link GGMLTensorEntry}s and addressed here with
+ * {@code long} byte offsets instead -- since only single-row (embedding lookup) access is needed.
+ */
+ public static void copyEmbeddingRow(GGMLTensorEntry entry, long rowIndex, int rowSize, FloatTensor dest, int destOffset) {
+     final GGMLType elementType = entry.ggmlType();
+     // Blocked (quantized) encodings cannot be addressed one element at a time.
+     if (elementType.getBlockSize() != 1) {
+         throw new UnsupportedOperationException("copyEmbeddingRow only supports unblocked (per-element) types, got " + elementType);
+     }
+     final MemorySegment data = entry.memorySegment();
+     final long bytesPerElement = elementType.getTypeSize();
+     // rowIndex is a long, so the whole product is evaluated in 64-bit arithmetic.
+     long cursor = rowIndex * rowSize * bytesPerElement;
+     for (int column = 0; column < rowSize; column++, cursor += bytesPerElement) {
+         final float decoded = switch (elementType) {
+             case F32 -> data.get(ValueLayout.JAVA_FLOAT_UNALIGNED, cursor);
+             case F16 -> Float.float16ToFloat(data.get(ValueLayout.JAVA_SHORT_UNALIGNED, cursor));
+             // BF16 is the upper half of a binary32 bit pattern: shift it back into place.
+             case BF16 -> Float.intBitsToFloat(data.get(ValueLayout.JAVA_SHORT_UNALIGNED, cursor) << 16);
+             default -> throw new UnsupportedOperationException("copyEmbeddingRow: unsupported type " + elementType);
+         };
+         dest.setFloat(destOffset + column, decoded);
+     }
+ }
+
+ /**
+  * Counterpart of {@link #copyEmbeddingRow(GGMLTensorEntry, long, int, FloatTensor, int)} that targets a
+  * TornadoVM {@link FloatArray}: row {@code rowIndex} of {@code entry} is decoded element by element,
+  * multiplied by {@code scale}, and stored at {@code dest[0 .. rowSize)}.
+  *
+  * Only per-element encodings (block size 1) are accepted, and of those only F32, F16 and BF16 are
+  * decoded; any other type raises {@link UnsupportedOperationException} when the first element
+  * is decoded.
+  *
+  * NOTE(review): byte offsets are measured from the very start of {@code entry.memorySegment()};
+  * if the entry's segment begins with a TornadoVM array header, the caller has to account for
+  * it -- confirm against the call site.
+  *
+  * @param entry    raw tensor entry holding the embedding table
+  * @param rowIndex zero-based row to read; a {@code long} so the byte offset is computed in 64 bits
+  * @param rowSize  number of elements per row
+  * @param dest     destination array, written from index 0
+  * @param scale    factor applied to every decoded element
+  */
+ public static void copyEmbeddingRowToFloatArray(GGMLTensorEntry entry, long rowIndex, int rowSize, FloatArray dest, float scale) {
+     final GGMLType elementType = entry.ggmlType();
+     if (elementType.getBlockSize() != 1) {
+         throw new UnsupportedOperationException("copyEmbeddingRowToFloatArray only supports unblocked (per-element) types, got " + elementType);
+     }
+     final MemorySegment data = entry.memorySegment();
+     final long bytesPerElement = elementType.getTypeSize();
+     long cursor = rowIndex * rowSize * bytesPerElement;
+     for (int column = 0; column < rowSize; column++, cursor += bytesPerElement) {
+         final float decoded = switch (elementType) {
+             case F32 -> data.get(ValueLayout.JAVA_FLOAT_UNALIGNED, cursor);
+             case F16 -> Float.float16ToFloat(data.get(ValueLayout.JAVA_SHORT_UNALIGNED, cursor));
+             // BF16 keeps the top 16 bits of a binary32 value; shifting restores the float bit pattern.
+             case BF16 -> Float.intBitsToFloat(data.get(ValueLayout.JAVA_SHORT_UNALIGNED, cursor) << 16);
+             default -> throw new UnsupportedOperationException("copyEmbeddingRowToFloatArray: unsupported type " + elementType);
+         };
+         dest.set(column, decoded * scale);
+     }
+ }
+
// Helper methods
public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction getTensorEntry) {
diff --git a/src/main/java/org/beehive/gpullama3/tensor/GGUF.java b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java
index 9cdc5b7d..9da00b76 100644
--- a/src/main/java/org/beehive/gpullama3/tensor/GGUF.java
+++ b/src/main/java/org/beehive/gpullama3/tensor/GGUF.java
@@ -1,6 +1,5 @@
package org.beehive.gpullama3.tensor;
-import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.auxiliary.Pair;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
@@ -122,8 +121,12 @@ public static Map loadTensorsStandard(FileChannel fileC
continue;
}
- int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
- int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
+ // Long arithmetic: some tensors (e.g. Gemma4's per-layer token embedding table) exceed
+ // Integer.MAX_VALUE elements/bytes, which would overflow the int-based
+ // FloatTensor.numberOfElements/GGMLType.byteSizeFor. Such tensors are read directly via
+ // long-offset MemorySegment access rather than wrapped in a FloatTensor, so only the
+ // slice size (which MemorySegment#asSlice accepts as a long) matters here.
+ long sizeInBytes = tensorByteSize(ti);
// per-tensor slice offset; ti.offset() is relative to tensor-data start
long offset = ti.offset();
@@ -167,8 +170,9 @@ public static Map loadTensorsTornado(FileChannel fileCh
continue;
}
- int numberOfElements = FloatTensor.numberOfElements(ti.dimensions());
- int sizeInBytes = Math.toIntExact(ti.ggmlType().byteSizeFor(numberOfElements));
+ // see loadTensorsStandard for why this uses long arithmetic instead of
+ // FloatTensor.numberOfElements/GGMLType.byteSizeFor
+ long sizeInBytes = tensorByteSize(ti);
// absolute tensor offset - relative to start of the file
long mappingOffset = tensorDataOffset + ti.offset();
@@ -193,6 +197,18 @@ public static Map loadTensorsTornado(FileChannel fileCh
return tensorEntries;
}
+ /** Computes a tensor's byte size with {@code long} arithmetic, avoiding int-overflow for very large tensors. */
+ private static long tensorByteSize(GGUFTensorInfo ti) {
+ long numberOfElements = 1L;
+ for (int dimension : ti.dimensions()) {
+ numberOfElements *= dimension;
+ }
+ long typeSize = ti.ggmlType().getTypeSize();
+ long blockSize = ti.ggmlType().getBlockSize();
+ assert (numberOfElements * typeSize) % blockSize == 0;
+ return numberOfElements * typeSize / blockSize;
+ }
+
public Map getTensorInfos() {
return tensorInfos;
}
diff --git a/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java b/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java
new file mode 100644
index 00000000..1d370c06
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tensor/standard/BF16FloatTensor.java
@@ -0,0 +1,57 @@
+package org.beehive.gpullama3.tensor.standard;
+
+import org.beehive.gpullama3.tensor.GGMLType;
+import jdk.incubator.vector.FloatVector;
+import jdk.incubator.vector.VectorSpecies;
+
+import java.lang.foreign.MemorySegment;
+
+/**
+ * {@link FloatTensor} backed by raw BF16 (bfloat16) data.
+ *
+ * BF16 stores the upper 16 bits of an IEEE-754 binary32 value (same sign/exponent layout,
+ * truncated mantissa), so widening to float32 is a plain left-shift by 16 bits -- no exponent
+ * rebiasing is needed, unlike IEEE binary16 (F16).
+ *
+ * The tensor is read-only and scalar-only: {@link #setFloat} and {@link #getFloatVector} both
+ * throw {@link UnsupportedOperationException}, and {@link #asMemorySegment()} returns
+ * {@code null}; element access goes through {@link #getFloat(int)}.
+ */
+public final class BF16FloatTensor extends FloatTensor {
+
+ final int size; // element count reported by size(); also the upper bound asserted in getFloat
+ final MemorySegment memorySegment; // backing storage read by getFloat, GGMLType.BFLOAT16_BYTES bytes per element
+
+ public BF16FloatTensor(int size, MemorySegment memorySegment) { // stores both arguments as-is; nothing is copied or validated
+ this.size = size;
+ this.memorySegment = memorySegment;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public void setFloat(int index, float value) {
+ throw new UnsupportedOperationException("setFloat"); // writes are not supported by this tensor
+ }
+
+ @Override
+ protected FloatVector getFloatVector(VectorSpecies species, int index) {
+ throw new UnsupportedOperationException("getFloatVector"); // no vectorized read path is provided
+ }
+
+ @Override
+ public GGMLType type() {
+ return GGMLType.BF16;
+ }
+
+ @Override
+ public MemorySegment asMemorySegment() {
+ return null; // the backing segment is not exposed through this accessor
+ }
+
+ @Override
+ public float getFloat(int index) {
+ assert 0 <= index && index < size; // bounds are only checked when assertions are enabled
+ short bits = readShort(memorySegment, index * (long) GGMLType.BFLOAT16_BYTES); // readShort is declared outside this class; the long cast keeps the byte offset in 64-bit arithmetic
+ return Float.intBitsToFloat(((int) bits) << 16); // the 16 stored bits become the top half of the float; sign-extension bits are shifted out
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java b/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java
new file mode 100644
index 00000000..1fe9da5a
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Gemma4Tokenizer.java
@@ -0,0 +1,146 @@
+package org.beehive.gpullama3.tokenizer;
+
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * SentencePiece-style BPE tokenizer with byte fallback, used by Gemma 4 models.
+ *
+ * Spaces are represented with the SentencePiece marker {@code ▁}, and any codepoint missing from
+ * the vocabulary falls back to its individual UTF-8 bytes encoded as {@code <0xXX>} tokens. Pairs
+ * are greedily merged according to the highest {@code tokenizer.ggml.scores} value, mirroring
+ * {@link MistralTokenizer}.
+ */
+public class Gemma4Tokenizer implements Tokenizer {
+
+ private final Vocabulary vocabulary;
+ private final Map specialTokens;
+ private final int[] tokenType;
+ private final int byte0;
+
+ public Gemma4Tokenizer(Map metadata, Vocabulary vocabulary) { // builds the special-token table and the byte-fallback base index from GGUF metadata
+ int[] tokenTypes = (int[]) metadata.get("tokenizer.ggml.token_type"); // indexed by vocabulary index below; NOTE(review): no null check -- confirm the key is always present
+
+ // Special tokens are anything that isn't a regular sub-word (NORMAL, type 1) or a raw byte-fallback token (BYTE, type 6).
+ Map specialTokens = IntStream.range(0, vocabulary.size())
+ .filter(t -> tokenTypes[t] != 1 && tokenTypes[t] != 6)
+ .boxed()
+ .collect(Collectors.toMap(vocabulary::get, t -> t, (first, second) -> first)); // on duplicate token text the first index encountered is kept
+
+ this.vocabulary = vocabulary;
+ this.specialTokens = new HashMap<>(specialTokens); // stored as a mutable copy
+ this.tokenType = tokenTypes;
+ this.byte0 = vocabulary.getIndex("<0x00>").orElseThrow(); // index of <0x00>; encodeImpl maps byte b to byte0 + b; fails here if the token is missing
+ }
+
+ @Override
+ public String regexPattern() {
+ return null; // no regex is supplied by this tokenizer; encode() works directly on code points
+ }
+
+ @Override
+ public Map getSpecialTokens() {
+ return specialTokens; // the internal map itself (token text -> index), not a defensive copy
+ }
+
+ /**
+  * Reports whether the token is anything other than a regular sub-word (type 1).
+  * Byte-fallback tokens (type 6) therefore also count as special here, which {@code decode}
+  * relies on to recognise {@code <0xXX>} tokens.
+  */
+ @Override
+ public boolean isSpecialToken(int tokenIndex) {
+     final boolean regularSubWord = getTokenType(tokenIndex) == 1;
+     return !regularSubWord;
+ }
+
+ @Override
+ public boolean shouldDisplayToken(int token) {
+ int type = getTokenType(token);
+ return type == 1 || type == 6;
+ }
+
+ public int getTokenType(int tokenIndex) {
+ return tokenType[tokenIndex]; // raw type code taken from the tokenizer.ggml.token_type metadata array; an out-of-range index throws ArrayIndexOutOfBoundsException
+ }
+
+ private List encodeImpl(String text) { // encodes text as given: one token per code point (or per UTF-8 byte), then greedy score-based merging; the space -> ▁ substitution is done by encode()
+ List tokens = new ArrayList<>();
+
+ // first encode every individual codepoint in the input string
+ for (int i = 0, cpi; i < text.length(); i += Character.charCount(cpi)) { // advances by 1 or 2 chars so surrogate pairs stay together
+ cpi = text.codePointAt(i);
+
+ String singleCodepoint = Character.toString(cpi);
+ int id = vocabulary.getIndex(singleCodepoint).orElse(-1); // -1 marks "not in the vocabulary"
+
+ if (id != -1) {
+ tokens.add(id);
+ } else {
+ // byte fallback: encode each UTF-8 byte as a <0xXX> token (offset by the index of <0x00>)
+ for (byte b : singleCodepoint.getBytes(StandardCharsets.UTF_8)) {
+ tokens.add(Byte.toUnsignedInt(b) + byte0); // assumes the 256 byte tokens sit at consecutive indices starting at <0x00>
+ }
+ }
+ }
+
+ // greedily merge the highest-scoring adjacent pair until no more merges apply
+ while (true) {
+ float bestScore = -1e10f; // sentinel: a merge must score strictly above this to be chosen
+ int bestId = -1;
+ int bestIdx = -1;
+
+ for (int i = 0; i < tokens.size() - 1; ++i) {
+ String merged = vocabulary.get(tokens.get(i)) + vocabulary.get(tokens.get(i + 1));
+ int id = vocabulary.getIndex(merged).orElse(-1);
+ if (id != -1 && vocabulary.getScore(id) > bestScore) { // strict > keeps the leftmost pair among equal scores
+ bestScore = vocabulary.getScore(id);
+ bestId = id;
+ bestIdx = i;
+ }
+ }
+
+ if (bestIdx == -1) {
+ break; // no adjacent pair concatenates to a vocabulary entry any more
+ }
+
+ tokens.set(bestIdx, bestId); // replace the left token of the pair with the merged token...
+ tokens.remove(bestIdx + 1); // ...and drop the right one (removal by position)
+ }
+
+ return tokens;
+ }
+
+ @Override
+ public List encode(String text, Set allowedSpecial) {
+ return encodeImpl(text.replace(' ', '▁')); // spaces become the SentencePiece marker; allowedSpecial is not consulted, all text takes the same code point + merge path
+ }
+
+ @Override
+ public List encodeAsList(String text) {
+ return encode(text, Collections.emptySet()); // same as encode(); the empty set is passed only to satisfy the signature
+ }
+
+ /**
+  * Converts token ids back into text.
+  *
+  * Regular sub-word tokens (type 1) have the SentencePiece marker {@code ▁} turned back into a
+  * space. Byte-fallback tokens of the form {@code <0xXX>} carry one UTF-8 byte each (that is how
+  * {@code encodeImpl} emits them), so consecutive byte tokens are collected and decoded together
+  * as UTF-8; a multi-byte character therefore round-trips instead of being rendered as one
+  * character per byte. A run that is not well-formed UTF-8 (for example a sequence cut off
+  * mid-character) is emitted one character per byte, value for value. Every other token is
+  * appended verbatim.
+  *
+  * @param tokens token ids to decode, in order
+  * @return the decoded text
+  */
+ @Override
+ public String decode(List<Integer> tokens) {
+     StringBuilder sb = new StringBuilder();
+     // Bytes of the current run of consecutive <0xXX> tokens; a run can never exceed the token count.
+     byte[] pendingBytes = new byte[tokens.size()];
+     int pendingCount = 0;
+     for (int token : tokens) {
+         String tokenString = vocabulary.get(token);
+         if (isSpecialToken(token)) {
+             int byteValue = byteFallbackValue(tokenString);
+             if (byteValue >= 0) {
+                 pendingBytes[pendingCount++] = (byte) byteValue;
+                 continue;
+             }
+         } else {
+             tokenString = tokenString.replace('▁', ' ');
+         }
+         // A non-byte token ends the current byte run: flush it before appending the token text.
+         appendUtf8Run(sb, pendingBytes, pendingCount);
+         pendingCount = 0;
+         sb.append(tokenString);
+     }
+     appendUtf8Run(sb, pendingBytes, pendingCount);
+     return sb.toString();
+ }
+
+ /** Returns the byte (0-255) named by a {@code <0xXX>} token string, or {@code -1} if the string is not of that form. */
+ private static int byteFallbackValue(String tokenString) {
+     if (tokenString.length() != 6 || !tokenString.startsWith("<0x") || !tokenString.endsWith(">")) {
+         return -1;
+     }
+     int high = Character.digit(tokenString.charAt(3), 16);
+     int low = Character.digit(tokenString.charAt(4), 16);
+     return (high < 0 || low < 0) ? -1 : (high << 4) | low;
+ }
+
+ /** Appends {@code bytes[0..length)} as UTF-8 text, or one character per byte when the run is not well-formed UTF-8. */
+ private static void appendUtf8Run(StringBuilder sb, byte[] bytes, int length) {
+     if (length == 0) {
+         return;
+     }
+     try {
+         sb.append(StandardCharsets.UTF_8.newDecoder()
+                 .onMalformedInput(java.nio.charset.CodingErrorAction.REPORT)
+                 .onUnmappableCharacter(java.nio.charset.CodingErrorAction.REPORT)
+                 .decode(java.nio.ByteBuffer.wrap(bytes, 0, length)));
+     } catch (java.nio.charset.CharacterCodingException malformed) {
+         // Incomplete or invalid sequence: render each byte as the code point of the same value,
+         // so a partial sequence still produces visible output.
+         for (int i = 0; i < length; i++) {
+             sb.append((char) (bytes[i] & 0xFF));
+         }
+     }
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
index b29f3576..be10d4b0 100644
--- a/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
+++ b/src/main/java/org/beehive/gpullama3/tokenizer/Vocabulary.java
@@ -36,6 +36,12 @@ public static Vocabulary loadQwen3Vocabulary(Map metadata) {
return new Vocabulary(tokens, scores);
}
+ public static Vocabulary loadGemma4Vocabulary(Map metadata) { // builds the vocabulary from the GGUF tokenizer metadata, including per-token scores
+ String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
+ float[] scores = (float[]) metadata.get("tokenizer.ggml.scores"); // scores are what Gemma4Tokenizer ranks merges by; null if the key is absent
+ return new Vocabulary(tokens, scores);
+ }
+
public static Vocabulary loadDevstralVocabulary(Map metadata) {
String[] tokens = (String[]) metadata.get("tokenizer.ggml.tokens");
return new Vocabulary(tokens, null);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java
new file mode 100644
index 00000000..47a12ec3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/kernels/Gemma4Kernels.java
@@ -0,0 +1,432 @@
+package org.beehive.gpullama3.tornadovm.kernels;
+
+import uk.ac.manchester.tornado.api.KernelContext;
+import uk.ac.manchester.tornado.api.annotations.Parallel;
+import uk.ac.manchester.tornado.api.math.TornadoMath;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.IntArray;
+
+/**
+ * Custom GPU kernels for the Gemma 4 architecture.
+ *
+ * Gemma 4's computation graph differs substantially from the "Llama-like" models the rest of the
+ * {@code tornadovm.kernels} package targets: every layer carries its own Q/K-norm and a "sandwich" of
+ * pre/post normalization around both attention and FFN, attention alternates between sliding-window
+ * (local) and full (global) variants -- with different head dimensions and RoPE tables -- some layers
+ * reuse an earlier layer's KV cache, the FFN uses a GeGLU activation, and every layer additionally
+ * mixes in a per-layer embedding (PLE). None of the existing fused kernels match this shape, so this
+ * class provides purpose-built (but otherwise unfused/modular) replacements; see
+ * {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4} for the reference computation
+ * each of these mirrors.
+ */
+// @formatter:off
+public class Gemma4Kernels {
+
+ /**
+  * Element-wise kernel writing {@code out[i] = weight[i] * (rmsScale[0] * x[i])} for every global
+  * index {@code i < size} -- RMSNorm with a learned scale, stored in a separate buffer.
+  *
+  * @param context  kernel context supplying the global thread index
+  * @param out      destination buffer
+  * @param x        input vector
+  * @param weight   per-element scale
+  * @param rmsScale buffer whose element 0 holds the scalar factor applied to {@code x}
+  * @param size     number of elements to process; threads at or beyond it do nothing
+  */
+ public static void applyRmsNorm(KernelContext context, FloatArray out, FloatArray x, FloatArray weight, FloatArray rmsScale, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float normalized = rmsScale.get(0) * x.get(index);
+     out.set(index, weight.get(index) * normalized);
+ }
+
+ /**
+  * Multiplies the first {@code size} elements of {@code x} by {@code scale}, in place
+  * (used for embedding scaling). Threads whose global index is {@code >= size} do nothing.
+  */
+ public static void scaleInPlace(KernelContext context, FloatArray x, float scale, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float scaled = x.get(index) * scale;
+     x.set(index, scaled);
+ }
+
+ /**
+  * Variant of {@link #scaleInPlace} whose factor is read from element 0 of {@code scaleTensor}
+  * when the kernel runs (the learned, per-layer scale), i.e. {@code x[i] *= scaleTensor[0]}
+  * for every global index {@code i < size}.
+  */
+ public static void scaleInPlaceFromTensor(KernelContext context, FloatArray x, FloatArray scaleTensor, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float scaled = x.get(index) * scaleTensor.get(0);
+     x.set(index, scaled);
+ }
+
+ /**
+  * Element-wise {@code out[i] = (a[i] + b[i]) * scale} for every global index {@code i < size}
+  * (used to merge the per-layer projection with the per-layer token embedding).
+  */
+ public static void addAndScale(KernelContext context, FloatArray out, FloatArray a, FloatArray b, float scale, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float sum = a.get(index) + b.get(index);
+     out.set(index, sum * scale);
+ }
+
+ /**
+  * Sandwich-norm plus residual add: {@code x[i] += weight[i] * (rmsScale[0] * delta[i])} for every
+  * global index {@code i < size}.
+  * Serves the post-attention norm, the post-FFN norm and the per-layer-embedding post-norm; each
+  * normalizes a freshly computed branch output and accumulates it onto the running residual.
+  *
+  * @param context  kernel context supplying the global thread index
+  * @param x        running residual, updated in place
+  * @param delta    branch output to normalize and add
+  * @param weight   per-element scale
+  * @param rmsScale buffer whose element 0 holds the scalar factor applied to {@code delta}
+  * @param size     number of elements to process
+  */
+ public static void rmsNormApplyWithResidual(KernelContext context, FloatArray x, FloatArray delta, FloatArray weight, FloatArray rmsScale, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float branch = weight.get(index) * (rmsScale.get(0) * delta.get(index));
+     x.set(index, x.get(index) + branch);
+ }
+
+ /**
+ * Per-head RMSNorm with a learned scale (Q-norm / K-norm): each workgroup normalizes one head
+ * of {@code vec} in place, mirroring {@code rmsnorm(vec, vec, weight, h*headDim, headDim, eps)}
+ * applied independently for every head {@code h}.
+ *
+ * The head is selected by the workgroup index; threads of the group stride over the head's
+ * elements, reduce their partial sums of squares through a local array, and then rescale the
+ * elements they own. {@code weight} is indexed by the position inside the head, so the same
+ * {@code headDim} weights are used for every head.
+ *
+ * @param context      kernel context (group index = head, local index = lane within the group)
+ * @param vec          vector of {@code nHeads * headDim} values, normalized in place
+ * @param weight       learned scale, read at indices {@code [0, headDim)}
+ * @param nHeads       number of heads; groups with an index at or beyond it return immediately
+ * @param headDim      elements per head
+ * @param localMemSize length of the local reduction array; it is indexed by the local thread id,
+ *                     so it must be at least the workgroup size
+ * @param rmsNormEps   epsilon added to the mean square before the inverse square root
+ */
+ public static void rmsNormPerHead(KernelContext context, FloatArray vec, FloatArray weight, int nHeads, int headDim, int localMemSize, float rmsNormEps) {
+ int headIdx = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = context.localGroupSizeX;
+ if (headIdx >= nHeads) { // the whole group leaves together, before any barrier is reached
+ return;
+ }
+ int base = headIdx * headDim;
+
+ float[] localSum = context.allocateFloatLocalArray(localMemSize);
+ float partial = 0f;
+ for (int i = localId; i < headDim; i += localSize) { // each thread accumulates the squares of the elements it owns
+ float v = vec.get(base + i);
+ partial += v * v;
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) { // halving reduction; every partial is summed only when localSize is a power of two
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float ss = localSum[0] / headDim + rmsNormEps; // mean square of the head plus epsilon
+ ss = 1.0f / TornadoMath.sqrt(ss);
+ context.localBarrier();
+ for (int i = localId; i < headDim; i += localSize) { // same ownership pattern as the accumulation loop above
+ float normalized = ss * vec.get(base + i);
+ vec.set(base + i, weight.get(i) * normalized);
+ }
+ }
+
+ /**
+ * Like {@link #rmsNormPerHead}, but without a learned scale (Gemma4 normalizes V with a plain, weight-less RMSNorm).
+ *
+ * One workgroup handles one head of {@code vec}: the threads reduce the sum of squares of the
+ * head through a local array and then multiply the elements they own by
+ * {@code 1 / sqrt(meanSquare + rmsNormEps)}, in place.
+ *
+ * @param context      kernel context (group index = head, local index = lane within the group)
+ * @param vec          vector of {@code nHeads * headDim} values, normalized in place
+ * @param nHeads       number of heads; groups with an index at or beyond it return immediately
+ * @param headDim      elements per head
+ * @param localMemSize length of the local reduction array; it is indexed by the local thread id,
+ *                     so it must be at least the workgroup size
+ * @param rmsNormEps   epsilon added to the mean square before the inverse square root
+ */
+ public static void rmsNormPerHeadNoWeight(KernelContext context, FloatArray vec, int nHeads, int headDim, int localMemSize, float rmsNormEps) {
+ int headIdx = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = context.localGroupSizeX;
+ if (headIdx >= nHeads) { // the whole group leaves together, before any barrier is reached
+ return;
+ }
+ int base = headIdx * headDim;
+
+ float[] localSum = context.allocateFloatLocalArray(localMemSize);
+ float partial = 0f;
+ for (int i = localId; i < headDim; i += localSize) { // each thread accumulates the squares of the elements it owns
+ float v = vec.get(base + i);
+ partial += v * v;
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) { // halving reduction; every partial is summed only when localSize is a power of two
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float ss = localSum[0] / headDim + rmsNormEps; // mean square of the head plus epsilon
+ ss = 1.0f / TornadoMath.sqrt(ss);
+ context.localBarrier();
+ for (int i = localId; i < headDim; i += localSize) {
+ vec.set(base + i, ss * vec.get(base + i));
+ }
+ }
+
+ /**
+  * NeoX-style RoPE (split-half pairs, precomputed cos/sin tables) applied to Q only -- for layers
+  * that reuse an earlier layer's KV cache, where K is never computed or rotated.
+  * Launched on a 2D grid of (nHeads, headDim/2): the X index selects the head, the Y index the
+  * pair. Element {@code ic} of a head is paired with element {@code ic + headDim/2}, and the
+  * table entry for the current position is read at {@code pos * (headDim/2) + ic}.
+  *
+  * @param context        kernel context supplying the 2D global indices
+  * @param positionHolder buffer whose element 0 holds the current position
+  * @param q              query vector, rotated in place
+  * @param freqCisReal    real (cos) table
+  * @param freqCisImag    imaginary (sin) table
+  * @param headDim        elements per head
+  */
+ public static void ropeNeoxRotateQOnly(KernelContext context, IntArray positionHolder, FloatArray q, FloatArray freqCisReal, FloatArray freqCisImag, int headDim) {
+     final int head = context.globalIdx;
+     final int pairIndex = context.globalIdy;
+     final int halfDim = headDim / 2;
+
+     final int tableIndex = positionHolder.get(0) * halfDim + pairIndex;
+     final float real = freqCisReal.get(tableIndex);
+     final float imag = freqCisImag.get(tableIndex);
+
+     // Read both members of the pair before writing either one.
+     final int lowIndex = head * headDim + pairIndex;
+     final int highIndex = lowIndex + halfDim;
+     final float low = q.get(lowIndex);
+     final float high = q.get(highIndex);
+     q.set(lowIndex, low * real - high * imag);
+     q.set(highIndex, low * imag + high * real);
+ }
+
+ /**
+ * NeoX-style RoPE rotation for Q and K, fused with the KV-cache write (K rotated then cached,
+ * V copied as-is) -- used by layers that own their KV cache. Launched on a 2D grid of
+ * (nHeads, headDim/2); K/V handling is gated on {@code h < nHeadKv} (mirrors
+ * {@code Qwen3Kernels.ropeRotationWithCacheCopy}'s {@code rotn} pattern for GQA).
+ *
+ * {@code cacheBaseOffset} is the (possibly shared, see {@link org.beehive.gpullama3.inference.state.Gemma4State#cacheLayerBaseOffset})
+ * base element offset of this layer's slot in the flat {@code keyCache}/{@code valueCache} buffers.
+ *
+ * Each thread handles one (head, pair) combination: element {@code ic} of the head is paired
+ * with element {@code ic + headDim/2}. The cache position written is
+ * {@code cacheBaseOffset + pos * kvDim + h * headDim}. {@code q} and {@code k} are also
+ * updated in place with their rotated values.
+ *
+ * @param context         kernel context supplying the 2D global indices (X = head, Y = pair)
+ * @param positionHolder  buffer whose element 0 holds the current position
+ * @param q               query vector, rotated in place
+ * @param k               key vector, rotated in place for heads below {@code nHeadKv}
+ * @param v               value vector, only read
+ * @param keyCache        flat key cache receiving the rotated K values
+ * @param valueCache      flat value cache receiving the unmodified V values
+ * @param freqCisReal     real (cos) table, read at {@code pos * (headDim/2) + ic}
+ * @param freqCisImag     imaginary (sin) table, read at the same index
+ * @param nHeadKv         number of KV heads
+ * @param headDim         elements per head
+ * @param kvDim           per-position stride of the caches
+ * @param cacheBaseOffset element offset of this layer's slot in the caches
+ */
+ public static void ropeNeoxRotateAndCacheCopy(
+ KernelContext context,
+ IntArray positionHolder,
+ FloatArray q,
+ FloatArray k,
+ FloatArray v,
+ FloatArray keyCache,
+ FloatArray valueCache,
+ FloatArray freqCisReal,
+ FloatArray freqCisImag,
+ int nHeadKv,
+ int headDim,
+ int kvDim,
+ int cacheBaseOffset) {
+
+ int h = context.globalIdx;
+ int ic = context.globalIdy;
+ int half = headDim / 2;
+ int pos = positionHolder.get(0);
+
+ float fcr = freqCisReal.get(pos * half + ic);
+ float fci = freqCisImag.get(pos * half + ic);
+
+ // Rotate Q (all heads)
+ int qBase = h * headDim;
+ float v0q = q.get(qBase + ic); // both pair members are read before either is overwritten
+ float v1q = q.get(qBase + ic + half);
+ q.set(qBase + ic, v0q * fcr - v1q * fci);
+ q.set(qBase + ic + half, v0q * fci + v1q * fcr);
+
+ // Rotate K and write rotated-K / raw-V into the cache (KV heads only)
+ if (h < nHeadKv) {
+ int kBase = h * headDim; // K and V use the same per-head layout, so kBase also indexes v below
+ float v0k = k.get(kBase + ic);
+ float v1k = k.get(kBase + ic + half);
+ float rotatedK0 = v0k * fcr - v1k * fci;
+ float rotatedK1 = v0k * fci + v1k * fcr;
+ k.set(kBase + ic, rotatedK0);
+ k.set(kBase + ic + half, rotatedK1);
+
+ int cacheOffset = cacheBaseOffset + pos * kvDim + h * headDim;
+ keyCache.set(cacheOffset + ic, rotatedK0);
+ keyCache.set(cacheOffset + ic + half, rotatedK1);
+ valueCache.set(cacheOffset + ic, v.get(kBase + ic));
+ valueCache.set(cacheOffset + ic + half, v.get(kBase + ic + half));
+ }
+ }
+
+ /**
+ * Causal self-attention restricted to a (possibly sliding) window: scores/softmax/weighted-sum
+ * over {@code t} in {@code [windowStart, pos]}, where {@code windowStart = max(0, pos - windowSize + 1)}.
+ * Full-attention layers pass {@code windowSize >= contextLength} so that {@code windowStart} is
+ * always {@code 0} (plain causal attention) -- see {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}.
+ * Gemma4 uses an attention scale of {@code 1.0} (no {@code 1/sqrt(headDim)}).
+ *
+ * {@code cacheBaseOffset} addresses the (possibly shared) KV-cache slot for this layer, see
+ * {@link #ropeNeoxRotateAndCacheCopy}.
+ *
+ * Heads are processed by a loop annotated with {@code @Parallel}; the per-head work is done by
+ * {@code gemma4ProcessHead}.
+ *
+ * @param q               query vector, {@code headDim} values per head
+ * @param keyCache        flat key cache
+ * @param valueCache      flat value cache
+ * @param xb              output buffer, {@code headDim} values per head
+ * @param wrapAtt         scratch buffer for scores/weights, indexed {@code h * contextLength + t}
+ * @param nHeads          number of query heads
+ * @param headDim         elements per head
+ * @param kvDim           per-position stride of the caches
+ * @param kvMul           query heads per KV head (the KV head used is {@code h / kvMul})
+ * @param positionHolder  buffer whose element 0 holds the current position
+ * @param cacheBaseOffset element offset of this layer's slot in the caches
+ * @param windowSize      number of most recent positions attended to
+ * @param contextLength   per-head stride of {@code wrapAtt}
+ */
+ public static void attentionWithSlidingWindow(
+ FloatArray q,
+ FloatArray keyCache,
+ FloatArray valueCache,
+ FloatArray xb,
+ FloatArray wrapAtt,
+ int nHeads,
+ int headDim,
+ int kvDim,
+ int kvMul,
+ IntArray positionHolder,
+ int cacheBaseOffset,
+ int windowSize,
+ int contextLength) {
+
+ int pos = positionHolder.get(0);
+ int windowStart = Math.max(0, pos - windowSize + 1); // first position inside the window, clamped at 0
+
+ for (@Parallel int h = 0; h < nHeads; h++) {
+ gemma4ProcessHead(q, keyCache, valueCache, xb, wrapAtt, h, headDim, kvDim, kvMul, cacheBaseOffset, pos, windowStart, contextLength);
+ }
+ }
+
+ private static void gemma4ProcessHead( // attention for a single head h: scores, softmax and value mix over t in [windowStart, pos]; result goes to xb[h*headDim ..]
+ FloatArray q,
+ FloatArray keyCache,
+ FloatArray valueCache,
+ FloatArray xb,
+ FloatArray wrapAtt,
+ int h,
+ int headDim,
+ int kvDim,
+ int kvMul,
+ int cacheBaseOffset,
+ int pos,
+ int windowStart,
+ int contextLength) {
+
+ // wrapAtt is sized (nHeads * contextLength); index by absolute time t with a per-head stride of contextLength.
+ int hOff = h * contextLength;
+ int kvHeadIdx = h / kvMul; // kvMul consecutive query heads share one KV head
+ int qOffset = h * headDim;
+
+ // STEP 1: scores for t in [windowStart, pos]
+ for (int t = windowStart; t <= pos; t++) {
+ int keyOffset = cacheBaseOffset + t * kvDim + kvHeadIdx * headDim;
+ float score = 0.0f;
+ for (int i = 0; i < headDim; i++) {
+ score += q.get(qOffset + i) * keyCache.get(keyOffset + i);
+ }
+ // Gemma4 attention scaling = 1.0 (no 1/sqrt(headDim))
+ wrapAtt.set(hOff + t, score);
+ }
+
+ // STEP 2: softmax over [windowStart, pos]
+ float maxScore = wrapAtt.get(hOff + windowStart);
+ for (int t = windowStart + 1; t <= pos; t++) {
+ float val = wrapAtt.get(hOff + t);
+ if (val > maxScore) {
+ maxScore = val;
+ }
+ }
+ float sum = 0.0f;
+ for (int t = windowStart; t <= pos; t++) {
+ int idx = hOff + t;
+ float expScore = TornadoMath.exp(wrapAtt.get(idx) - maxScore); // the maximum is subtracted before exponentiating, so the largest exponent is 0
+ wrapAtt.set(idx, expScore);
+ sum += expScore;
+ }
+ float normFactor = (sum > 0.0f) ? (1.0f / sum) : (1.0f / (pos - windowStart + 1)); // if the sum is not positive, fall back to one over the window length
+ for (int t = windowStart; t <= pos; t++) {
+ int idx = hOff + t;
+ wrapAtt.set(idx, wrapAtt.get(idx) * normFactor);
+ }
+
+ // STEP 3: weighted sum of values
+ for (int i = 0; i < headDim; i++) {
+ float weightedSum = 0.0f;
+ for (int t = windowStart; t <= pos; t++) {
+ int valueOffset = cacheBaseOffset + t * kvDim + kvHeadIdx * headDim; // same addressing as keyOffset in STEP 1
+ weightedSum += wrapAtt.get(hOff + t) * valueCache.get(valueOffset + i);
+ }
+ xb.set(h * headDim + i, weightedSum);
+ }
+ }
+
+ /**
+ * Fused GeGLU FFN gate/up projection: {@code hb[row] = gelu(W1[row] . xNorm) * (W3[row] . xNorm)}.
+ * Mirrors {@code TransformerComputeKernelsLayered.fusedRmsNormFFNGateUp} but (a) takes an
+ * already-normalized input -- Gemma4 materializes the normalized branch separately via
+ * {@link #applyRmsNorm} since the same normalized {@code xb} also feeds the attention QKV
+ * projections -- and (b) uses GELU rather than SiLU (see {@link TransformerComputeKernelsLayered#geluActivation}).
+ *
+ * One workgroup computes one output row: its threads stride over the {@code dim} columns,
+ * reduce their partial dot products through a local array (first for W1, then for W3, reusing
+ * the same array), and local thread 0 writes the result.
+ *
+ * @param context            kernel context (group index = output row, local index = lane)
+ * @param xNorm              already-normalized input vector of length {@code dim}
+ * @param hb                 output vector of length {@code hiddenDim}
+ * @param w1                 gate weights, row-major with {@code dim} half-float values per row
+ * @param w3                 up weights, same layout as {@code w1}
+ * @param dim                input length / columns per row
+ * @param hiddenDim          number of output rows; groups at or beyond it return immediately
+ * @param localWorkGroupSize used both as the column stride of each thread and as the length of
+ *                           the local reduction array
+ */
+ public static void fusedGateUpGeGLU(
+ KernelContext context,
+ FloatArray xNorm,
+ FloatArray hb,
+ HalfFloatArray w1,
+ HalfFloatArray w3,
+ int dim,
+ int hiddenDim,
+ int localWorkGroupSize) {
+
+ int rowId = context.groupIdx;
+ int localId = context.localIdx;
+ if (rowId >= hiddenDim) { // the whole group leaves together, before any barrier is reached
+ return;
+ }
+
+ float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
+ int rowOffset = rowId * dim;
+
+ // === W1 (gate) ===
+ float sum1 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ sum1 += w1.get(rowOffset + j).getFloat32() * xNorm.get(j);
+ }
+ localSum[localId] = sum1;
+ context.localBarrier();
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { // halving reduction; every partial is summed only when localWorkGroupSize is a power of two
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result1 = localSum[0];
+ context.localBarrier(); // every thread must have read localSum[0] before the array is reused for W3
+
+ // === W3 (up) ===
+ float sum3 = 0.0f;
+ for (int j = localId; j < dim; j += localWorkGroupSize) {
+ sum3 += w3.get(rowOffset + j).getFloat32() * xNorm.get(j);
+ }
+ localSum[localId] = sum3;
+ context.localBarrier();
+ for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) {
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float result3 = localSum[0];
+
+ if (localId == 0) { // a single thread per row writes the output
+ hb.set(rowId, TransformerComputeKernelsLayered.geluActivation(result1) * result3);
+ }
+ }
+
+ /**
+  * PLE gating step: {@code gate[i] = gelu(gate[i]) * perLayerInputs[peOffset + i]} for every
+  * global index {@code i < size}; {@code gate} is updated in place.
+  */
+ public static void pleGateGeluMul(KernelContext context, FloatArray gate, FloatArray perLayerInputs, int peOffset, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float activated = TransformerComputeKernelsLayered.geluActivation(gate.get(index));
+     gate.set(index, activated * perLayerInputs.get(peOffset + index));
+ }
+
+ /**
+ * Per-segment scale + RMSNorm with a single shared learned scale, used for the per-layer
+ * projection's normalization: {@code perLayerProjScratch} is laid out as {@code [numLayers][segmentSize]},
+ * and {@code weight} (size {@code segmentSize}) is reused identically for every segment. One
+ * workgroup processes one segment (segment index = {@code groupIdx}), mirroring
+ * {@code rmsnorm(scratch, scratch, perLayerProjNorm, l*segmentSize, segmentSize, eps)} for every {@code l}.
+ *
+ * The pre-scaled value is written back to {@code x} during the accumulation pass, so the
+ * normalization pass reads the scaled values. There is no bound check on the segment index:
+ * the launch grid must contain exactly one workgroup per segment.
+ *
+ * @param context      kernel context (group index = segment, local index = lane)
+ * @param x            buffer holding the segments, scaled and normalized in place
+ * @param weight       learned scale, read at indices {@code [0, segmentSize)}
+ * @param segmentSize  elements per segment
+ * @param localMemSize length of the local reduction array; it is indexed by the local thread id,
+ *                     so it must be at least the workgroup size
+ * @param preScale     factor applied to every element before normalization
+ * @param rmsNormEps   epsilon added to the mean square before the inverse square root
+ */
+ public static void pleProjScaleAndNormalize(KernelContext context, FloatArray x, FloatArray weight, int segmentSize, int localMemSize, float preScale, float rmsNormEps) {
+ int segIdx = context.groupIdx;
+ int localId = context.localIdx;
+ int localSize = context.localGroupSizeX;
+ int base = segIdx * segmentSize;
+
+ float[] localSum = context.allocateFloatLocalArray(localMemSize);
+ float partial = 0f;
+ for (int i = localId; i < segmentSize; i += localSize) {
+ float v = x.get(base + i) * preScale;
+ x.set(base + i, v); // store the scaled value so the second pass normalizes it
+ partial += v * v;
+ }
+ localSum[localId] = partial;
+ context.localBarrier();
+ for (int stride = localSize / 2; stride > 0; stride >>= 1) { // halving reduction; every partial is summed only when localSize is a power of two
+ if (localId < stride) {
+ localSum[localId] += localSum[localId + stride];
+ }
+ context.localBarrier();
+ }
+ float ss = localSum[0] / segmentSize + rmsNormEps; // mean square of the scaled segment plus epsilon
+ ss = 1.0f / TornadoMath.sqrt(ss);
+ context.localBarrier();
+ for (int i = localId; i < segmentSize; i += localSize) { // each thread revisits exactly the elements it scaled above
+ float normalized = ss * x.get(base + i);
+ x.set(base + i, weight.get(i) * normalized);
+ }
+ }
+
+ /**
+  * Final logit soft-capping, in place: {@code logits[i] = softcap * tanh(logits[i] / softcap)}
+  * for every global index {@code i < size}.
+  */
+ public static void applyLogitSoftcap(KernelContext context, FloatArray logits, float softcap, int size) {
+     final int index = context.globalIdx;
+     if (index >= size) {
+         return;
+     }
+     final float squashed = TornadoMath.tanh(logits.get(index) / softcap);
+     logits.set(index, squashed * softcap);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
index 42d2dc0c..64ba2815 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/QuantizationPlannerFactory.java
@@ -1,6 +1,7 @@
package org.beehive.gpullama3.tornadovm.layerplanner;
import org.beehive.gpullama3.inference.state.DevstralState;
+import org.beehive.gpullama3.inference.state.Gemma4State;
import org.beehive.gpullama3.inference.state.GraniteState;
import org.beehive.gpullama3.tensor.GGMLType;
import org.beehive.gpullama3.inference.state.LlamaState;
@@ -10,6 +11,7 @@
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.DevstralFP16LayerPlanner;
+import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.Gemma4FP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.GraniteFP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.LlamaFP16LayerPlanner;
import org.beehive.gpullama3.tornadovm.layerplanner.model.fp16.MistralFP16LayerPlanner;
@@ -63,6 +65,7 @@ private static GenericLayerPlanner createFP16Planner(State state, Model model) {
case DEVSTRAL_2 -> new DevstralFP16LayerPlanner((DevstralState) state, model);
case QWEN_2 -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
case QWEN_3 -> new Qwen3FP16LayerPlanner((Qwen3State) state, model);
+ case GEMMA_4 -> new Gemma4FP16LayerPlanner((Gemma4State) state, model);
case PHI_3 -> new Phi3FP16LayerPlanner((Phi3State) state, model);
case GRANITE -> new GraniteFP16LayerPlanner((GraniteState) state, model);
case DEEPSEEK_R1_DISTILL_QWEN -> new Qwen2FP16LayerPlanner((Qwen2State) state, model);
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java
new file mode 100644
index 00000000..9b476cd4
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layerplanner/model/fp16/Gemma4FP16LayerPlanner.java
@@ -0,0 +1,29 @@
+package org.beehive.gpullama3.tornadovm.layerplanner.model.fp16;
+
+import org.beehive.gpullama3.inference.state.Gemma4State;
+import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
+import org.beehive.gpullama3.tornadovm.layers.Activation;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Gemma4LogitsFP16Layer;
+import org.beehive.gpullama3.tornadovm.layers.type.fp16.Gemma4FP16FFNLayers;
+
+/**
+ * Gemma4FP16LayerPlanner: Gemma 4 model with FP16 weights.
+ *
+ * Follows the same pattern as Qwen3FP16LayerPlanner: wires together the (model-agnostic) Activation
+ * layer, the Gemma4-specific FFN layers, and a Gemma4-specific logits layer (which adds the final
+ * logit soft-cap), then assembles the inference plan.
+ *
+ * The superclass is parameterized with the Gemma4 state, configuration and weights types so that the
+ * inherited {@code weights} and {@code config} fields have the concrete types required by the
+ * {@link Gemma4FP16FFNLayers} constructor; a raw {@code FP16LayerPlanner} would erase them.
+ * NOTE(review): the type-argument order (state, configuration, weights) is assumed -- confirm it
+ * against the declaration of FP16LayerPlanner.
+ */
+public class Gemma4FP16LayerPlanner extends FP16LayerPlanner<Gemma4State, Gemma4Configuration, Gemma4TornadoWeights> {
+
+    /**
+     * Creates the activation, FFN and logits layers and then builds the inference plan.
+     * The FFN layers are created before the logits layer because the latter is given the ID of the
+     * last FFN task graph.
+     *
+     * @param state Gemma4 inference state handed to every layer
+     * @param model model forwarded to the superclass constructor
+     */
+    public Gemma4FP16LayerPlanner(Gemma4State state, Model model) {
+        super(state, model);
+        this.activationLayer = new Activation("activationUpdate", state, weights, config);
+        this.ffnLayers = new Gemma4FP16FFNLayers("gemma4FFN", state, weights, config, schedulerType);
+        this.logitsLayer = new Gemma4LogitsFP16Layer("logits", state, weights, config, ffnLayers.getLastFFNLayerTaskGraphID(), schedulerType);
+        createTornadoInferencePlan();
+    }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java
new file mode 100644
index 00000000..2b97c241
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4FP16FFNLayers.java
@@ -0,0 +1,375 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.Gemma4State;
+import org.beehive.gpullama3.inference.weights.tornado.Gemma4TornadoWeights;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.Gemma4Kernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import org.beehive.gpullama3.tornadovm.layers.AbstractFFNLayers;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.WorkerGrid;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * Gemma4FP16FFNLayers: FP16 transformer-layer task graphs for the Gemma 4 architecture.
+ *
+ * Gemma 4's layers differ enough from the "Llama-like" models that nothing here is fused the way
+ * {@code Qwen3FP16FFNLayers} is -- each layer carries its own Q/K-norm and a "sandwich" of pre/post
+ * normalization around both attention and FFN, attention head dimensions and RoPE tables differ
+ * between sliding-window and full-attention layers (and are baked into each layer's task graph as
+ * compile-time constants -- see {@link Gemma4Configuration#headDim}), some layers reuse an earlier
+ * layer's KV cache instead of computing their own, the FFN uses GeGLU, and every layer mixes in a
+ * per-layer embedding (PLE) contribution. See {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}
+ * for the reference computation each task mirrors.
+ *
+ * Layer 0's task graph additionally carries one-time-per-token setup that the reference
+ * implementation performs before the layer loop: scaling the token embedding by {@code sqrt(dim)},
+ * and computing the per-layer-embedding inputs ({@code perLayerInputs}) from the per-layer model
+ * projection and the (host-gathered) per-layer token embedding row -- see {@link #appendPLESetupTasks}.
+ */
+public class Gemma4FP16FFNLayers extends AbstractFFNLayers {
+
+ /**
+ * Work-group size handed to the per-head Q/K/V-norm kernels and to the PLE projection norm, and used
+ * as the local size of their worker grids. Per the original note it must evenly divide both head
+ * dimensions (256, 512).
+ */
+ private static final int HEAD_NORM_LOCAL_SIZE = 64;
+
+ private final Gemma4State gemma4State; // typed reference to the state passed to the constructor
+ private final int nHead; // config.numberOfHeads()
+ private final int nHeadKv; // config.numberOfKeyValueHeads()
+ private final int kvMul; // config.kvMul(); passed to the attention task
+ private final int dim; // config.dim()
+ private final int nEmbdPerLayer; // config.embeddingLengthPerLayer(): length of one layer's PLE slice
+ private final int perLayerTotal; // numberOfLayers * nEmbdPerLayer: length of the concatenated PLE inputs
+ private final float embedScale; // sqrt(dim); passed to the layer-0 scale_embedding task
+ // NOTE(review): assigned in the constructor but never read in this class -- presumably the host
+ // applies it when gathering wrapPerLayerTokenEmbedRow; confirm or remove.
+ private final float perLayerTokEmbedScale;
+ private final float perLayerProjScale; // 1/sqrt(dim); passed to the layer-0 ple_proj_scale_norm task
+ private final float perLayerInputScale; // 1/sqrt(2); passed to the layer-0 ple_merge task
+
+ /**
+ * Caches the layer-independent dimensions and scale factors read from {@code config}, then calls
+ * {@code setupFFNLayers()}.
+ *
+ * @param taskGraphName name forwarded to the superclass
+ * @param state Gemma4 state whose device buffers the task graphs read and write
+ * @param weights Gemma4 weights forwarded to the superclass
+ * @param config Gemma4 configuration supplying head counts, dim, per-layer embedding length and layer count
+ * @param schedulerType scheduler type forwarded to the superclass
+ */
+ public Gemma4FP16FFNLayers(String taskGraphName, Gemma4State state, Gemma4TornadoWeights weights, Gemma4Configuration config, SchedulerType schedulerType) {
+ super(taskGraphName, state, weights, config, schedulerType);
+ this.gemma4State = state;
+ this.nHead = config.numberOfHeads();
+ this.nHeadKv = config.numberOfKeyValueHeads();
+ this.kvMul = config.kvMul();
+ this.dim = config.dim();
+ this.nEmbdPerLayer = config.embeddingLengthPerLayer();
+ this.perLayerTotal = config.numberOfLayers() * nEmbdPerLayer;
+ this.embedScale = (float) Math.sqrt(dim);
+ this.perLayerTokEmbedScale = (float) Math.sqrt(nEmbdPerLayer);
+ this.perLayerProjScale = (float) (1.0 / Math.sqrt(dim));
+ this.perLayerInputScale = (float) (1.0 / Math.sqrt(2.0));
+ // NOTE(review): setupFFNLayers() is defined outside this file; presumably it builds the per-layer
+ // graphs via createFFNLayerTaskGraph, so every field above must be assigned before this call.
+ setupFFNLayers();
+ }
+
+ // ═══════════════════════════════════════════════════════════════════════════════════
+ // TASK GRAPH
+ // ═══════════════════════════════════════════════════════════════════════════════════
+
+ /**
+ * Builds the task graph named {@code layer_N} for transformer layer {@code N = layerIndex}.
+ *
+ * The graph declares the layer's weight transfers, adds the once-per-token setup tasks when
+ * {@code layerIndex == 0}, then registers the attention, FFN and per-layer-embedding tasks in
+ * execution order and finally persists {@code wrapX} on the device. Every task name used here is
+ * looked up again, as {@code layer_N.taskName}, in {@link #updateGridScheduler}; keep both in sync.
+ *
+ * @param layerIndex zero-based index of the layer to build
+ * @return the populated task graph for that layer
+ */
+ @Override
+ protected TaskGraph createFFNLayerTaskGraph(int layerIndex) {
+ var taskGraphName = "layer_" + layerIndex;
+ // Per-layer geometry: head dimension, attention type, KV ownership and FFN length are all
+ // queried per layer index.
+ final int headDim = config.headDim(layerIndex);
+ final boolean isSwa = config.isSwa(layerIndex);
+ final boolean hasOwnKv = config.hasOwnKv(layerIndex);
+ final int qDim = nHead * headDim;
+ final int kvDim = nHeadKv * headDim;
+ final int ffnLen = config.feedForwardLength(layerIndex);
+ final int cacheBaseOffset = gemma4State.cacheLayerBaseOffset[layerIndex]; // passed to the cache-writing and attention tasks
+ final int windowSize = isSwa ? config.slidingWindowSize() : config.contextLength(); // non-SWA layers use the full context length
+ // Sliding-window and full-attention layers use separate RoPE tables.
+ final var freqCisReal = (isSwa ? weights.freqCisRealSwa : weights.freqCisRealFull).asFloatArray();
+ final var freqCisImag = (isSwa ? weights.freqCisImagSwa : weights.freqCisImagFull).asFloatArray();
+ final int peOffset = layerIndex * nEmbdPerLayer; // offset of this layer's slice, passed to ple_gate_gelu_mul
+
+ var unifiedLayer = new TaskGraph(taskGraphName);
+ unifiedLayer.consumeFromDevice(gemma4State.wrapX);
+ // Layer weights are uploaded only on the first execution.
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ weights.rms_att_weightLayered[layerIndex].asFloatArray(),
+ weights.wqLayered[layerIndex].asHalfFloatArray(),
+ weights.wkLayered[layerIndex].asHalfFloatArray(),
+ weights.wvLayered[layerIndex].asHalfFloatArray(),
+ weights.woLayered[layerIndex].asHalfFloatArray(),
+ weights.attnQNorm[layerIndex].asFloatArray(),
+ weights.attnKNorm[layerIndex].asFloatArray(),
+ weights.attnPostNorm[layerIndex].asFloatArray(),
+ weights.rms_ffn_weightLayered[layerIndex].asFloatArray(),
+ weights.w1Layered[layerIndex].asHalfFloatArray(),
+ weights.w3Layered[layerIndex].asHalfFloatArray(),
+ weights.w2Layered[layerIndex].asHalfFloatArray(),
+ weights.ffnPostNorm[layerIndex].asFloatArray(),
+ weights.perLayerInpGate[layerIndex].asHalfFloatArray(),
+ weights.perLayerProj[layerIndex].asHalfFloatArray(),
+ weights.perLayerPostNorm[layerIndex].asFloatArray());
+ // The output-scale tensor is optional: a null entry means the layer has none.
+ if (weights.layerOutputScale[layerIndex] != null) {
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION, weights.layerOutputScale[layerIndex].asFloatArray());
+ }
+ unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
+
+ // Once-per-token setup is added to layer 0 only, ahead of its attention tasks.
+ if (layerIndex == 0) {
+ appendPLESetupTasks(unifiedLayer);
+ }
+
+ // ═══════════════════════════════════ ATTENTION ═══════════════════════════════════
+ // Pre-attention RMS norm of wrapX into wrapXb (reduce, optional finalize, apply).
+ unifiedLayer.task("attn_norm_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, gemma4State.temp, gemma4State.wrapX, dim, config.rmsNormEps(), gemma4State.localSize);
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("attn_norm_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, gemma4State.temp, dim, config.rmsNormEps());
+ }
+ unifiedLayer.task("attn_norm_apply",
+ Gemma4Kernels::applyRmsNorm,
+ context, gemma4State.wrapXb, gemma4State.wrapX, weights.rms_att_weightLayered[layerIndex].asFloatArray(), gemma4State.temp, dim);
+
+ unifiedLayer.task("q_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapXb, gemma4State.wrapQ, weights.wqLayered[layerIndex].asHalfFloatArray(), dim, qDim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("q_norm",
+ Gemma4Kernels::rmsNormPerHead,
+ context, gemma4State.wrapQ, weights.attnQNorm[layerIndex].asFloatArray(), nHead, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps());
+
+ // Layers with their own KV project and normalize K/V, then rotate and write the cache; the
+ // others only rotate Q (per the class notes they reuse an earlier layer's cache).
+ if (hasOwnKv) {
+ unifiedLayer.task("k_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapXb, gemma4State.wrapK, weights.wkLayered[layerIndex].asHalfFloatArray(), dim, kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("k_norm",
+ Gemma4Kernels::rmsNormPerHead,
+ context, gemma4State.wrapK, weights.attnKNorm[layerIndex].asFloatArray(), nHeadKv, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps());
+ unifiedLayer.task("v_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapXb, gemma4State.wrapV, weights.wvLayered[layerIndex].asHalfFloatArray(), dim, kvDim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("v_norm",
+ Gemma4Kernels::rmsNormPerHeadNoWeight,
+ context, gemma4State.wrapV, nHeadKv, headDim, HEAD_NORM_LOCAL_SIZE, config.rmsNormEps());
+ unifiedLayer.task("rope_and_cache",
+ Gemma4Kernels::ropeNeoxRotateAndCacheCopy,
+ context, gemma4State.positionHolder, gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV,
+ gemma4State.wrapKeyCache, gemma4State.wrapValueCache, freqCisReal, freqCisImag,
+ nHeadKv, headDim, kvDim, cacheBaseOffset);
+ } else {
+ unifiedLayer.task("rope_q_only",
+ Gemma4Kernels::ropeNeoxRotateQOnly,
+ context, gemma4State.positionHolder, gemma4State.wrapQ, freqCisReal, freqCisImag, headDim);
+ }
+
+ // NOTE(review): this is the only task in the graph that is not passed `context`; confirm that
+ // Gemma4Kernels.attentionWithSlidingWindow indeed takes no context argument.
+ unifiedLayer.task("attention",
+ Gemma4Kernels::attentionWithSlidingWindow,
+ gemma4State.wrapQ, gemma4State.wrapKeyCache, gemma4State.wrapValueCache, gemma4State.wrapXb, gemma4State.wrapAtt,
+ nHead, headDim, kvDim, kvMul, gemma4State.positionHolder, cacheBaseOffset, windowSize, config.contextLength());
+
+ unifiedLayer.task("wo_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapXb, gemma4State.wrapXb2, weights.woLayered[layerIndex].asHalfFloatArray(), qDim, dim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // Post-attention norm of wrapXb2, combined with wrapX by the residual kernel.
+ unifiedLayer.task("post_attn_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, gemma4State.tempPostAttn, gemma4State.wrapXb2, dim, config.rmsNormEps(), gemma4State.localSize);
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("post_attn_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, gemma4State.tempPostAttn, dim, config.rmsNormEps());
+ }
+ unifiedLayer.task("post_attn_apply",
+ Gemma4Kernels::rmsNormApplyWithResidual,
+ context, gemma4State.wrapX, gemma4State.wrapXb2, weights.attnPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostAttn, dim);
+
+ // ═══════════════════════════════════════ FFN ═════════════════════════════════════
+ // Same sandwich as attention: pre-norm into wrapXb, gate/up + down projection into wrapXb2,
+ // then post-norm with residual back into wrapX.
+ unifiedLayer.task("ffn_norm_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, gemma4State.tempFFN, gemma4State.wrapX, dim, config.rmsNormEps(), gemma4State.localSize);
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ffn_norm_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, gemma4State.tempFFN, dim, config.rmsNormEps());
+ }
+ unifiedLayer.task("ffn_norm_apply",
+ Gemma4Kernels::applyRmsNorm,
+ context, gemma4State.wrapXb, gemma4State.wrapX, weights.rms_ffn_weightLayered[layerIndex].asFloatArray(), gemma4State.tempFFN, dim);
+
+ unifiedLayer.task("ffn_gate_up",
+ Gemma4Kernels::fusedGateUpGeGLU,
+ context, gemma4State.wrapXb, gemma4State.wrapHb, weights.w1Layered[layerIndex].asHalfFloatArray(), weights.w3Layered[layerIndex].asHalfFloatArray(),
+ dim, ffnLen, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("ffn_down_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapHb, gemma4State.wrapXb2, weights.w2Layered[layerIndex].asHalfFloatArray(), ffnLen, dim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ unifiedLayer.task("post_ffn_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, gemma4State.tempPostFfn, gemma4State.wrapXb2, dim, config.rmsNormEps(), gemma4State.localSize);
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("post_ffn_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, gemma4State.tempPostFfn, dim, config.rmsNormEps());
+ }
+ unifiedLayer.task("post_ffn_apply",
+ Gemma4Kernels::rmsNormApplyWithResidual,
+ context, gemma4State.wrapX, gemma4State.wrapXb2, weights.ffnPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostFfn, dim);
+
+ // ═══════════════════════════ PER-LAYER EMBEDDING (PLE) ═══════════════════════════
+ // Gate projection (dim -> nEmbdPerLayer), combined with this layer's slice of
+ // wrapPerLayerInputs, projected back (nEmbdPerLayer -> dim) and normalized with residual into wrapX.
+ unifiedLayer.task("ple_gate_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapX, gemma4State.wrapPerLayerGate, weights.perLayerInpGate[layerIndex].asHalfFloatArray(), dim, nEmbdPerLayer, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("ple_gate_gelu_mul",
+ Gemma4Kernels::pleGateGeluMul,
+ context, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerInputs, peOffset, nEmbdPerLayer);
+ unifiedLayer.task("ple_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut, weights.perLayerProj[layerIndex].asHalfFloatArray(), nEmbdPerLayer, dim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ unifiedLayer.task("ple_post_reduce",
+ TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
+ context, gemma4State.tempPostPle, gemma4State.wrapPerLayerOut, dim, config.rmsNormEps(), gemma4State.localSize);
+ if (shouldUseFinalNormalization()) {
+ unifiedLayer.task("ple_post_finalize",
+ TransformerComputeKernelsLayered::reductionFinalNormalization,
+ context, gemma4State.tempPostPle, dim, config.rmsNormEps());
+ }
+ unifiedLayer.task("ple_post_apply",
+ Gemma4Kernels::rmsNormApplyWithResidual,
+ context, gemma4State.wrapX, gemma4State.wrapPerLayerOut, weights.perLayerPostNorm[layerIndex].asFloatArray(), gemma4State.tempPostPle, dim);
+
+ // Optional final scaling of wrapX; must mirror the null check used for the transfer above.
+ if (weights.layerOutputScale[layerIndex] != null) {
+ unifiedLayer.task("layer_output_scale",
+ Gemma4Kernels::scaleInPlaceFromTensor,
+ context, gemma4State.wrapX, weights.layerOutputScale[layerIndex].asFloatArray(), dim);
+ }
+
+ unifiedLayer.persistOnDevice(gemma4State.wrapX);
+ return unifiedLayer;
+ }
+
+ /**
+ * One-time-per-token setup tasks, prepended to layer 0's graph: scales the token embedding by
+ * {@code sqrt(dim)} (Gemma4 scales embeddings on input -- the generic {@link org.beehive.gpullama3.tornadovm.layers.Activation}
+ * task graph that produced {@code wrapX} doesn't know about this), then computes the per-layer
+ * embedding inputs from the per-layer model projection and the (host-gathered) per-token
+ * per-layer-token-embedding row. Mirrors steps 1-2 of {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}.
+ *
+ * @param unifiedLayer layer 0's task graph, to which the four setup tasks are added
+ */
+ private void appendPLESetupTasks(TaskGraph unifiedLayer) {
+ unifiedLayer.task("scale_embedding",
+ Gemma4Kernels::scaleInPlace,
+ context, gemma4State.wrapX, embedScale, dim);
+
+ // Project wrapX to all layers' PLE slices at once (dim -> perLayerTotal), scale/normalize the
+ // projection, then merge it with the token-embedding row into wrapPerLayerInputs.
+ unifiedLayer.task("ple_model_proj",
+ TransformerComputeKernelsLayered::matrixVectorGeneric,
+ context, gemma4State.wrapX, gemma4State.wrapPerLayerProjScratch, weights.perLayerModelProj.asHalfFloatArray(), dim, perLayerTotal, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ unifiedLayer.task("ple_proj_scale_norm",
+ Gemma4Kernels::pleProjScaleAndNormalize,
+ context, gemma4State.wrapPerLayerProjScratch, weights.perLayerProjNorm.asFloatArray(), nEmbdPerLayer, HEAD_NORM_LOCAL_SIZE, perLayerProjScale, config.rmsNormEps());
+ unifiedLayer.task("ple_merge",
+ Gemma4Kernels::addAndScale,
+ context, gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerProjScratch, gemma4State.wrapPerLayerTokenEmbedRow, perLayerInputScale, perLayerTotal);
+ }
+
+ /**
+ * Declares the state buffers a layer graph uses. Layer 0 uploads the per-token inputs and the
+ * reduction scratch buffers on every execution and the shared weights, RoPE tables and workspace
+ * buffers on the first execution only; every later layer consumes the workspace buffers that are
+ * already on the device.
+ *
+ * @param unifiedLayer task graph being configured
+ * @param layerIndex zero-based layer index; only {@code 0} takes the upload path
+ * @return the same {@code unifiedLayer} instance
+ */
+ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int layerIndex) {
+ if (layerIndex == 0) {
+ unifiedLayer.transferToDevice(DataTransferMode.EVERY_EXECUTION,
+ gemma4State.positionHolder, gemma4State.wrapPerLayerTokenEmbedRow,
+ gemma4State.temp, gemma4State.tempFFN, gemma4State.tempPostAttn, gemma4State.tempPostFfn, gemma4State.tempPostPle);
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ weights.perLayerModelProj.asHalfFloatArray(), weights.perLayerProjNorm.asFloatArray(),
+ weights.freqCisRealSwa.asFloatArray(), weights.freqCisImagSwa.asFloatArray(),
+ weights.freqCisRealFull.asFloatArray(), weights.freqCisImagFull.asFloatArray());
+ unifiedLayer.transferToDevice(DataTransferMode.FIRST_EXECUTION,
+ context, gemma4State.wrapXb, gemma4State.wrapXb2,
+ gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV,
+ gemma4State.wrapKeyCache, gemma4State.wrapValueCache,
+ gemma4State.wrapAtt, gemma4State.wrapHb,
+ gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerProjScratch,
+ gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut);
+ } else {
+ // NOTE(review): the temp* reduction buffers and the RoPE tables are used by later layers'
+ // tasks but are not listed here; confirm that leaving them undeclared is intended.
+ unifiedLayer.consumeFromDevice(context, gemma4State.wrapXb, gemma4State.wrapXb2,
+ gemma4State.wrapQ, gemma4State.wrapK, gemma4State.wrapV,
+ gemma4State.wrapKeyCache, gemma4State.wrapValueCache,
+ gemma4State.wrapAtt, gemma4State.wrapHb,
+ gemma4State.wrapPerLayerInputs, gemma4State.wrapPerLayerGate, gemma4State.wrapPerLayerOut,
+ gemma4State.positionHolder);
+ }
+ return unifiedLayer;
+ }
+
+ // ═══════════════════════════════════════════════════════════════════════════════════
+ // GRID SCHEDULER
+ // ═══════════════════════════════════════════════════════════════════════════════════
+
+ /**
+ * Registers a worker grid for every task added by {@link #createFFNLayerTaskGraph} and
+ * {@link #appendPLESetupTasks}, keyed as {@code layer_N.taskName}. The conditions used here
+ * ({@code hasOwnKv}, {@code shouldUseFinalNormalization()}, a non-null {@code layerOutputScale})
+ * mirror the ones that decide whether the corresponding task exists.
+ *
+ * @param gridScheduler scheduler to add the worker grids to
+ * @return the same {@code gridScheduler} instance
+ */
+ @Override
+ public GridScheduler updateGridScheduler(GridScheduler gridScheduler) {
+ // Grids whose size does not depend on the layer are created once and shared by all layers.
+ WorkerGrid rmsNormWorker = WorkerGridFactory.createRmsNormWorker(dim, gemma4State.localSize);
+ WorkerGrid dimElementWiseWorker = WorkerGridFactory.genericWorker(dim, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid woProjWorker = WorkerGridFactory.genericWorker(dim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC); // shared by wo_proj, ffn_down_proj and ple_proj
+ WorkerGrid pleGateProjWorker = WorkerGridFactory.genericWorker(nEmbdPerLayer * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid pleGateGeluWorker = WorkerGridFactory.genericWorker(nEmbdPerLayer, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ // === Layer-0 PLE setup ===
+ gridScheduler.addWorkerGrid("layer_0.scale_embedding", dimElementWiseWorker);
+ gridScheduler.addWorkerGrid("layer_0.ple_model_proj", WorkerGridFactory.genericWorker(perLayerTotal * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC));
+ gridScheduler.addWorkerGrid("layer_0.ple_proj_scale_norm", WorkerGridFactory.genericWorker(config.numberOfLayers() * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE));
+ gridScheduler.addWorkerGrid("layer_0.ple_merge", WorkerGridFactory.genericWorker(perLayerTotal, LOCAL_WORK_GROUP_SIZE_ALLOC));
+
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String prefix = "layer_" + i + ".";
+ // Head dimension, KV ownership and FFN length are queried per layer, so the grids that
+ // depend on them are built inside the loop.
+ int headDim = config.headDim(i);
+ boolean hasOwnKv = config.hasOwnKv(i);
+ int qDim = nHead * headDim;
+ int kvDim = nHeadKv * headDim;
+ int ffnLen = config.feedForwardLength(i);
+
+ WorkerGrid headNormWorker = WorkerGridFactory.genericWorker(nHead * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE);
+ WorkerGrid kvHeadNormWorker = WorkerGridFactory.genericWorker(nHeadKv * HEAD_NORM_LOCAL_SIZE, HEAD_NORM_LOCAL_SIZE);
+ WorkerGrid ropeWorker = WorkerGridFactory.createRoPEWorker(nHead, headDim);
+ WorkerGrid attentionWorker = WorkerGridFactory.createAttentionWorker(nHead, headDim);
+ WorkerGrid qProjWorker = WorkerGridFactory.genericWorker(qDim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid kvProjWorker = WorkerGridFactory.genericWorker(kvDim * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC);
+ WorkerGrid ffnGateUpWorker = WorkerGridFactory.genericWorker(ffnLen * LOCAL_WORK_GROUP_SIZE_ALLOC, LOCAL_WORK_GROUP_SIZE_ALLOC);
+
+ gridScheduler.addWorkerGrid(prefix + "attn_norm_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "attn_norm_apply", dimElementWiseWorker);
+ gridScheduler.addWorkerGrid(prefix + "q_proj", qProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "q_norm", headNormWorker);
+ // Same branch as in createFFNLayerTaskGraph: K/V tasks exist only for layers with their own KV.
+ if (hasOwnKv) {
+ gridScheduler.addWorkerGrid(prefix + "k_proj", kvProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "k_norm", kvHeadNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "v_proj", kvProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "v_norm", kvHeadNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "rope_and_cache", ropeWorker);
+ } else {
+ gridScheduler.addWorkerGrid(prefix + "rope_q_only", ropeWorker);
+ }
+ gridScheduler.addWorkerGrid(prefix + "attention", attentionWorker);
+ gridScheduler.addWorkerGrid(prefix + "wo_proj", woProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_attn_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_attn_apply", dimElementWiseWorker);
+
+ gridScheduler.addWorkerGrid(prefix + "ffn_norm_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "ffn_norm_apply", dimElementWiseWorker);
+ gridScheduler.addWorkerGrid(prefix + "ffn_gate_up", ffnGateUpWorker);
+ gridScheduler.addWorkerGrid(prefix + "ffn_down_proj", woProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_ffn_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_ffn_apply", dimElementWiseWorker);
+
+ gridScheduler.addWorkerGrid(prefix + "ple_gate_proj", pleGateProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "ple_gate_gelu_mul", pleGateGeluWorker);
+ gridScheduler.addWorkerGrid(prefix + "ple_proj", woProjWorker);
+ gridScheduler.addWorkerGrid(prefix + "ple_post_reduce", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "ple_post_apply", dimElementWiseWorker);
+
+ // The *_finalize tasks are only created when shouldUseFinalNormalization() is true.
+ if (shouldUseFinalNormalization()) {
+ gridScheduler.addWorkerGrid(prefix + "attn_norm_finalize", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_attn_finalize", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "ffn_norm_finalize", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "post_ffn_finalize", rmsNormWorker);
+ gridScheduler.addWorkerGrid(prefix + "ple_post_finalize", rmsNormWorker);
+ }
+ if (weights.layerOutputScale[i] != null) {
+ gridScheduler.addWorkerGrid(prefix + "layer_output_scale", dimElementWiseWorker);
+ }
+ }
+ return gridScheduler;
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java
new file mode 100644
index 00000000..baa76009
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/layers/type/fp16/Gemma4LogitsFP16Layer.java
@@ -0,0 +1,114 @@
+package org.beehive.gpullama3.tornadovm.layers.type.fp16;
+
+import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
+import org.beehive.gpullama3.tornadovm.kernels.Gemma4Kernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernels;
+import org.beehive.gpullama3.tornadovm.kernels.TransformerComputeKernelsLayered;
+import org.beehive.gpullama3.tornadovm.layerplanner.WorkerGridFactory;
+import org.beehive.gpullama3.tornadovm.layerplanner.strategy.SchedulerType;
+import uk.ac.manchester.tornado.api.GridScheduler;
+import uk.ac.manchester.tornado.api.TaskGraph;
+import uk.ac.manchester.tornado.api.enums.DataTransferMode;
+
+/**
+ * FP16 logits layer for Gemma4.
+ *
+ * Per the original author it matches {@link LogitsFP16Layer} with a single addition: after the
+ * vocabulary projection, a non-zero {@code gemma4.final_logit_softcapping} value schedules one more
+ * task applying {@code logits = softcap * tanh(logits / softcap)}. The reference computation is
+ * {@link org.beehive.gpullama3.inference.InferenceCore#forwardJavaGemma4}.
+ */
+public class Gemma4LogitsFP16Layer extends LogitsFP16Layer {
+
+    private static final String SOFTCAP_TASK = "logit_softcap";
+
+    public Gemma4LogitsFP16Layer(String name, State state, Weights weights, Configuration config,
+            String lastTaskGraphID, SchedulerType schedulerType) {
+        super(name, state, weights, config, lastTaskGraphID, schedulerType);
+    }
+
+    /** Soft-cap value read from the Gemma4 configuration; {@code 0} means no soft-cap task is scheduled. */
+    private float softcap() {
+        Gemma4Configuration gemma4Config = (Gemma4Configuration) config;
+        return gemma4Config.finalLogitSoftcapping();
+    }
+
+    // @formatter:off
+    /**
+     * Assembles the {@code logits} task graph: final RMS norm of {@code wrapX}, vocabulary projection
+     * into {@code wrapLogits}, the optional soft-cap task, and the copy of the logits back to the host.
+     *
+     * @param weights weights supplying the final norm weight and the output projection matrix
+     * @param config  configuration supplying dim, RMS-norm epsilon and vocabulary size
+     * @return the populated task graph
+     */
+    @Override
+    protected TaskGraph setupLogitsTaskGraph(TornadoWeights weights, Configuration config) {
+        final int dim = config.dim();
+        final float eps = config.rmsNormEps();
+        final int vocabSize = config.vocabularySize();
+        final float cap = softcap();
+        final var outputWeights = weights.wclsByteArray.asHalfFloatArray();
+        final var finalNormWeights = weights.rms_final_weight_as_floatArray.asFloatArray();
+
+        // Data setup: wrapX comes from the last FFN graph; the scratch buffer is re-sent every
+        // execution, everything else only on the first one.
+        TaskGraph graph = new TaskGraph("logits")
+                .consumeFromDevice(lastTaskGraphID, state.wrapX)
+                .transferToDevice(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
+                .transferToDevice(DataTransferMode.FIRST_EXECUTION,
+                        context, state.wrapLogits, state.wrapXbFP16, outputWeights, finalNormWeights);
+
+        // Final RMS normalization: reduce, optional finalize pass, then apply into wrapXbFP16.
+        graph.task("rms_reduce", TransformerComputeKernels::reductionOneBlockWithLayer,
+                context, state.tempLogits, state.wrapX, dim, eps, state.localSize);
+        if (schedulerType == SchedulerType.NON_NVIDIA) {
+            graph.task("rms_finalize", TransformerComputeKernelsLayered::reductionFinalNormalization,
+                    context, state.tempLogits, dim, eps);
+        }
+        graph.task("rms_apply_fp16", TransformerComputeKernels::mapContextWithQuantizeLogits,
+                context, state.wrapXbFP16, state.wrapX, finalNormWeights, state.tempLogits);
+
+        // Vocabulary projection, followed by the Gemma4-specific soft-cap when it is enabled.
+        graph.task("vocab_proj", TransformerComputeKernelsLayered::matrixVectorGeneric,
+                context, state.wrapXbFP16, state.wrapLogits, outputWeights, dim, vocabSize,
+                LOCAL_WORK_GROUP_SIZE_ALLOC * THREAD_SCALE_FOR_LOGITS);
+        if (cap != 0.0f) {
+            graph.task(SOFTCAP_TASK, Gemma4Kernels::applyLogitSoftcap,
+                    context, state.wrapLogits, cap, vocabSize);
+        }
+
+        // Logits are copied back to the host after every execution.
+        return graph.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
+    }
+    // @formatter:on
+
+    /**
+     * Adds the worker grid for the soft-cap task on top of the grids registered by the superclass.
+     * Nothing extra is registered when the soft-cap is zero, matching the task graph.
+     */
+    @Override
+    public GridScheduler updateGridScheduler(GridScheduler tornadoForwardScheduler) {
+        var scheduler = super.updateGridScheduler(tornadoForwardScheduler);
+        if (softcap() == 0.0f) {
+            return scheduler;
+        }
+        var softcapWorker = WorkerGridFactory.genericWorker(config.vocabularySize(), LOCAL_WORK_GROUP_SIZE_ALLOC);
+        scheduler.addWorkerGrid("logits." + SOFTCAP_TASK, softcapWorker);
+        return scheduler;
+    }
+}