Skip to content

Commit 3576e10

Browse files
committed
built-in support of translategemma models.
1 parent 27f0764 commit 3576e10

7 files changed

Lines changed: 209 additions & 2 deletions

File tree

models/gemma.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,174 @@ namespace chatllm::gemma::siglip
491491
}
492492
}
493493

494+
namespace chatllm::gemma::translation
495+
{
496+
const static json::JSON language_code_dict({
497+
"aa", "Afar", "aa-DJ", "Afar", "aa-ER", "Afar", "ab", "Abkhazian", "af", "Afrikaans", "af-NA", "Afrikaans", "ak", "Akan", "am", "Amharic", "an", "Aragonese", "ar", "Arabic",
498+
"ar-AE", "Arabic", "ar-BH", "Arabic", "ar-DJ", "Arabic", "ar-DZ", "Arabic", "ar-EG", "Arabic", "ar-EH", "Arabic", "ar-ER", "Arabic", "ar-IL", "Arabic", "ar-IQ", "Arabic", "ar-JO", "Arabic",
499+
"ar-KM", "Arabic", "ar-KW", "Arabic", "ar-LB", "Arabic", "ar-LY", "Arabic", "ar-MA", "Arabic", "ar-MR", "Arabic", "ar-OM", "Arabic", "ar-PS", "Arabic", "ar-QA", "Arabic", "ar-SA", "Arabic",
500+
"ar-SD", "Arabic", "ar-SO", "Arabic", "ar-SS", "Arabic", "ar-SY", "Arabic", "ar-TD", "Arabic", "ar-TN", "Arabic", "ar-YE", "Arabic", "as", "Assamese", "az", "Azerbaijani", "az-Arab", "Azerbaijani",
501+
"az-Arab-IQ", "Azerbaijani", "az-Arab-TR", "Azerbaijani", "az-Cyrl", "Azerbaijani", "az-Latn", "Azerbaijani", "ba", "Bashkir", "be", "Belarusian", "be-tarask", "Belarusian", "bg", "Bulgarian", "bg-BG", "Bulgarian", "bm", "Bambara",
502+
"bm-Nkoo", "Bambara", "bn", "Bengali", "bn-IN", "Bengali", "bo", "Tibetan", "bo-IN", "Tibetan", "br", "Breton", "bs", "Bosnian", "bs-Cyrl", "Bosnian", "bs-Latn", "Bosnian", "ca", "Catalan",
503+
"ca-AD", "Catalan", "ca-ES", "Catalan", "ca-FR", "Catalan", "ca-IT", "Catalan", "ce", "Chechen", "co", "Corsican", "cs", "Czech", "cs-CZ", "Czech", "cv", "Chuvash", "cy", "Welsh",
504+
"da", "Danish", "da-DK", "Danish", "da-GL", "Danish", "de", "German", "de-AT", "German", "de-BE", "German", "de-CH", "German", "de-DE", "German", "de-IT", "German", "de-LI", "German",
505+
"de-LU", "German", "dv", "Divehi", "dz", "Dzongkha", "ee", "Ewe", "ee-TG", "Ewe", "el", "Greek", "el-CY", "Greek", "el-GR", "Greek", "el-polyton", "Greek", "en", "English",
506+
"en-AE", "English", "en-AG", "English", "en-AI", "English", "en-AS", "English", "en-AT", "English", "en-AU", "English", "en-BB", "English", "en-BE", "English", "en-BI", "English", "en-BM", "English",
507+
"en-BS", "English", "en-BW", "English", "en-BZ", "English", "en-CA", "English", "en-CC", "English", "en-CH", "English", "en-CK", "English", "en-CM", "English", "en-CX", "English", "en-CY", "English",
508+
"en-CZ", "English", "en-DE", "English", "en-DG", "English", "en-DK", "English", "en-DM", "English", "en-ER", "English", "en-ES", "English", "en-FI", "English", "en-FJ", "English", "en-FK", "English",
509+
"en-FM", "English", "en-FR", "English", "en-GB", "English", "en-GD", "English", "en-GG", "English", "en-GH", "English", "en-GI", "English", "en-GM", "English", "en-GS", "English", "en-GU", "English",
510+
"en-GY", "English", "en-HK", "English", "en-HU", "English", "en-ID", "English", "en-IE", "English", "en-IL", "English", "en-IM", "English", "en-IN", "English", "en-IO", "English", "en-IT", "English",
511+
"en-JE", "English", "en-JM", "English", "en-KE", "English", "en-KI", "English", "en-KN", "English", "en-KY", "English", "en-LC", "English", "en-LR", "English", "en-LS", "English", "en-MG", "English",
512+
"en-MH", "English", "en-MO", "English", "en-MP", "English", "en-MS", "English", "en-MT", "English", "en-MU", "English", "en-MV", "English", "en-MW", "English", "en-MY", "English", "en-NA", "English",
513+
"en-NF", "English", "en-NG", "English", "en-NL", "English", "en-NO", "English", "en-NR", "English", "en-NU", "English", "en-NZ", "English", "en-PG", "English", "en-PH", "English", "en-PK", "English",
514+
"en-PL", "English", "en-PN", "English", "en-PR", "English", "en-PT", "English", "en-PW", "English", "en-RO", "English", "en-RW", "English", "en-SB", "English", "en-SC", "English", "en-SD", "English",
515+
"en-SE", "English", "en-SG", "English", "en-SH", "English", "en-SI", "English", "en-SK", "English", "en-SL", "English", "en-SS", "English", "en-SX", "English", "en-SZ", "English", "en-TC", "English",
516+
"en-TK", "English", "en-TO", "English", "en-TT", "English", "en-TV", "English", "en-TZ", "English", "en-UG", "English", "en-UM", "English", "en-VC", "English", "en-VG", "English", "en-VI", "English",
517+
"en-VU", "English", "en-WS", "English", "en-ZA", "English", "en-ZM", "English", "en-ZW", "English", "eo", "Esperanto", "es", "Spanish", "es-AR", "Spanish", "es-BO", "Spanish", "es-BR", "Spanish",
518+
"es-BZ", "Spanish", "es-CL", "Spanish", "es-CO", "Spanish", "es-CR", "Spanish", "es-CU", "Spanish", "es-DO", "Spanish", "es-EA", "Spanish", "es-EC", "Spanish", "es-ES", "Spanish", "es-GQ", "Spanish",
519+
"es-GT", "Spanish", "es-HN", "Spanish", "es-IC", "Spanish", "es-MX", "Spanish", "es-NI", "Spanish", "es-PA", "Spanish", "es-PE", "Spanish", "es-PH", "Spanish", "es-PR", "Spanish", "es-PY", "Spanish",
520+
"es-SV", "Spanish", "es-US", "Spanish", "es-UY", "Spanish", "es-VE", "Spanish", "et", "Estonian", "et-EE", "Estonian", "eu", "Basque", "fa", "Persian", "fa-AF", "Persian", "fa-IR", "Persian",
521+
"ff", "Fulah", "ff-Adlm", "Fulah", "ff-Adlm-BF", "Fulah", "ff-Adlm-CM", "Fulah", "ff-Adlm-GH", "Fulah", "ff-Adlm-GM", "Fulah", "ff-Adlm-GW", "Fulah", "ff-Adlm-LR", "Fulah", "ff-Adlm-MR", "Fulah", "ff-Adlm-NE", "Fulah",
522+
"ff-Adlm-NG", "Fulah", "ff-Adlm-SL", "Fulah", "ff-Adlm-SN", "Fulah", "ff-Latn", "Fulah", "ff-Latn-BF", "Fulah", "ff-Latn-CM", "Fulah", "ff-Latn-GH", "Fulah", "ff-Latn-GM", "Fulah", "ff-Latn-GN", "Fulah", "ff-Latn-GW", "Fulah",
523+
"ff-Latn-LR", "Fulah", "ff-Latn-MR", "Fulah", "ff-Latn-NE", "Fulah", "ff-Latn-NG", "Fulah", "ff-Latn-SL", "Fulah", "fi", "Finnish", "fi-FI", "Finnish", "fil-PH", "Filipino", "fo", "Faroese", "fo-DK", "Faroese",
524+
"fr", "French", "fr-BE", "French", "fr-BF", "French", "fr-BI", "French", "fr-BJ", "French", "fr-BL", "French", "fr-CA", "French", "fr-CD", "French", "fr-CF", "French", "fr-CG", "French",
525+
"fr-CH", "French", "fr-CI", "French", "fr-CM", "French", "fr-DJ", "French", "fr-DZ", "French", "fr-FR", "French", "fr-GA", "French", "fr-GF", "French", "fr-GN", "French", "fr-GP", "French",
526+
"fr-GQ", "French", "fr-HT", "French", "fr-KM", "French", "fr-LU", "French", "fr-MA", "French", "fr-MC", "French", "fr-MF", "French", "fr-MG", "French", "fr-ML", "French", "fr-MQ", "French",
527+
"fr-MR", "French", "fr-MU", "French", "fr-NC", "French", "fr-NE", "French", "fr-PF", "French", "fr-PM", "French", "fr-RE", "French", "fr-RW", "French", "fr-SC", "French", "fr-SN", "French",
528+
"fr-SY", "French", "fr-TD", "French", "fr-TG", "French", "fr-TN", "French", "fr-VU", "French", "fr-WF", "French", "fr-YT", "French", "fy", "Western Frisian", "ga", "Irish", "ga-GB", "Irish",
529+
"gd", "Scottish Gaelic", "gl", "Galician", "gn", "Guarani", "gu", "Gujarati", "gu-IN", "Gujarati", "gv", "Manx", "ha", "Hausa", "ha-Arab", "Hausa", "ha-Arab-SD", "Hausa", "ha-GH", "Hausa",
530+
"ha-NE", "Hausa", "he", "Hebrew", "he-IL", "Hebrew", "hi", "Hindi", "hi-IN", "Hindi", "hi-Latn", "Hindi", "hr", "Croatian", "hr-BA", "Croatian", "hr-HR", "Croatian", "ht", "Haitian",
531+
"hu", "Hungarian", "hu-HU", "Hungarian", "hy", "Armenian", "ia", "Interlingua", "id", "Indonesian", "id-ID", "Indonesian", "ie", "Interlingue", "ig", "Igbo", "ii", "Sichuan Yi", "ik", "Inupiaq",
532+
"io", "Ido", "is", "Icelandic", "it", "Italian", "it-CH", "Italian", "it-IT", "Italian", "it-SM", "Italian", "it-VA", "Italian", "iu", "Inuktitut", "iu-Latn", "Inuktitut", "ja", "Japanese",
533+
"ja-JP", "Japanese", "jv", "Javanese", "ka", "Georgian", "ki", "Kikuyu", "kk", "Kazakh", "kk-Arab", "Kazakh", "kk-Cyrl", "Kazakh", "kk-KZ", "Kazakh", "kl", "Kalaallisut", "km", "Central Khmer",
534+
"kn", "Kannada", "kn-IN", "Kannada", "ko", "Korean", "ko-CN", "Korean", "ko-KP", "Korean", "ko-KR", "Korean", "ks", "Kashmiri", "ks-Arab", "Kashmiri", "ks-Deva", "Kashmiri", "ku", "Kurdish",
535+
"kw", "Cornish", "ky", "Kyrgyz", "la", "Latin", "lb", "Luxembourgish", "lg", "Ganda", "ln", "Lingala", "ln-AO", "Lingala", "ln-CF", "Lingala", "ln-CG", "Lingala", "lo", "Lao",
536+
"lt", "Lithuanian", "lt-LT", "Lithuanian", "lu", "Luba-Katanga", "lv", "Latvian", "lv-LV", "Latvian", "mg", "Malagasy", "mi", "Maori", "mk", "Macedonian", "ml", "Malayalam", "ml-IN", "Malayalam",
537+
"mn", "Mongolian", "mn-Mong", "Mongolian", "mn-Mong-MN", "Mongolian", "mr", "Marathi", "mr-IN", "Marathi", "ms", "Malay", "ms-Arab", "Malay", "ms-Arab-BN", "Malay", "ms-BN", "Malay", "ms-ID", "Malay",
538+
"ms-SG", "Malay", "mt", "Maltese", "my", "Burmese", "nb", "Norwegian Bokmål", "nb-SJ", "Norwegian Bokmål", "nd", "North Ndebele", "ne", "Nepali", "ne-IN", "Nepali", "nl", "Dutch", "nl-AW", "Dutch",
539+
"nl-BE", "Dutch", "nl-BQ", "Dutch", "nl-CW", "Dutch", "nl-NL", "Dutch", "nl-SR", "Dutch", "nl-SX", "Dutch", "nn", "Norwegian Nynorsk", "no", "Norwegian", "no-NO", "Norwegian", "nr", "South Ndebele",
540+
"nv", "Navajo", "ny", "Chichewa", "oc", "Occitan", "oc-ES", "Occitan", "om", "Oromo", "om-KE", "Oromo", "or", "Oriya", "os", "Ossetian", "os-RU", "Ossetian", "pa", "Punjabi",
541+
"pa-IN", "Punjabi", "pa-Arab", "Punjabi", "pa-Guru", "Punjabi", "pl", "Polish", "pl-PL", "Polish", "ps", "Pashto", "ps-PK", "Pashto", "pt", "Portuguese", "pt-AO", "Portuguese", "pt-BR", "Portuguese",
542+
"pt-CH", "Portuguese", "pt-CV", "Portuguese", "pt-GQ", "Portuguese", "pt-GW", "Portuguese", "pt-LU", "Portuguese", "pt-MO", "Portuguese", "pt-MZ", "Portuguese", "pt-PT", "Portuguese", "pt-ST", "Portuguese", "pt-TL", "Portuguese",
543+
"qu", "Quechua", "qu-BO", "Quechua", "qu-EC", "Quechua", "rm", "Romansh", "rn", "Rundi", "ro", "Romanian", "ro-MD", "Romanian", "ro-RO", "Romanian", "ru", "Russian", "ru-BY", "Russian",
544+
"ru-KG", "Russian", "ru-KZ", "Russian", "ru-MD", "Russian", "ru-RU", "Russian", "ru-UA", "Russian", "rw", "Kinyarwanda", "sa", "Sanskrit", "sc", "Sardinian", "sd", "Sindhi", "sd-Arab", "Sindhi",
545+
"sd-Deva", "Sindhi", "se", "Northern Sami", "se-FI", "Northern Sami", "se-SE", "Northern Sami", "sg", "Sango", "si", "Sinhala", "sk", "Slovak", "sk-SK", "Slovak", "sl", "Slovenian", "sl-SI", "Slovenian",
546+
"sn", "Shona", "so", "Somali", "so-DJ", "Somali", "so-ET", "Somali", "so-KE", "Somali", "sq", "Albanian", "sq-MK", "Albanian", "sq-XK", "Albanian", "sr", "Serbian", "sr-RS", "Serbian",
547+
"sr-Cyrl", "Serbian", "sr-Cyrl-BA", "Serbian", "sr-Cyrl-ME", "Serbian", "sr-Cyrl-XK", "Serbian", "sr-Latn", "Serbian", "sr-Latn-BA", "Serbian", "sr-Latn-ME", "Serbian", "sr-Latn-XK", "Serbian", "ss", "Swati", "ss-SZ", "Swati",
548+
"st", "Southern Sotho", "st-LS", "Southern Sotho", "su", "Sundanese", "su-Latn", "Sundanese", "sv", "Swedish", "sv-AX", "Swedish", "sv-FI", "Swedish", "sv-SE", "Swedish", "sw", "Swahili", "sw-CD", "Swahili",
549+
"sw-KE", "Swahili", "sw-TZ", "Swahili", "sw-UG", "Swahili", "ta", "Tamil", "ta-IN", "Tamil", "ta-LK", "Tamil", "ta-MY", "Tamil", "ta-SG", "Tamil", "te", "Telugu", "te-IN", "Telugu",
550+
"tg", "Tajik", "th", "Thai", "th-TH", "Thai", "ti", "Tigrinya", "ti-ER", "Tigrinya", "tk", "Turkmen", "tl", "Tagalog", "tn", "Tswana", "tn-BW", "Tswana", "to", "Tonga",
551+
"tr", "Turkish", "tr-CY", "Turkish", "tr-TR", "Turkish", "ts", "Tsonga", "tt", "Tatar", "ug", "Uyghur", "uk", "Ukrainian", "uk-UA", "Ukrainian", "ur", "Urdu", "ur-IN", "Urdu",
552+
"ur-PK", "Urdu", "uz", "Uzbek", "uz-Arab", "Uzbek", "uz-Cyrl", "Uzbek", "uz-Latn", "Uzbek", "ve", "Venda", "vi", "Vietnamese", "vi-VN", "Vietnamese", "vo", "Volapük", "wa", "Walloon",
553+
"wo", "Wolof", "xh", "Xhosa", "yi", "Yiddish", "yo", "Yoruba", "yo-BJ", "Yoruba", "za", "Zhuang", "zh", "Chinese", "zh-CH", "Chinese", "zh-TW", "Chinese", "zh-Hans", "Chinese",
554+
"zh-Hans-HK", "Chinese", "zh-Hans-MO", "Chinese", "zh-Hans-MY", "Chinese", "zh-Hans-SG", "Chinese", "zh-Hant", "Chinese", "zh-Hant-HK", "Chinese", "zh-Hant-MO", "Chinese", "zh-Hant-MY", "Chinese", "zh-Latn", "Chinese", "zu", "Zulu",
555+
"zu-ZA", "Zulu"});
556+
557+
class TranslateGemmaHistoryEncoder : public v3::ChatHistoryEncoder
558+
{
559+
public:
560+
typedef v3::ChatHistoryEncoder Base;
561+
void append_user(int round_idx, const std::string &user, std::vector<int> &ids) const override;
562+
void append_user(int round_idx, const Content &user, std::vector<int> &ids) const override;
563+
public:
564+
void parse_command(std::string &user) const;
565+
public:
566+
std::string def_source_lang_code = "zh";
567+
std::string def_target_lang_code = "en";
568+
std::string source_lang_code;
569+
std::string target_lang_code;
570+
};
571+
572+
static TranslateGemmaHistoryEncoder _translate_encoder;
573+
574+
void TranslateGemmaHistoryEncoder::parse_command(std::string &user) const
575+
{
576+
user = utils::trim(user);
577+
if (user.size() < 1) return;
578+
if (user[0] != '/') return;
579+
auto pos = user.find(' ');
580+
std::string command = user.substr(1, pos - 1);
581+
if (pos != std::string::npos)
582+
user = utils::trim(user.substr(pos + 1));
583+
else
584+
user = "";
585+
std::vector<std::string> items;
586+
utils::split(command, "->", items);
587+
if (items.size() !=2) return;
588+
589+
auto obj = const_cast<TranslateGemmaHistoryEncoder *>(this);
590+
591+
if (language_code_dict.hasKey(items[0]))
592+
obj->source_lang_code = items[0];
593+
if (language_code_dict.hasKey(items[1]))
594+
obj->target_lang_code = items[1];
595+
}
596+
597+
void TranslateGemmaHistoryEncoder::append_user(int round_idx, const std::string &user, std::vector<int> &ids) const
598+
{
599+
Content c(nullptr, user);
600+
append_user(round_idx, c, ids);
601+
}
602+
603+
void TranslateGemmaHistoryEncoder::append_user(int round_idx, const Content &user, std::vector<int> &ids) const
604+
{
605+
std::string text = "";
606+
std::string image = "";
607+
608+
{
609+
auto obj = const_cast<TranslateGemmaHistoryEncoder *>(this);
610+
obj->source_lang_code = def_source_lang_code;
611+
obj->target_lang_code = def_target_lang_code;
612+
}
613+
614+
for (auto &piece : user.pieces)
615+
{
616+
switch (piece.type)
617+
{
618+
case ContentPiece::Type::Text:
619+
{
620+
std::string s = piece.content;
621+
parse_command(s);
622+
if (s.size() < 1) break;
623+
CHATLLM_CHECK(text.size() < 1) << "only one text input is allowed";
624+
text = s;
625+
}
626+
break;
627+
case ContentPiece::Type::Image:
628+
CHATLLM_CHECK(image.size() < 1) << "only one image input is allowed";
629+
image = piece.content;
630+
break;
631+
default:
632+
CHATLLM_CHECK(false) << "only text/image input are allowed";
633+
break;
634+
}
635+
}
636+
637+
const std::string source_lang = language_code_dict[source_lang_code].ToString();
638+
const std::string target_lang = language_code_dict[target_lang_code].ToString();
639+
640+
const std::string s1 = "You are a professional " + source_lang + " (" + source_lang_code + ") to " +
641+
target_lang + " (" + target_lang_code + ") translator. Your goal is to accurately convey the meaning and "
642+
"nuances of the original " + source_lang + " text while adhering to " + target_lang + " grammar, "
643+
"vocabulary, and cultural sensitivities.\n";
644+
645+
const std::string s2 = text.size() > 0 ?
646+
"Produce only the " + target_lang + " translation, without any additional explanations or " +
647+
"commentary. Please translate the following " + source_lang + " text into " + target_lang + ":\n\n\n" + text
648+
:
649+
"Please translate the " + source_lang + " text in the provided image into " + target_lang + ". " +
650+
"Produce only the " + target_lang + " translation, without any additional explanations, " +
651+
"alternatives or commentary. Focus only on the text, do not output where the text is located, " +
652+
"surrounding objects or any other explanation about the picture. Ignore symbols, pictogram, and " +
653+
"arrows!\n\n\n";
654+
655+
Content c(nullptr, s1 + s2);
656+
if (image.size() > 0)
657+
c.push_back(image, ContentPiece::Type::Image);
658+
Base::append_user(round_idx, c, ids);
659+
}
660+
}
661+
494662
namespace chatllm::gemma::v3
495663
{
496664
static ChatHistoryEncoder _chat_encoder;
@@ -601,6 +769,17 @@ namespace chatllm::gemma::v3
601769
_chat_encoder.vit_loaded = visual.load(loader);
602770
}
603771

