Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions src/main/java/org/beehive/gpullama3/inference/InferenceCore.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

import org.beehive.gpullama3.auxiliary.Parallel;
import org.beehive.gpullama3.tensor.standard.FloatTensor;
import org.beehive.gpullama3.inference.state.Gemma4State;
import org.beehive.gpullama3.inference.state.Phi3State;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.inference.weights.standard.Gemma4StandardWeights;
import org.beehive.gpullama3.model.loader.ModelLoader;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.gemma4.Gemma4Configuration;
import org.beehive.gpullama3.model.granite.GraniteConfiguration;
import org.beehive.gpullama3.model.devstral.DevstralConfiguration;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
Expand Down Expand Up @@ -534,6 +538,186 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token,
return state.logits;
}

/** RMS-normalizes without applying a learned scale (Gemma4 normalizes V with a plain, weight-less RMSNorm). */
private static void rmsnormNoWeight(FloatTensor out, FloatTensor x, int offset, int size, float rmsNormEps) {
float ss = x.reduce(offset, size, 0f, (acc, xi) -> acc + xi * xi);
ss /= size;
ss += rmsNormEps;
ss = (float) (1.0 / Math.sqrt(ss));
final float finalss = ss;
out.mapWithIndexInPlace(offset, size, (value, index) -> finalss * x.getFloat(index));
}

/** Tanh-approximation GELU, matching ggml's {@code ggml_gelu_f32} (used by Gemma4's GeGLU FFN and PLE gate). */
private static float gelu(float x) {
return 0.5f * x * (1.0f + (float) Math.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)));
}

/** NeoX-style RoPE: rotates pairs (i, i + headDim/2) within each head using precomputed cos/sin tables. */
private static void ropeRotateNeox(FloatTensor vec, int nHeads, int headDim, int position, FloatTensor freqCisReal, FloatTensor freqCisImag) {
int nComplHead = headDim / 2;
for (int h = 0; h < nHeads; h++) {
int base = h * headDim;
for (int ic = 0; ic < nComplHead; ic++) {
float fcr = freqCisReal.getFloat(position * nComplHead + ic);
float fci = freqCisImag.getFloat(position * nComplHead + ic);
float v0 = vec.getFloat(base + ic);
float v1 = vec.getFloat(base + ic + nComplHead);
vec.setFloat(base + ic, v0 * fcr - v1 * fci);
vec.setFloat(base + ic + nComplHead, v0 * fci + v1 * fcr);
}
}
}

