diff --git a/bindings/python/quant.h b/bindings/python/quant.h index 49fea77..36cbbb2 100644 --- a/bindings/python/quant.h +++ b/bindings/python/quant.h @@ -62,6 +62,13 @@ int quant_generate(quant_ctx* ctx, const char* prompt, void (*on_token)(const char* text, void* user_data), void* user_data); +// Multi-turn chat with KV cache reuse (O(delta) per turn instead of O(n^2)). +// Subsequent calls only re-prefill the suffix that diverges from history. +// Pass prompt = NULL to reset the chat session. Returns tokens generated. +int quant_chat(quant_ctx* ctx, const char* prompt, + void (*on_token)(const char* text, void* user_data), + void* user_data); + // Generate and return full response as string. Caller must free(). char* quant_ask(quant_ctx* ctx, const char* prompt); @@ -202,8 +209,6 @@ static inline int clock_gettime(int id, struct timespec* ts) { // Section 1: Types and Specs (from tq_types.h, tq_spec.h) // ============================================================================ - - /* Cross-language static assert: works in both C11 and C++11/17 */ #ifdef __cplusplus #define TQ_STATIC_ASSERT(cond, msg) static_assert(cond, msg) @@ -219,8 +224,6 @@ static inline int clock_gettime(int id, struct timespec* ts) { #define TQ_PI_2 1.5707963267948966f #endif - - /* ============================================================ * Constants * ============================================================ */ @@ -398,8 +401,6 @@ typedef struct { int enable_recompression;/* Tier 1 → Tier 2 re-compression */ } tq_progressive_config_t; - - /* TurboQuant KV cache block: RHT + Lloyd-Max codebook + QJL residual * 3-bit variant: 2-bit codebook (4 levels) + 1-bit QJL sign hash * Block covers TQ_BK elements (128). @@ -469,12 +470,6 @@ TQ_CHECK_SIZE(block_tq_turbo_kv_4b, 8 + TQ_BK * 3 / 8 + TQ_BK / 8); TQ_CHECK_SIZE(block_tq_turbo_kv_1b, 8 + TQ_BK / 8); TQ_CHECK_SIZE(block_tq_turbo_kv_2b, 8 + TQ_BK / 8 + TQ_BK / 8); - - - - - - /* Format specification — version-aware, ONNX-inspired */ #define TQ_SPEC_VERSION 1 @@ -500,18 +495,10 @@ typedef struct { uint8_t flags; /* TQ_FLAG_* bitmask */ } tq_format_spec_t; - - - - // ============================================================================ // Section 2: Engine Types (from tq_engine.h) // ============================================================================ - - - - /* ============================================================ * Model configuration * ============================================================ */ @@ -886,6 +873,7 @@ typedef struct { int n_threads; float rep_penalty; /* repetition penalty (default: 1.1, 1.0 = disabled) */ int rep_window; /* how many recent tokens to penalize (default: 32) */ + unsigned long long rng_seed; /* sampling seed (default: 42, 0 = use 42 for back-compat) */ /* Callback for streaming output */ void (*on_token)(const char* text, void* user_data); void* user_data; @@ -1123,9 +1111,6 @@ void tq_tp_run(void* (*fn)(void*), void** args, int n_tasks); /* Max threads supported by thread pool */ #define TQ_TP_MAX 16 - - - // ============================================================================ // Section 3: GGUF Types (from tq_gguf.h) // ============================================================================ @@ -1143,10 +1128,6 @@ void tq_tp_run(void* (*fn)(void*), void** args, int n_tasks); * directly into TurboQuant inference engine. */ - - - - /* ============================================================ * GGUF format constants * ============================================================ */ @@ -1462,14 +1443,10 @@ int tq_metal_moe_forward( const int* up_types, /* per-expert up quant types, NULL = use weight_type */ const int* down_types); /* per-expert down quant types, NULL = use weight_type */ - - - // ============================================================================ // Section 4: Internal API (from turboquant.h) // ============================================================================ - /** * TurboQuant.cpp — Cross-platform KV cache compression library * @@ -1477,9 +1454,6 @@ int tq_metal_moe_forward( * Zero external dependencies (libc/libm only). */ - - - /* ============================================================ * Version * ============================================================ */ @@ -1753,21 +1727,28 @@ void tq_progressive_free(tq_progressive_t* p); tq_progressive_config_t tq_progressive_default_config(void); - - - - // ============================================================================ // Section 5: quant_ctx struct definition // ============================================================================ - struct quant_ctx { tq_model_t* model; tq_state_t* state; tq_tokenizer_t* tokenizer; tq_gen_config_t config; - int n_ctx_tokens; /* number of tokens currently in KV cache */ + int n_ctx_tokens; /* number of tokens currently in KV cache */ + /* Prefix-match cache for chat history reuse: + * stores the actual token IDs that are committed to the KV cache, + * so the next quant_generate() can skip the matching prefix and + * only prefill the diverging suffix. Critical for chat mode where + * each turn re-sends the entire conversation history. */ + int* cached_tokens; + int n_cached; + int cached_capacity; + /* Text-prefix cache: stores the entire prompt + generated response + * text from the last call, allowing the next call to bypass BPE + * re-tokenization issues by matching at the byte level. */ + char* cached_text; }; // ============================================================================ @@ -1788,7 +1769,6 @@ struct quant_ctx { * - Random signs decorrelate channels across different blocks */ - #ifdef __ARM_NEON #include #endif @@ -1902,7 +1882,6 @@ void tq_rht_inverse(float* data, int n, uint32_t seed) { */ /* Generic reference — no compiler-specific pragmas */ - /* ---------- FP16 helpers ---------- */ static uint16_t uni_fp32_to_fp16(float v) { @@ -2285,7 +2264,6 @@ void tq_uniform_3b_attention_ref(const float* query, const void* kv, // Section 8: Type Traits (from tq_traits.c) // ============================================================================ - /* Stub implementations for excluded quantization types (polar, qjl, turbo, mixed) */ static void tq_stub_quantize(const float* src, void* dst, int n) { (void)src; (void)dst; (void)n; @@ -2583,7 +2561,6 @@ tq_type tq_type_from_name(const char* name) { * No external dependencies — libc/libm only. */ - #ifdef __ARM_NEON #include #endif @@ -2617,7 +2594,6 @@ static struct { static int g_n_threads = 1; - static void* tp_worker(void* arg) { int id = (int)(intptr_t)arg; int my_gen = 0; @@ -4173,6 +4149,7 @@ tq_gen_config_t tq_default_gen_config(void) { config.n_threads = 1; config.rep_penalty = 1.1f; config.rep_window = 32; + config.rng_seed = 42ULL; config.on_token = NULL; config.user_data = NULL; return config; @@ -4388,8 +4365,6 @@ void tq_matmul_1bit(float* out, const float* x, * SPDX-License-Identifier: MIT */ - - #ifdef _WIN32 #else #endif @@ -5098,8 +5073,6 @@ const tq_gguf_tensor_t* tq_gguf_find_tensor(const tq_gguf_ctx_t* ctx, const char * Pure C11, no external dependencies. */ - - #if defined(__ARM_NEON) || defined(__ARM_NEON__) #include #define TQ_HAS_NEON 1 @@ -7174,7 +7147,6 @@ void tq_metal_batch_end_if_available(void) { * Also supports the legacy llama2.c binary tokenizer format as fallback. */ - /* Global for qsort comparator (vocab index sorting) */ static char** g_vocab_for_sort; static int cmp_vocab_idx(const void* a, const void* b) { @@ -8033,32 +8005,75 @@ tq_tokenizer_t* tq_load_tokenizer_from_gguf(const void* gguf_ctx_ptr) { } } - /* Load merges if available */ + /* Build sorted indices BEFORE merge parsing so str_lookup() can use + * binary search instead of O(n) linear scan. For 248K vocab with + * ~50K merges (3 lookups each), this turns a ~10 s init into ~100 ms. */ + tok->sorted_indices = (int*)malloc(vocab_size * sizeof(int)); + if (tok->sorted_indices) { + for (int i = 0; i < (int)vocab_size; i++) tok->sorted_indices[i] = i; + g_vocab_for_sort = tok->vocab; + qsort(tok->sorted_indices, vocab_size, sizeof(int), cmp_vocab_idx); + } + + /* Load and parse merges if available. + * GGUF stores merges as a string array of "tok_a tok_b" pairs. + * We need to look up token IDs and build (id_a, id_b, id_merged) triples + * so the BPE encoder can use them. */ int64_t merges_idx = tq_gguf_find_key(gguf, "tokenizer.ggml.merges"); if (merges_idx >= 0) { const tq_gguf_kv_t* mkv = &gguf->kv[merges_idx]; if (mkv->type == TQ_GGUF_TYPE_ARRAY && mkv->value.array.elem_type == TQ_GGUF_TYPE_STRING) { - /* Parse merge rules: "token_a token_b" -> find IDs, store as merge pairs */ - uint64_t n_merges = mkv->value.array.count; - tok->n_merges = (int)n_merges; - tok->merge_pairs = (int*)malloc(n_merges * 3 * sizeof(int)); + uint64_t n_merges_total = mkv->value.array.count; + tok->merge_pairs = (int*)malloc(n_merges_total * 3 * sizeof(int)); + tok->n_merges = 0; if (tok->merge_pairs) { - memset(tok->merge_pairs, 0, n_merges * 3 * sizeof(int)); + tq_gguf_string_t* merge_strings = (tq_gguf_string_t*)mkv->value.array.data; + for (uint64_t mi = 0; mi < n_merges_total; mi++) { + if (!merge_strings[mi].str || merge_strings[mi].len == 0) continue; + + /* Copy merge string and split on space: "tok_a tok_b" */ + char buf[2048]; + int slen = (int)merge_strings[mi].len; + if (slen >= (int)sizeof(buf)) continue; + memcpy(buf, merge_strings[mi].str, (size_t)slen); + buf[slen] = '\0'; + + char* sep = strchr(buf, ' '); + if (!sep) continue; + *sep = '\0'; + const char* str_a = buf; + const char* str_b = sep + 1; + + /* Build merged string: concatenation of tok_a + tok_b */ + char merged[2048]; + int la = (int)strlen(str_a); + int lb = (int)strlen(str_b); + if (la + lb >= (int)sizeof(merged)) continue; + memcpy(merged, str_a, (size_t)la); + memcpy(merged + la, str_b, (size_t)lb); + merged[la + lb] = '\0'; + + /* Look up token IDs via binary search (sorted_indices built above) */ + int id_a = str_lookup(tok, str_a); + int id_b = str_lookup(tok, str_b); + int id_merged = str_lookup(tok, merged); + + if (id_a >= 0 && id_b >= 0 && id_merged >= 0) { + tok->merge_pairs[tok->n_merges * 3 + 0] = id_a; + tok->merge_pairs[tok->n_merges * 3 + 1] = id_b; + tok->merge_pairs[tok->n_merges * 3 + 2] = id_merged; + /* Priority: earlier merges in GGUF = higher priority */ + tok->scores[id_merged] = (float)(n_merges_total - mi); + tok->n_merges++; + } + } + fprintf(stderr, "tq_load_tokenizer_from_gguf: parsed %d/%d merges\n", + tok->n_merges, (int)n_merges_total); } } } - /* Build sorted indices for encoding (binary search by string). - * Use qsort for O(n log n) instead of insertion sort O(n²) — critical - * for 248K vocab where insertion sort would take minutes. */ - tok->sorted_indices = (int*)malloc(vocab_size * sizeof(int)); - if (tok->sorted_indices) { - for (int i = 0; i < (int)vocab_size; i++) tok->sorted_indices[i] = i; - g_vocab_for_sort = tok->vocab; - qsort(tok->sorted_indices, vocab_size, sizeof(int), cmp_vocab_idx); - } - fprintf(stderr, "tq_load_tokenizer_from_gguf: loaded %d tokens (max_len=%d)\n", tok->vocab_size, tok->max_token_len); return tok; @@ -8476,7 +8491,6 @@ const char* tq_decode(const tq_tokenizer_t* tok, int prev_token, int token) { * Supports hybrid architectures (e.g., Qwen3.5 DeltaNet + self_attn). */ - #ifdef _WIN32 #else #endif @@ -9939,18 +9953,11 @@ static tq_model_t* tq_load_safetensors(const char* path) { free(tensors); - /* Qwen3.5 RMSNorm adjustment: Qwen3_5RMSNorm computes - * output = norm(x) * (1.0 + weight), NOT norm(x) * weight. - * We bake the "+1" into the weight so tq_rmsnorm can stay as - * out = x * rsqrt * weight. - * - * This applies to: input_layernorm, post_attention_layernorm, - * model.norm, q_norm, k_norm. - * It does NOT apply to: linear_attn.norm (Qwen3_5RMSNormGated - * uses plain weight without +1). - * - * We detect Qwen3.5 by the presence of DeltaNet layers. */ - if (model->config.delta_n_heads > 0) { + /* Qwen3.5 (DeltaNet hybrid) RMSNorm adjustment. + * Only for non-GGUF models (raw checkpoints). GGUF files from + * llama.cpp already have +1 baked in by the converter. + * Qwen2/Qwen3 use standard RMSNorm and never need +1. */ + if (model->config.delta_n_heads > 0 && !model->gguf_ctx) { int dim_h = model->config.hidden_dim; int head_dim_h = model->config.head_dim; @@ -9979,7 +9986,7 @@ static tq_model_t* tq_load_safetensors(const char* path) { for (int i = 0; i < dim_h; i++) model->output_norm[i] += 1.0f; } - fprintf(stderr, "tq_load_model: applied Qwen3.5 RMSNorm +1 weight adjustment\n"); + fprintf(stderr, "tq_load_model: applied Qwen RMSNorm +1 weight adjustment\n"); } /* Gemma3 RMSNorm adjustment: same (1+w) scaling as Qwen3.5 */ @@ -12143,8 +12150,13 @@ tq_model_t* tq_load_gguf(const char* path) { } const size_t MAX_FP32_BYTES = (size_t)16 * 1024 * 1024 * 1024ULL; /* 16 GB */ - /* TQ_NO_Q4=1 disables Q4 recompression → use direct GGUF dequant for better quality */ + /* TQ_NO_Q4=1 disables Q4 recompression → use direct GGUF dequant for better quality. + * Can be set via environment variable or compile-time define (useful for WASM). */ +#ifdef TQ_NO_Q4 + if (1) { +#else if (getenv("TQ_NO_Q4")) { +#endif fprintf(stderr, "tq_load_gguf: TQ_NO_Q4 set — skipping Q4 conversion, using GGUF on-the-fly dequant\n"); goto skip_q4_conversion; } @@ -12893,7 +12905,6 @@ void tq_quantize_weights_1bit(tq_model_t* model) { * -> residual add */ - /* Unified Q2/1-bit matmul dispatch. * When model->use_1bit_weights, Q2 fields contain sign bits + norms, * dispatched to tq_matmul_1bit (FP32 input required). @@ -15153,7 +15164,6 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) { } } - /* Increment profile token count if profiling is active */ if (s->profile_kv) { s->profile_kv_count++; @@ -15204,7 +15214,6 @@ float* tq_forward(tq_model_t* model, tq_state_t* s, int token, int pos) { * - Full generation loop with streaming callback */ - /* ============================================================ * Argmax sampling: return token with highest logit * ============================================================ */ @@ -15425,7 +15434,12 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, fprintf(stderr, "\n"); } - /* Prefill: process all prompt tokens */ + /* Prefill: process all prompt tokens. + * NOTE: No emscripten_sleep() here — the call stack during tq_forward() + * is too deep for ASYNCIFY to unwind (matmul → SIMD kernels). Adding + * sleep here breaks ASYNCIFY for the entire generate call, including + * the token streaming callback. The browser shows "Thinking..." via + * requestAnimationFrame before entering this blocking prefill. */ for (int i = 0; i < n_prompt; i++) { tq_forward(model, state, prompt_tokens[i], i); } @@ -15460,9 +15474,11 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, } } - /* Sample first generated token */ + /* Sample first generated token. The seed is configurable via + * config->rng_seed (default 42); 0 falls back to 42 so existing + * callers that never set rng_seed get bit-identical behaviour. */ int pos = n_prompt; - unsigned long long rng_state = 42; + unsigned long long rng_state = config->rng_seed ? config->rng_seed : 42ULL; int next_token = tq_sample_topp(state->logits, vocab_size, config->temperature, config->top_p, &rng_state); @@ -15627,6 +15643,498 @@ int tq_generate(tq_model_t* model, tq_tokenizer_t* tokenizer, return generated; } +/* ============================================================================ + * tq_generate_continue — reuse an existing tq_state_t across calls. + * + * Unlike tq_generate (which allocates and frees its own state on every call), + * this function takes a caller-managed state plus a record of which tokens + * are currently committed to the KV cache. It computes the longest common + * prefix between the cached tokens and the new prompt, then prefills only + * the diverging suffix. After generation, *cached_tokens_out and + * *n_cached_out are updated to reflect the new cache contents. + * + * This turns chat mode from O(n^2) (full re-prefill every turn) into + * O(delta) (only the new tokens of each turn). + * + * Returns the number of tokens generated, or -1 on error. + * ============================================================================ */ +static int tq_lcp_int(const int* a, int na, const int* b, int nb) { + int lim = na < nb ? na : nb; + int i = 0; + while (i < lim && a[i] == b[i]) i++; + return i; +} + +int tq_generate_continue(tq_model_t* model, + tq_tokenizer_t* tokenizer, + tq_state_t* state, + const char* prompt, + tq_gen_config_t* config, + int** cached_tokens_io, /* in/out: cached prefix tokens */ + int* n_cached_io, /* in/out: cached count */ + int* cached_capacity_io, /* in/out: allocated capacity */ + char* output, int output_size) { + if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io) { + return -1; + } + + /* Heap-allocated prompt token buffer (was a 4096-stack array, which + * silently truncated after ~10 turns of accumulating chat history). + * Cap at the model's max_seq_len so we never exceed KV bounds. */ + int max_prompt = model->config.max_seq_len > 0 + ? model->config.max_seq_len : 4096; + int* new_tokens = (int*)malloc((size_t)max_prompt * sizeof(int)); + if (!new_tokens) return -1; + int n_new = 0; + if (tokenizer && prompt) { + int add_bos = (model->config.model_type == 1) ? 1 : 0; + n_new = tq_encode(tokenizer, prompt, new_tokens, max_prompt, add_bos); + } + if (n_new <= 0) { + new_tokens[0] = (model->config.model_type == 1) ? 2 : 1; + n_new = 1; + } + + /* Overflow check: reject prompts that won't fit. The previous + * behavior was to silently drop oldest tokens via a sliding window, + * but that desynced any cached_text the higher-level wrapper held + * (cached_text claimed the full prompt, while cached_tokens only + * had the truncated tail — next turn's text-prefix match would + * map text bytes to the wrong KV positions). Returning -2 lets the + * caller decide (reset chat, show error). */ + int reserve = config->max_tokens > 0 ? config->max_tokens : 256; + int budget = max_prompt - reserve - 32; + if (budget < 64) budget = 64; + if (n_new > budget) { + free(new_tokens); + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat] OVERFLOW n_new=%d budget=%d max=%d\n", + n_new, budget, max_prompt); + } + return -2; + } + + /* Find longest common prefix with the cached tokens. + * If the new prompt is just an extension of the cached one, we skip + * everything up to the LCP and only prefill the suffix. */ + int n_cached = *n_cached_io; + int* cached_tokens = *cached_tokens_io; + + int lcp = tq_lcp_int(cached_tokens, n_cached, new_tokens, n_new); + + /* Prefill the new suffix [lcp, n_new) */ + for (int i = lcp; i < n_new; i++) { + tq_forward(model, state, new_tokens[i], i); + } + int pos = n_new; + int prefill_tokens = n_new - lcp; + int prefix_hit = lcp; + + /* Save the n_new prompt into the cache buffer (will append generated + * tokens below). Grow the buffer if needed. */ + int needed_cap = n_new + config->max_tokens + 16; + if (*cached_capacity_io < needed_cap) { + int new_cap = needed_cap < 4096 ? 4096 : needed_cap; + int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int)); + if (!nb) { free(new_tokens); return -1; } + *cached_tokens_io = nb; + *cached_capacity_io = new_cap; + cached_tokens = nb; + } + memcpy(cached_tokens, new_tokens, (size_t)n_new * sizeof(int)); + *n_cached_io = n_new; + n_cached = n_new; + + /* --- generation loop (mirrors tq_generate's loop) --- */ + int vocab_size = model->config.vocab_size; + float rep_penalty = config->rep_penalty; + int rep_window = config->rep_window; + if (rep_window > 64) rep_window = 64; + int recent_tokens[64]; + int recent_count = 0; + for (int i = (n_new > rep_window ? n_new - rep_window : 0); i < n_new; i++) { + recent_tokens[recent_count % 64] = new_tokens[i]; + recent_count++; + } + + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size && state->logits) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + + uint64_t rng_state = config->rng_seed ? (uint64_t)config->rng_seed + : (uint64_t)time(NULL); + int next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + + int generated = 0; + int output_pos = 0; + int prev_token = new_tokens[n_new - 1]; + + int eos_tokens[] = { + 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046, + }; + int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]); + + while (generated < config->max_tokens) { + int is_eos = 0; + for (int e = 0; e < n_eos; e++) { + if (next_token == eos_tokens[e]) { is_eos = 1; break; } + } + if (is_eos) break; + + if (pos >= model->config.max_seq_len) break; /* simple stop, no shift */ + + /* Decode + stream */ + if (tokenizer) { + const char* piece = tq_decode(tokenizer, prev_token, next_token); + int should_stop = 0; + if (piece) { + if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") || + strstr(piece, "<|start_header_id|>")) { + should_stop = 1; piece = ""; + } + } + if (should_stop) break; + int piece_len = (int)strlen(piece ? piece : ""); + if (config->on_token && piece) config->on_token(piece, config->user_data); + if (output && piece && output_pos + piece_len < output_size - 1) { + memcpy(output + output_pos, piece, piece_len); + output_pos += piece_len; + } + } + + /* Append generated token to cache record */ + if (n_cached < *cached_capacity_io) { + cached_tokens[n_cached++] = next_token; + *n_cached_io = n_cached; + } + + prev_token = next_token; + tq_forward(model, state, next_token, pos); + pos++; + generated++; + + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + + next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + recent_tokens[recent_count % 64] = next_token; + recent_count++; + } + + if (output && output_size > 0) { + output[output_pos < output_size ? output_pos : output_size - 1] = '\0'; + } + + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, + "[chat] prefix_hit=%d prefill=%d generated=%d cached=%d\n", + prefix_hit, prefill_tokens, generated, *n_cached_io); + } + + free(new_tokens); + return generated; +} + +/* ============================================================================ + * tq_generate_chat_text — text-prefix matching for chat reuse + * + * Solves the BPE re-tokenization issue: when the model generates response + * tokens via sample_topp, those token IDs may not match what tq_encode() + * produces from the same response text in the next turn's prompt. The + * token-level LCP in tq_generate_continue truncates at that boundary. + * + * This function tracks the *text* of the last prompt+response. On the next + * call, if the new prompt starts with cached_text byte-for-byte, the entire + * cached state is valid — tokenize ONLY the new SUFFIX text and prefill + * those tokens at positions [n_cached..]. No LCP, no truncation. + * + * Pass cached_text_io == NULL to disable text-prefix tracking. + * ============================================================================ */ + +typedef struct { + char* buf; + size_t len; + size_t cap; + int tainted; /* 1 if accumulation ever failed → buf is incomplete */ + void (*user_cb)(const char*, void*); + void* user_data; +} chat_accum_t; + +static void chat_accum_callback(const char* tok, void* u) { + chat_accum_t* ctx = (chat_accum_t*)u; + if (!tok) return; + /* Always pass through to the user's callback first — losing tokens + * from the user's stream because of an INTERNAL realloc failure is + * far worse than a stale cached_text on the next turn. */ + if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); + if (ctx->tainted) return; + size_t tlen = strlen(tok); + if (ctx->len + tlen + 1 > ctx->cap) { + size_t new_cap = (ctx->cap + tlen + 64) * 2; + char* nb = (char*)realloc(ctx->buf, new_cap); + if (!nb) { ctx->tainted = 1; return; } + ctx->buf = nb; + ctx->cap = new_cap; + } + memcpy(ctx->buf + ctx->len, tok, tlen); + ctx->len += tlen; + ctx->buf[ctx->len] = '\0'; +} + +int tq_generate_chat_text(tq_model_t* model, + tq_tokenizer_t* tokenizer, + tq_state_t* state, + const char* prompt, + tq_gen_config_t* config, + char** cached_text_io, + int** cached_tokens_io, + int* n_cached_io, + int* cached_capacity_io, + char* output, int output_size) { + if (!model || !state || !config || !cached_tokens_io || !n_cached_io || !cached_capacity_io || !prompt) { + return -1; + } + + int matched_text_len = 0; + int prefix_pos = 0; + + if (cached_text_io && *cached_text_io && *n_cached_io > 0) { + size_t cached_len = strlen(*cached_text_io); + if (cached_len > 0 && strncmp(*cached_text_io, prompt, cached_len) == 0) { + matched_text_len = (int)cached_len; + prefix_pos = *n_cached_io; + } + } + + chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, .tainted = 0, + .user_cb = config->on_token, + .user_data = config->user_data }; + void (*orig_cb)(const char*, void*) = config->on_token; + void* orig_ud = config->user_data; + config->on_token = chat_accum_callback; + config->user_data = &accum; + + int generated = 0; + + if (matched_text_len > 0) { + const char* suffix = prompt + matched_text_len; + int max_prompt = model->config.max_seq_len > 0 + ? model->config.max_seq_len : 4096; + int* suffix_toks = (int*)malloc((size_t)max_prompt * sizeof(int)); + if (!suffix_toks) { + config->on_token = orig_cb; config->user_data = orig_ud; + return -1; + } + int n_suffix = 0; + if (*suffix != '\0') { + n_suffix = tq_encode(tokenizer, suffix, suffix_toks, max_prompt, 0); + if (n_suffix < 0) n_suffix = 0; + } + + /* Context overflow: return -2 instead of falling back to a + * dangerous full reprefill. The state still has stale KV at + * positions [n_new..prefix_pos) that would corrupt later tokens. + * Caller should reset the chat and retry. */ + int reserve = config->max_tokens > 0 ? config->max_tokens : 256; + if (prefix_pos + n_suffix + reserve + 32 > max_prompt) { + free(suffix_toks); + config->on_token = orig_cb; config->user_data = orig_ud; + if (accum.buf) free(accum.buf); + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, + "[chat-text] OVERFLOW prefix_pos=%d n_suffix=%d reserve=%d max=%d\n", + prefix_pos, n_suffix, reserve, max_prompt); + } + return -2; + } + + int needed = prefix_pos + n_suffix + reserve + 16; + if (*cached_capacity_io < needed) { + int new_cap = needed < 4096 ? 4096 : needed; + int* nb = (int*)realloc(*cached_tokens_io, (size_t)new_cap * sizeof(int)); + if (!nb) { free(suffix_toks); config->on_token = orig_cb; config->user_data = orig_ud; return -1; } + *cached_tokens_io = nb; + *cached_capacity_io = new_cap; + } + + int* cached = *cached_tokens_io; + for (int i = 0; i < n_suffix; i++) { + cached[prefix_pos + i] = suffix_toks[i]; + tq_forward(model, state, suffix_toks[i], prefix_pos + i); + } + *n_cached_io = prefix_pos + n_suffix; + free(suffix_toks); + + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat-text] FAST text_match=%d new_suffix_tokens=%d\n", + matched_text_len, n_suffix); + } + + /* Generation loop — mirrors tq_generate_continue including + * rep_penalty (which the fast path was silently dropping). */ + int vocab_size = model->config.vocab_size; + int n_cached = *n_cached_io; + int pos = n_cached; + int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1; + + float rep_penalty = config->rep_penalty; + int rep_window = config->rep_window; + if (rep_window > 64) rep_window = 64; + int recent_tokens[64]; + int recent_count = 0; + for (int i = (n_cached > rep_window ? n_cached - rep_window : 0); i < n_cached; i++) { + recent_tokens[recent_count % 64] = cached[i]; + recent_count++; + } + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size && state->logits) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + + uint64_t rng_state = config->rng_seed + ? (uint64_t)config->rng_seed : (uint64_t)time(NULL); + int next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + + int output_pos = 0; + int eos_tokens[] = { 1, 2, 106, 128001, 128006, 128007, 128008, 128009, 248044, 248046 }; + int n_eos = sizeof(eos_tokens) / sizeof(eos_tokens[0]); + + while (generated < config->max_tokens) { + int is_eos = 0; + for (int e = 0; e < n_eos; e++) { + if (next_token == eos_tokens[e]) { is_eos = 1; break; } + } + if (is_eos) break; + if (pos >= model->config.max_seq_len) break; + + const char* piece = tokenizer ? tq_decode(tokenizer, prev_token, next_token) : ""; + int should_stop = 0; + if (piece) { + if (strstr(piece, "<|im_end|>") || strstr(piece, "<|eot_id|>") || + strstr(piece, "<|start_header_id|>")) { + should_stop = 1; piece = ""; + } + } + if (should_stop) break; + + int piece_len = (int)strlen(piece ? piece : ""); + if (config->on_token && piece) config->on_token(piece, config->user_data); + if (output && piece && output_pos + piece_len < output_size - 1) { + memcpy(output + output_pos, piece, piece_len); + output_pos += piece_len; + } + + if (n_cached < *cached_capacity_io) { + cached[n_cached++] = next_token; + *n_cached_io = n_cached; + } + + prev_token = next_token; + tq_forward(model, state, next_token, pos); + pos++; + generated++; + + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + + next_token = tq_sample_topp(state->logits, vocab_size, + config->temperature, config->top_p, + &rng_state); + recent_tokens[recent_count % 64] = next_token; + recent_count++; + } + + if (output && output_size > 0) { + output[output_pos < output_size ? output_pos : output_size - 1] = '\0'; + } + } else { + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat-text] SLOW no text-prefix match, full tokenize\n"); + } + generated = tq_generate_continue( + model, tokenizer, state, prompt, config, + cached_tokens_io, n_cached_io, cached_capacity_io, + output, output_size); + } + + config->on_token = orig_cb; + config->user_data = orig_ud; + + /* Update cached_text only if we know the KV state corresponds + * EXACTLY to (prompt + accum.buf): + * - generated >= 0: generation didn't error out + * - !accum.tainted: every generated token was captured + * On any failure, clear cached_text so the next call falls through + * to the slow path with a clean slate instead of trusting bytes + * that don't match the KV cache. */ + if (cached_text_io) { + if (generated < 0 || accum.tainted) { + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } else { + size_t plen = strlen(prompt); + size_t glen = accum.len; + size_t new_len = plen + glen; + char* nt = (char*)malloc(new_len + 1); + if (nt) { + memcpy(nt, prompt, plen); + if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); + nt[new_len] = '\0'; + if (*cached_text_io) free(*cached_text_io); + *cached_text_io = nt; + } else { + /* malloc failed → can't refresh cached_text. Clearing it + * is safer than leaving the previous (now stale) value. */ + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } + } + } + if (accum.buf) free(accum.buf); + + return generated; +} // ============================================================================ @@ -15862,22 +16370,7 @@ void quant_free_string(char* str) { if (str) free(str); } -/* ================================================================ - * Context persistence — save/load KV cache to disk - * - * File format (binary, little-endian): - * magic: 4 bytes "QKVC" - * version: uint32 (1) - * n_layers: uint32 - * kv_dim: uint32 (n_kv_heads * head_dim) - * max_seq: uint32 - * n_tokens: uint32 (number of filled positions) - * kv_type: uint32 (TQ_TYPE_* enum or TQ_TYPE_COUNT for fp32) - * has_fp16v: uint32 (1 if value_cache_fp16 is used) - * reserved: 32 bytes (future use) - * data: raw KV cache bytes - * ================================================================ */ - +/* Context persistence: QKVC format (64-byte header + raw KV data) */ int quant_save_context(quant_ctx* ctx, const char* path) { if (!ctx || !ctx->state || !path) return -1; FILE* fp = fopen(path, "wb"); @@ -15886,29 +16379,17 @@ int quant_save_context(quant_ctx* ctx, const char* path) { tq_state_t* s = ctx->state; tq_model_config_t* c = &ctx->model->config; int kv_dim = c->n_kv_heads * c->head_dim; - - /* Header */ fwrite("QKVC", 1, 4, fp); - uint32_t version = 1; - uint32_t nl = (uint32_t)c->n_layers; - uint32_t kd = (uint32_t)kv_dim; - uint32_t ms = (uint32_t)c->max_seq_len; - uint32_t nt = (uint32_t)ctx->n_ctx_tokens; - uint32_t kt = (uint32_t)s->kv_quant_type; - uint32_t hfp16 = s->value_cache_fp16 ? 1 : 0; - fwrite(&version, 4, 1, fp); - fwrite(&nl, 4, 1, fp); - fwrite(&kd, 4, 1, fp); - fwrite(&ms, 4, 1, fp); - fwrite(&nt, 4, 1, fp); - fwrite(&kt, 4, 1, fp); - fwrite(&hfp16, 4, 1, fp); - char reserved[32] = {0}; - fwrite(reserved, 1, 32, fp); + uint32_t hdr[7] = { 1, (uint32_t)c->n_layers, (uint32_t)kv_dim, + (uint32_t)c->max_seq_len, (uint32_t)ctx->n_ctx_tokens, + (uint32_t)s->kv_quant_type, s->value_cache_fp16 ? 1u : 0u }; + fwrite(hdr, 4, 7, fp); + char reserved[32] = {0}; fwrite(reserved, 1, 32, fp); + uint32_t nl = hdr[1], nt = hdr[4], kt = hdr[5]; /* KV data: write only the filled portion (nt tokens) */ for (uint32_t l = 0; l < nl; l++) { - size_t layer_stride = (size_t)ms * kv_dim; + size_t layer_stride = (size_t)c->max_seq_len * kv_dim; /* Key cache: FP32 or quantized */ if (s->key_cache) { fwrite(s->key_cache + l * layer_stride, sizeof(float), @@ -15916,7 +16397,7 @@ int quant_save_context(quant_ctx* ctx, const char* path) { } if (s->quant_key_cache && kt < TQ_TYPE_COUNT) { size_t blk_sz = tq_type_type_size(kt); - uint8_t* qbase = (uint8_t*)s->quant_key_cache + l * (size_t)ms * blk_sz; + uint8_t* qbase = (uint8_t*)s->quant_key_cache + l * (size_t)c->max_seq_len * blk_sz; fwrite(qbase, blk_sz, nt, fp); } /* Value cache: FP32 or FP16 */ @@ -15925,7 +16406,7 @@ int quant_save_context(quant_ctx* ctx, const char* path) { (size_t)nt * kv_dim, fp); } if (s->value_cache_fp16) { - size_t layer_stride16 = (size_t)ms * kv_dim; + size_t layer_stride16 = (size_t)c->max_seq_len * kv_dim; fwrite(s->value_cache_fp16 + l * layer_stride16, sizeof(uint16_t), (size_t)nt * kv_dim, fp); } @@ -15942,37 +16423,16 @@ int quant_load_context(quant_ctx* ctx, const char* path) { FILE* fp = fopen(path, "rb"); if (!fp) return -1; - /* Read and validate header */ char magic[4]; - if (fread(magic, 1, 4, fp) != 4 || memcmp(magic, "QKVC", 4) != 0) { - fclose(fp); return -1; - } - uint32_t version, nl, kd, ms, nt, kt, hfp16; - fread(&version, 4, 1, fp); - fread(&nl, 4, 1, fp); - fread(&kd, 4, 1, fp); - fread(&ms, 4, 1, fp); - fread(&nt, 4, 1, fp); - fread(&kt, 4, 1, fp); - fread(&hfp16, 4, 1, fp); - char reserved[32]; - fread(reserved, 1, 32, fp); - + if (fread(magic, 1, 4, fp) != 4 || memcmp(magic, "QKVC", 4) != 0) { fclose(fp); return -1; } + uint32_t hdr[7]; fread(hdr, 4, 7, fp); + char reserved[32]; fread(reserved, 1, 32, fp); + uint32_t nl = hdr[1], nt = hdr[4], kt = hdr[5]; tq_state_t* s = ctx->state; tq_model_config_t* c = &ctx->model->config; int kv_dim = c->n_kv_heads * c->head_dim; - - /* Validate compatibility */ - if (nl != (uint32_t)c->n_layers || kd != (uint32_t)kv_dim) { - fprintf(stderr, "quant_load_context: model mismatch (layers %u vs %d, kv_dim %u vs %d)\n", - nl, c->n_layers, kd, kv_dim); - fclose(fp); return -1; - } - if (nt > (uint32_t)c->max_seq_len) { - fprintf(stderr, "quant_load_context: saved %u tokens > max_seq_len %d\n", - nt, c->max_seq_len); - fclose(fp); return -1; - } + if (nl != (uint32_t)c->n_layers || hdr[2] != (uint32_t)kv_dim) { fclose(fp); return -1; } + if (nt > (uint32_t)c->max_seq_len) { fclose(fp); return -1; } /* Read KV data */ for (uint32_t l = 0; l < nl; l++) { @@ -16009,9 +16469,70 @@ void quant_free_ctx(quant_ctx* ctx) { if (!ctx) return; tq_free_state(ctx->state); tq_free_tokenizer(ctx->tokenizer); + if (ctx->cached_tokens) free(ctx->cached_tokens); + if (ctx->cached_text) free(ctx->cached_text); free(ctx); } +/* ---------------------------------------------------------------------- + * quant_chat — chat-mode generate that reuses the KV cache across calls. + * + * Unlike quant_generate (which resets the state on every call and so makes + * each turn O(history_length)), quant_chat keeps the state alive between + * calls. The first call to quant_chat() prefills and generates as normal. + * Subsequent calls compute the longest common prefix between the new prompt + * and the previously processed tokens, skip the matched prefix, and only + * prefill the diverging suffix. + * + * Result: turn N's prefill cost is O(new tokens this turn), not + * O(total history). Chat experience matches what users expect from ollama. + * + * Reset behavior: pass NULL prompt to wipe the cache (start a new chat). + * Returns the number of tokens generated, or -1 on error. + * ---------------------------------------------------------------------- */ +int quant_chat(quant_ctx* ctx, const char* prompt, + void (*on_token)(const char* text, void* user_data), + void* user_data) { + if (!ctx || !ctx->model) return -1; + + /* NULL prompt = reset the chat (clear cache + state) */ + if (!prompt) { + tq_free_state(ctx->state); + ctx->state = tq_create_state_ex(&ctx->model->config, + ctx->config.kv_type, + ctx->config.value_quant_bits); + if (ctx->cached_tokens) free(ctx->cached_tokens); + ctx->cached_tokens = NULL; + ctx->n_cached = 0; + ctx->cached_capacity = 0; + ctx->n_ctx_tokens = 0; + if (ctx->cached_text) { free(ctx->cached_text); ctx->cached_text = NULL; } + return 0; + } + + if (!ctx->state) { + ctx->state = tq_create_state_ex(&ctx->model->config, + ctx->config.kv_type, + ctx->config.value_quant_bits); + if (!ctx->state) return -1; + } + + ctx->config.on_token = on_token; + ctx->config.user_data = user_data; + + char output[65536]; + /* Use the text-prefix path so chat replays bypass BPE re-tokenization + * issues. Falls back to token-LCP path if text prefix doesn't match. */ + int n = tq_generate_chat_text( + ctx->model, ctx->tokenizer, ctx->state, prompt, &ctx->config, + &ctx->cached_text, + &ctx->cached_tokens, &ctx->n_cached, &ctx->cached_capacity, + output, sizeof(output)); + + if (n > 0) ctx->n_ctx_tokens = ctx->n_cached; + return n; +} + void quant_free_model(quant_model* model) { tq_free_model((tq_model_t*)model); } diff --git a/bindings/python/quantcpp/__init__.py b/bindings/python/quantcpp/__init__.py index 88bdb6e..e9559ef 100644 --- a/bindings/python/quantcpp/__init__.py +++ b/bindings/python/quantcpp/__init__.py @@ -35,6 +35,16 @@ ) +class ChatContextOverflow(RuntimeError): + """Raised when chat history exceeds the model's context window. + + The C side has already auto-reset the session by the time this is + raised — the caller must trim its conversation history (drop the + oldest turns) and retry. Catching this is the supported way to + detect "we hit max_seq_len" without parsing log output. + """ + + # ----------------------------------------------------------------------- # Model registry — small GGUF models auto-downloaded from HuggingFace # ----------------------------------------------------------------------- @@ -394,6 +404,15 @@ def chat(self, prompt: str) -> Iterator[str]: Falls back to ``generate()`` on older library builds without ``quant_chat`` symbol. + + Raises + ------ + ChatContextOverflow + When the conversation history exceeds the model's context + window. The session has been auto-reset; the caller should + trim history and retry. + RuntimeError + On other generation failures (allocation, invalid state). """ self._ensure_open() lib = get_lib() @@ -414,6 +433,7 @@ def chat(self, prompt: str) -> Iterator[str]: tokens = [] done = threading.Event() error_box = [None] + rc_box = [0] def _on_token(text_ptr, _user_data): if text_ptr: @@ -424,7 +444,8 @@ def _on_token(text_ptr, _user_data): def _run(): try: with self._lock: - lib.quant_chat(self._ctx, prompt.encode("utf-8"), cb, None) + rc_box[0] = lib.quant_chat( + self._ctx, prompt.encode("utf-8"), cb, None) except Exception as e: error_box[0] = e finally: @@ -448,6 +469,19 @@ def _run(): if error_box[0] is not None: raise error_box[0] + # Surface generation failures from the C side. Previously these + # were silently swallowed: -2 (context overflow) and -1 (alloc + # failure) both produced empty token streams that callers could + # not distinguish from "the model decided to say nothing". + rc = rc_box[0] + if rc == -2: + raise ChatContextOverflow( + "conversation history exceeds the model's context window — " + "session has been reset, retry with shorter history" + ) + if rc < 0: + raise RuntimeError(f"quant_chat failed with rc={rc}") + def reset_chat(self) -> None: """Reset the chat KV cache. Next chat() call starts fresh.""" self._ensure_open() @@ -528,4 +562,4 @@ def load(path: str, **kwargs) -> Model: return Model(path, **kwargs) -__all__ = ["Model", "load", "download", "__version__"] +__all__ = ["Model", "load", "download", "ChatContextOverflow", "__version__"] diff --git a/bindings/python/quantcpp/cli.py b/bindings/python/quantcpp/cli.py index 954d7fc..830204f 100644 --- a/bindings/python/quantcpp/cli.py +++ b/bindings/python/quantcpp/cli.py @@ -151,24 +151,63 @@ def cmd_run(args): print(tok, end="", flush=True) print() else: + from quantcpp import ChatContextOverflow print("quantcpp \u2014 type your message, Ctrl+C to exit", file=sys.stderr) # Multi-turn chat: accumulate history as ChatML so the model sees # prior turns. m.chat() reuses the KV cache via prefix-match, so # repeating the history is cheap (O(new tokens), not O(n^2)). - history = "" + # turns is a list of (user_msg, assistant_msg) pairs so we can + # trim from the front when we hit context overflow. + turns = [] + def _build_history(extra_user=None): + parts = [] + for u, a in turns: + parts.append(f"<|im_start|>user\n{u}<|im_end|>\n<|im_start|>assistant\n{a}<|im_end|>\n") + if extra_user is not None: + parts.append(f"<|im_start|>user\n{extra_user}<|im_end|>\n<|im_start|>assistant\n") + return "".join(parts) + try: while True: question = input("\nYou: ") if not question.strip(): continue - history += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n" print("AI: ", end="", flush=True) reply_buf = [] - for tok in m.chat(history): - print(tok, end="", flush=True) - reply_buf.append(tok) + # Retry loop: on context overflow, drop the oldest turn + # and try again. Without this, the C side resets the KV + # cache but Python's history still has the bloat, so + # every subsequent turn would loop back into overflow. + attempt = 0 + while True: + history = _build_history(extra_user=question) + try: + for tok in m.chat(history): + print(tok, end="", flush=True) + reply_buf.append(tok) + break + except ChatContextOverflow: + if not turns: + print("\n[chat] message alone exceeds context window — try a shorter question.", + file=sys.stderr) + reply_buf = [] # nothing was emitted + break + dropped = turns.pop(0) + attempt += 1 + print(f"\n[chat] context full \u2014 dropped oldest turn ({len(dropped[0])+len(dropped[1])} chars), retrying...", + file=sys.stderr) + # The session was already reset by the C side; + # retrying with the trimmed history will hit + # the slow path on this turn and the fast path + # again from the next turn onward. + if attempt > 8: + print("[chat] too many overflow retries, giving up on this turn.", + file=sys.stderr) + reply_buf = [] + break print() - history += "".join(reply_buf) + "<|im_end|>\n" + if reply_buf: + turns.append((question, "".join(reply_buf))) except (KeyboardInterrupt, EOFError): print("\nBye!", file=sys.stderr) diff --git a/quant.h b/quant.h index 9a2691c..36cbbb2 100644 --- a/quant.h +++ b/quant.h @@ -15695,17 +15695,23 @@ int tq_generate_continue(tq_model_t* model, n_new = 1; } - /* Sliding window: drop oldest prompt tokens if the new prompt would - * leave no room for max_tokens of generation. Keeps the most recent - * tokens. Forces full reprefill since the prefix shifted. */ + /* Overflow check: reject prompts that won't fit. The previous + * behavior was to silently drop oldest tokens via a sliding window, + * but that desynced any cached_text the higher-level wrapper held + * (cached_text claimed the full prompt, while cached_tokens only + * had the truncated tail — next turn's text-prefix match would + * map text bytes to the wrong KV positions). Returning -2 lets the + * caller decide (reset chat, show error). */ int reserve = config->max_tokens > 0 ? config->max_tokens : 256; int budget = max_prompt - reserve - 32; if (budget < 64) budget = 64; if (n_new > budget) { - int drop = n_new - budget; - memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int)); - n_new = budget; - *n_cached_io = 0; + free(new_tokens); + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat] OVERFLOW n_new=%d budget=%d max=%d\n", + n_new, budget, max_prompt); + } + return -2; } /* Find longest common prefix with the cached tokens. @@ -15872,6 +15878,7 @@ typedef struct { char* buf; size_t len; size_t cap; + int tainted; /* 1 if accumulation ever failed → buf is incomplete */ void (*user_cb)(const char*, void*); void* user_data; } chat_accum_t; @@ -15879,18 +15886,22 @@ typedef struct { static void chat_accum_callback(const char* tok, void* u) { chat_accum_t* ctx = (chat_accum_t*)u; if (!tok) return; + /* Always pass through to the user's callback first — losing tokens + * from the user's stream because of an INTERNAL realloc failure is + * far worse than a stale cached_text on the next turn. */ + if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); + if (ctx->tainted) return; size_t tlen = strlen(tok); if (ctx->len + tlen + 1 > ctx->cap) { size_t new_cap = (ctx->cap + tlen + 64) * 2; char* nb = (char*)realloc(ctx->buf, new_cap); - if (!nb) return; + if (!nb) { ctx->tainted = 1; return; } ctx->buf = nb; ctx->cap = new_cap; } memcpy(ctx->buf + ctx->len, tok, tlen); ctx->len += tlen; ctx->buf[ctx->len] = '\0'; - if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); } int tq_generate_chat_text(tq_model_t* model, @@ -15918,7 +15929,7 @@ int tq_generate_chat_text(tq_model_t* model, } } - chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, + chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, .tainted = 0, .user_cb = config->on_token, .user_data = config->user_data }; void (*orig_cb)(const char*, void*) = config->on_token; @@ -15982,12 +15993,35 @@ int tq_generate_chat_text(tq_model_t* model, matched_text_len, n_suffix); } - /* Generation loop */ + /* Generation loop — mirrors tq_generate_continue including + * rep_penalty (which the fast path was silently dropping). */ int vocab_size = model->config.vocab_size; int n_cached = *n_cached_io; int pos = n_cached; int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1; + float rep_penalty = config->rep_penalty; + int rep_window = config->rep_window; + if (rep_window > 64) rep_window = 64; + int recent_tokens[64]; + int recent_count = 0; + for (int i = (n_cached > rep_window ? n_cached - rep_window : 0); i < n_cached; i++) { + recent_tokens[recent_count % 64] = cached[i]; + recent_count++; + } + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size && state->logits) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + uint64_t rng_state = config->rng_seed ? (uint64_t)config->rng_seed : (uint64_t)time(NULL); int next_token = tq_sample_topp(state->logits, vocab_size, @@ -16033,9 +16067,24 @@ int tq_generate_chat_text(tq_model_t* model, pos++; generated++; + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + next_token = tq_sample_topp(state->logits, vocab_size, config->temperature, config->top_p, &rng_state); + recent_tokens[recent_count % 64] = next_token; + recent_count++; } if (output && output_size > 0) { @@ -16051,21 +16100,35 @@ int tq_generate_chat_text(tq_model_t* model, output, output_size); } -update_cache: config->on_token = orig_cb; config->user_data = orig_ud; + /* Update cached_text only if we know the KV state corresponds + * EXACTLY to (prompt + accum.buf): + * - generated >= 0: generation didn't error out + * - !accum.tainted: every generated token was captured + * On any failure, clear cached_text so the next call falls through + * to the slow path with a clean slate instead of trusting bytes + * that don't match the KV cache. */ if (cached_text_io) { - size_t plen = strlen(prompt); - size_t glen = accum.len; - size_t new_len = plen + glen; - char* nt = (char*)malloc(new_len + 1); - if (nt) { - memcpy(nt, prompt, plen); - if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); - nt[new_len] = '\0'; - if (*cached_text_io) free(*cached_text_io); - *cached_text_io = nt; + if (generated < 0 || accum.tainted) { + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } else { + size_t plen = strlen(prompt); + size_t glen = accum.len; + size_t new_len = plen + glen; + char* nt = (char*)malloc(new_len + 1); + if (nt) { + memcpy(nt, prompt, plen); + if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); + nt[new_len] = '\0'; + if (*cached_text_io) free(*cached_text_io); + *cached_text_io = nt; + } else { + /* malloc failed → can't refresh cached_text. Clearing it + * is safer than leaving the previous (now stale) value. */ + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } } } if (accum.buf) free(accum.buf); diff --git a/src/engine/tq_generate.c b/src/engine/tq_generate.c index 1f45a35..0211a83 100644 --- a/src/engine/tq_generate.c +++ b/src/engine/tq_generate.c @@ -653,20 +653,23 @@ int tq_generate_continue(tq_model_t* model, n_new = 1; } - /* Sliding window: if the new prompt + reserved generation room would - * exceed max_seq_len, drop the oldest tokens from the front of the - * prompt. We keep the most recent (max_seq_len - max_tokens - 32) tokens. - * Note: this discards conversation history; ideally callers send - * pre-trimmed prompts, but this prevents catastrophic failure. */ + /* Overflow check: reject prompts that won't fit. The previous + * behavior was to silently drop oldest tokens via a sliding window, + * but that desynced any cached_text the higher-level wrapper held + * (cached_text claimed the full prompt, while cached_tokens only + * had the truncated tail — next turn's text-prefix match would + * map text bytes to the wrong KV positions). Returning -2 lets the + * caller decide (reset chat, show error). */ int reserve = config->max_tokens > 0 ? config->max_tokens : 256; int budget = max_prompt - reserve - 32; if (budget < 64) budget = 64; if (n_new > budget) { - int drop = n_new - budget; - memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int)); - n_new = budget; - /* Force full reprefill since the prefix shifted */ - *n_cached_io = 0; + free(new_tokens); + if (getenv("TQ_CHAT_DEBUG")) { + fprintf(stderr, "[chat] OVERFLOW n_new=%d budget=%d max=%d\n", + n_new, budget, max_prompt); + } + return -2; } int n_cached = *n_cached_io; @@ -835,6 +838,7 @@ typedef struct { char* buf; size_t len; size_t cap; + int tainted; /* 1 if accumulation ever failed → buf is incomplete */ void (*user_cb)(const char*, void*); void* user_data; } chat_accum_t; @@ -842,18 +846,22 @@ typedef struct { static void chat_accum_callback(const char* tok, void* u) { chat_accum_t* ctx = (chat_accum_t*)u; if (!tok) return; + /* Always pass through to the user's callback first — losing tokens + * from the user's stream because of an INTERNAL realloc failure is + * far worse than a stale cached_text on the next turn. */ + if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); + if (ctx->tainted) return; size_t tlen = strlen(tok); if (ctx->len + tlen + 1 > ctx->cap) { size_t new_cap = (ctx->cap + tlen + 64) * 2; char* nb = (char*)realloc(ctx->buf, new_cap); - if (!nb) return; + if (!nb) { ctx->tainted = 1; return; } ctx->buf = nb; ctx->cap = new_cap; } memcpy(ctx->buf + ctx->len, tok, tlen); ctx->len += tlen; ctx->buf[ctx->len] = '\0'; - if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data); } int tq_generate_chat_text(tq_model_t* model, @@ -897,7 +905,7 @@ int tq_generate_chat_text(tq_model_t* model, /* Wrap user callback to capture generated text into a buffer for the * next call's cached_text update. */ - chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, + chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, .tainted = 0, .user_cb = config->on_token, .user_data = config->user_data }; void (*orig_cb)(const char*, void*) = config->on_token; @@ -971,12 +979,36 @@ int tq_generate_chat_text(tq_model_t* model, matched_text_len, n_suffix); } - /* --- Run generation loop directly --- */ + /* --- Run generation loop directly. Mirrors tq_generate_continue + * including rep_penalty (the fast path was silently dropping + * it before, leaving rep_penalty inconsistent across turns). */ int vocab_size = model->config.vocab_size; int n_cached = *n_cached_io; int pos = n_cached; int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1; + float rep_penalty = config->rep_penalty; + int rep_window = config->rep_window; + if (rep_window > 64) rep_window = 64; + int recent_tokens[64]; + int recent_count = 0; + for (int i = (n_cached > rep_window ? n_cached - rep_window : 0); i < n_cached; i++) { + recent_tokens[recent_count % 64] = cached[i]; + recent_count++; + } + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size && state->logits) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + unsigned long long rng_state = config->rng_seed ? (unsigned long long)config->rng_seed : (unsigned long long)time(NULL); int next_token = tq_sample_topp(state->logits, vocab_size, @@ -1022,9 +1054,24 @@ int tq_generate_chat_text(tq_model_t* model, pos++; generated++; + if (rep_penalty > 1.0f) { + int window = recent_count < rep_window ? recent_count : rep_window; + for (int r = 0; r < window; r++) { + int idx = (recent_count - 1 - r) % 64; + if (idx < 0) idx += 64; + int tok = recent_tokens[idx]; + if (tok >= 0 && tok < vocab_size) { + if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty; + else state->logits[tok] *= rep_penalty; + } + } + } + next_token = tq_sample_topp(state->logits, vocab_size, config->temperature, config->top_p, &rng_state); + recent_tokens[recent_count % 64] = next_token; + recent_count++; } if (output && output_size > 0) { @@ -1041,24 +1088,36 @@ int tq_generate_chat_text(tq_model_t* model, output, output_size); } -update_cache: /* Restore the original callback before returning to caller */ config->on_token = orig_cb; config->user_data = orig_ud; - /* Update cached_text = prompt + generated text. The next call can - * fast-path against this if its prompt starts with this string. */ + /* Update cached_text only if we know the KV state corresponds + * EXACTLY to (prompt + accum.buf): + * - generated >= 0: generation didn't error out + * - !accum.tainted: every generated token was captured + * On any failure, clear cached_text so the next call falls through + * to the slow path with a clean slate instead of trusting bytes + * that don't match the KV cache. */ if (cached_text_io) { - size_t plen = strlen(prompt); - size_t glen = accum.len; - size_t new_len = plen + glen; - char* nt = (char*)malloc(new_len + 1); - if (nt) { - memcpy(nt, prompt, plen); - if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); - nt[new_len] = '\0'; - if (*cached_text_io) free(*cached_text_io); - *cached_text_io = nt; + if (generated < 0 || accum.tainted) { + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } else { + size_t plen = strlen(prompt); + size_t glen = accum.len; + size_t new_len = plen + glen; + char* nt = (char*)malloc(new_len + 1); + if (nt) { + memcpy(nt, prompt, plen); + if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen); + nt[new_len] = '\0'; + if (*cached_text_io) free(*cached_text_io); + *cached_text_io = nt; + } else { + /* malloc failed → can't refresh cached_text. Clearing it + * is safer than leaving the previous (now stale) value. */ + if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; } + } } } if (accum.buf) free(accum.buf); diff --git a/src/server/tq_server.c b/src/server/tq_server.c index 81db519..711557b 100644 --- a/src/server/tq_server.c +++ b/src/server/tq_server.c @@ -109,6 +109,12 @@ typedef struct { int cached_capacity; char* cached_text; /* prompt + generated, for text-prefix matching */ long last_used; /* monotonic counter for LRU */ + /* Track the kv_type / value_quant_bits used to allocate kv_state. + * If a later request reuses this session id with different params, + * we must rebuild the state — the cached KV blocks are formatted + * for the original config and would be misinterpreted otherwise. */ + tq_type kv_type; + int value_quant_bits; } kv_session_t; struct tq_server { @@ -123,7 +129,8 @@ struct tq_server { }; /* Find or allocate a session by id. Caller holds inference_mutex. - * Returns a pointer into server->sessions. Never NULL (LRU evicts). */ + * Returns a pointer into server->sessions, or NULL on allocation failure + * (caller must check and respond with HTTP 500). */ static kv_session_t* get_or_create_session(tq_server_t* server, const char* sid, tq_type kv_type, @@ -141,8 +148,36 @@ static kv_session_t* get_or_create_session(tq_server_t* server, continue; } if (strncmp(server->sessions[i].id, sid, SESSION_ID_MAX) == 0) { - server->sessions[i].last_used = server->session_clock; - return &server->sessions[i]; + kv_session_t* hit = &server->sessions[i]; + hit->last_used = server->session_clock; + /* If the client switched kv_type / value_quant_bits between + * turns, the cached KV blocks are formatted for the OLD + * config. We must rebuild — reusing the state would + * misinterpret quantized blocks and produce garbage. */ + if (hit->kv_type != kv_type || + hit->value_quant_bits != value_quant_bits) { + fprintf(stderr, "[server] session %s: kv_type/vq_bits changed, rebuilding state\n", hit->id); + if (hit->kv_state) tq_free_state(hit->kv_state); + if (hit->cached_tokens) free(hit->cached_tokens); + if (hit->cached_text) free(hit->cached_text); + hit->kv_state = tq_create_state_ex( + &server->config.model->config, kv_type, value_quant_bits); + if (!hit->kv_state) { + /* Free state failed → mark slot empty so we don't + * leave a half-baked entry that future calls would + * NULL-deref. */ + fprintf(stderr, "[server] tq_create_state_ex failed (rebuild) for session %s\n", hit->id); + memset(hit, 0, sizeof(*hit)); + return NULL; + } + hit->cached_tokens = NULL; + hit->n_cached = 0; + hit->cached_capacity = 0; + hit->cached_text = NULL; + hit->kv_type = kv_type; + hit->value_quant_bits = value_quant_bits; + } + return hit; } if (server->sessions[i].last_used < lru_time) { lru_time = server->sessions[i].last_used; @@ -163,6 +198,17 @@ static kv_session_t* get_or_create_session(tq_server_t* server, strncpy(s->id, sid, SESSION_ID_MAX - 1); s->kv_state = tq_create_state_ex( &server->config.model->config, kv_type, value_quant_bits); + if (!s->kv_state) { + /* tq_create_state_ex returned NULL (OOM, bad config). Clear the + * slot id so the slot looks empty again, otherwise the next + * call with the same sid would find this entry and dereference + * a NULL kv_state. */ + fprintf(stderr, "[server] tq_create_state_ex failed for session %s\n", sid); + memset(s, 0, sizeof(*s)); + return NULL; + } + s->kv_type = kv_type; + s->value_quant_bits = value_quant_bits; s->last_used = server->session_clock; return s; } @@ -779,13 +825,22 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod kv_session_t* sess = get_or_create_session(server, req.session_id, gen_cfg.kv_type, gen_cfg.value_quant_bits); - int gen_rc = tq_generate_chat_text(server->config.model, server->config.tokenizer, - sess->kv_state, req.prompt, &gen_cfg, - &sess->cached_text, - &sess->cached_tokens, &sess->n_cached, - &sess->cached_capacity, - output, sizeof(output)); - if (gen_rc == -2) { + int gen_rc; + if (!sess) { + /* tq_create_state_ex failed inside get_or_create_session. + * Synthesize an error event in the SSE stream so the client + * doesn't see a happy "stop" with empty content. */ + gen_rc = -1; + LOG_ERROR("Session allocation failed"); + } else { + gen_rc = tq_generate_chat_text(server->config.model, server->config.tokenizer, + sess->kv_state, req.prompt, &gen_cfg, + &sess->cached_text, + &sess->cached_tokens, &sess->n_cached, + &sess->cached_capacity, + output, sizeof(output)); + } + if (gen_rc == -2 && sess) { /* Context overflow — auto-reset session and surface error. * Client should retry with a shorter conversation history. */ LOG_ERROR("Session %s: context overflow, auto-reset", sess->id); @@ -797,7 +852,37 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod if (sess->cached_text) { free(sess->cached_text); sess->cached_text = NULL; } } - /* Send final chunk with finish_reason */ + /* Send final chunk. finish_reason: "stop" on success, "error" + * on -1, "length" on -2 (overflow). The previous code always + * sent "stop" even when generation errored, leaving clients + * thinking the model decided to produce zero tokens. */ + const char* finish_reason = "stop"; + if (gen_rc == -2) finish_reason = "length"; + else if (gen_rc < 0) finish_reason = "error"; + + if (gen_rc < 0) { + /* Emit an error delta so OpenAI-compatible clients can see + * what went wrong. Most clients surface the delta content. */ + char err_chunk[SSE_CHUNK_SIZE]; + const char* msg = (gen_rc == -2) + ? "context overflow — session reset, retry with shorter history" + : "internal error during generation"; + snprintf(err_chunk, sizeof(err_chunk), + "{" + "\"id\":\"%s\"," + "\"object\":\"chat.completion.chunk\"," + "\"created\":%ld," + "\"model\":\"%s\"," + "\"choices\":[{" + "\"index\":0," + "\"delta\":{\"content\":\"[%s]\"}," + "\"finish_reason\":null" + "}]" + "}", + completion_id, (long)time(NULL), model_id, msg); + send_sse_event(fd, err_chunk); + } + char final_chunk[SSE_CHUNK_SIZE]; snprintf(final_chunk, sizeof(final_chunk), "{" @@ -808,14 +893,15 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod "\"choices\":[{" "\"index\":0," "\"delta\":{}," - "\"finish_reason\":\"stop\"" + "\"finish_reason\":\"%s\"" "}]" "}", - completion_id, (long)time(NULL), model_id); + completion_id, (long)time(NULL), model_id, finish_reason); send_sse_event(fd, final_chunk); send_sse_event(fd, "[DONE]"); - LOG_INFO("Streaming complete: %d tokens", sse_ctx.token_count); + LOG_INFO("Streaming complete: %d tokens (rc=%d)", + sse_ctx.token_count, gen_rc); } else { /* --- Non-streaming --- */ @@ -828,6 +914,16 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod kv_session_t* sess = get_or_create_session(server, req.session_id, gen_cfg.kv_type, gen_cfg.value_quant_bits); + if (!sess) { + LOG_ERROR("Session allocation failed"); + free(collect.buf); + pthread_mutex_unlock(&server->inference_mutex); + free_chat_request(&req); + send_json(fd, 500, "Internal Server Error", + "{\"error\":{\"message\":\"Failed to allocate KV state for session\"," + "\"type\":\"server_error\",\"code\":\"session_alloc_failed\"}}"); + return; + } int gen_rc = tq_generate_chat_text(server->config.model, server->config.tokenizer, sess->kv_state, req.prompt, &gen_cfg, &sess->cached_text, @@ -852,6 +948,20 @@ static void handle_chat_completions(tq_server_t* server, int fd, const char* bod "\"type\":\"context_overflow\",\"code\":\"context_full\"}}"); return; } + if (gen_rc < 0) { + /* Other error (-1: invalid args, OOM during prefill, etc.). + * The previous code fell through and sent HTTP 200 with an + * empty content string, which is indistinguishable from a + * deliberate empty completion. Return 500 instead. */ + LOG_ERROR("Session %s: generation failed (rc=%d)", sess->id, gen_rc); + free(collect.buf); + pthread_mutex_unlock(&server->inference_mutex); + free_chat_request(&req); + send_json(fd, 500, "Internal Server Error", + "{\"error\":{\"message\":\"Generation failed (allocation error or invalid state)\"," + "\"type\":\"server_error\",\"code\":\"generation_failed\"}}"); + return; + } const char* content = collect.buf ? collect.buf : ""; diff --git a/wasm/quant.wasm b/wasm/quant.wasm index 061f952..f018484 100755 Binary files a/wasm/quant.wasm and b/wasm/quant.wasm differ diff --git a/wasm/quant_wasm.c b/wasm/quant_wasm.c index 281fd31..98a0316 100644 --- a/wasm/quant_wasm.c +++ b/wasm/quant_wasm.c @@ -61,6 +61,14 @@ static void on_token_sync(const char* text, void* ud) { EMSCRIPTEN_KEEPALIVE int wasm_load_model(const char* path) { js_on_status("Loading model..."); + /* Reset generation state on load — if a previous run was interrupted + * (page reload mid-stream, JS error in the token callback), the + * busy flag would otherwise be stuck at 1 and every subsequent + * generate call would early-return -1 forever. */ + g_generating = 0; + g_output_pos = 0; + g_output[0] = '\0'; + g_stream_count = 0; if (g_model) { quant_free_model(g_model); g_model = NULL; } if (g_ctx) { quant_free_ctx(g_ctx); g_ctx = NULL; }