@@ -61,8 +61,9 @@ Llama::Llama() :
6161 _max_tokens(0 ),
6262 _log_level(GGML_LOG_LEVEL_CONT),
6363 _n_gpu_layers(0 ),
64- _n_past (0 ),
64+ _n_system_tokens (0 ),
6565 _is_gemma4(false ),
66+ _sampler_dirty(false ),
6667 _seed(LLAMA_DEFAULT_SEED) {
6768 llama_log_set ([](enum ggml_log_level level, const char *text, void *user_data) {
6869 Llama *llama = (Llama *)user_data;
@@ -99,8 +100,9 @@ Llama::Llama(Llama &&other) noexcept
99100 , _max_tokens(other._max_tokens)
100101 , _log_level(other._log_level)
101102 , _n_gpu_layers(other._n_gpu_layers)
102- , _n_past (other._n_past )
103+ , _n_system_tokens (other._n_system_tokens )
103104 , _is_gemma4(other._is_gemma4)
105+ , _sampler_dirty(other._sampler_dirty)
104106 , _seed(other._seed) {
105107}
106108
@@ -129,8 +131,9 @@ void Llama::reset() {
129131 _top_p = 1 .0f ;
130132 _min_p = 0 .0f ;
131133 _max_tokens = 150 ;
132- _n_past = 0 ;
134+ _n_system_tokens = 0 ;
133135 _seed = LLAMA_DEFAULT_SEED;
136+ _sampler_dirty = true ;
134137 if (_ctx) {
135138 llama_memory_clear (llama_get_memory (_ctx), true );
136139 }
@@ -210,13 +213,13 @@ bool Llama::load_embedding_model(string model_path) {
210213void Llama::set_grammar (const string &src, const string &root) {
211214 _grammar_src = src;
212215 _grammar_root = root;
216+ dirty ();
213217}
214218
215219bool Llama::add_message (LlamaIter &iter, const string &role, const string &content) {
216220 llama_chat_message message = {role.c_str (), content.c_str ()};
217221 int buf_size = 2 * (int )(role.size () + content.size () + 64 );
218222 vector<char > buf (buf_size);
219- bool add_ass = (role == " user" || role == " tool" );
220223 int32_t n = 0 ;
221224
222225 if (_template.empty ()) {
@@ -225,14 +228,18 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
225228 }
226229
227230 if (_is_gemma4) {
228- string str = " <|turn>" + role + " \n " + content + " <turn|>\n " ;
229- if (add_ass) {
230- str += " <|turn>model\n " ;
231+ // see: https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
232+ string str;
233+ if (role == " system" ) {
234+ str = " <|turn>system\n <|think|>" + content + " <turn|>\n " ;
235+ } else {
236+ str = " <|turn>" + role + " \n " + content + " <turn|>\n " ;
231237 }
232238 n = str.size ();
233239 buf.assign (str.begin (), str.end ());
234240 buf.push_back (' \0 ' );
235241 } else {
242+ bool add_ass = (role == " user" || role == " tool" || role == " tool_result" );
236243 n = llama_chat_apply_template (_template.c_str (), &message, 1 , add_ass, buf.data (), buf_size);
237244 if (n < 0 ) {
238245 _last_error = " No chat template no supported" ;
@@ -244,16 +251,25 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
244251 }
245252 string prompt (buf.data (), n);
246253
247- if (!configure_sampler ()) {
248- return false ;
254+ if (_sampler_dirty) {
255+ // avoid wasteful rebuild
256+ if (!configure_sampler ()) {
257+ return false ;
258+ }
259+ _sampler_dirty = false ;
249260 }
250261
251262 vector<llama_token> prompt_tokens = tokenize (prompt);
252263 if (prompt_tokens.size () == 0 ) {
253264 return false ;
254265 }
255266
256- if (!make_space_for_tokens (prompt_tokens.size (), _n_past)) {
267+ if (role == " system" ) {
268+ // always retain system tokens
269+ _n_system_tokens = prompt_tokens.size ();
270+ }
271+
272+ if (!make_space_for_tokens (prompt_tokens.size ())) {
257273 return false ;
258274 }
259275
@@ -278,7 +294,6 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
278294 }
279295 }
280296
281- _n_past += prompt_tokens.size ();
282297 iter._t_start = std::chrono::high_resolution_clock::now ();
283298 iter._llama = this ;
284299 iter._has_next = true ;
@@ -326,7 +341,6 @@ string Llama::all(LlamaIter &iter) {
326341
327342 // end-of-generation check
328343 if (llama_vocab_is_eog (_vocab, tok)) {
329- iter._has_next = false ;
330344 break ;
331345 }
332346
@@ -342,6 +356,9 @@ string Llama::all(LlamaIter &iter) {
342356 }
343357 }
344358
359+ // tokens exhausted - call add_message to continue
360+ iter._has_next = false ;
361+
345362 // detokenize sequentially
346363 if (!decoded.empty ()) {
347364 for (llama_token tok : decoded) {
@@ -407,8 +424,8 @@ bool Llama::batch_decode_tokens(vector<llama_token> &tokens) {
407424 llama_batch batch = llama_batch_get_one (tokens.data () + i, batch_size);
408425 int result = llama_decode (_ctx, batch);
409426 if (result != 0 ) {
410- _last_error = std::format (" Failed to decode batch. position:{} error:{} [size:{}, past:{} ]" ,
411- i, result, tokens.size (), _n_past );
427+ _last_error = std::format (" Failed to decode batch. position:{} error:{} [size:{}]" ,
428+ i, result, tokens.size ());
412429 return false ;
413430 }
414431 }
@@ -507,9 +524,8 @@ bool Llama::ends_with_sentence_boundary(const string &text) {
507524//
508525// Parameters:
509526// n_tokens - Number of tokens we need space for
510- // keep_min - Minimum tokens to keep (e.g., system prompt), default 0
511527//
512- bool Llama::make_space_for_tokens (int n_tokens, int keep_min ) {
528+ bool Llama::make_space_for_tokens (int n_tokens) {
513529 int n_ctx = llama_n_ctx (_ctx);
514530 if (n_tokens > n_ctx) {
515531 _last_error = " Too many tokens, increase context size (n_ctx)" ;
@@ -539,10 +555,10 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
539555 // Calculate how many tokens to remove
540556 int tokens_to_remove = space_needed - space_available;
541557
542- // Can't remove more than we have (minus keep_min )
543- int removable = current_used - keep_min ;
558+ // Can't remove more than we have (minus _n_system_tokens )
559+ int removable = current_used - _n_system_tokens ;
544560 if (tokens_to_remove > removable) {
545- _last_error = " Can't make enough space while keeping keep_min tokens" ;
561+ _last_error = " Can't make enough space while keeping num_system_tokens tokens" ;
546562 return false ;
547563 }
548564
0 commit comments