823 changes: 672 additions & 151 deletions bindings/python/quant.h

Large diffs are not rendered by default.

38 changes: 36 additions & 2 deletions bindings/python/quantcpp/__init__.py
@@ -35,6 +35,16 @@
)


class ChatContextOverflow(RuntimeError):
"""Raised when chat history exceeds the model's context window.

The C side has already auto-reset the session by the time this is
raised — the caller must trim its conversation history (drop the
oldest turns) and retry. Catching this is the supported way to
detect "we hit max_seq_len" without parsing log output.
"""


# -----------------------------------------------------------------------
# Model registry — small GGUF models auto-downloaded from HuggingFace
# -----------------------------------------------------------------------
@@ -394,6 +404,15 @@ def chat(self, prompt: str) -> Iterator[str]:

    Falls back to ``generate()`` on older library builds that lack the
    ``quant_chat`` symbol.

Raises
------
ChatContextOverflow
When the conversation history exceeds the model's context
window. The session has been auto-reset; the caller should
trim history and retry.
RuntimeError
On other generation failures (allocation, invalid state).
"""
self._ensure_open()
lib = get_lib()
@@ -414,6 +433,7 @@ def chat(self, prompt: str) -> Iterator[str]:
tokens = []
done = threading.Event()
error_box = [None]
rc_box = [0]

def _on_token(text_ptr, _user_data):
if text_ptr:
@@ -424,7 +444,8 @@ def _on_token(text_ptr, _user_data):
def _run():
try:
with self._lock:
lib.quant_chat(self._ctx, prompt.encode("utf-8"), cb, None)
rc_box[0] = lib.quant_chat(
self._ctx, prompt.encode("utf-8"), cb, None)
except Exception as e:
error_box[0] = e
finally:
@@ -448,6 +469,19 @@ def _run():
if error_box[0] is not None:
raise error_box[0]

# Surface generation failures from the C side. Previously these
# were silently swallowed: -2 (context overflow) and -1 (alloc
# failure) both produced empty token streams that callers could
# not distinguish from "the model decided to say nothing".
rc = rc_box[0]
if rc == -2:
raise ChatContextOverflow(
"conversation history exceeds the model's context window — "
"session has been reset, retry with shorter history"
)
if rc < 0:
raise RuntimeError(f"quant_chat failed with rc={rc}")

def reset_chat(self) -> None:
"""Reset the chat KV cache. Next chat() call starts fresh."""
self._ensure_open()
@@ -528,4 +562,4 @@ def load(path: str, **kwargs) -> Model:
return Model(path, **kwargs)


__all__ = ["Model", "load", "download", "__version__"]
__all__ = ["Model", "load", "download", "ChatContextOverflow", "__version__"]
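
For reference, a minimal caller-side sketch of the new error contract. Only ChatContextOverflow, load(), and Model.chat() come from the bindings; the model path, the turns list, and the ask() helper are illustrative assumptions. The interactive CLI change below follows the same pattern.

from quantcpp import ChatContextOverflow, load

m = load("model.gguf")   # hypothetical path to a local GGUF file
turns = []               # caller-owned (user, assistant) pairs

def ask(question: str) -> str:
    while True:
        # Rebuild the ChatML history from whatever turns we still keep.
        prompt = "".join(
            f"<|im_start|>user\n{u}<|im_end|>\n<|im_start|>assistant\n{a}<|im_end|>\n"
            for u, a in turns
        ) + f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
        try:
            reply = "".join(m.chat(prompt))
        except ChatContextOverflow:
            if not turns:
                raise              # the question alone does not fit; nothing to trim
            turns.pop(0)           # drop the oldest turn and retry with less history
            continue
        turns.append((question, reply))
        return reply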
51 changes: 45 additions & 6 deletions bindings/python/quantcpp/cli.py
@@ -151,24 +151,63 @@ def cmd_run(args):
print(tok, end="", flush=True)
print()
else:
from quantcpp import ChatContextOverflow
print("quantcpp \u2014 type your message, Ctrl+C to exit", file=sys.stderr)
# Multi-turn chat: accumulate history as ChatML so the model sees
# prior turns. m.chat() reuses the KV cache via prefix-match, so
# repeating the history is cheap (O(new tokens), not O(n^2)).
history = ""
# turns is a list of (user_msg, assistant_msg) pairs so we can
# trim from the front when we hit context overflow.
turns = []
def _build_history(extra_user=None):
parts = []
for u, a in turns:
parts.append(f"<|im_start|>user\n{u}<|im_end|>\n<|im_start|>assistant\n{a}<|im_end|>\n")
if extra_user is not None:
parts.append(f"<|im_start|>user\n{extra_user}<|im_end|>\n<|im_start|>assistant\n")
return "".join(parts)

try:
while True:
question = input("\nYou: ")
if not question.strip():
continue
history += f"<|im_start|>user\n{question}<|im_end|>\n<|im_start|>assistant\n"
print("AI: ", end="", flush=True)
reply_buf = []
for tok in m.chat(history):
print(tok, end="", flush=True)
reply_buf.append(tok)
# Retry loop: on context overflow, drop the oldest turn
# and try again. Without this, the C side resets the KV
# cache but Python's history still has the bloat, so
# every subsequent turn would loop back into overflow.
attempt = 0
while True:
history = _build_history(extra_user=question)
try:
for tok in m.chat(history):
print(tok, end="", flush=True)
reply_buf.append(tok)
break
except ChatContextOverflow:
if not turns:
print("\n[chat] message alone exceeds context window — try a shorter question.",
file=sys.stderr)
reply_buf = [] # nothing was emitted
break
dropped = turns.pop(0)
attempt += 1
print(f"\n[chat] context full \u2014 dropped oldest turn ({len(dropped[0])+len(dropped[1])} chars), retrying...",
file=sys.stderr)
# The session was already reset by the C side;
# retrying with the trimmed history will hit
# the slow path on this turn and the fast path
# again from the next turn onward.
if attempt > 8:
print("[chat] too many overflow retries, giving up on this turn.",
file=sys.stderr)
reply_buf = []
break
print()
history += "".join(reply_buf) + "<|im_end|>\n"
if reply_buf:
turns.append((question, "".join(reply_buf)))
except (KeyboardInterrupt, EOFError):
print("\nBye!", file=sys.stderr)

107 changes: 85 additions & 22 deletions quant.h
@@ -15695,17 +15695,23 @@ int tq_generate_continue(tq_model_t* model,
n_new = 1;
}

/* Sliding window: drop oldest prompt tokens if the new prompt would
* leave no room for max_tokens of generation. Keeps the most recent
* tokens. Forces full reprefill since the prefix shifted. */
/* Overflow check: reject prompts that won't fit. The previous
* behavior was to silently drop oldest tokens via a sliding window,
* but that desynced any cached_text the higher-level wrapper held
* (cached_text claimed the full prompt, while cached_tokens only
* had the truncated tail — next turn's text-prefix match would
* map text bytes to the wrong KV positions). Returning -2 lets the
* caller decide (reset chat, show error). */
int reserve = config->max_tokens > 0 ? config->max_tokens : 256;
int budget = max_prompt - reserve - 32;
if (budget < 64) budget = 64;
if (n_new > budget) {
int drop = n_new - budget;
memmove(new_tokens, new_tokens + drop, (size_t)budget * sizeof(int));
n_new = budget;
*n_cached_io = 0;
free(new_tokens);
if (getenv("TQ_CHAT_DEBUG")) {
fprintf(stderr, "[chat] OVERFLOW n_new=%d budget=%d max=%d\n",
n_new, budget, max_prompt);
}
return -2;
}

/* Find longest common prefix with the cached tokens.
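
A worked Python restatement of the budget check in the overflow hunk above. The 2048-token context and 256-token reserve are hypothetical numbers, not values from any particular model:

max_prompt, max_tokens = 2048, 256           # illustrative model limits
reserve = max_tokens if max_tokens > 0 else 256
budget = max(max_prompt - reserve - 32, 64)  # 2048 - 256 - 32 = 1760
n_new = 1900                                 # tokens in this turn's prompt
rc = -2 if n_new > budget else 0             # 1900 > 1760, so rc == -2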
@@ -15872,25 +15878,30 @@ typedef struct {
char* buf;
size_t len;
size_t cap;
int tainted; /* 1 if accumulation ever failed → buf is incomplete */
void (*user_cb)(const char*, void*);
void* user_data;
} chat_accum_t;

static void chat_accum_callback(const char* tok, void* u) {
chat_accum_t* ctx = (chat_accum_t*)u;
if (!tok) return;
/* Always pass through to the user's callback first — losing tokens
* from the user's stream because of an INTERNAL realloc failure is
* far worse than a stale cached_text on the next turn. */
if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
if (ctx->tainted) return;
size_t tlen = strlen(tok);
if (ctx->len + tlen + 1 > ctx->cap) {
size_t new_cap = (ctx->cap + tlen + 64) * 2;
char* nb = (char*)realloc(ctx->buf, new_cap);
if (!nb) return;
if (!nb) { ctx->tainted = 1; return; }
ctx->buf = nb;
ctx->cap = new_cap;
}
memcpy(ctx->buf + ctx->len, tok, tlen);
ctx->len += tlen;
ctx->buf[ctx->len] = '\0';
if (ctx->user_cb) ctx->user_cb(tok, ctx->user_data);
}

int tq_generate_chat_text(tq_model_t* model,
@@ -15918,7 +15929,7 @@ int tq_generate_chat_text(tq_model_t* model,
}
}

chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0,
chat_accum_t accum = { .buf = NULL, .len = 0, .cap = 0, .tainted = 0,
.user_cb = config->on_token,
.user_data = config->user_data };
void (*orig_cb)(const char*, void*) = config->on_token;
@@ -15982,12 +15993,35 @@
matched_text_len, n_suffix);
}

/* Generation loop */
/* Generation loop — mirrors tq_generate_continue including
* rep_penalty (which the fast path was silently dropping). */
int vocab_size = model->config.vocab_size;
int n_cached = *n_cached_io;
int pos = n_cached;
int prev_token = n_cached > 0 ? cached[n_cached - 1] : 1;

float rep_penalty = config->rep_penalty;
int rep_window = config->rep_window;
if (rep_window > 64) rep_window = 64;
int recent_tokens[64];
int recent_count = 0;
for (int i = (n_cached > rep_window ? n_cached - rep_window : 0); i < n_cached; i++) {
recent_tokens[recent_count % 64] = cached[i];
recent_count++;
}
if (rep_penalty > 1.0f) {
int window = recent_count < rep_window ? recent_count : rep_window;
for (int r = 0; r < window; r++) {
int idx = (recent_count - 1 - r) % 64;
if (idx < 0) idx += 64;
int tok = recent_tokens[idx];
if (tok >= 0 && tok < vocab_size && state->logits) {
if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
else state->logits[tok] *= rep_penalty;
}
}
}

uint64_t rng_state = config->rng_seed
? (uint64_t)config->rng_seed : (uint64_t)time(NULL);
int next_token = tq_sample_topp(state->logits, vocab_size,
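
The restored penalty is the usual divide-positive, multiply-negative rule over a short window of recent token ids; a compact Python sketch of the same logic (illustrative only, operating on a plain list of floats):

def apply_rep_penalty(logits, recent, penalty):
    """Divide positive logits and multiply negative ones for recent tokens."""
    if penalty <= 1.0:
        return
    for tok in recent:                   # at most the last rep_window (<= 64) ids
        if 0 <= tok < len(logits):
            if logits[tok] > 0:
                logits[tok] /= penalty   # repeated token becomes less likely
            else:
                logits[tok] *= penalty   # more negative, also less likely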
@@ -16033,9 +16067,24 @@
pos++;
generated++;

if (rep_penalty > 1.0f) {
int window = recent_count < rep_window ? recent_count : rep_window;
for (int r = 0; r < window; r++) {
int idx = (recent_count - 1 - r) % 64;
if (idx < 0) idx += 64;
int tok = recent_tokens[idx];
if (tok >= 0 && tok < vocab_size) {
if (state->logits[tok] > 0) state->logits[tok] /= rep_penalty;
else state->logits[tok] *= rep_penalty;
}
}
}

next_token = tq_sample_topp(state->logits, vocab_size,
config->temperature, config->top_p,
&rng_state);
recent_tokens[recent_count % 64] = next_token;
recent_count++;
}

if (output && output_size > 0) {
Expand All @@ -16051,21 +16100,35 @@ int tq_generate_chat_text(tq_model_t* model,
output, output_size);
}

update_cache:
config->on_token = orig_cb;
config->user_data = orig_ud;

/* Update cached_text only if we know the KV state corresponds
* EXACTLY to (prompt + accum.buf):
* - generated >= 0: generation didn't error out
* - !accum.tainted: every generated token was captured
* On any failure, clear cached_text so the next call falls through
* to the slow path with a clean slate instead of trusting bytes
* that don't match the KV cache. */
if (cached_text_io) {
size_t plen = strlen(prompt);
size_t glen = accum.len;
size_t new_len = plen + glen;
char* nt = (char*)malloc(new_len + 1);
if (nt) {
memcpy(nt, prompt, plen);
if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
nt[new_len] = '\0';
if (*cached_text_io) free(*cached_text_io);
*cached_text_io = nt;
if (generated < 0 || accum.tainted) {
if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; }
} else {
size_t plen = strlen(prompt);
size_t glen = accum.len;
size_t new_len = plen + glen;
char* nt = (char*)malloc(new_len + 1);
if (nt) {
memcpy(nt, prompt, plen);
if (glen > 0 && accum.buf) memcpy(nt + plen, accum.buf, glen);
nt[new_len] = '\0';
if (*cached_text_io) free(*cached_text_io);
*cached_text_io = nt;
} else {
/* malloc failed → can't refresh cached_text. Clearing it
* is safer than leaving the previous (now stale) value. */
if (*cached_text_io) { free(*cached_text_io); *cached_text_io = NULL; }
}
}
}
if (accum.buf) free(accum.buf);
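
A minimal Python restatement of the rule this block enforces (the function and its argument names are illustrative, not part of the bindings API):

def refresh_cached_text(prompt, generated, generated_text, tainted):
    """Keep cached_text byte-accurate with the KV cache, or drop it entirely."""
    if generated < 0 or tainted or generated_text is None:
        return None                      # force the slow path on the next turn
    return prompt + generated_text       # exactly the text the KV cache encodes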