772+
void ConditionalGeneration::set_tokenizer(BaseTokenizer *tokenizer)
773+
{
774+
BaseModelForConditionalGeneration::set_tokenizer(tokenizer);
775+
if ("TranslateGemma" == name_)
776+
{
777+
Tokenizer *tok = dynamic_cast<Tokenizer *>(tokenizer);
778+
tok->set_chat_encoder(&translation::_translate_encoder);
779+
multi_turn = false;
780+
}
781+
}
782+
604783
bool ConditionalGeneration::load_more(const json::JSON &config)
605784
{
606785
BaseModelForConditionalGeneration::load_more(config);

models/gemma.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,7 @@ namespace chatllm::gemma::v3
331331
ConditionalGeneration(const Config &config, const RuntimeConfig &runtime_config, ModelType type = MODEL_TYPE_GEMMA3);
332332
void load(ModelLoader &loader) override;
333333

334+
void set_tokenizer(BaseTokenizer *tokenizer) override;
334335
bool load_more(const json::JSON &config) override;
335336
void set_additional_args(const std::map<std::string, std::string> &args) override;
336337

src/basics.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ namespace utils
6565
std::string trim(const std::string& str);
6666

6767
std::string join(const std::vector<std::string>& vec, const std::string& sep);
68+
void split(const std::string &str, const std::string &delimiter, std::vector<std::string> &items);
6869

6970
std::string replace_all(const std::string& str, const std::string& from, const std::string& to);
7071

src/chat.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1991,6 +1991,11 @@ namespace chatllm
19911991
return "";
19921992
}
19931993

