Review notes (text before the first "diff --git" header is ignored by patch tools):
  * Hunk 2: the target and draft contexts are now rolled back all-or-nothing. If a
    partial seq_rm fails on either one, both are cleared and the slot's cache
    bookkeeping is reset, so the draft can no longer keep a stale [0, p0) prefix
    (or lose its prefix silently) while the target holds something different.
  * The patch had been flattened onto one line; line structure is restored and each
    hunk is indented relative to itself. Apply with `git apply --ignore-whitespace`.
  * Hunk 2's new-side line count is recomputed; the post-image blob id on the
    `index` line predates this revision.

diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp
index 3e3db3c5e35..e1c89896ba6 100644
--- a/tools/server/server-context.cpp
+++ b/tools/server/server-context.cpp
@@ -5527,7 +5527,7 @@ struct server_context_impl {
     SLT_WRN(slot, "%s\n", st1.str().c_str());
 }
 
-if (pos_min >= pos_min_thold) {
+if (n_swa > 0 && pos_min >= pos_min_thold) {
     // search for a context checkpoint
     const auto it = std::find_if(
             slot.prompt.checkpoints.rbegin(),
@@ -5660,9 +5660,26 @@ struct server_context_impl {
 
 SLT_TRC(slot, "cached n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0);
 
-common_context_seq_rm(ctx_tgt, slot.id, p0, -1);
-if (ctx_dft) {
-    common_context_seq_rm(ctx_dft.get(), slot.id, p0, -1);
-}
+// Roll the cached sequence back to p0 in the target and, when present, the draft context.
+// llama_memory_seq_rm() returns false instead of truncating when the memory cannot drop a
+// partial range (e.g. recurrent state). Both contexts are rolled back in lockstep, so a
+// failure in either one falls back to clearing both rather than leaving them out of sync.
+bool rolled_back = llama_memory_seq_rm(llama_get_memory(ctx_tgt), slot.id, p0, -1);
+if (rolled_back && ctx_dft) {
+    rolled_back = llama_memory_seq_rm(llama_get_memory(ctx_dft.get()), slot.id, p0, -1);
+}
+
+if (!rolled_back) {
+    SLT_WRN(slot, "partial seq_rm at p0=%d failed (recurrent backend cannot roll into cache); clearing slot and re-evaluating from scratch\n", p0);
+
+    // nothing is cached any more: the whole prompt gets re-evaluated
+    slot.n_prompt_tokens_cache = 0;
+    slot.prompt.tokens.keep_first(0);
+
+    common_context_seq_rm(ctx_tgt, slot.id, 0, -1);
+    if (ctx_dft) {
+        common_context_seq_rm(ctx_dft.get(), slot.id, 0, -1);
+    }
+}
 
 // If using an alora, there may be uncached tokens that come