public static FloatTensor forwardJavaGemma4(Model model, State state, int token, int position) {
final Gemma4Configuration config = (Gemma4Configuration) model.configuration();
final Gemma4StandardWeights weights = (Gemma4StandardWeights) model.weights();
final Gemma4State gs = (Gemma4State) state;

final int dim = config.dim();
final int nHead = config.numberOfHeads();
final int nHeadKv = config.numberOfKeyValueHeads();
final int kvMul = config.kvMul();
final int nLayers = config.numberOfLayers();
final int nEmbdPerLayer = config.embeddingLengthPerLayer();
final int perLayerTotal = nLayers * nEmbdPerLayer;
final float attentionScale = 1.0f; // Gemma4 attention uses scaling = 1.0 (no 1/sqrt(headDim))

// 1. token embedding, scaled by sqrt(dim)
weights.tokenEmbeddingTable.copyTo(token * dim, state.x, 0, dim);
final float embedScale = (float) Math.sqrt(dim);
state.x.mapInPlace(v -> v * embedScale);

// 2. per-layer embeddings (PLE): inp_per_layer[l] = (rmsnorm(proj(x) / sqrt(dim)) + tokEmbd[l]*sqrt(nEmbdPerLayer)) / sqrt(2)
// per_layer_token_embd is ~2.35B elements (too large for the int-indexed FloatTensor API), so it
// is addressed one embedding row at a time directly from its raw tensor entry.
ModelLoader.copyEmbeddingRow(weights.perLayerTokenEmbd, token, perLayerTotal, gs.perLayerInputs, 0);
final float perLayerTokEmbedScale = (float) Math.sqrt(nEmbdPerLayer);
gs.perLayerInputs.mapInPlace(v -> v * perLayerTokEmbedScale);

weights.perLayerModelProj.matmul(state.x, gs.perLayerProjScratch, perLayerTotal, dim);
final float perLayerProjScale = (float) (1.0 / Math.sqrt(dim));
gs.perLayerProjScratch.mapInPlace(v -> v * perLayerProjScale);
for (int l = 0; l < nLayers; l++) {
rmsnorm(gs.perLayerProjScratch, gs.perLayerProjScratch, weights.perLayerProjNorm, l * nEmbdPerLayer, nEmbdPerLayer, config.rmsNormEps());
}
final float perLayerInputScale = (float) (1.0 / Math.sqrt(2.0));
for (int i = 0; i < perLayerTotal; i++) {
float v = (gs.perLayerProjScratch.getFloat(i) + gs.perLayerInputs.getFloat(i)) * perLayerInputScale;
gs.perLayerInputs.setFloat(i, v);
}

// 3. transformer layers
for (int l = 0; l < nLayers; l++) {
final int curLayer = l;
final int headDim = config.headDim(l);
final boolean isSwa = config.isSwa(l);
final int qDim = nHead * headDim;
final int kvDim = nHeadKv * headDim;

FloatTensor freqCisReal = isSwa ? weights.freqCisRealSwa : weights.freqCisRealFull;
FloatTensor freqCisImag = isSwa ? weights.freqCisImagSwa : weights.freqCisImagFull;

// attn_norm
rmsnorm(state.xb, state.x, weights.attnNorm[l], 0, dim, config.rmsNormEps());

// Q projection, per-head Q-norm, RoPE
weights.wq[l].matmul(state.xb, state.q, qDim, dim);
for (int h = 0; h < nHead; h++) {
rmsnorm(state.q, state.q, weights.attnQNorm[l], h * headDim, headDim, config.rmsNormEps());
}
ropeRotateNeox(state.q, nHead, headDim, position, freqCisReal, freqCisImag);

// K/V: either compute and cache them here, or reuse an earlier layer's cache ("shared KV layers")
final int kvSrcLayer;
if (config.hasOwnKv(l)) {
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
for (int h = 0; h < nHeadKv; h++) {
rmsnorm(state.k, state.k, weights.attnKNorm[l], h * headDim, headDim, config.rmsNormEps());
rmsnormNoWeight(state.v, state.v, h * headDim, headDim, config.rmsNormEps());
}
ropeRotateNeox(state.k, nHeadKv, headDim, position, freqCisReal, freqCisImag);

state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
kvSrcLayer = l;
} else {
kvSrcLayer = config.kvReuseLayer(l);
}

// self-attention (causal; sliding-window layers additionally restrict to a local window)
final int windowStart = isSwa ? Math.max(0, position - config.slidingWindowSize() + 1) : 0;
Parallel.parallelFor(0, nHead, h -> {
int qOffset = h * headDim;
int attOffset = h * config.contextLength();
int kvHeadOffset = (h / kvMul) * headDim;

for (int t = windowStart; t <= position; t++) {
int kvOffset = t * kvDim + kvHeadOffset;
float score = state.q.dot(qOffset, state.keyCache[kvSrcLayer], kvOffset, headDim);
score *= attentionScale;
state.att.setFloat(attOffset + t, score);
}

state.att.softmaxInPlace(attOffset + windowStart, position - windowStart + 1);

int xbOffset = h * headDim;
state.xb.fillInPlace(xbOffset, headDim, 0f);
for (int t = windowStart; t <= position; t++) {
int kvOffset = t * kvDim + kvHeadOffset;
float a = state.att.getFloat(attOffset + t);
state.xb.saxpyInPlace(xbOffset, state.valueCache[kvSrcLayer], kvOffset, headDim, a);
}
});

// wo projection, post-attention norm, residual -> attn_out (kept in state.x)
weights.wo[curLayer].matmul(state.xb, state.xb2, dim, qDim);
rmsnorm(state.xb2, state.xb2, weights.attnPostNorm[curLayer], 0, dim, config.rmsNormEps());
state.x.addInPlace(state.xb2); // state.x now holds attn_out = inpL + post_attn_norm(attn(...))

// FFN (GeGLU: down(gelu(gate(x)) * up(x))), post-FFN norm, residual -> cur (kept in state.x)
rmsnorm(state.xb, state.x, weights.ffnNorm[curLayer], 0, dim, config.rmsNormEps());
weights.ffnGate[curLayer].matmul(state.xb, state.hb, config.feedForwardLength(curLayer), dim);
weights.ffnUp[curLayer].matmul(state.xb, state.hb2, config.feedForwardLength(curLayer), dim);
state.hb.mapInPlace(InferenceCore::gelu);
state.hb.multiplyInPlace(state.hb2);
weights.ffnDown[curLayer].matmul(state.hb, state.xb2, dim, config.feedForwardLength(curLayer));
rmsnorm(state.xb2, state.xb2, weights.ffnPostNorm[curLayer], 0, dim, config.rmsNormEps());
state.x.addInPlace(state.xb2); // state.x now holds cur = attn_out + post_ffn_norm(ffn(...))

// per-layer embedding (PLE): cur += per_layer_post_norm(proj(gelu(inp_gate(cur)) * inp_per_layer[l]))
weights.perLayerInpGate[curLayer].matmul(state.x, gs.perLayerGate, nEmbdPerLayer, dim);
gs.perLayerGate.mapInPlace(InferenceCore::gelu);
int peOffset = curLayer * nEmbdPerLayer;
for (int j = 0; j < nEmbdPerLayer; j++) {
gs.perLayerGate.setFloat(j, gs.perLayerGate.getFloat(j) * gs.perLayerInputs.getFloat(peOffset + j));
}
weights.perLayerProj[curLayer].matmul(gs.perLayerGate, gs.perLayerOut, dim, nEmbdPerLayer);
rmsnorm(gs.perLayerOut, gs.perLayerOut, weights.perLayerPostNorm[curLayer], 0, dim, config.rmsNormEps());
state.x.addInPlace(gs.perLayerOut);

// optional learned per-layer output scale
FloatTensor outScale = weights.layerOutputScale[curLayer];
if (outScale != null) {
final float scale = outScale.getFloat(0);
state.x.mapInPlace(v -> v * scale);
}
}

// final norm, classifier, and logit soft-capping: logits = softcap * tanh(logits / softcap)
rmsnorm(state.x, state.x, weights.outputNorm, 0, dim, config.rmsNormEps());
weights.outputWeight.matmul(state.x, state.logits, config.vocabularySize(), dim);

final float softcap = config.finalLogitSoftcapping();
if (softcap != 0.0f) {
final float invSoftcap = 1.0f / softcap;
state.logits.mapInPlace(v -> (float) Math.tanh(v * invSoftcap) * softcap);
}

return state.logits;
}

public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) {
Phi3Configuration config = (Phi3Configuration) model.configuration();
Phi3StandardWeights weights = (Phi3StandardWeights) model.weights();
Expand Down
Loading
Loading