Skip to content

Commit 4737d15

Browse files
author
Chris Warren-Smith
committed
LLAMA: fix nitro tool handling
1 parent 17a6e5f commit 4737d15

3 files changed

Lines changed: 250 additions & 197 deletions

File tree

llama/llama-sb.cpp

Lines changed: 35 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,9 @@ Llama::Llama() :
6161
_max_tokens(0),
6262
_log_level(GGML_LOG_LEVEL_CONT),
6363
_n_gpu_layers(0),
64-
_n_past(0),
64+
_n_system_tokens(0),
6565
_is_gemma4(false),
66+
_sampler_dirty(false),
6667
_seed(LLAMA_DEFAULT_SEED) {
6768
llama_log_set([](enum ggml_log_level level, const char *text, void *user_data) {
6869
Llama *llama = (Llama *)user_data;
@@ -99,8 +100,9 @@ Llama::Llama(Llama &&other) noexcept
99100
, _max_tokens(other._max_tokens)
100101
, _log_level(other._log_level)
101102
, _n_gpu_layers(other._n_gpu_layers)
102-
, _n_past(other._n_past)
103+
, _n_system_tokens(other._n_system_tokens)
103104
, _is_gemma4(other._is_gemma4)
105+
, _sampler_dirty(other._sampler_dirty)
104106
, _seed(other._seed) {
105107
}
106108

@@ -129,8 +131,9 @@ void Llama::reset() {
129131
_top_p = 1.0f;
130132
_min_p = 0.0f;
131133
_max_tokens = 150;
132-
_n_past = 0;
134+
_n_system_tokens = 0;
133135
_seed = LLAMA_DEFAULT_SEED;
136+
_sampler_dirty = true;
134137
if (_ctx) {
135138
llama_memory_clear(llama_get_memory(_ctx), true);
136139
}
@@ -210,13 +213,13 @@ bool Llama::load_embedding_model(string model_path) {
210213
// Stores the grammar source text and the name of its root rule, then marks
// the sampler dirty so that add_message() calls configure_sampler() again
// before the next generation.
// NOTE(review): presumably configure_sampler() is what consumes
// _grammar_src/_grammar_root - its body is not visible here, confirm.
void Llama::set_grammar(const string &src, const string &root) {
211214
_grammar_src = src;
212215
_grammar_root = root;
216+
dirty();
213217
}
214218

215219
bool Llama::add_message(LlamaIter &iter, const string &role, const string &content) {
216220
llama_chat_message message = {role.c_str(), content.c_str()};
217221
int buf_size = 2 * (int)(role.size() + content.size() + 64);
218222
vector<char> buf(buf_size);
219-
bool add_ass = (role == "user" || role == "tool");
220223
int32_t n = 0;
221224

222225
if (_template.empty()) {
@@ -225,14 +228,18 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
225228
}
226229

227230
if (_is_gemma4) {
228-
string str = "<|turn>" + role + "\n" + content + "<turn|>\n";
229-
if (add_ass) {
230-
str += "<|turn>model\n";
231+
// see: https://ai.google.dev/gemma/docs/core/prompt-formatting-gemma4
232+
string str;
233+
if (role == "system") {
234+
str = "<|turn>system\n<|think|>" + content + "<turn|>\n";
235+
} else {
236+
str = "<|turn>" + role + "\n" + content + "<turn|>\n";
231237
}
232238
n = str.size();
233239
buf.assign(str.begin(), str.end());
234240
buf.push_back('\0');
235241
} else {
242+
bool add_ass = (role == "user" || role == "tool" || role == "tool_result");
236243
n = llama_chat_apply_template(_template.c_str(), &message, 1, add_ass, buf.data(), buf_size);
237244
if (n < 0) {
238245
_last_error = "No chat template no supported";
@@ -244,16 +251,25 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
244251
}
245252
string prompt(buf.data(), n);
246253

247-
if (!configure_sampler()) {
248-
return false;
254+
if (_sampler_dirty) {
255+
// avoid wasteful rebuild
256+
if (!configure_sampler()) {
257+
return false;
258+
}
259+
_sampler_dirty = false;
249260
}
250261

251262
vector<llama_token> prompt_tokens = tokenize(prompt);
252263
if (prompt_tokens.size() == 0) {
253264
return false;
254265
}
255266

256-
if (!make_space_for_tokens(prompt_tokens.size(), _n_past)) {
267+
if (role == "system") {
268+
// always retain system tokens
269+
_n_system_tokens = prompt_tokens.size();
270+
}
271+
272+
if (!make_space_for_tokens(prompt_tokens.size())) {
257273
return false;
258274
}
259275

@@ -278,7 +294,6 @@ bool Llama::add_message(LlamaIter &iter, const string &role, const string &conte
278294
}
279295
}
280296

281-
_n_past += prompt_tokens.size();
282297
iter._t_start = std::chrono::high_resolution_clock::now();
283298
iter._llama = this;
284299
iter._has_next = true;
@@ -326,7 +341,6 @@ string Llama::all(LlamaIter &iter) {
326341

327342
// end-of-generation check
328343
if (llama_vocab_is_eog(_vocab, tok)) {
329-
iter._has_next = false;
330344
break;
331345
}
332346

@@ -342,6 +356,9 @@ string Llama::all(LlamaIter &iter) {
342356
}
343357
}
344358

359+
// tokens exhausted - call add_message to continue
360+
iter._has_next = false;
361+
345362
// detokenize sequentially
346363
if (!decoded.empty()) {
347364
for (llama_token tok : decoded) {
@@ -407,8 +424,8 @@ bool Llama::batch_decode_tokens(vector<llama_token> &tokens) {
407424
llama_batch batch = llama_batch_get_one(tokens.data() + i, batch_size);
408425
int result = llama_decode(_ctx, batch);
409426
if (result != 0) {
410-
_last_error = std::format("Failed to decode batch. position:{} error:{} [size:{}, past:{}]",
411-
i, result, tokens.size(), _n_past);
427+
_last_error = std::format("Failed to decode batch. position:{} error:{} [size:{}]",
428+
i, result, tokens.size());
412429
return false;
413430
}
414431
}
@@ -507,9 +524,8 @@ bool Llama::ends_with_sentence_boundary(const string &text) {
507524
//
508525
// Parameters:
509526
// n_tokens - Number of tokens we need space for
510-
// keep_min - Minimum tokens to keep (e.g., system prompt), default 0
511527
//
512-
bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
528+
bool Llama::make_space_for_tokens(int n_tokens) {
513529
int n_ctx = llama_n_ctx(_ctx);
514530
if (n_tokens > n_ctx) {
515531
_last_error = "Too many tokens, increase context size (n_ctx)";
@@ -539,10 +555,10 @@ bool Llama::make_space_for_tokens(int n_tokens, int keep_min) {
539555
// Calculate how many tokens to remove
540556
int tokens_to_remove = space_needed - space_available;
541557

542-
// Can't remove more than we have (minus keep_min)
543-
int removable = current_used - keep_min;
558+
// Can't remove more than we have (minus _n_system_tokens)
559+
int removable = current_used - _n_system_tokens;
544560
if (tokens_to_remove > removable) {
545-
_last_error = "Can't make enough space while keeping keep_min tokens";
561+
_last_error = "Can't make enough space while keeping num_system_tokens tokens";
546562
return false;
547563
}
548564

llama/llama-sb.h

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,17 @@ struct Llama {
8181
// generation parameters
8282
// appends one stop sequence; does not mark the sampler dirty
// NOTE(review): if _stop_sequences holds std::string, a null `stop` would be
// undefined behaviour - element type is not visible here, confirm
void add_stop(const char *stop) { _stop_sequences.push_back(stop); }
8383
// removes every stop sequence previously added via add_stop()
void clear_stops() { _stop_sequences.clear(); }
84-
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; }
85-
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; }
86-
void set_penalty_freq(float penalty_freq) { _penalty_freq = penalty_freq; }
87-
void set_penalty_present(float penalty_present) { _penalty_present = penalty_present; }
88-
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; }
89-
void set_min_p(float min_p) { _min_p = min_p; }
90-
void set_temperature(float temperature) { _temperature = temperature; }
91-
void set_top_k(int top_k) { _top_k = top_k; }
92-
void set_top_p(float top_p) { _top_p = top_p; }
84+
// stores penalty_last_n and flags the sampler for rebuild on the next add_message()
void set_penalty_last_n(int32_t penalty_last_n) { _penalty_last_n = penalty_last_n; dirty(); }
85+
// stores penalty_repeat and flags the sampler for rebuild on the next add_message()
void set_penalty_repeat(float penalty_repeat) { _penalty_repeat = penalty_repeat; dirty(); }
86+
// stores penalty_freq and flags the sampler for rebuild on the next add_message()
void set_penalty_freq(float penalty_freq) { _penalty_freq = penalty_freq; dirty(); }
87+
// stores penalty_present and flags the sampler for rebuild on the next add_message()
void set_penalty_present(float penalty_present) { _penalty_present = penalty_present; dirty(); }
88+
// stores max_tokens (reset() restores 150) and flags the sampler for rebuild
// NOTE(review): it is not visible here whether configure_sampler() reads
// _max_tokens; if it does not, this dirty() only causes a redundant rebuild - confirm
void set_max_tokens(int max_tokens) { _max_tokens = max_tokens; dirty(); }
89+
// stores min_p (reset() restores 0.0f) and flags the sampler for rebuild on the next add_message()
void set_min_p(float min_p) { _min_p = min_p; dirty(); }
90+
// stores temperature and flags the sampler for rebuild on the next add_message()
void set_temperature(float temperature) { _temperature = temperature; dirty(); }
91+
// stores top_k and flags the sampler for rebuild on the next add_message()
void set_top_k(int top_k) { _top_k = top_k; dirty(); }
92+
// stores top_p (reset() restores 1.0f) and flags the sampler for rebuild on the next add_message()
void set_top_p(float top_p) { _top_p = top_p; dirty(); }
9393
void set_grammar(const string &src, const string &root);
94-
void set_seed(unsigned int seed) { _seed = seed; }
94+
// stores the seed (reset() restores LLAMA_DEFAULT_SEED) and flags the sampler
// for rebuild on the next add_message()
void set_seed(unsigned int seed) { _seed = seed; dirty(); }
9595

9696
// error handling
9797
// returns the most recent error message as a C string; the pointer refers to
// _last_error's own buffer, so copy it if it must outlive the next call that
// overwrites _last_error
const char *last_error() { return _last_error.c_str(); }
@@ -110,8 +110,9 @@ struct Llama {
110110
private:
111111
bool batch_decode_tokens(vector<llama_token> &tokens);
112112
bool configure_sampler();
113+
// flags the sampler configuration as stale; add_message() calls
// configure_sampler() only while this flag is set and clears it on success
void dirty() {_sampler_dirty = true; }
113114
bool ends_with_sentence_boundary(const string &out);
114-
bool make_space_for_tokens(int n_tokens, int keep_min);
115+
bool make_space_for_tokens(int n_tokens);
115116
vector<llama_token> tokenize(const string &prompt);
116117
string token_to_string(LlamaIter &iter, llama_token tok);
117118
void set_last_error(const char *message);
@@ -136,7 +137,8 @@ struct Llama {
136137
int _max_tokens;
137138
int _log_level;
138139
int _n_gpu_layers;
139-
int _n_past;
140+
int _n_system_tokens;
140141
bool _is_gemma4;
142+
bool _sampler_dirty;
141143
unsigned int _seed;
142144
};

0 commit comments

Comments
 (0)