1994+
bool Pipeline::support_multi_turn(void) const
1995+
{
1996+
return is_loaded() ? modelobj.model->support_multi_turn() : false;
1997+
}
1998+
19941999
void Pipeline::restart(void)
19952000
{
19962001
initializing = true;

src/chat.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,8 @@ namespace chatllm
10201020
virtual void set_additional_args(const std::map<std::string, std::string> &args) {}
10211021

10221022
virtual LayerAllocatorManager *get_alloc_manager(void) = 0;
1023+
1024+
virtual bool support_multi_turn(void) const { return false; }
10231025
};
10241026

10251027
class ModelProxy : public AbstractModel
@@ -1293,6 +1295,8 @@ namespace chatllm
12931295
int save_session(ModelSessionMemory &session) const override;
12941296
int load_session(ModelSessionMemory &session) override;
12951297

1298+
bool support_multi_turn(void) const override { return multi_turn; }
1299+
12961300
private:
12971301
struct state
12981302
{
@@ -1309,6 +1313,7 @@ namespace chatllm
13091313
BaseTokenizer *tokenizer;
13101314
ModelPurpose purpose;
13111315
bool aborted;
1316+
bool multi_turn = true;
13121317
private:
13131318
int _seed;
13141319
};
@@ -1406,6 +1411,7 @@ namespace chatllm
14061411
virtual std::string get_additional_description(void) const;
14071412

14081413
bool is_loaded(void) const { return modelobj.loaded; }
1414+
virtual bool support_multi_turn(void) const;
14091415

14101416
virtual void restart(void);
14111417
virtual void rewind(int n_past);

src/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,7 +1075,7 @@ void chat(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer)
10751075
std::string output = pipeline.chat(history, gen_config, &streamer);
10761076
history.push_back(output, chatllm::MsgRole::User);
10771077

1078-
if (args.single_turn)
1078+
if (args.single_turn || !pipeline.support_multi_turn())
10791079
{
10801080
history.clear();
10811081
pipeline.restart();
@@ -1107,7 +1107,7 @@ void chat(Args &args, chatllm::Pipeline &pipeline, TextStreamer &streamer)
11071107
std::string output = pipeline.chat(history, gen_config, &streamer);
11081108
history.push_back(output, chatllm::MsgRole::Assistant);
11091109

1110-
if (args.single_turn)
1110+
if (args.single_turn || !pipeline.support_multi_turn())
11111111
{
11121112
history.clear();
11131113
pipeline.restart();

0 commit comments

Comments
 (0)