From 87744d2abc27be9f65bb538375e3468e7aad1cbb Mon Sep 17 00:00:00 2001 From: dianjixz <18637716021@163.com> Date: Mon, 12 May 2025 10:20:54 +0800 Subject: [PATCH 01/79] [update] README.md add model list --- README_zh.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/README_zh.md b/README_zh.md index dbeff436..4216b0aa 100644 --- a/README_zh.md +++ b/README_zh.md @@ -14,6 +14,7 @@ * [特性](#特性) * [Demo](#demo) +* [模型列表](#模型列表) * [环境要求](#环境要求) * [编译](#编译) * [安装](#安装) @@ -54,6 +55,11 @@ StackFlow 语音助手的主要工作模式: - [StackFlow yolo 视觉检测](https://github.com/Abandon-ht/ModuleLLM_Development_Guide/tree/dev/ESP32/cpp) - [StackFlow VLM 图片描述](https://github.com/Abandon-ht/ModuleLLM_Development_Guide/tree/dev/ESP32/cpp) +## 模型列表 +| 模型名 | 模型类型 | 模型大小 | 模型能力 | 模型配置文件 | 计算单元 | +| :----: | :----: | :----: | :----: | :----: | :----: | +| [sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01](https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2) | KWS | 6.4M | 关键词识别 | [mode_sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.json](projects/llm_framework/main_kws/mode_sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.json) | CPU | + ## 环境要求 ## 当前 StackFlow 的 AI 单元是建立在 AXERA 加速平台之上的,主要的芯片平台为 ax630c、ax650n。系统要求为 ubuntu。 From 10e4bdf828912595765ef4031b95435073054ea7 Mon Sep 17 00:00:00 2001 From: yuyun2000 <15515722313yxw@gmail.com> Date: Thu, 15 May 2025 15:17:04 +0800 Subject: [PATCH 02/79] Refactor SOLA component code Streamline and simplify code in the SOLA module for improved readability and maintenance --- .../llm_framework/main_melotts/src/main.cpp | 119 ++---------------- 1 file changed, 12 insertions(+), 107 deletions(-) diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index 4fab699f..5c371d97 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -243,7 +243,6 @@ class 
llm_task { try { std::vector wav_pcm_data; if (msg_str.empty()) { - SLOGI("empty"); if (out_callback_) { std::string output = wav_pcm_data.empty() ? std::string() : std::string((char *)wav_pcm_data.data(), @@ -252,9 +251,7 @@ class llm_task { } return false; } - SLOGI("Processing text: %s", msg_str.c_str()); - // Convert text to phonemes and tones std::vector phones_bef, tones_bef; lexicon_->convert(msg_str, phones_bef, tones_bef); auto phones = intersperse(phones_bef, 0); @@ -262,9 +259,6 @@ class llm_task { int phone_len = phones.size(); std::vector langids(phone_len, 3); - SLOGI("Phoneme conversion completed, length: %d", phone_len); - - // Run the encoder to generate hidden representations auto encoder_output = encoder_->Run(phones, tones, langids, g_matrix, mode_config_.noise_scale, mode_config_.noise_scale_w, mode_config_.get_length_scale(), mode_config_.sdp_ratio); @@ -273,33 +267,22 @@ class llm_task { auto zp_info = encoder_output.at(0).GetTensorTypeAndShapeInfo(); auto zp_shape = zp_info.GetShape(); - SLOGI("Encoder output completed, shape: [%ld, %ld, %ld], expected audio length: %d", zp_shape[0], - zp_shape[1], zp_shape[2], audio_len); - - // Calculate decoder parameters int zp_size = decoder_->GetInputSize(0) / sizeof(float); int dec_len = zp_size / zp_shape[1]; int audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float); - const int pad_frames = 16; + const int pad_frames = 24; const int samples_per_frame = 512; - SLOGI("Decoder configuration: frame length=%d, audio slice length=%d, pad length=%d, samples per frame=%d", - dec_len, audio_slice_len, pad_frames, samples_per_frame); - const int effective_frames = dec_len - 2 * pad_frames; int dec_slice_num = static_cast(std::ceil(static_cast(zp_shape[2]) / static_cast(effective_frames))); - SLOGI("Will perform %d inferences, each with effective frames: %d", dec_slice_num, effective_frames); + const int sola_buffer_frame = pad_frames * samples_per_frame; + const int sola_search_frame = pad_frames * 
samples_per_frame; + const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; - // SOLA parameters setup - const int sola_buffer_frame = pad_frames * samples_per_frame; // Overlap buffer length - const int sola_search_frame = pad_frames * samples_per_frame; // Search window length - const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; // Effective block length - - // Create fade-in/fade-out windows for smooth transitions std::vector fade_in_window(sola_buffer_frame); std::vector fade_out_window(sola_buffer_frame); @@ -308,50 +291,35 @@ class llm_task { fade_out_window[i] = 1.0f - fade_in_window[i]; } - // Initialize SOLA buffer std::vector sola_buffer(sola_buffer_frame, 0.0f); bool first_frame = true; std::vector pcmlist; - // Main decoding loop - process each slice for (int i = 0; i < dec_slice_num; i++) { - // Calculate start position for current batch input int input_start = i * effective_frames; - // Consider forward padding, but ensure non-negative if (i > 0) { input_start -= pad_frames; } input_start = std::max(0, input_start); - // Actual input length int actual_len = std::min(dec_len, static_cast(zp_shape[2] - input_start)); - // Calculate effective output range (frame level) int output_start_frame, output_end_frame; if (i == 0) { - // First frame: skip padding at beginning output_start_frame = 0; output_end_frame = effective_frames - 1; } else if (i == dec_slice_num - 1) { - // Last frame: calculate from current segment start output_start_frame = i * effective_frames; - // Last frame extends to encoder's maximum output length - output_end_frame = static_cast(zp_shape[2]) - 1; + output_end_frame = static_cast(zp_shape[2]) - 1; } else { - // Middle frames: standard calculation output_start_frame = i * effective_frames; output_end_frame = (i + 1) * effective_frames - 1; } - SLOGI("Inference #%d: input frame range=[%d-%d], actual length=%d, output frame range=[%d-%d]", i + 1, - input_start, input_start + actual_len - 1, 
actual_len, output_start_frame, output_end_frame); - - // Prepare decoder input, initialize all to zero std::vector zp(zp_size, 0); - // Copy data to decoder input for (int n = 0; n < zp_shape[1]; n++) { int copy_size = std::min(actual_len, static_cast(zp_shape[2] - input_start)); if (copy_size > 0) { @@ -360,76 +328,49 @@ class llm_task { } } - // Run decoder std::vector decoder_output(audio_slice_len); decoder_->SetInput(zp.data(), 0); decoder_->SetInput(g_matrix.data(), 1); - SLOGI("Inference #%d: starting decoding...", i + 1); - if (0 != decoder_->Run()) { - SLOGI("Inference #%d: decoding failed", i + 1); throw std::string("decoder_ RunSync error"); } decoder_->GetOutput(decoder_output.data(), 0); - // === SOLA Processing Logic === if (first_frame) { - // Special handling for first frame - should not skip initial content - // First frame starts directly from decoder output without skipping - int audio_start = 0; // Start from beginning, don't skip pad_frames + int audio_start = 0; + int audio_len = decoder_output.size() - sola_buffer_frame; + audio_len = std::max(0, audio_len); - // Calculate data length for first frame - // First frame should preserve complete decoder output, only reserving sola_buffer_frame at the end - // for next frame alignment - int audio_len = decoder_output.size() - sola_buffer_frame; - - // Boundary check - audio_len = std::max(0, audio_len); // Ensure non-negative - - // Add first frame data if (audio_len > 0) { pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start, decoder_output.begin() + audio_start + audio_len); } - // Save sola_buffer_frame length from the end to SOLA buffer for next frame alignment int buffer_start = audio_len; - // Ensure sufficient data is available for copying if (buffer_start + sola_buffer_frame <= decoder_output.size()) { std::copy(decoder_output.begin() + buffer_start, decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin()); } else { - // Possible case: first frame data 
is shorter than sola_buffer_frame int available = static_cast(decoder_output.size() - buffer_start); if (available > 0) { std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin()); - // Fill with zeros std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f); } else { - // Completely insufficient data, fill all with zeros std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f); } } first_frame = false; - - SLOGI( - "Inference #%d: First frame processing, added %d samples from position %d to output, saved %d " - "samples to SOLA buffer", - i + 1, audio_len, audio_start, sola_buffer_frame); } else { - // Non-first frame: SOLA alignment required int audio_start = pad_frames * samples_per_frame; - // 1. Prepare search window - beginning portion of current frame std::vector search_window(sola_buffer_frame + sola_search_frame); std::copy(decoder_output.begin() + audio_start, decoder_output.begin() + audio_start + search_window.size(), search_window.begin()); - // 2. Find best alignment point (calculate cross-correlation) int best_offset = 0; float best_correlation = -1.0; @@ -442,7 +383,6 @@ class llm_task { energy += search_window[j + offset] * search_window[j + offset]; } - // Normalize correlation (avoid division by zero) float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f; if (normalized_correlation > best_correlation) { @@ -451,40 +391,28 @@ class llm_task { } } - SLOGI("Inference #%d: SOLA found best alignment offset %d with correlation coefficient %f", i + 1, - best_offset, best_correlation); - - // 3. Apply alignment offset int aligned_start = audio_start + best_offset; - // 4. 
Smooth transition processing (crossfade in alignment region) std::vector crossfade_region(sola_buffer_frame); for (int j = 0; j < sola_buffer_frame; j++) { - // Apply fade-in/fade-out window functions crossfade_region[j] = decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j]; } - // 5. Add crossfade region to output pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end()); int remaining_start = aligned_start + sola_buffer_frame; if (i == dec_slice_num - 1) { int total_expected_samples = audio_len * samples_per_frame / 512; - - int processed_samples = static_cast(pcmlist.size()); - - int remaining_needed = total_expected_samples - processed_samples; - remaining_needed = std::max(0, remaining_needed); + int processed_samples = static_cast(pcmlist.size()); + int remaining_needed = total_expected_samples - processed_samples; + remaining_needed = std::max(0, remaining_needed); int remaining_len = std::min(remaining_needed, static_cast(decoder_output.size() - remaining_start)); - SLOGI("Inference #%d (final): Expected total=%d, processed=%d, needed=%d, available=%d", i + 1, - total_expected_samples, processed_samples, remaining_needed, remaining_len); - if (remaining_len > 0) { pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start, decoder_output.begin() + remaining_start + remaining_len); @@ -492,7 +420,6 @@ class llm_task { } else { int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame; - remaining_len = std::min(remaining_len, static_cast(decoder_output.size() - remaining_start)); @@ -514,55 +441,33 @@ class llm_task { } std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f); } - - SLOGI("Inference #%d: Added %d + %d samples to output, cumulative length: %zu", i + 1, - sola_buffer_frame, remaining_len, pcmlist.size()); } } } - SLOGI("All inference completed, raw generated PCM length: %zu", pcmlist.size()); - if (pcmlist.size() > audio_len) { - 
SLOGI("Truncating output from %zu to %d samples as per encoder prediction", pcmlist.size(), audio_len); pcmlist.resize(audio_len); } - SLOGI("Final PCM length after truncation: %zu", pcmlist.size()); - - // Post-processing: resample and convert to int16 double src_ratio = static_cast(mode_config_.audio_rate) / static_cast(mode_config_.mode_rate); std::vector tmp_pcm((pcmlist.size() * src_ratio + 1)); int len; - SLOGI("Starting audio resampling, source rate: %f, target rate: %f, ratio: %f", - static_cast(mode_config_.mode_rate), static_cast(mode_config_.audio_rate), src_ratio); - resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio); - SLOGI("Resampling completed, length after resampling: %d", len); - - // Convert to 16-bit PCM wav_pcm_data.reserve(len); std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data), [](const auto val) { return static_cast(val * INT16_MAX); }); - SLOGI("Final audio length: %zu samples", wav_pcm_data.size()); - - // Call the output callback function with the result if (out_callback_) { out_callback_( std::string(reinterpret_cast(wav_pcm_data.data()), wav_pcm_data.size() * sizeof(int16_t)), finish); } - - SLOGI("TTS processing completed, output callback invoked"); } catch (const std::exception &e) { - SLOGI("TTS processing exception: %s", e.what()); return true; } catch (...) 
{ - SLOGI("TTS processing encountered an unknown exception"); return true; } return false; @@ -975,4 +880,4 @@ int main(int argc, char *argv[]) } llm.llm_firework_exit(); return 0; -} \ No newline at end of file +} From 74c41a3e4479d883c527fffef052a1ce4f70fb10 Mon Sep 17 00:00:00 2001 From: yuyun2000 <15515722313yxw@gmail.com> Date: Fri, 16 May 2025 11:10:58 +0800 Subject: [PATCH 03/79] Add text normalization for Chinese, Japanese, and English Implement regex-based text normalization functionality to support trilingual (CJE) content processing --- .../llm_framework/include/fst/accumulator.h | 903 +++++++++ projects/llm_framework/include/fst/add-on.h | 248 +++ .../llm_framework/include/fst/arc-arena.h | 232 +++ projects/llm_framework/include/fst/arc-map.h | 1285 +++++++++++++ projects/llm_framework/include/fst/arc.h | 317 ++++ .../llm_framework/include/fst/arcfilter.h | 93 + projects/llm_framework/include/fst/arcsort.h | 211 +++ projects/llm_framework/include/fst/bi-table.h | 480 +++++ projects/llm_framework/include/fst/cache.h | 1327 +++++++++++++ projects/llm_framework/include/fst/closure.h | 134 ++ .../llm_framework/include/fst/compact-fst.h | 1564 ++++++++++++++++ projects/llm_framework/include/fst/compat.h | 130 ++ .../llm_framework/include/fst/complement.h | 277 +++ .../include/fst/compose-filter.h | 571 ++++++ projects/llm_framework/include/fst/compose.h | 1035 ++++++++++ projects/llm_framework/include/fst/concat.h | 220 +++ projects/llm_framework/include/fst/config.h | 3 + .../llm_framework/include/fst/config.h.in | 11 + projects/llm_framework/include/fst/connect.h | 323 ++++ .../llm_framework/include/fst/const-fst.h | 485 +++++ .../llm_framework/include/fst/determinize.h | 1093 +++++++++++ .../llm_framework/include/fst/dfs-visit.h | 202 ++ .../llm_framework/include/fst/difference.h | 205 ++ .../llm_framework/include/fst/disambiguate.h | 564 ++++++ projects/llm_framework/include/fst/edit-fst.h | 702 +++++++ projects/llm_framework/include/fst/encode.h | 556 
++++++ .../llm_framework/include/fst/epsnormalize.h | 61 + projects/llm_framework/include/fst/equal.h | 169 ++ .../llm_framework/include/fst/equivalent.h | 230 +++ .../llm_framework/include/fst/expanded-fst.h | 179 ++ .../include/fst/expectation-weight.h | 134 ++ .../fst/extensions/compress/compress-script.h | 53 + .../fst/extensions/compress/compress.h | 906 +++++++++ .../include/fst/extensions/compress/elias.h | 97 + .../include/fst/extensions/compress/gzfile.h | 127 ++ .../include/fst/extensions/compress/randmod.h | 102 + .../fst/extensions/far/compile-strings.h | 260 +++ .../include/fst/extensions/far/create.h | 46 + .../include/fst/extensions/far/equal.h | 69 + .../include/fst/extensions/far/extract.h | 118 ++ .../include/fst/extensions/far/far-class.h | 258 +++ .../include/fst/extensions/far/far.h | 481 +++++ .../include/fst/extensions/far/farlib.h | 19 + .../include/fst/extensions/far/farscript.h | 269 +++ .../include/fst/extensions/far/getters.h | 30 + .../include/fst/extensions/far/info.h | 147 ++ .../include/fst/extensions/far/isomorphic.h | 69 + .../fst/extensions/far/print-strings.h | 105 ++ .../include/fst/extensions/far/script-impl.h | 23 + .../include/fst/extensions/far/stlist.h | 273 +++ .../include/fst/extensions/far/sttable.h | 353 ++++ .../linear/linear-fst-data-builder.h | 1074 +++++++++++ .../fst/extensions/linear/linear-fst-data.h | 526 ++++++ .../fst/extensions/linear/linear-fst.h | 1173 ++++++++++++ .../fst/extensions/linear/linearscript.h | 391 ++++ .../fst/extensions/linear/loglinear-apply.h | 77 + .../include/fst/extensions/linear/trie.h | 444 +++++ .../include/fst/extensions/mpdt/compose.h | 267 +++ .../include/fst/extensions/mpdt/expand.h | 335 ++++ .../include/fst/extensions/mpdt/info.h | 190 ++ .../include/fst/extensions/mpdt/mpdt.h | 357 ++++ .../include/fst/extensions/mpdt/mpdtlib.h | 18 + .../include/fst/extensions/mpdt/mpdtscript.h | 156 ++ .../fst/extensions/mpdt/read_write_utils.h | 86 + .../include/fst/extensions/mpdt/reverse.h 
| 54 + .../fst/extensions/ngram/bitmap-index.h | 168 ++ .../include/fst/extensions/ngram/ngram-fst.h | 1027 ++++++++++ .../include/fst/extensions/ngram/nthbit.h | 49 + .../include/fst/extensions/pdt/collection.h | 107 ++ .../include/fst/extensions/pdt/compose.h | 493 +++++ .../include/fst/extensions/pdt/expand.h | 933 +++++++++ .../include/fst/extensions/pdt/getters.h | 22 + .../include/fst/extensions/pdt/info.h | 152 ++ .../include/fst/extensions/pdt/paren.h | 440 +++++ .../include/fst/extensions/pdt/pdt.h | 165 ++ .../include/fst/extensions/pdt/pdtlib.h | 19 + .../include/fst/extensions/pdt/pdtscript.h | 244 +++ .../include/fst/extensions/pdt/replace.h | 827 ++++++++ .../include/fst/extensions/pdt/reverse.h | 38 + .../fst/extensions/pdt/shortest-path.h | 715 +++++++ .../include/fst/extensions/special/phi-fst.h | 183 ++ .../include/fst/extensions/special/rho-fst.h | 172 ++ .../fst/extensions/special/sigma-fst.h | 176 ++ .../llm_framework/include/fst/factor-weight.h | 496 +++++ .../llm_framework/include/fst/filter-state.h | 199 ++ projects/llm_framework/include/fst/flags.h | 219 +++ .../llm_framework/include/fst/float-weight.h | 858 +++++++++ projects/llm_framework/include/fst/fst-decl.h | 254 +++ projects/llm_framework/include/fst/fst.h | 1007 ++++++++++ projects/llm_framework/include/fst/fstlib.h | 130 ++ .../include/fst/generic-register.h | 126 ++ projects/llm_framework/include/fst/heap.h | 168 ++ projects/llm_framework/include/fst/icu.h | 129 ++ .../llm_framework/include/fst/intersect.h | 181 ++ .../llm_framework/include/fst/interval-set.h | 398 ++++ projects/llm_framework/include/fst/invert.h | 139 ++ .../llm_framework/include/fst/isomorphic.h | 183 ++ .../include/fst/label-reachable.h | 511 +++++ .../include/fst/lexicographic-weight.h | 173 ++ projects/llm_framework/include/fst/lock.h | 62 + projects/llm_framework/include/fst/log.h | 78 + .../include/fst/lookahead-filter.h | 623 ++++++ .../include/fst/lookahead-matcher.h | 841 +++++++++ 
projects/llm_framework/include/fst/map.h | 110 ++ .../llm_framework/include/fst/mapped-file.h | 81 + .../llm_framework/include/fst/matcher-fst.h | 347 ++++ projects/llm_framework/include/fst/matcher.h | 1575 ++++++++++++++++ projects/llm_framework/include/fst/memory.h | 443 +++++ projects/llm_framework/include/fst/minimize.h | 568 ++++++ .../llm_framework/include/fst/mutable-fst.h | 398 ++++ .../llm_framework/include/fst/pair-weight.h | 155 ++ .../llm_framework/include/fst/partition.h | 305 +++ .../llm_framework/include/fst/power-weight.h | 168 ++ .../include/fst/product-weight.h | 107 ++ projects/llm_framework/include/fst/project.h | 159 ++ .../llm_framework/include/fst/properties.h | 468 +++++ projects/llm_framework/include/fst/prune.h | 341 ++++ projects/llm_framework/include/fst/push.h | 155 ++ projects/llm_framework/include/fst/queue.h | 948 ++++++++++ .../include/fst/randequivalent.h | 114 ++ projects/llm_framework/include/fst/randgen.h | 756 ++++++++ projects/llm_framework/include/fst/rational.h | 307 +++ projects/llm_framework/include/fst/register.h | 115 ++ projects/llm_framework/include/fst/relabel.h | 472 +++++ .../llm_framework/include/fst/replace-util.h | 629 +++++++ projects/llm_framework/include/fst/replace.h | 1492 +++++++++++++++ projects/llm_framework/include/fst/reverse.h | 116 ++ projects/llm_framework/include/fst/reweight.h | 127 ++ .../llm_framework/include/fst/rmepsilon.h | 548 ++++++ .../include/fst/rmfinalepsilon.h | 80 + .../include/fst/script/arc-class.h | 40 + .../include/fst/script/arciterator-class.h | 212 +++ .../include/fst/script/arcsort.h | 44 + .../include/fst/script/arg-packs.h | 37 + .../include/fst/script/closure.h | 28 + .../include/fst/script/compile-impl.h | 217 +++ .../include/fst/script/compile.h | 98 + .../include/fst/script/compose.h | 34 + .../llm_framework/include/fst/script/concat.h | 40 + .../include/fst/script/connect.h | 23 + .../include/fst/script/convert.h | 35 + .../llm_framework/include/fst/script/decode.h | 49 
+ .../include/fst/script/determinize.h | 59 + .../include/fst/script/difference.h | 35 + .../include/fst/script/disambiguate.h | 54 + .../include/fst/script/draw-impl.h | 227 +++ .../llm_framework/include/fst/script/draw.h | 85 + .../llm_framework/include/fst/script/encode.h | 51 + .../include/fst/script/encodemapper-class.h | 169 ++ .../include/fst/script/epsnormalize.h | 31 + .../llm_framework/include/fst/script/equal.h | 32 + .../include/fst/script/equivalent.h | 34 + .../include/fst/script/fst-class.h | 530 ++++++ .../include/fst/script/fstscript-decl.h | 32 + .../include/fst/script/fstscript.h | 155 ++ .../include/fst/script/getters.h | 76 + .../include/fst/script/info-impl.h | 314 ++++ .../llm_framework/include/fst/script/info.h | 50 + .../include/fst/script/intersect.h | 35 + .../llm_framework/include/fst/script/invert.h | 23 + .../include/fst/script/isomorphic.h | 34 + .../llm_framework/include/fst/script/map.h | 158 ++ .../include/fst/script/minimize.h | 33 + .../include/fst/script/print-impl.h | 132 ++ .../llm_framework/include/fst/script/print.h | 79 + .../include/fst/script/project.h | 28 + .../llm_framework/include/fst/script/prune.h | 51 + .../llm_framework/include/fst/script/push.h | 53 + .../include/fst/script/randequivalent.h | 67 + .../include/fst/script/randgen.h | 63 + .../include/fst/script/register.h | 99 + .../include/fst/script/relabel.h | 64 + .../include/fst/script/replace.h | 72 + .../include/fst/script/reverse.h | 30 + .../include/fst/script/reweight.h | 37 + .../include/fst/script/rmepsilon.h | 109 ++ .../include/fst/script/script-impl.h | 211 +++ .../include/fst/script/shortest-distance.h | 214 +++ .../include/fst/script/shortest-path.h | 116 ++ .../include/fst/script/stateiterator-class.h | 85 + .../include/fst/script/synchronize.h | 29 + .../include/fst/script/text-io.h | 28 + .../include/fst/script/topsort.h | 26 + .../llm_framework/include/fst/script/union.h | 29 + .../llm_framework/include/fst/script/verify.h | 27 + 
.../include/fst/script/weight-class.h | 235 +++ .../llm_framework/include/fst/set-weight.h | 618 ++++++ .../include/fst/shortest-distance.h | 351 ++++ .../llm_framework/include/fst/shortest-path.h | 549 ++++++ .../include/fst/signed-log-weight.h | 440 +++++ .../include/fst/sparse-power-weight.h | 209 +++ .../include/fst/sparse-tuple-weight.h | 422 +++++ .../llm_framework/include/fst/state-map.h | 613 ++++++ .../include/fst/state-reachable.h | 224 +++ .../llm_framework/include/fst/state-table.h | 494 +++++ .../llm_framework/include/fst/statesort.h | 74 + .../llm_framework/include/fst/string-weight.h | 807 ++++++++ projects/llm_framework/include/fst/string.h | 286 +++ .../include/fst/symbol-table-ops.h | 76 + .../llm_framework/include/fst/symbol-table.h | 445 +++++ .../llm_framework/include/fst/synchronize.h | 414 ++++ .../include/fst/test-properties.h | 246 +++ .../include/fst/test/algo_test.h | 1414 ++++++++++++++ .../llm_framework/include/fst/test/fst_test.h | 318 ++++ .../llm_framework/include/fst/test/rand-fst.h | 90 + .../include/fst/test/weight-tester.h | 207 ++ projects/llm_framework/include/fst/topsort.h | 95 + .../llm_framework/include/fst/tuple-weight.h | 163 ++ projects/llm_framework/include/fst/types.h | 41 + .../llm_framework/include/fst/union-find.h | 84 + .../llm_framework/include/fst/union-weight.h | 505 +++++ projects/llm_framework/include/fst/union.h | 157 ++ projects/llm_framework/include/fst/util.h | 400 ++++ .../llm_framework/include/fst/vector-fst.h | 796 ++++++++ projects/llm_framework/include/fst/verify.h | 100 + projects/llm_framework/include/fst/visit.h | 321 ++++ projects/llm_framework/include/fst/weight.h | 389 ++++ .../llm_framework/include/gflags/defines.h | 48 + .../llm_framework/include/gflags/gflags.h | 626 +++++++ .../include/gflags/gflags_completions.h | 121 ++ .../include/gflags/gflags_declare.h | 156 ++ .../llm_framework/include/glog/log_severity.h | 92 + projects/llm_framework/include/glog/logging.h | 1662 +++++++++++++++++ 
.../llm_framework/include/glog/raw_logging.h | 180 ++ .../llm_framework/include/glog/stl_logging.h | 220 +++ .../llm_framework/include/glog/vlog_is_on.h | 129 ++ .../llm_framework/main_melotts/SConstruct | 7 +- .../main_melotts/mode_melotts-en-default.json | 2 + .../main_melotts/mode_melotts-en-us.json | 8 +- .../main_melotts/mode_melotts-ja-jp.json | 2 + .../main_melotts/mode_melotts-zh-cn.json | 2 + .../llm_framework/main_melotts/src/main.cpp | 20 +- .../main_melotts/src/runner/Lexicon.hpp | 25 +- .../main_melotts/src/runner/base64.cpp | 100 +- .../main_melotts/src/runner/base64.h | 10 +- .../src/runner/processor/CMakeLists.txt | 13 + .../src/runner/processor/wetext_processor.cc | 86 + .../src/runner/processor/wetext_processor.h | 51 + .../runner/processor/wetext_token_parser.cc | 161 ++ .../runner/processor/wetext_token_parser.h | 94 + .../src/runner/utils/CMakeLists.txt | 3 + .../src/runner/utils/wetext_flags.h | 23 + .../src/runner/utils/wetext_log.h | 23 + .../src/runner/utils/wetext_string.cc | 89 + .../src/runner/utils/wetext_string.h | 42 + 245 files changed, 66914 insertions(+), 73 deletions(-) create mode 100644 projects/llm_framework/include/fst/accumulator.h create mode 100644 projects/llm_framework/include/fst/add-on.h create mode 100644 projects/llm_framework/include/fst/arc-arena.h create mode 100644 projects/llm_framework/include/fst/arc-map.h create mode 100644 projects/llm_framework/include/fst/arc.h create mode 100644 projects/llm_framework/include/fst/arcfilter.h create mode 100644 projects/llm_framework/include/fst/arcsort.h create mode 100644 projects/llm_framework/include/fst/bi-table.h create mode 100644 projects/llm_framework/include/fst/cache.h create mode 100644 projects/llm_framework/include/fst/closure.h create mode 100644 projects/llm_framework/include/fst/compact-fst.h create mode 100644 projects/llm_framework/include/fst/compat.h create mode 100644 projects/llm_framework/include/fst/complement.h create mode 100644 
projects/llm_framework/include/fst/compose-filter.h create mode 100644 projects/llm_framework/include/fst/compose.h create mode 100644 projects/llm_framework/include/fst/concat.h create mode 100644 projects/llm_framework/include/fst/config.h create mode 100644 projects/llm_framework/include/fst/config.h.in create mode 100644 projects/llm_framework/include/fst/connect.h create mode 100644 projects/llm_framework/include/fst/const-fst.h create mode 100644 projects/llm_framework/include/fst/determinize.h create mode 100644 projects/llm_framework/include/fst/dfs-visit.h create mode 100644 projects/llm_framework/include/fst/difference.h create mode 100644 projects/llm_framework/include/fst/disambiguate.h create mode 100644 projects/llm_framework/include/fst/edit-fst.h create mode 100644 projects/llm_framework/include/fst/encode.h create mode 100644 projects/llm_framework/include/fst/epsnormalize.h create mode 100644 projects/llm_framework/include/fst/equal.h create mode 100644 projects/llm_framework/include/fst/equivalent.h create mode 100644 projects/llm_framework/include/fst/expanded-fst.h create mode 100644 projects/llm_framework/include/fst/expectation-weight.h create mode 100644 projects/llm_framework/include/fst/extensions/compress/compress-script.h create mode 100644 projects/llm_framework/include/fst/extensions/compress/compress.h create mode 100644 projects/llm_framework/include/fst/extensions/compress/elias.h create mode 100644 projects/llm_framework/include/fst/extensions/compress/gzfile.h create mode 100644 projects/llm_framework/include/fst/extensions/compress/randmod.h create mode 100644 projects/llm_framework/include/fst/extensions/far/compile-strings.h create mode 100644 projects/llm_framework/include/fst/extensions/far/create.h create mode 100644 projects/llm_framework/include/fst/extensions/far/equal.h create mode 100644 projects/llm_framework/include/fst/extensions/far/extract.h create mode 100644 
projects/llm_framework/include/fst/extensions/far/far-class.h create mode 100644 projects/llm_framework/include/fst/extensions/far/far.h create mode 100644 projects/llm_framework/include/fst/extensions/far/farlib.h create mode 100644 projects/llm_framework/include/fst/extensions/far/farscript.h create mode 100644 projects/llm_framework/include/fst/extensions/far/getters.h create mode 100644 projects/llm_framework/include/fst/extensions/far/info.h create mode 100644 projects/llm_framework/include/fst/extensions/far/isomorphic.h create mode 100644 projects/llm_framework/include/fst/extensions/far/print-strings.h create mode 100644 projects/llm_framework/include/fst/extensions/far/script-impl.h create mode 100644 projects/llm_framework/include/fst/extensions/far/stlist.h create mode 100644 projects/llm_framework/include/fst/extensions/far/sttable.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/linear-fst.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/linearscript.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h create mode 100644 projects/llm_framework/include/fst/extensions/linear/trie.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/compose.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/expand.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/info.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/mpdt.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/mpdtlib.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/mpdtscript.h create mode 100644 projects/llm_framework/include/fst/extensions/mpdt/read_write_utils.h create mode 100644 
projects/llm_framework/include/fst/extensions/mpdt/reverse.h create mode 100644 projects/llm_framework/include/fst/extensions/ngram/bitmap-index.h create mode 100644 projects/llm_framework/include/fst/extensions/ngram/ngram-fst.h create mode 100644 projects/llm_framework/include/fst/extensions/ngram/nthbit.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/collection.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/compose.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/expand.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/getters.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/info.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/paren.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/pdt.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/pdtlib.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/pdtscript.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/replace.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/reverse.h create mode 100644 projects/llm_framework/include/fst/extensions/pdt/shortest-path.h create mode 100644 projects/llm_framework/include/fst/extensions/special/phi-fst.h create mode 100644 projects/llm_framework/include/fst/extensions/special/rho-fst.h create mode 100644 projects/llm_framework/include/fst/extensions/special/sigma-fst.h create mode 100644 projects/llm_framework/include/fst/factor-weight.h create mode 100644 projects/llm_framework/include/fst/filter-state.h create mode 100644 projects/llm_framework/include/fst/flags.h create mode 100644 projects/llm_framework/include/fst/float-weight.h create mode 100644 projects/llm_framework/include/fst/fst-decl.h create mode 100644 projects/llm_framework/include/fst/fst.h create mode 100644 projects/llm_framework/include/fst/fstlib.h create mode 100644 
projects/llm_framework/include/fst/generic-register.h create mode 100644 projects/llm_framework/include/fst/heap.h create mode 100644 projects/llm_framework/include/fst/icu.h create mode 100644 projects/llm_framework/include/fst/intersect.h create mode 100644 projects/llm_framework/include/fst/interval-set.h create mode 100644 projects/llm_framework/include/fst/invert.h create mode 100644 projects/llm_framework/include/fst/isomorphic.h create mode 100644 projects/llm_framework/include/fst/label-reachable.h create mode 100644 projects/llm_framework/include/fst/lexicographic-weight.h create mode 100644 projects/llm_framework/include/fst/lock.h create mode 100644 projects/llm_framework/include/fst/log.h create mode 100644 projects/llm_framework/include/fst/lookahead-filter.h create mode 100644 projects/llm_framework/include/fst/lookahead-matcher.h create mode 100644 projects/llm_framework/include/fst/map.h create mode 100644 projects/llm_framework/include/fst/mapped-file.h create mode 100644 projects/llm_framework/include/fst/matcher-fst.h create mode 100644 projects/llm_framework/include/fst/matcher.h create mode 100644 projects/llm_framework/include/fst/memory.h create mode 100644 projects/llm_framework/include/fst/minimize.h create mode 100644 projects/llm_framework/include/fst/mutable-fst.h create mode 100644 projects/llm_framework/include/fst/pair-weight.h create mode 100644 projects/llm_framework/include/fst/partition.h create mode 100644 projects/llm_framework/include/fst/power-weight.h create mode 100644 projects/llm_framework/include/fst/product-weight.h create mode 100644 projects/llm_framework/include/fst/project.h create mode 100644 projects/llm_framework/include/fst/properties.h create mode 100644 projects/llm_framework/include/fst/prune.h create mode 100644 projects/llm_framework/include/fst/push.h create mode 100644 projects/llm_framework/include/fst/queue.h create mode 100644 projects/llm_framework/include/fst/randequivalent.h create mode 100644 
projects/llm_framework/include/fst/randgen.h create mode 100644 projects/llm_framework/include/fst/rational.h create mode 100644 projects/llm_framework/include/fst/register.h create mode 100644 projects/llm_framework/include/fst/relabel.h create mode 100644 projects/llm_framework/include/fst/replace-util.h create mode 100644 projects/llm_framework/include/fst/replace.h create mode 100644 projects/llm_framework/include/fst/reverse.h create mode 100644 projects/llm_framework/include/fst/reweight.h create mode 100644 projects/llm_framework/include/fst/rmepsilon.h create mode 100644 projects/llm_framework/include/fst/rmfinalepsilon.h create mode 100644 projects/llm_framework/include/fst/script/arc-class.h create mode 100644 projects/llm_framework/include/fst/script/arciterator-class.h create mode 100644 projects/llm_framework/include/fst/script/arcsort.h create mode 100644 projects/llm_framework/include/fst/script/arg-packs.h create mode 100644 projects/llm_framework/include/fst/script/closure.h create mode 100644 projects/llm_framework/include/fst/script/compile-impl.h create mode 100644 projects/llm_framework/include/fst/script/compile.h create mode 100644 projects/llm_framework/include/fst/script/compose.h create mode 100644 projects/llm_framework/include/fst/script/concat.h create mode 100644 projects/llm_framework/include/fst/script/connect.h create mode 100644 projects/llm_framework/include/fst/script/convert.h create mode 100644 projects/llm_framework/include/fst/script/decode.h create mode 100644 projects/llm_framework/include/fst/script/determinize.h create mode 100644 projects/llm_framework/include/fst/script/difference.h create mode 100644 projects/llm_framework/include/fst/script/disambiguate.h create mode 100644 projects/llm_framework/include/fst/script/draw-impl.h create mode 100644 projects/llm_framework/include/fst/script/draw.h create mode 100644 projects/llm_framework/include/fst/script/encode.h create mode 100644 
projects/llm_framework/include/fst/script/encodemapper-class.h create mode 100644 projects/llm_framework/include/fst/script/epsnormalize.h create mode 100644 projects/llm_framework/include/fst/script/equal.h create mode 100644 projects/llm_framework/include/fst/script/equivalent.h create mode 100644 projects/llm_framework/include/fst/script/fst-class.h create mode 100644 projects/llm_framework/include/fst/script/fstscript-decl.h create mode 100644 projects/llm_framework/include/fst/script/fstscript.h create mode 100644 projects/llm_framework/include/fst/script/getters.h create mode 100644 projects/llm_framework/include/fst/script/info-impl.h create mode 100644 projects/llm_framework/include/fst/script/info.h create mode 100644 projects/llm_framework/include/fst/script/intersect.h create mode 100644 projects/llm_framework/include/fst/script/invert.h create mode 100644 projects/llm_framework/include/fst/script/isomorphic.h create mode 100644 projects/llm_framework/include/fst/script/map.h create mode 100644 projects/llm_framework/include/fst/script/minimize.h create mode 100644 projects/llm_framework/include/fst/script/print-impl.h create mode 100644 projects/llm_framework/include/fst/script/print.h create mode 100644 projects/llm_framework/include/fst/script/project.h create mode 100644 projects/llm_framework/include/fst/script/prune.h create mode 100644 projects/llm_framework/include/fst/script/push.h create mode 100644 projects/llm_framework/include/fst/script/randequivalent.h create mode 100644 projects/llm_framework/include/fst/script/randgen.h create mode 100644 projects/llm_framework/include/fst/script/register.h create mode 100644 projects/llm_framework/include/fst/script/relabel.h create mode 100644 projects/llm_framework/include/fst/script/replace.h create mode 100644 projects/llm_framework/include/fst/script/reverse.h create mode 100644 projects/llm_framework/include/fst/script/reweight.h create mode 100644 
projects/llm_framework/include/fst/script/rmepsilon.h create mode 100644 projects/llm_framework/include/fst/script/script-impl.h create mode 100644 projects/llm_framework/include/fst/script/shortest-distance.h create mode 100644 projects/llm_framework/include/fst/script/shortest-path.h create mode 100644 projects/llm_framework/include/fst/script/stateiterator-class.h create mode 100644 projects/llm_framework/include/fst/script/synchronize.h create mode 100644 projects/llm_framework/include/fst/script/text-io.h create mode 100644 projects/llm_framework/include/fst/script/topsort.h create mode 100644 projects/llm_framework/include/fst/script/union.h create mode 100644 projects/llm_framework/include/fst/script/verify.h create mode 100644 projects/llm_framework/include/fst/script/weight-class.h create mode 100644 projects/llm_framework/include/fst/set-weight.h create mode 100644 projects/llm_framework/include/fst/shortest-distance.h create mode 100644 projects/llm_framework/include/fst/shortest-path.h create mode 100644 projects/llm_framework/include/fst/signed-log-weight.h create mode 100644 projects/llm_framework/include/fst/sparse-power-weight.h create mode 100644 projects/llm_framework/include/fst/sparse-tuple-weight.h create mode 100644 projects/llm_framework/include/fst/state-map.h create mode 100644 projects/llm_framework/include/fst/state-reachable.h create mode 100644 projects/llm_framework/include/fst/state-table.h create mode 100644 projects/llm_framework/include/fst/statesort.h create mode 100644 projects/llm_framework/include/fst/string-weight.h create mode 100644 projects/llm_framework/include/fst/string.h create mode 100644 projects/llm_framework/include/fst/symbol-table-ops.h create mode 100644 projects/llm_framework/include/fst/symbol-table.h create mode 100644 projects/llm_framework/include/fst/synchronize.h create mode 100644 projects/llm_framework/include/fst/test-properties.h create mode 100644 projects/llm_framework/include/fst/test/algo_test.h 
create mode 100644 projects/llm_framework/include/fst/test/fst_test.h create mode 100644 projects/llm_framework/include/fst/test/rand-fst.h create mode 100644 projects/llm_framework/include/fst/test/weight-tester.h create mode 100644 projects/llm_framework/include/fst/topsort.h create mode 100644 projects/llm_framework/include/fst/tuple-weight.h create mode 100644 projects/llm_framework/include/fst/types.h create mode 100644 projects/llm_framework/include/fst/union-find.h create mode 100644 projects/llm_framework/include/fst/union-weight.h create mode 100644 projects/llm_framework/include/fst/union.h create mode 100644 projects/llm_framework/include/fst/util.h create mode 100644 projects/llm_framework/include/fst/vector-fst.h create mode 100644 projects/llm_framework/include/fst/verify.h create mode 100644 projects/llm_framework/include/fst/visit.h create mode 100644 projects/llm_framework/include/fst/weight.h create mode 100644 projects/llm_framework/include/gflags/defines.h create mode 100644 projects/llm_framework/include/gflags/gflags.h create mode 100644 projects/llm_framework/include/gflags/gflags_completions.h create mode 100644 projects/llm_framework/include/gflags/gflags_declare.h create mode 100644 projects/llm_framework/include/glog/log_severity.h create mode 100644 projects/llm_framework/include/glog/logging.h create mode 100644 projects/llm_framework/include/glog/raw_logging.h create mode 100644 projects/llm_framework/include/glog/stl_logging.h create mode 100644 projects/llm_framework/include/glog/vlog_is_on.h create mode 100644 projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt create mode 100644 projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc create mode 100644 projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h create mode 100644 projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc create mode 100644 
projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h create mode 100644 projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt create mode 100644 projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h create mode 100644 projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h create mode 100644 projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc create mode 100644 projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h diff --git a/projects/llm_framework/include/fst/accumulator.h b/projects/llm_framework/include/fst/accumulator.h new file mode 100644 index 00000000..5ae19247 --- /dev/null +++ b/projects/llm_framework/include/fst/accumulator.h @@ -0,0 +1,903 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes to accumulate arc weights. Useful for weight lookahead. + +#ifndef FST_ACCUMULATOR_H_ +#define FST_ACCUMULATOR_H_ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + +namespace fst { + +// This class accumulates arc weights using the semiring Plus(). 
+template +class DefaultAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DefaultAccumulator() {} + + DefaultAccumulator(const DefaultAccumulator &acc, bool safe = false) {} + + void Init(const Fst &fst, bool copy = false) {} + + void SetState(StateId state) {} + + Weight Sum(Weight w, Weight v) { return Plus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + Adder adder(w); // maintains cumulative sum accurately + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) + adder.Add(aiter->Value().weight); + return adder.Sum(); + } + + constexpr bool Error() const { return false; } + + private: + DefaultAccumulator &operator=(const DefaultAccumulator &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() assuming an +// arc weight has a WeightConvert specialization to and from log64 weights. +template +class LogAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + LogAccumulator() {} + + LogAccumulator(const LogAccumulator &acc, bool safe = false) {} + + void Init(const Fst &fst, bool copy = false) {} + + void SetState(StateId s) {} + + Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + auto sum = w; + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + } + + constexpr bool Error() const { return false; } + + private: + Weight LogPlus(Weight w, Weight v) { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - internal::LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - internal::LogPosExp(f2 - 
f1))); + } + } + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + + LogAccumulator &operator=(const LogAccumulator &) = delete; +}; + +// Interface for shareable data for fast log accumulator copies. Holds pointers +// to data only, storage is provided by derived classes. +class FastLogAccumulatorData { + public: + FastLogAccumulatorData(int arc_limit, int arc_period) + : arc_limit_(arc_limit), + arc_period_(arc_period), + weights_ptr_(nullptr), + num_weights_(0), + weight_positions_ptr_(nullptr), + num_positions_(0) {} + + virtual ~FastLogAccumulatorData() {} + + // Cumulative weight per state for all states s.t. # of arcs > arc_limit_ + // with arcs in order. The first element per state is Log64Weight::Zero(). + const double *Weights() const { return weights_ptr_; } + + int NumWeights() const { return num_weights_; } + + // Maps from state to corresponding beginning weight position in weights_. + // Position -1 means no pre-computed weights for that state. + const int *WeightPositions() const { return weight_positions_ptr_; } + + int NumPositions() const { return num_positions_; } + + int ArcLimit() const { return arc_limit_; } + + int ArcPeriod() const { return arc_period_; } + + // Returns true if the data object is mutable and supports SetData(). + virtual bool IsMutable() const = 0; + + // Does not take ownership but may invalidate the contents of weights and + // weight_positions. 
+ virtual void SetData(std::vector *weights, + std::vector *weight_positions) = 0; + + protected: + void Init(int num_weights, const double *weights, int num_positions, + const int *weight_positions) { + weights_ptr_ = weights; + num_weights_ = num_weights; + weight_positions_ptr_ = weight_positions; + num_positions_ = num_positions; + } + + private: + const int arc_limit_; + const int arc_period_; + const double *weights_ptr_; + int num_weights_; + const int *weight_positions_ptr_; + int num_positions_; + + FastLogAccumulatorData(const FastLogAccumulatorData &) = delete; + FastLogAccumulatorData &operator=(const FastLogAccumulatorData &) = delete; +}; + +// FastLogAccumulatorData with mutable storage; filled by +// FastLogAccumulator::Init. +class MutableFastLogAccumulatorData : public FastLogAccumulatorData { + public: + MutableFastLogAccumulatorData(int arc_limit, int arc_period) + : FastLogAccumulatorData(arc_limit, arc_period) {} + + bool IsMutable() const override { return true; } + + void SetData(std::vector *weights, + std::vector *weight_positions) override { + weights_.swap(*weights); + weight_positions_.swap(*weight_positions); + Init(weights_.size(), weights_.data(), weight_positions_.size(), + weight_positions_.data()); + } + + private: + std::vector weights_; + std::vector weight_positions_; + + MutableFastLogAccumulatorData(const MutableFastLogAccumulatorData &) = delete; + MutableFastLogAccumulatorData &operator=( + const MutableFastLogAccumulatorData &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() assuming an +// arc weight has a WeightConvert specialization to and from log64 weights. The +// member function Init(fst) has to be called to setup pre-computed weight +// information. 
+template +class FastLogAccumulator { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit FastLogAccumulator(ssize_t arc_limit = 20, ssize_t arc_period = 10) + : to_log_weight_(), + to_weight_(), + arc_limit_(arc_limit), + arc_period_(arc_period), + data_(std::make_shared(arc_limit, + arc_period)), + state_weights_(nullptr), + error_(false) {} + + explicit FastLogAccumulator(std::shared_ptr data) + : to_log_weight_(), + to_weight_(), + arc_limit_(data->ArcLimit()), + arc_period_(data->ArcPeriod()), + data_(data), + state_weights_(nullptr), + error_(false) {} + + FastLogAccumulator(const FastLogAccumulator &acc, bool safe = false) + : to_log_weight_(), + to_weight_(), + arc_limit_(acc.arc_limit_), + arc_period_(acc.arc_period_), + data_(acc.data_), + state_weights_(nullptr), + error_(acc.error_) {} + + void SetState(StateId s) { + const auto *weights = data_->Weights(); + const auto *weight_positions = data_->WeightPositions(); + state_weights_ = nullptr; + if (s < data_->NumPositions()) { + const auto pos = weight_positions[s]; + if (pos >= 0) state_weights_ = &(weights[pos]); + } + } + + Weight Sum(Weight w, Weight v) const { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) const { + if (error_) return Weight::NoWeight(); + auto sum = w; + // Finds begin and end of pre-stored weights. + ssize_t index_begin = -1; + ssize_t index_end = -1; + ssize_t stored_begin = end; + ssize_t stored_end = end; + if (state_weights_) { + index_begin = begin > 0 ? (begin - 1) / arc_period_ + 1 : 0; + index_end = end / arc_period_; + stored_begin = index_begin * arc_period_; + stored_end = index_end * arc_period_; + } + // Computes sum before pre-stored weights. 
+ if (begin < stored_begin) { + const auto pos_end = std::min(stored_begin, end); + aiter->Seek(begin); + for (auto pos = begin; pos < pos_end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + } + // Computes sum between pre-stored weights. + if (stored_begin < stored_end) { + const auto f1 = state_weights_[index_end]; + const auto f2 = state_weights_[index_begin]; + if (f1 < f2) sum = LogPlus(sum, LogMinus(f1, f2)); + // Commented out for efficiency; adds Zero(). + /* + else { + // explicitly computes if cumulative sum lacks precision + aiter->Seek(stored_begin); + for (auto pos = stored_begin; pos < stored_end; aiter->Next(), ++pos) + sum = LogPlus(sum, aiter->Value().weight); + } + */ + } + // Computes sum after pre-stored weights. + if (stored_end < end) { + const auto pos_start = std::max(stored_begin, stored_end); + aiter->Seek(pos_start); + for (auto pos = pos_start; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + } + return sum; + } + + template + void Init(const FST &fst, bool copy = false) { + if (copy || !data_->IsMutable()) return; + if (data_->NumPositions() != 0 || arc_limit_ < arc_period_) { + FSTERROR() << "FastLogAccumulator: Initialization error"; + error_ = true; + return; + } + std::vector weights; + std::vector weight_positions; + weight_positions.reserve(CountStates(fst)); + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst.NumArcs(s) >= arc_limit_) { + auto sum = FloatLimits::PosInfinity(); + if (weight_positions.size() <= s) weight_positions.resize(s + 1, -1); + weight_positions[s] = weights.size(); + weights.push_back(sum); + size_t narcs = 0; + ArcIterator aiter(fst, s); + aiter.SetFlags(kArcWeightValue | kArcNoCache, kArcFlags); + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + sum = LogPlus(sum, arc.weight); + // Stores cumulative weight distribution per arc_period_. 
+ if (++narcs % arc_period_ == 0) weights.push_back(sum); + } + } + } + data_->SetData(&weights, &weight_positions); + } + + bool Error() const { return error_; } + + std::shared_ptr GetData() const { return data_; } + + private: + static double LogPosExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F + exp(-x)); + } + + static double LogMinusExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) const { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); + } + } + + double LogPlus(double f1, Weight v) const { + const auto f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits::PosInfinity()) { + return f2; + } else if (f1 > f2) { + return f2 - LogPosExp(f1 - f2); + } else { + return f1 - LogPosExp(f2 - f1); + } + } + + // Assumes f1 < f2. + Weight LogMinus(double f1, double f2) const { + if (f2 == FloatLimits::PosInfinity()) { + return to_weight_(Log64Weight(f1)); + } else { + return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); + } + } + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + const ssize_t arc_limit_; // Minimum number of arcs to pre-compute state. + const ssize_t arc_period_; // Saves cumulative weights per arc_period_. + std::shared_ptr data_; + const double *state_weights_; + bool error_; + + FastLogAccumulator &operator=(const FastLogAccumulator &) = delete; +}; + +// Stores shareable data for cache log accumulator copies. All copies share the +// same cache. 
+template +class CacheLogAccumulatorData { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + CacheLogAccumulatorData(bool gc, size_t gc_limit) + : cache_gc_(gc), cache_limit_(gc_limit), cache_size_(0) {} + + CacheLogAccumulatorData(const CacheLogAccumulatorData &data) + : cache_gc_(data.cache_gc_), + cache_limit_(data.cache_limit_), + cache_size_(0) {} + + bool CacheDisabled() const { return cache_gc_ && cache_limit_ == 0; } + + std::vector *GetWeights(StateId s) { + auto it = cache_.find(s); + if (it != cache_.end()) { + it->second.recent = true; + return it->second.weights.get(); + } else { + return nullptr; + } + } + + void AddWeights(StateId s, std::vector *weights) { + if (cache_gc_ && cache_size_ >= cache_limit_) GC(false); + cache_.insert(std::make_pair(s, CacheState(weights, true))); + if (cache_gc_) cache_size_ += weights->capacity() * sizeof(double); + } + + private: + // Cached information for a given state. + struct CacheState { + std::unique_ptr> weights; // Accumulated weights. + bool recent; // Has this state been accessed since last GC? + + CacheState(std::vector *weights, bool recent) + : weights(weights), recent(recent) {} + }; + + // Garbage collect: Deletes from cache states that have not been accessed + // since the last GC ('free_recent = false') until 'cache_size_' is 2/3 of + // 'cache_limit_'. If it does not free enough memory, start deleting + // recently accessed states. + void GC(bool free_recent) { + auto cache_target = (2 * cache_limit_) / 3 + 1; + auto it = cache_.begin(); + while (it != cache_.end() && cache_size_ > cache_target) { + auto &cs = it->second; + if (free_recent || !cs.recent) { + cache_size_ -= cs.weights->capacity() * sizeof(double); + cache_.erase(it++); + } else { + cs.recent = false; + ++it; + } + } + if (!free_recent && cache_size_ > cache_target) GC(true); + } + + std::unordered_map cache_; // Cache. + bool cache_gc_; // Enables garbage collection. 
+ size_t cache_limit_; // # of bytes allowed before GC. + size_t cache_size_; // # of bytes cached. + + CacheLogAccumulatorData &operator=(const CacheLogAccumulatorData &) = delete; +}; + +// This class accumulates arc weights using the log semiring Plus() assuming an +// arc weight has a WeightConvert specialization to and from log64 weights. It +// is similar to the FastLogAccumulator. However here, the accumulated weights +// are pre-computed and stored only for the states that are visited. The member +// function Init(fst) has to be called to set up this accumulator. +template +class CacheLogAccumulator { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit CacheLogAccumulator(ssize_t arc_limit = 10, bool gc = false, + size_t gc_limit = 10 * 1024 * 1024) + : arc_limit_(arc_limit), + data_(std::make_shared>(gc, gc_limit)), + s_(kNoStateId), + error_(false) {} + + CacheLogAccumulator(const CacheLogAccumulator &acc, bool safe = false) + : arc_limit_(acc.arc_limit_), + fst_(acc.fst_ ? acc.fst_->Copy() : nullptr), + data_(safe ? std::make_shared>(*acc.data_) + : acc.data_), + s_(kNoStateId), + error_(acc.error_) {} + + // Argument arc_limit specifies the minimum number of arcs to pre-compute. 
+ void Init(const Fst &fst, bool copy = false) { + if (!copy && fst_) { + FSTERROR() << "CacheLogAccumulator: Initialization error"; + error_ = true; + return; + } + fst_.reset(fst.Copy()); + } + + void SetState(StateId s, int depth = 0) { + if (s == s_) return; + s_ = s; + if (data_->CacheDisabled() || error_) { + weights_ = nullptr; + return; + } + if (!fst_) { + FSTERROR() << "CacheLogAccumulator::SetState: Incorrectly initialized"; + error_ = true; + weights_ = nullptr; + return; + } + weights_ = data_->GetWeights(s); + if ((weights_ == nullptr) && (fst_->NumArcs(s) >= arc_limit_)) { + weights_ = new std::vector; + weights_->reserve(fst_->NumArcs(s) + 1); + weights_->push_back(FloatLimits::PosInfinity()); + data_->AddWeights(s, weights_); + } + } + + Weight Sum(Weight w, Weight v) { return LogPlus(w, v); } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (weights_ == nullptr) { + auto sum = w; + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + } else { + Extend(end, aiter); + const auto &f1 = (*weights_)[end]; + const auto &f2 = (*weights_)[begin]; + if (f1 < f2) { + return LogPlus(w, LogMinus(f1, f2)); + } else { + // Commented out for efficiency; adds Zero(). + /* + auto sum = w; + // Explicitly computes if cumulative sum lacks precision. + aiter->Seek(begin); + for (auto pos = begin; pos < end; aiter->Next(), ++pos) { + sum = LogPlus(sum, aiter->Value().weight); + } + return sum; + */ + return w; + } + } + } + + // Returns first position from aiter->Position() whose accumulated + // value is greater or equal to w (w.r.t. Zero() < One()). The + // iterator may be repositioned. 
+ template + size_t LowerBound(Weight w, ArcIter *aiter) { + const auto f = to_log_weight_(w).Value(); + auto pos = aiter->Position(); + if (weights_) { + Extend(fst_->NumArcs(s_), aiter); + return std::lower_bound(weights_->begin() + pos + 1, weights_->end(), + f, std::greater()) - + weights_->begin() - 1; + } else { + size_t n = 0; + auto x = FloatLimits::PosInfinity(); + for (aiter->Reset(); !aiter->Done(); aiter->Next(), ++n) { + x = LogPlus(x, aiter->Value().weight); + if (n >= pos && x <= f) break; + } + return n; + } + } + + bool Error() const { return error_; } + + private: + double LogPosExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F + exp(-x)); + } + + double LogMinusExp(double x) { + return x == FloatLimits::PosInfinity() ? 0.0 + : log(1.0F - exp(-x)); + } + + Weight LogPlus(Weight w, Weight v) { + if (w == Weight::Zero()) { + return v; + } + const auto f1 = to_log_weight_(w).Value(); + const auto f2 = to_log_weight_(v).Value(); + if (f1 > f2) { + return to_weight_(Log64Weight(f2 - LogPosExp(f1 - f2))); + } else { + return to_weight_(Log64Weight(f1 - LogPosExp(f2 - f1))); + } + } + + double LogPlus(double f1, Weight v) { + const auto f2 = to_log_weight_(v).Value(); + if (f1 == FloatLimits::PosInfinity()) { + return f2; + } else if (f1 > f2) { + return f2 - LogPosExp(f1 - f2); + } else { + return f1 - LogPosExp(f2 - f1); + } + } + + // Assumes f1 < f2. + Weight LogMinus(double f1, double f2) { + if (f2 == FloatLimits::PosInfinity()) { + return to_weight_(Log64Weight(f1)); + } else { + return to_weight_(Log64Weight(f1 - LogMinusExp(f2 - f1))); + } + } + + // Extends weights up to index 'end'. 
+ template + void Extend(ssize_t end, ArcIter *aiter) { + if (weights_->size() <= end) { + for (aiter->Seek(weights_->size() - 1); weights_->size() <= end; + aiter->Next()) { + weights_->push_back(LogPlus(weights_->back(), aiter->Value().weight)); + } + } + } + + + const WeightConvert to_log_weight_{}; + const WeightConvert to_weight_{}; + ssize_t arc_limit_; // Minimum # of arcs to cache a state. + std::vector *weights_; // Accumulated weights for cur. state. + std::unique_ptr> fst_; // Input FST. + std::shared_ptr> data_; // Cache data. + StateId s_; // Current state. + bool error_; +}; + +// Stores shareable data for replace accumulator copies. +template +class ReplaceAccumulatorData { + public: + using Arc = typename Accumulator::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + + ReplaceAccumulatorData() : state_table_(nullptr) {} + + explicit ReplaceAccumulatorData( + const std::vector &accumulators) + : state_table_(nullptr) { + accumulators_.reserve(accumulators.size()); + for (const auto accumulator : accumulators) { + accumulators_.emplace_back(accumulator); + } + } + + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + state_table_ = state_table; + accumulators_.resize(fst_tuples.size()); + for (Label i = 0; i < accumulators_.size(); ++i) { + if (!accumulators_[i]) { + accumulators_[i].reset(new Accumulator()); + accumulators_[i]->Init(*(fst_tuples[i].second)); + } + fst_array_.emplace_back(fst_tuples[i].second->Copy()); + } + } + + const StateTuple &GetTuple(StateId s) const { return state_table_->Tuple(s); } + + Accumulator *GetAccumulator(size_t i) { return accumulators_[i].get(); } + + const Fst *GetFst(size_t i) const { return fst_array_[i].get(); } + + private: + const StateTable *state_table_; + std::vector> accumulators_; + std::vector>> fst_array_; +}; + +// This class accumulates weights in a 
ReplaceFst. The 'Init' method takes as +// input the argument used to build the ReplaceFst and the ReplaceFst state +// table. It uses accumulators of type 'Accumulator' in the underlying FSTs. +template > +class ReplaceAccumulator { + public: + using Arc = typename Accumulator::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + using Weight = typename Arc::Weight; + + ReplaceAccumulator() + : init_(false), + data_(std::make_shared< + ReplaceAccumulatorData>()), + error_(false) {} + + explicit ReplaceAccumulator(const std::vector &accumulators) + : init_(false), + data_(std::make_shared>( + accumulators)), + error_(false) {} + + ReplaceAccumulator(const ReplaceAccumulator &acc, + bool safe = false) + : init_(acc.init_), data_(acc.data_), error_(acc.error_) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator: Can't copy unintialized accumulator"; + } + if (safe) FSTERROR() << "ReplaceAccumulator: Safe copy not supported"; + } + + // Does not take ownership of the state table, the state table is owned by + // the ReplaceFst. + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + init_ = true; + data_->Init(fst_tuples, state_table); + } + + // Method required by LookAheadMatcher. However, ReplaceAccumulator needs to + // be initialized by calling the Init method above before being passed to + // LookAheadMatcher. + // + // TODO(allauzen): Revisit this. Consider creating a method + // Init(const ReplaceFst&, bool) and using friendship to get access + // to the innards of ReplaceFst. 
+ void Init(const Fst &fst, bool copy = false) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator::Init: Accumulator needs to be" + << " initialized before being passed to LookAheadMatcher"; + error_ = true; + } + } + + void SetState(StateId s) { + if (!init_) { + FSTERROR() << "ReplaceAccumulator::SetState: Incorrectly initialized"; + error_ = true; + return; + } + auto tuple = data_->GetTuple(s); + fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based. + data_->GetAccumulator(fst_id_)->SetState(tuple.fst_state); + if ((tuple.prefix_id != 0) && + (data_->GetFst(fst_id_)->Final(tuple.fst_state) != Weight::Zero())) { + offset_ = 1; + offset_weight_ = data_->GetFst(fst_id_)->Final(tuple.fst_state); + } else { + offset_ = 0; + offset_weight_ = Weight::Zero(); + } + aiter_.reset( + new ArcIterator>(*data_->GetFst(fst_id_), tuple.fst_state)); + } + + Weight Sum(Weight w, Weight v) { + if (error_) return Weight::NoWeight(); + return data_->GetAccumulator(fst_id_)->Sum(w, v); + } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (error_) return Weight::NoWeight(); + auto sum = begin == end ? Weight::Zero() + : data_->GetAccumulator(fst_id_)->Sum( + w, aiter_.get(), begin ? begin - offset_ : 0, + end - offset_); + if (begin == 0 && end != 0 && offset_ > 0) sum = Sum(offset_weight_, sum); + return sum; + } + + bool Error() const { return error_; } + + private: + bool init_; + std::shared_ptr> data_; + Label fst_id_; + size_t offset_; + Weight offset_weight_; + std::unique_ptr>> aiter_; + bool error_; +}; + +// SafeReplaceAccumulator accumulates weights in a ReplaceFst and copies of it +// are always thread-safe copies. 
+template +class SafeReplaceAccumulator { + public: + using Arc = typename Accumulator::Arc; + using StateId = typename Arc::StateId; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + using StateTable = T; + using StateTuple = typename StateTable::StateTuple; + + SafeReplaceAccumulator() {} + + SafeReplaceAccumulator(const SafeReplaceAccumulator ©, bool safe) + : SafeReplaceAccumulator(copy) {} + + explicit SafeReplaceAccumulator( + const std::vector &accumulators) { + for (const auto &accumulator : accumulators) { + accumulators_.emplace_back(accumulator, true); + } + } + + void Init(const std::vector *>> &fst_tuples, + const StateTable *state_table) { + state_table_ = state_table; + for (Label i = 0; i < fst_tuples.size(); ++i) { + if (i == accumulators_.size()) { + accumulators_.resize(accumulators_.size() + 1); + accumulators_[i].Init(*(fst_tuples[i].second)); + } + fst_array_.emplace_back(fst_tuples[i].second->Copy(true)); + } + init_ = true; + } + + void Init(const Fst &fst, bool copy = false) { + if (!init_) { + FSTERROR() << "SafeReplaceAccumulator::Init: Accumulator needs to be" + << " initialized before being passed to LookAheadMatcher"; + error_ = true; + } + } + + void SetState(StateId s) { + auto tuple = state_table_->Tuple(s); + fst_id_ = tuple.fst_id - 1; // Replace FST ID is 1-based + GetAccumulator(fst_id_)->SetState(tuple.fst_state); + offset_ = 0; + offset_weight_ = Weight::Zero(); + const auto final_weight = GetFst(fst_id_)->Final(tuple.fst_state); + if ((tuple.prefix_id != 0) && (final_weight != Weight::Zero())) { + offset_ = 1; + offset_weight_ = final_weight; + } + aiter_.Set(*GetFst(fst_id_), tuple.fst_state); + } + + Weight Sum(Weight w, Weight v) { + if (error_) return Weight::NoWeight(); + return GetAccumulator(fst_id_)->Sum(w, v); + } + + template + Weight Sum(Weight w, ArcIter *aiter, ssize_t begin, ssize_t end) { + if (error_) return Weight::NoWeight(); + if (begin == end) return Weight::Zero(); + auto sum = 
GetAccumulator(fst_id_)->Sum( + w, aiter_.get(), begin ? begin - offset_ : 0, end - offset_); + if (begin == 0 && end != 0 && offset_ > 0) { + sum = Sum(offset_weight_, sum); + } + return sum; + } + + bool Error() const { return error_; } + + private: + class ArcIteratorPtr { + public: + ArcIteratorPtr() {} + + ArcIteratorPtr(const ArcIteratorPtr ©) {} + + void Set(const Fst &fst, StateId state_id) { + ptr_.reset(new ArcIterator>(fst, state_id)); + } + + ArcIterator> *get() { return ptr_.get(); } + + private: + std::unique_ptr>> ptr_; + }; + + Accumulator *GetAccumulator(size_t i) { return &accumulators_[i]; } + + const Fst *GetFst(size_t i) const { return fst_array_[i].get(); } + + const StateTable *state_table_; + std::vector accumulators_; + std::vector>> fst_array_; + ArcIteratorPtr aiter_; + bool init_ = false; + bool error_ = false; + Label fst_id_; + size_t offset_; + Weight offset_weight_; +}; + +} // namespace fst + +#endif // FST_ACCUMULATOR_H_ diff --git a/projects/llm_framework/include/fst/add-on.h b/projects/llm_framework/include/fst/add-on.h new file mode 100644 index 00000000..4a95111f --- /dev/null +++ b/projects/llm_framework/include/fst/add-on.h @@ -0,0 +1,248 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST implementation class to attach an arbitrary object with a read/write +// method to an FST and its file representation. The FST is given a new type +// name. + +#ifndef FST_ADD_ON_H_ +#define FST_ADD_ON_H_ + +#include +#include +#include +#include + +#include + +#include + + +namespace fst { + +// Identifies stream data as an add-on FST. +static constexpr int32 kAddOnMagicNumber = 446681434; + +// Nothing to save. 
+class NullAddOn { + public: + NullAddOn() {} + + static NullAddOn *Read(std::istream &strm, const FstReadOptions &opts) { + return new NullAddOn(); + } + + bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const { + return true; + } +}; + +// Create a new add-on from a pair of add-ons. +template +class AddOnPair { + public: + // Argument reference count incremented. + AddOnPair(std::shared_ptr a1, std::shared_ptr a2) + : a1_(std::move(a1)), a2_(std::move(a2)) {} + + const A1 *First() const { return a1_.get(); } + + const A2 *Second() const { return a2_.get(); } + + std::shared_ptr SharedFirst() const { return a1_; } + + std::shared_ptr SharedSecond() const { return a2_; } + + static AddOnPair *Read(std::istream &istrm, + const FstReadOptions &opts) { + A1 *a1 = nullptr; + bool have_addon1 = false; + ReadType(istrm, &have_addon1); + if (have_addon1) a1 = A1::Read(istrm, opts); + + A2 *a2 = nullptr; + bool have_addon2 = false; + ReadType(istrm, &have_addon2); + if (have_addon2) a2 = A2::Read(istrm, opts); + + return new AddOnPair(std::shared_ptr(a1), + std::shared_ptr(a2)); + } + + bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const { + bool have_addon1 = a1_ != nullptr; + WriteType(ostrm, have_addon1); + if (have_addon1) a1_->Write(ostrm, opts); + bool have_addon2 = a2_ != nullptr; + WriteType(ostrm, have_addon2); + if (have_addon2) a2_->Write(ostrm, opts); + return true; + } + + private: + std::shared_ptr a1_; + std::shared_ptr a2_; +}; + +namespace internal { + +// Adds an object of type T to an FST. T must support: +// +// T* Read(std::istream &); +// bool Write(std::ostream &); +// +// The resulting type is a new FST implementation. 
+template +class AddOnImpl : public FstImpl { + public: + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + + // We make a thread-safe copy of the FST by default since an FST + // implementation is expected to not share mutable data between objects. + AddOnImpl(const FST &fst, const string &type, + std::shared_ptr t = std::shared_ptr()) + : fst_(fst, true), t_(std::move(t)) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + // Conversion from const Fst & to F always copies the underlying + // implementation. + AddOnImpl(const Fst &fst, const string &type, + std::shared_ptr t = std::shared_ptr()) + : fst_(fst), t_(std::move(t)) { + SetType(type); + SetProperties(fst_.Properties(kFstProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + // We make a thread-safe copy of the FST by default since an FST + // implementation is expected to not share mutable data between objects. 
+ AddOnImpl(const AddOnImpl &impl) + : fst_(impl.fst_, true), t_(impl.t_) { + SetType(impl.Type()); + SetProperties(fst_.Properties(kCopyProperties, false)); + SetInputSymbols(fst_.InputSymbols()); + SetOutputSymbols(fst_.OutputSymbols()); + } + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId s) const { return fst_.Final(s); } + + size_t NumArcs(StateId s) const { return fst_.NumArcs(s); } + + size_t NumInputEpsilons(StateId s) const { return fst_.NumInputEpsilons(s); } + + size_t NumOutputEpsilons(StateId s) const { + return fst_.NumOutputEpsilons(s); + } + + size_t NumStates() const { return fst_.NumStates(); } + + static AddOnImpl *Read(std::istream &strm, + const FstReadOptions &opts) { + FstReadOptions nopts(opts); + FstHeader hdr; + if (!nopts.header) { + hdr.Read(strm, nopts.source); + nopts.header = &hdr; + } + std::unique_ptr> impl( + new AddOnImpl(nopts.header->FstType())); + if (!impl->ReadHeader(strm, nopts, kMinFileVersion, &hdr)) return nullptr; + impl.reset(); + int32 magic_number = 0; + ReadType(strm, &magic_number); // Ensures this is an add-on FST. + if (magic_number != kAddOnMagicNumber) { + LOG(ERROR) << "AddOnImpl::Read: Bad add-on header: " << nopts.source; + return nullptr; + } + FstReadOptions fopts(opts); + fopts.header = nullptr; // Contained header was written out. + std::unique_ptr fst(FST::Read(strm, fopts)); + if (!fst) return nullptr; + std::shared_ptr t; + bool have_addon = false; + ReadType(strm, &have_addon); + if (have_addon) { // Reads add-on object if present. + t = std::shared_ptr(T::Read(strm, fopts)); + if (!t) return nullptr; + } + return new AddOnImpl(*fst, nopts.header->FstType(), t); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + FstWriteOptions nopts(opts); + nopts.write_isymbols = false; // Allows contained FST to hold any symbols. 
+ nopts.write_osymbols = false; + WriteHeader(strm, nopts, kFileVersion, &hdr); + WriteType(strm, kAddOnMagicNumber); // Ensures this is an add-on FST. + FstWriteOptions fopts(opts); + fopts.write_header = true; // Forces writing contained header. + if (!fst_.Write(strm, fopts)) return false; + bool have_addon = !!t_; + WriteType(strm, have_addon); + // Writes add-on object if present. + if (have_addon) t_->Write(strm, opts); + return true; + } + + void InitStateIterator(StateIteratorData *data) const { + fst_.InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const { + fst_.InitArcIterator(s, data); + } + + FST &GetFst() { return fst_; } + + const FST &GetFst() const { return fst_; } + + const T *GetAddOn() const { return t_.get(); } + + std::shared_ptr GetSharedAddOn() const { return t_; } + + void SetAddOn(std::shared_ptr t) { t_ = t; } + + private: + explicit AddOnImpl(const string &type) : t_() { + SetType(type); + SetProperties(kExpanded); + } + + // Current file format version. + static constexpr int kFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + FST fst_; + std::shared_ptr t_; + + AddOnImpl &operator=(const AddOnImpl &) = delete; +}; + +template +constexpr int AddOnImpl::kFileVersion; + +template +constexpr int AddOnImpl::kMinFileVersion; + +} // namespace internal +} // namespace fst + +#endif // FST_ADD_ON_H_ diff --git a/projects/llm_framework/include/fst/arc-arena.h b/projects/llm_framework/include/fst/arc-arena.h new file mode 100644 index 00000000..13fe918a --- /dev/null +++ b/projects/llm_framework/include/fst/arc-arena.h @@ -0,0 +1,232 @@ +#ifndef FST_ARC_ARENA_H_ +#define FST_ARC_ARENA_H_ + +#include +#include +#include +#include +#include +#include + +namespace fst { + +// ArcArena is used for fast allocation of contiguous arrays of arcs. 
//
// To create an arc array:
//   for each state:
//     for each arc:
//       arena.PushArc();
//     // Commits these arcs and returns pointer to them.
//     Arc *arcs = arena.GetArcs();
//
//     OR
//
//     arena.DropArcs();  // Throws away current arcs, reuse the space.
//
// The arcs returned are guaranteed to be contiguous and the pointer returned
// will never be invalidated until the arena is cleared for reuse.
//
// The contents of the arena can be released with a call to arena.Clear() after
// which the arena will restart with an initial allocation capable of holding at
// least all of the arcs requested in the last usage before Clear() making
// subsequent uses of the Arena more efficient.
//
// The max_retained_size option can limit the amount of arc space requested on
// Clear() to avoid excess growth from intermittent high usage.
template <typename Arc>
class ArcArena {
 public:
  // Allocates the first block of 'block_size' arcs. 'max_retained_size' caps
  // the size of the block that Clear() re-allocates.
  explicit ArcArena(size_t block_size = 256,
                    size_t max_retained_size = 1000000)
      : block_size_(block_size),
        max_retained_size_(max_retained_size) {
    blocks_.emplace_back(MakeSharedBlock(block_size_));
    first_block_size_ = block_size_;
    total_size_ = block_size_;
    arcs_ = blocks_.back().get();
    end_ = arcs_ + block_size_;
    next_ = arcs_;
  }

  // The copy shares ownership of every existing block with 'copy', then moves
  // the pending (uncommitted) arcs into a new block of its own so that the two
  // arenas never append into the same storage.
  ArcArena(const ArcArena& copy)
      : arcs_(copy.arcs_), next_(copy.next_), end_(copy.end_),
        block_size_(copy.block_size_),
        first_block_size_(copy.first_block_size_),
        total_size_(copy.total_size_),
        max_retained_size_(copy.max_retained_size_),
        blocks_(copy.blocks_) {
    NewBlock(block_size_);
  }

  // Guarantees that the next n calls to PushArc() will not allocate.
  void ReserveArcs(size_t n) {
    const size_t available = end_ - next_;
    if (n <= available) return;
    // The new block has to hold the pending arcs as well as the n new ones.
    const size_t pending = next_ - arcs_;
    NewBlock(pending + n);
  }

  // Appends an arc to the pending sequence, growing geometrically when the
  // current block is full.
  void PushArc(const Arc& arc) {
    if (next_ == end_) {
      const size_t length = next_ - arcs_;
      // At least one slot is requested so a zero block size cannot produce a
      // block with no room for the arc about to be written.
      NewBlock(std::max<size_t>(length * 2, 1));
    }
    *next_ = arc;
    ++next_;
  }

  // Commits the pending arcs and returns a pointer to the first of them.
  const Arc* GetArcs() {
    const auto *arcs = arcs_;
    arcs_ = next_;
    return arcs;
  }

  // Discards the pending arcs so their space is reused.
  void DropArcs() { next_ = arcs_; }

  // Total number of arc slots allocated since construction or the last
  // Clear().
  size_t Size() const { return total_size_; }

  // Keeps a single block; when more space than that block was used, it is
  // replaced by one of min(max_retained_size, total size) arcs.
  void Clear() {
    blocks_.resize(1);
    if (total_size_ > first_block_size_) {
      first_block_size_ = std::min(max_retained_size_, total_size_);
      blocks_.back() = MakeSharedBlock(first_block_size_);
    }
    total_size_ = first_block_size_;
    arcs_ = blocks_.back().get();
    end_ = arcs_ + first_block_size_;
    next_ = arcs_;
  }

 private:
  // Allocates a new block with capacity of at least n or block_size,
  // copying incomplete arc sequence from old block to new block. The capacity
  // is never smaller than the pending sequence, so the copy always fits.
  void NewBlock(size_t n) {
    const size_t length = next_ - arcs_;
    const size_t new_block_size = std::max(std::max(n, block_size_), length);
    total_size_ += new_block_size;
    blocks_.emplace_back(MakeSharedBlock(new_block_size));
    std::copy(arcs_, next_, blocks_.back().get());
    arcs_ = blocks_.back().get();
    next_ = arcs_ + length;
    end_ = arcs_ + new_block_size;
  }

  // Array storage managed by shared_ptr; the array deleter makes delete[] run.
  std::shared_ptr<Arc> MakeSharedBlock(size_t size) {
    return std::shared_ptr<Arc>(new Arc[size], std::default_delete<Arc[]>());
  }

  Arc *arcs_;        // Start of the pending (uncommitted) sequence.
  Arc *next_;        // Slot the next PushArc() writes to.
  const Arc *end_;   // One past the last slot of the current block.
  size_t block_size_;
  size_t first_block_size_;
  size_t total_size_;
  size_t max_retained_size_;
  std::list<std::shared_ptr<Arc>> blocks_;
};

// ArcArenaStateStore uses a reusable ArcArena to store arc arrays and does not
// require that the Expander call ReserveArcs first.
//
// TODO(tombagby): Make cache type configurable.
// TODO(tombagby): Provide ThreadLocal/Concurrent configuration.
+template +class ArcArenaStateStore { + public: + using Arc = A; + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + + ArcArenaStateStore() : arena_(64 * 1024) { + } + + class State { + public: + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return narcs_; } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + const Arc *Arcs() const { return arcs_; } + + int* MutableRefCount() const { return nullptr; } + + private: + State(Weight weight, int32 niepsilons, int32 noepsilons, int32 narcs, + const Arc *arcs) + : final_(std::move(weight)), + niepsilons_(niepsilons), + noepsilons_(noepsilons), + narcs_(narcs), + arcs_(arcs) {} + + Weight final_; + size_t niepsilons_; + size_t noepsilons_; + size_t narcs_; + const Arc *arcs_; + + friend class ArcArenaStateStore; + }; + + template + State *FindOrExpand(Expander &expander, StateId state_id) { // NOLINT + auto it = cache_.insert(std::pair(state_id, nullptr)); + if (!it.second) return it.first->second; + // Needs a new state. + StateBuilder builder(&arena_); + expander.Expand(state_id, &builder); + const auto arcs = arena_.GetArcs(); + size_t narcs = builder.narcs_; + size_t niepsilons = 0; + size_t noepsilons = 0; + for (size_t i = 0; i < narcs; ++i) { + if (arcs[i].ilabel == 0) ++niepsilons; + if (arcs[i].olabel == 0) ++noepsilons; + } + states_.emplace_back( + State(builder.final_, niepsilons, noepsilons, narcs, arcs)); + // Places it in the cache. + auto state = &states_.back(); + it.first->second = state; + return state; + } + + State *Find(StateId state_id) const { + auto it = cache_.find(state_id); + return (it == cache_.end()) ? 
nullptr : it->second; + } + + private: + class StateBuilder { + public: + explicit StateBuilder(ArcArena* arena) + : arena_(arena), final_(Weight::Zero()), narcs_(0) {} + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void ReserveArcs(size_t n) { arena_->ReserveArcs(n); } + + void AddArc(const Arc &arc) { + ++narcs_; + arena_->PushArc(arc); + } + + private: + friend class ArcArenaStateStore; + + ArcArena *arena_; + Weight final_; + size_t narcs_; + }; + + std::unordered_map cache_; + std::deque states_; + ArcArena arena_; +}; + +} // namespace fst + +#endif // FST_ARC_ARENA_H_ diff --git a/projects/llm_framework/include/fst/arc-map.h b/projects/llm_framework/include/fst/arc-map.h new file mode 100644 index 00000000..24db4911 --- /dev/null +++ b/projects/llm_framework/include/fst/arc-map.h @@ -0,0 +1,1285 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to map over/transform arcs e.g., change semirings or +// implement project/invert. Consider using when operation does +// not change the number of arcs (except possibly superfinal arcs). + +#ifndef FST_ARC_MAP_H_ +#define FST_ARC_MAP_H_ + +#include +#include +#include + +#include + +#include +#include + + +namespace fst { + +// Determines how final weights are mapped. +enum MapFinalAction { + // A final weight is mapped into a final weight. An error is raised if this + // is not possible. + MAP_NO_SUPERFINAL, + // A final weight is mapped to an arc to the superfinal state when the result + // cannot be represented as a final weight. The superfinal state will be + // added only if it is needed. + MAP_ALLOW_SUPERFINAL, + // A final weight is mapped to an arc to the superfinal state unless the + // result can be represented as a final weight of weight Zero(). The + // superfinal state is always added (if the input is not the empty FST). + MAP_REQUIRE_SUPERFINAL +}; + +// Determines how symbol tables are mapped. 
+enum MapSymbolsAction { + // Symbols should be cleared in the result by the map. + MAP_CLEAR_SYMBOLS, + // Symbols should be copied from the input FST by the map. + MAP_COPY_SYMBOLS, + // Symbols should not be modified in the result by the map itself. + // (They may set by the mapper). + MAP_NOOP_SYMBOLS +}; + +// The ArcMapper interfaces defines how arcs and final weights are mapped. +// This is useful for implementing operations that do not change the number of +// arcs (except possibly superfinal arcs). +// +// template +// class ArcMapper { +// public: +// using FromArc = A; +// using ToArc = B; +// +// // Maps an arc type FromArc to arc type ToArc. +// ToArc operator()(const FromArc &arc); +// +// // Specifies final action the mapper requires (see above). +// // The mapper will be passed final weights as arcs of the form +// // Arc(0, 0, weight, kNoStateId). +// MapFinalAction FinalAction() const; +// +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// +// // This specifies the known properties of an FST mapped by this mapper. It +// takes as argument the input FSTs's known properties. +// uint64 Properties(uint64 props) const; +// }; +// +// The ArcMap functions and classes below will use the FinalAction() +// method of the mapper to determine how to treat final weights, e.g., whether +// to add a superfinal state. They will use the Properties() method to set the +// result FST properties. +// +// We include a various map versions below. One dimension of variation is +// whether the mapping mutates its input, writes to a new result FST, or is an +// on-the-fly FST. Another dimension is how we pass the mapper. We allow passing +// the mapper by pointer for cases that we need to change the state of the +// user's mapper. 
This is the case with the EncodeMapper, which is reused +// during decoding. We also include map versions that pass the mapper by value +// or const reference when this suffices. + +// Maps an arc type A using a mapper function object C, passed +// by pointer. This version modifies its Fst input. +template +void ArcMap(MutableFst *fst, C *mapper) { + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetOutputSymbols(nullptr); + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kFstProperties, false); + const auto final_action = mapper->FinalAction(); + auto superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + const auto state = siter.Value(); + for (MutableArcIterator> aiter(fst, state); + !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + aiter.SetValue((*mapper)(arc)); + } + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + const auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc"; + fst->SetProperties(kError, kError); + } + fst->SetFinal(state, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (state != superfinal) { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. 
+ if (superfinal == kNoStateId) { + superfinal = fst->AddState(); + fst->SetFinal(superfinal, Weight::One()); + } + final_arc.nextstate = superfinal; + fst->AddArc(state, std::move(final_arc)); + fst->SetFinal(state, Weight::Zero()); + } else { + fst->SetFinal(state, final_arc.weight); + } + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + if (state != superfinal) { + const FromArc arc(0, 0, fst->Final(state), kNoStateId); + const auto final_arc = (*mapper)(arc); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != Weight::Zero()) { + fst->AddArc(state, ToArc(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + } + fst->SetFinal(state, Weight::Zero()); + } + break; + } + } + } + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + +// Maps an arc type A using a mapper function object C, passed by value. This +// version modifies its FST input. +template +void ArcMap(MutableFst *fst, C mapper) { + ArcMap(fst, &mapper); +} + +// Maps an arc type A to an arc type B using mapper function object C, +// passed by pointer. This version writes the mapped input FST to an +// output MutableFst. 
+template +void ArcMap(const Fst &ifst, MutableFst *ofst, C *mapper) { + using FromArc = A; + using StateId = typename FromArc::StateId; + ofst->DeleteStates(); + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetInputSymbols(ifst.InputSymbols()); + } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetOutputSymbols(nullptr); + } + const auto iprops = ifst.Properties(kCopyProperties, false); + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + const auto final_action = mapper->FinalAction(); + if (ifst.Properties(kExpanded, false)) { + ofst->ReserveStates( + CountStates(ifst) + (final_action == MAP_NO_SUPERFINAL ? 0 : 1)); + } + // Adds all states. + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + ofst->AddState(); + } + StateId superfinal = kNoStateId; + if (final_action == MAP_REQUIRE_SUPERFINAL) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (s == ifst.Start()) ofst->SetStart(s); + ofst->ReserveArcs( + s, ifst.NumArcs(s) + (final_action != MAP_NO_SUPERFINAL ? 
1 : 0)); + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + ofst->AddArc(s, (*mapper)(aiter.Value())); + } + switch (final_action) { + case MAP_NO_SUPERFINAL: + default: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMap: Non-zero arc labels for superfinal arc"; + ofst->SetProperties(kError, kError); + } + ofst->SetFinal(s, final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + // Add a superfinal state if not already done. + if (superfinal == kNoStateId) { + superfinal = ofst->AddState(); + ofst->SetFinal(superfinal, B::Weight::One()); + } + final_arc.nextstate = superfinal; + ofst->AddArc(s, std::move(final_arc)); + ofst->SetFinal(s, B::Weight::Zero()); + } else { + ofst->SetFinal(s, final_arc.weight); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + B final_arc = (*mapper)(A(0, 0, ifst.Final(s), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) { + ofst->AddArc(s, B(final_arc.ilabel, final_arc.olabel, + final_arc.weight, superfinal)); + } + ofst->SetFinal(s, B::Weight::Zero()); + break; + } + } + } + const auto oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper function +// object C, passed by value. This version writes the mapped input +// Fst to an output MutableFst. +template +void ArcMap(const Fst &ifst, MutableFst *ofst, C mapper) { + ArcMap(ifst, ofst, &mapper); +} + +struct ArcMapFstOptions : public CacheOptions { + // ArcMapFst default caching behaviour is to do no caching. Most mappers are + // cheap and therefore we save memory by not doing caching. 
+ ArcMapFstOptions() : CacheOptions(true, 0) {} + + explicit ArcMapFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} +}; + +template +class ArcMapFst; + +namespace internal { + +// Implementation of delayed ArcMapFst. +template +class ArcMapFstImpl : public CacheImpl { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::EmplaceArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::PushArc; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + ArcMapFstImpl(const Fst &fst, const C &mapper, + const ArcMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(new C(mapper)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const Fst &fst, C *mapper, const ArcMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ArcMapFstImpl(const ArcMapFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_)), + own_mapper_(true), + superfinal_(kNoStateId), + nstates_(0) { + Init(); + } + + ~ArcMapFstImpl() override { + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) SetStart(FindOState(fst_->Start())); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + FSTERROR() << "ArcMapFst: Non-zero arc labels for superfinal arc"; + SetProperties(kError, kError); + } + SetFinal(s, 
final_arc.weight); + break; + } + case MAP_ALLOW_SUPERFINAL: { + if (s == superfinal_) { + SetFinal(s, Weight::One()); + } else { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel == 0 && final_arc.olabel == 0) { + SetFinal(s, final_arc.weight); + } else { + SetFinal(s, Weight::Zero()); + } + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + SetFinal(s, s == superfinal_ ? Weight::One() : Weight::Zero()); + break; + } + } + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + // Add exiting arcs. + if (s == superfinal_) { + SetArcs(s); + return; + } + for (ArcIterator> aiter(*fst_, FindIState(s)); !aiter.Done(); + aiter.Next()) { + auto aarc = aiter.Value(); + aarc.nextstate = FindOState(aarc.nextstate); + PushArc(s, (*mapper_)(aarc)); + } + + // Check for superfinal arcs. 
+ if (!HasFinal(s) || Final(s) == Weight::Zero()) { + switch (final_action_) { + case MAP_NO_SUPERFINAL: + default: + break; + case MAP_ALLOW_SUPERFINAL: { + auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) { + if (superfinal_ == kNoStateId) superfinal_ = nstates_++; + final_arc.nextstate = superfinal_; + PushArc(s, std::move(final_arc)); + } + break; + } + case MAP_REQUIRE_SUPERFINAL: { + const auto final_arc = + (*mapper_)(A(0, 0, fst_->Final(FindIState(s)), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0 || + final_arc.weight != B::Weight::Zero()) { + EmplaceArc(s, final_arc.ilabel, final_arc.olabel, final_arc.weight, + superfinal_); + } + break; + } + } + } + SetArcs(s); + } + + private: + void Init() { + SetType("map"); + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetInputSymbols(fst_->InputSymbols()); + } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetInputSymbols(nullptr); + } + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetOutputSymbols(fst_->OutputSymbols()); + } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetOutputSymbols(nullptr); + } + if (fst_->Start() == kNoStateId) { + final_action_ = MAP_NO_SUPERFINAL; + SetProperties(kNullProperties); + } else { + final_action_ = mapper_->FinalAction(); + uint64 props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + if (final_action_ == MAP_REQUIRE_SUPERFINAL) superfinal_ = 0; + } + } + + // Maps from output state to input state. + StateId FindIState(StateId s) { + if (superfinal_ == kNoStateId || s < superfinal_) { + return s; + } else { + return s - 1; + } + } + + // Maps from input state to output state. 
+ StateId FindOState(StateId is) { + auto os = is; + if (!(superfinal_ == kNoStateId || is < superfinal_)) ++os; + if (os >= nstates_) nstates_ = os + 1; + return os; + } + + std::unique_ptr> fst_; + C *mapper_; + const bool own_mapper_; + MapFinalAction final_action_; + StateId superfinal_; + StateId nstates_; +}; + +} // namespace internal + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed FST. +template +class ArcMapFst : public ImplToFst> { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::ArcMapFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + ArcMapFst(const Fst &fst, const C &mapper, const ArcMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + ArcMapFst(const Fst &fst, C *mapper, const ArcMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + ArcMapFst(const Fst &fst, const C &mapper) + : ImplToFst( + std::make_shared(fst, mapper, ArcMapFstOptions())) {} + + ArcMapFst(const Fst &fst, C *mapper) + : ImplToFst( + std::make_shared(fst, mapper, ArcMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + ArcMapFst(const ArcMapFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this ArcMapFst. See Fst<>::Copy() for further doc. + ArcMapFst *Copy(bool safe = false) const override { + return new ArcMapFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + protected: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + private: + ArcMapFst &operator=(const ArcMapFst &) = delete; +}; + +// Specialization for ArcMapFst. +// +// This may be derived from. 
+template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename B::StateId; + + explicit StateIterator(const ArcMapFst &fst) + : impl_(fst.GetImpl()), + siter_(*impl_->fst_), + s_(0), + superfinal_(impl_->final_action_ == MAP_REQUIRE_SUPERFINAL) { + CheckSuperfinal(); + } + + bool Done() const final { return siter_.Done() && !superfinal_; } + + StateId Value() const final { return s_; } + + void Next() final { + ++s_; + if (!siter_.Done()) { + siter_.Next(); + CheckSuperfinal(); + } else if (superfinal_) { + superfinal_ = false; + } + } + + void Reset() final { + s_ = 0; + siter_.Reset(); + superfinal_ = impl_->final_action_ == MAP_REQUIRE_SUPERFINAL; + CheckSuperfinal(); + } + + private: + void CheckSuperfinal() { + if (impl_->final_action_ != MAP_ALLOW_SUPERFINAL || superfinal_) return; + if (!siter_.Done()) { + const auto final_arc = + (*impl_->mapper_)(A(0, 0, impl_->fst_->Final(s_), kNoStateId)); + if (final_arc.ilabel != 0 || final_arc.olabel != 0) superfinal_ = true; + } + } + + const internal::ArcMapFstImpl *impl_; + StateIterator> siter_; + StateId s_; + bool superfinal_; // True if there is a superfinal state and not done. +}; + +// Specialization for ArcMapFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename A::StateId; + + ArcIterator(const ArcMapFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void ArcMapFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Utility Mappers. + +// Mapper that returns its input. 
+template +class IdentityArcMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { return arc; } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } +}; + +// Mapper that converts all input symbols to epsilon. +template +class InputEpsilonMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(0, arc.olabel, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kSetArcProperties) | kIEpsilons | kILabelSorted; + } +}; + +// Mapper that converts all output symbols to epsilon. +template +class OutputEpsilonMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, 0, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kSetArcProperties) | kOEpsilons | kOLabelSorted; + } +}; + +// Mapper that returns its input with final states redirected to a single +// super-final state. 
+template +class SuperFinalMapper { + public: + using FromArc = A; + using ToArc = A; + using Label = typename FromArc::Label; + using Weight = typename FromArc::Weight; + + // Arg allows setting super-final label. + explicit SuperFinalMapper(Label final_label = 0) + : final_label_(final_label) {} + + ToArc operator()(const FromArc &arc) const { + // Super-final arc. + if (arc.nextstate == kNoStateId && arc.weight != Weight::Zero()) { + return ToArc(final_label_, final_label_, arc.weight, kNoStateId); + } else { + return arc; + } + } + + constexpr MapFinalAction FinalAction() const { + return MAP_REQUIRE_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + if (final_label_ == 0) { + return props & kAddSuperFinalProperties; + } else { + return props & kAddSuperFinalProperties & + kILabelInvariantProperties & kOLabelInvariantProperties; + } + } + + private: + Label final_label_; +}; + +// Mapper that leaves labels and nextstate unchanged and constructs a new weight +// from the underlying value of the arc weight. If no weight converter is +// explicitly specified, requires that there is a WeightConvert class +// specialization that converts the weights. 
+template > +class WeightConvertMapper { + public: + using FromArc = A; + using ToArc = B; + using Converter = C; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + constexpr explicit WeightConvertMapper(const Converter &c = Converter()) + : convert_weight_(c) {} + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, convert_weight_(arc.weight), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } + + private: + const Converter convert_weight_; +}; + +// Non-precision-changing weight conversions; consider using more efficient +// Cast method instead. + +using StdToLogMapper = WeightConvertMapper; + +using LogToStdMapper = WeightConvertMapper; + +// Precision-changing weight conversions. + +using StdToLog64Mapper = WeightConvertMapper; + +using LogToLog64Mapper = WeightConvertMapper; + +using Log64ToStdMapper = WeightConvertMapper; + +using Log64ToLogMapper = WeightConvertMapper; + +// Mapper from A to GallicArc. +template +class ToGallicMapper { + public: + using FromArc = A; + using ToArc = GallicArc; + + using SW = StringWeight; + using AW = typename FromArc::Weight; + using GW = typename ToArc::Weight; + + ToArc operator()(const FromArc &arc) const { + // Super-final arc. + if (arc.nextstate == kNoStateId && arc.weight != AW::Zero()) { + return ToArc(0, 0, GW(SW::One(), arc.weight), kNoStateId); + // Super-non-final arc. + } else if (arc.nextstate == kNoStateId) { + return ToArc(0, 0, GW::Zero(), kNoStateId); + // Epsilon label. + } else if (arc.olabel == 0) { + return ToArc(arc.ilabel, arc.ilabel, GW(SW::One(), arc.weight), + arc.nextstate); + // Regular label. 
+ } else { + return ToArc(arc.ilabel, arc.ilabel, GW(SW(arc.olabel), arc.weight), + arc.nextstate); + } + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return ProjectProperties(props, true) & kWeightInvariantProperties; + } +}; + +// Mapper from GallicArc to A. +template +class FromGallicMapper { + public: + using FromArc = GallicArc; + using ToArc = A; + + using Label = typename ToArc::Label; + using AW = typename ToArc::Weight; + using GW = typename FromArc::Weight; + + explicit FromGallicMapper(Label superfinal_label = 0) + : superfinal_label_(superfinal_label), error_(false) {} + + ToArc operator()(const FromArc &arc) const { + // 'Super-non-final' arc. + if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) { + return A(arc.ilabel, 0, AW::Zero(), kNoStateId); + } + Label l = kNoLabel; + AW weight; + if (!Extract(arc.weight, &weight, &l) || arc.ilabel != arc.olabel) { + FSTERROR() << "FromGallicMapper: Unrepresentable weight: " << arc.weight + << " for arc with ilabel = " << arc.ilabel + << ", olabel = " << arc.olabel + << ", nextstate = " << arc.nextstate; + error_ = true; + } + if (arc.ilabel == 0 && l != 0 && arc.nextstate == kNoStateId) { + return ToArc(superfinal_label_, l, weight, arc.nextstate); + } else { + return ToArc(arc.ilabel, l, weight, arc.nextstate); + } + } + + constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; 
+ if (error_) outprops |= kError; + return outprops; + } + + private: + template + static bool Extract(const GallicWeight &gallic_weight, + typename A::Weight *weight, typename A::Label *label) { + using GWT = StringWeight; + const GWT &w1 = gallic_weight.Value1(); + const AW &w2 = gallic_weight.Value2(); + typename GWT::Iterator iter1(w1); + const Label l = w1.Size() == 1 ? iter1.Value() : 0; + if (l == kStringInfinity || l == kStringBad || w1.Size() > 1) return false; + *label = l; + *weight = w2; + return true; + } + + static bool Extract(const GallicWeight &gallic_weight, + typename A::Weight *weight, typename A::Label *label) { + if (gallic_weight.Size() > 1) return false; + if (gallic_weight.Size() == 0) { + *label = 0; + *weight = A::Weight::Zero(); + return true; + } + return Extract(gallic_weight.Back(), weight, label); + } + + const Label superfinal_label_; + mutable bool error_; +}; + +// Mapper from GallicArc to A. +template +class GallicToNewSymbolsMapper { + public: + using FromArc = GallicArc; + using ToArc = A; + + using Label = typename ToArc::Label; + using StateId = typename ToArc::StateId; + using AW = typename ToArc::Weight; + using GW = typename FromArc::Weight; + using SW = StringWeight; + + explicit GallicToNewSymbolsMapper(MutableFst *fst) + : fst_(fst), + lmax_(0), + osymbols_(fst->OutputSymbols()), + isymbols_(nullptr), + error_(false) { + fst_->DeleteStates(); + state_ = fst_->AddState(); + fst_->SetStart(state_); + fst_->SetFinal(state_, AW::One()); + if (osymbols_) { + string name = osymbols_->Name() + "_from_gallic"; + fst_->SetInputSymbols(new SymbolTable(name)); + isymbols_ = fst_->MutableInputSymbols(); + const int64 zero = 0; + isymbols_->AddSymbol(osymbols_->Find(zero), 0); + } else { + fst_->SetInputSymbols(nullptr); + } + } + + ToArc operator()(const FromArc &arc) { + // Super-non-final arc. 
+ if (arc.nextstate == kNoStateId && arc.weight == GW::Zero()) { + return ToArc(arc.ilabel, 0, AW::Zero(), kNoStateId); + } + SW w1 = arc.weight.Value1(); + AW w2 = arc.weight.Value2(); + Label l; + if (w1.Size() == 0) { + l = 0; + } else { + auto insert_result = map_.insert(std::make_pair(w1, kNoLabel)); + if (!insert_result.second) { + l = insert_result.first->second; + } else { + l = ++lmax_; + insert_result.first->second = l; + StringWeightIterator iter1(w1); + StateId n; + string s; + for (size_t i = 0, p = state_; i < w1.Size(); + ++i, iter1.Next(), p = n) { + n = i == w1.Size() - 1 ? state_ : fst_->AddState(); + fst_->AddArc(p, ToArc(i ? 0 : l, iter1.Value(), AW::One(), n)); + if (isymbols_) { + if (i) s = s + "_"; + s = s + osymbols_->Find(iter1.Value()); + } + } + if (isymbols_) isymbols_->AddSymbol(s, l); + } + } + if (l == kStringInfinity || l == kStringBad || arc.ilabel != arc.olabel) { + FSTERROR() << "GallicToNewSymbolMapper: Unrepresentable weight: " << l; + error_ = true; + } + return ToArc(arc.ilabel, l, w2, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_ALLOW_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 inprops) const { + uint64 outprops = inprops & kOLabelInvariantProperties & + kWeightInvariantProperties & kAddSuperFinalProperties; + if (error_) outprops |= kError; + return outprops; + } + + private: + class StringKey { + public: + size_t operator()(const SW &x) const { return x.Hash(); } + }; + + using Map = std::unordered_map; + + MutableFst *fst_; + Map map_; + Label lmax_; + StateId state_; + const SymbolTable *osymbols_; + SymbolTable *isymbols_; + mutable bool error_; +}; + +// TODO(kbg): Add common base class for those mappers which do nothing except +// mutate their weights. + +// Mapper to add a constant to all weights. 
+template +class PlusMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + constexpr explicit PlusMapper(Weight weight) : weight_(std::move(weight)) {} + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Plus(arc.weight, weight_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const Weight weight_; +}; + +// Mapper to (right) multiply a constant to all weights. +template +class TimesMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + constexpr explicit TimesMapper(Weight weight) : weight_(std::move(weight)) {} + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Times(arc.weight, weight_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const Weight weight_; +}; + +// Mapper to take all weights to a constant power. The power argument is stored +// as a double, so if there is a floating-point power implementation for this +// weight type, it will take precedence. 
Otherwise, the power argument's 53 bits +// of integer precision will be implicitly converted to a size_t and the default +// power implementation (iterated multiplication) will be used instead. +template +class PowerMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + explicit PowerMapper(double power) : power_(power) {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, Power(arc.weight, power_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const double power_; +}; + +// Mapper to reciprocate all non-Zero() weights. +template +class InvertWeightMapper { + public: + using FromArc = A; + using ToArc = A; + using Weight = typename FromArc::Weight; + + ToArc operator()(const FromArc &arc) const { + if (arc.weight == Weight::Zero()) return arc; + return ToArc(arc.ilabel, arc.olabel, Divide(Weight::One(), arc.weight), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } +}; + +// Mapper to map all non-Zero() weights to One(). 
+template +class RmWeightMapper { + public: + using FromArc = A; + using ToArc = B; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, + arc.weight != FromWeight::Zero() ? + ToWeight::One() : ToWeight::Zero(), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kWeightInvariantProperties) | kUnweighted; + } +}; + +// Mapper to quantize all weights. +template +class QuantizeMapper { + public: + using FromArc = A; + using ToArc = B; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + + QuantizeMapper() : delta_(kDelta) {} + + explicit QuantizeMapper(float d) : delta_(d) {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, arc.weight.Quantize(delta_), + arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return props & kWeightInvariantProperties; + } + + private: + const float delta_; +}; + +// Mapper from A to B under the assumption: +// +// B::Weight = A::Weight::ReverseWeight +// B::Label == A::Label +// B::StateId == A::StateId +// +// The weight is reversed, while the label and nextstate are preserved. 
+template +class ReverseWeightMapper { + public: + using FromArc = A; + using ToArc = B; + + constexpr ToArc operator()(const FromArc &arc) const { + return ToArc(arc.ilabel, arc.olabel, arc.weight.Reverse(), arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_ARC_MAP_H_ diff --git a/projects/llm_framework/include/fst/arc.h b/projects/llm_framework/include/fst/arc.h new file mode 100644 index 00000000..651b11df --- /dev/null +++ b/projects/llm_framework/include/fst/arc.h @@ -0,0 +1,317 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Commonly used FST arc types. + +#ifndef FST_ARC_H_ +#define FST_ARC_H_ + +#include +#include +#include +#include + + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +struct ArcTpl { + public: + using Weight = W; + using Label = int; + using StateId = int; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ArcTpl() noexcept(std::is_nothrow_default_constructible::value) {} + + template + ArcTpl(Label ilabel, Label olabel, T &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string(Weight::Type() == "tropical" ? "standard" : Weight::Type()); + return *type; + } +}; + +using StdArc = ArcTpl; +using LogArc = ArcTpl; +using Log64Arc = ArcTpl; +using SignedLogArc = ArcTpl; +using SignedLog64Arc = ArcTpl; +using MinMaxArc = ArcTpl; + +// Arc with integer labels and state IDs and string weights. 
+template +struct StringArc { + public: + using Label = int; + using Weight = StringWeight; + using StateId = int; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + StringArc() = default; + + template + StringArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string( + S == STRING_LEFT ? "left_standard_string" + : (S == STRING_RIGHT ? "right_standard_string" + : "restricted_standard_string")); + return *type; + } +}; + +// Arc with label and state Id type the same as template arg and with +// weights over the Gallic semiring w.r.t the output labels and weights of A. +template +struct GallicArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = GallicWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + GallicArc() = default; + + template + GallicArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + explicit GallicArc(const Arc &arc) + : ilabel(arc.ilabel), olabel(arc.ilabel), weight(arc.olabel, arc.weight), + nextstate(arc.nextstate) {} + + static const string &Type() { + static const auto *const type = new string( + (G == GALLIC_LEFT + ? "left_gallic_" + : (G == GALLIC_RIGHT + ? "right_gallic_" + : (G == GALLIC_RESTRICT + ? "restricted_gallic_" + : (G == GALLIC_MIN ? "min_gallic_" : "gallic_")))) + + Arc::Type()); + return *type; + } +}; + +// Arc with the reverse of the weight found in its template arg. 
+template +struct ReverseArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using AWeight = typename Arc::Weight; + using Weight = typename AWeight::ReverseWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ReverseArc() = default; + + template + ReverseArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string("reverse_" + Arc::Type()); + return *type; + } +}; + +// Arc with integer labels and state IDs and lexicographic weights. +template +struct LexicographicArc { + using Label = int; + using StateId = int; + using Weight = LexicographicWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + LexicographicArc() = default; + + template + LexicographicArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const string *const type = new string(Weight::Type()); + return *type; + } +}; + +// Arc with integer labels and state IDs and product weights. +template +struct ProductArc { + using Label = int; + using StateId = int; + using Weight = ProductWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ProductArc() = default; + + template + ProductArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = new string(Weight::Type()); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// weights over the n-th Cartesian power of the weight type of the template +// argument. 
+template +struct PowerArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = PowerWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + PowerArc() = default; + + template + PowerArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string(Arc::Type() + "_^" + std::to_string(n)); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// weights over the arbitrary Cartesian power of the weight type. +template +struct SparsePowerArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; // State IDs follow the wrapped arc. + using Weight = SparsePowerWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + SparsePowerArc() = default; + + template + SparsePowerArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const string *const type = [] { + string type = Arc::Type() + "_^n"; + if (sizeof(K) != sizeof(uint32)) { + type += "_" + std::to_string(CHAR_BIT * sizeof(K)); + } + return new string(type); + }(); + return *type; + } +}; + +// Arc with label and state ID type the same as first template argument and with +// expectation weight over the first template argument's weight type and the +// second template argument. 
+template +struct ExpectationArc { + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using X1 = typename Arc::Weight; + using Weight = ExpectationWeight; + + Label ilabel; + Label olabel; + Weight weight; + StateId nextstate; + + ExpectationArc() = default; + + template + ExpectationArc(Label ilabel, Label olabel, W &&weight, StateId nextstate) + : ilabel(ilabel), + olabel(olabel), + weight(std::forward(weight)), + nextstate(nextstate) {} + + static const string &Type() { + static const auto *const type = + new string("expectation_" + Arc::Type() + "_" + X2::Type()); + return *type; + } +}; + +} // namespace fst + +#endif // FST_ARC_H_ diff --git a/projects/llm_framework/include/fst/arcfilter.h b/projects/llm_framework/include/fst/arcfilter.h new file mode 100644 index 00000000..598e543f --- /dev/null +++ b/projects/llm_framework/include/fst/arcfilter.h @@ -0,0 +1,93 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function objects to restrict which arcs are traversed in an FST. + +#ifndef FST_ARCFILTER_H_ +#define FST_ARCFILTER_H_ + + +#include +#include + + +namespace fst { + +// True for all arcs. +template +class AnyArcFilter { + public: + bool operator()(const Arc &arc) const { return true; } +}; + +// True for (input/output) epsilon arcs. +template +class EpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { + return arc.ilabel == 0 && arc.olabel == 0; + } +}; + +// True for input epsilon arcs. +template +class InputEpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { return arc.ilabel == 0; } +}; + +// True for output epsilon arcs. +template +class OutputEpsilonArcFilter { + public: + bool operator()(const Arc &arc) const { return arc.olabel == 0; } +}; + +// True if specified label matches (doesn't match) when keep_match is +// true (false). 
+template +class LabelArcFilter { + public: + using Label = typename Arc::Label; + + explicit LabelArcFilter(Label label, bool match_input = true, + bool keep_match = true) + : label_(label), match_input_(match_input), keep_match_(keep_match) {} + + bool operator()(const Arc &arc) const { + const bool match = (match_input_ ? arc.ilabel : arc.olabel) == label_; + return keep_match_ ? match : !match; + } + + private: + const Label label_; + const bool match_input_; + const bool keep_match_; +}; + +// True if specified labels match (don't match) when keep_match is true (false). +template +class MultiLabelArcFilter { + public: + using Label = typename Arc::Label; + + explicit MultiLabelArcFilter(bool match_input = true, bool keep_match = true) + : match_input_(match_input), keep_match_(keep_match) {} + + bool operator()(const Arc &arc) const { + const Label label = match_input_ ? arc.ilabel : arc.olabel; + const bool match = labels_.Find(label) != labels_.End(); + return keep_match_ ? match : !match; + } + + void AddLabel(Label label) { labels_.Insert(label); } + + private: + CompactSet labels_; + const bool match_input_; + const bool keep_match_; +}; + +} // namespace fst + +#endif // FST_ARCFILTER_H_ diff --git a/projects/llm_framework/include/fst/arcsort.h b/projects/llm_framework/include/fst/arcsort.h new file mode 100644 index 00000000..b5ab50e0 --- /dev/null +++ b/projects/llm_framework/include/fst/arcsort.h @@ -0,0 +1,211 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to sort arcs in an FST. 
+ +#ifndef FST_ARCSORT_H_ +#define FST_ARCSORT_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +template +class ArcSortMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + constexpr ArcSortMapper(const Fst &fst, const Compare &comp) + : fst_(fst), comp_(comp), i_(0) {} + + // Allows updating Fst argument; pass only if changed. + ArcSortMapper(const ArcSortMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_), comp_(mapper.comp_), i_(0) {} + + StateId Start() { return fst_.Start(); } + + Weight Final(StateId s) const { return fst_.Final(s); } + + void SetState(StateId s) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(s)); + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + std::sort(arcs_.begin(), arcs_.end(), comp_); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + MapSymbolsAction InputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + MapSymbolsAction OutputSymbolsAction() const { return MAP_COPY_SYMBOLS; } + + uint64 Properties(uint64 props) const { return comp_.Properties(props); } + + private: + const Fst &fst_; + const Compare &comp_; // Not owned; the comparator must outlive this mapper. + std::vector arcs_; + size_t i_; // current arc position (index into arcs_, never negative) + + ArcSortMapper &operator=(const ArcSortMapper &) = delete; +}; + +// Sorts the arcs in an FST according to function object 'comp' of type Compare. +// This version modifies its input. Comparison function objects ILabelCompare +// and OLabelCompare are provided by the library. In general, Compare must meet +// the requirements for a comparison function object (e.g., similar to those +// used by std::sort). 
It must also have a member Properties(uint64) that +// specifies the known properties of the sorted FST; it takes as argument the +// input FST's known properties before the sort. +// +// Complexity: +// +// - Time: O(v d log d) +// - Space: O(d) +// +// where v = # of states and d = maximum out-degree. +template +void ArcSort(MutableFst *fst, Compare comp) { + ArcSortMapper mapper(*fst, comp); + StateMap(fst, mapper); +} + +using ArcSortFstOptions = CacheOptions; + +// Sorts the arcs in an FST according to function object 'comp' of type Compare. +// This version is a delayed FST. Comparsion function objects ILabelCompare and +// OLabelCompare are provided by the library. In general, Compare must meet the +// requirements for a comparision function object (e.g., similar to those +// used by std::sort). It must also have a member Properties(uint64) that +// specifies the known properties of the sorted FST; it takes as argument the +// input FST's known properties. +// +// Complexity: +// +// - Time: O(v d log d) +// - Space: O(d) +// +// where v = # of states visited, d = maximum out-degree of states visited. +// Constant time and space to visit an input state is assumed and exclusive of +// caching. +template +class ArcSortFst : public StateMapFst> { + using StateMapFst>::GetImpl; + + public: + using StateId = typename Arc::StateId; + using Mapper = ArcSortMapper; + + ArcSortFst(const Fst &fst, const Compare &comp) + : StateMapFst(fst, + ArcSortMapper(fst, comp)) {} + + ArcSortFst(const Fst &fst, const Compare &comp, + const ArcSortFstOptions &opts) + : StateMapFst(fst, Mapper(fst, comp), opts) {} + + // See Fst<>::Copy() for doc. + ArcSortFst(const ArcSortFst &fst, bool safe = false) + : StateMapFst(fst, safe) {} + + // Gets a copy of this ArcSortFst. See Fst<>::Copy() for further doc. 
+ ArcSortFst *Copy(bool safe = false) const override { + return new ArcSortFst(*this, safe); + } + + size_t NumArcs(StateId s) const override { + return GetImpl()->GetFst()->NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) const override { + return GetImpl()->GetFst()->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) const override { + return GetImpl()->GetFst()->NumOutputEpsilons(s); + } +}; + +// Specialization for ArcSortFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const ArcSortFst &fst) + : StateIterator>>(fst) { + } +}; + +// Specialization for ArcSortFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const ArcSortFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, + s) {} +}; + +// Compare class for comparing input labels of arcs. +template +class ILabelCompare { + public: + constexpr ILabelCompare() {} + + constexpr bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.ilabel < arc2.ilabel; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kILabelSorted | + (props & kAcceptor ? kOLabelSorted : 0); + } +}; + +// Compare class for comparing output labels of arcs. +template +class OLabelCompare { + public: + constexpr OLabelCompare() {} + + constexpr bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.olabel < arc2.olabel; + } + + constexpr uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kOLabelSorted | + (props & kAcceptor ? kILabelSorted : 0); + } +}; + +// Useful aliases when using StdArc. 
+ +template +using StdArcSortFst = ArcSortFst; + +using StdILabelCompare = ILabelCompare; + +using StdOLabelCompare = OLabelCompare; + +} // namespace fst + +#endif // FST_ARCSORT_H_ diff --git a/projects/llm_framework/include/fst/bi-table.h b/projects/llm_framework/include/fst/bi-table.h new file mode 100644 index 00000000..9651cfe8 --- /dev/null +++ b/projects/llm_framework/include/fst/bi-table.h @@ -0,0 +1,480 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for representing a bijective mapping between an arbitrary entry +// of type T and a signed integral ID. + +#ifndef FST_BI_TABLE_H_ +#define FST_BI_TABLE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Bitables model bijective mappings between entries of an arbitrary type T and +// an signed integral ID of type I. The IDs are allocated starting from 0 in +// order. +// +// template +// class BiTable { +// public: +// +// // Required constructors. +// BiTable(); +// +// // Looks up integer ID from entry. If it doesn't exist and insert +// / is true, adds it; otherwise, returns -1. +// I FindId(const T &entry, bool insert = true); +// +// // Looks up entry from integer ID. +// const T &FindEntry(I) const; +// +// // Returns number of stored entries. +// I Size() const; +// }; + +// An implementation using a hash map for the entry to ID mapping. H is the +// hash function and E is the equality function. If passed to the constructor, +// ownership is given to this class. +template > +class HashBiTable { + public: + // Reserves space for table_size elements. If passing H and E to the + // constructor, this class owns them. + explicit HashBiTable(size_t table_size = 0, H *h = nullptr, E *e = nullptr) : + hash_func_(h ? h : new H()), hash_equal_(e ? 
e : new E()), + entry2id_(table_size, *hash_func_, *hash_equal_) { + if (table_size) id2entry_.reserve(table_size); + } + + HashBiTable(const HashBiTable &table) + : hash_func_(new H(*table.hash_func_)), + hash_equal_(new E(*table.hash_equal_)), + entry2id_(table.entry2id_.begin(), table.entry2id_.end(), + table.entry2id_.size(), *hash_func_, *hash_equal_), + id2entry_(table.id2entry_) {} + + I FindId(const T &entry, bool insert = true) { + if (!insert) { + const auto it = entry2id_.find(entry); + return it == entry2id_.end() ? -1 : it->second - 1; + } + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found; stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + // TODO(riley): Add fancy clear-to-size, as in CompactHashBiTable. + void Clear() { + entry2id_.clear(); + id2entry_.clear(); + } + + private: + std::unique_ptr hash_func_; + std::unique_ptr hash_equal_; + std::unordered_map entry2id_; + std::vector id2entry_; +}; + +// Enables alternative hash set representations below. +enum HSType { HS_STL = 0, HS_DENSE = 1, HS_SPARSE = 2, HS_FLAT = 3 }; + +// Default hash set is STL hash_set. +template +struct HashSet : public std::unordered_set> { + explicit HashSet(size_t n = 0, const H &h = H(), const E &e = E()) + : std::unordered_set>(n, h, e) {} + + void rehash(size_t n) {} +}; + +// An implementation using a hash set for the entry to ID mapping. The hash set +// holds keys which are either the ID or kCurrentKey. These keys can be mapped +// to entries either by looking up in the entry vector or, if kCurrentKey, in +// current_entry_. The hash and key equality functions map to entries first. H +// is the hash function and E is the equality function. If passed to the +// constructor, ownership is given to this class. 
+// TODO(rybach): remove support for (deprecated and unused) HS_DENSE, HS_SPARSE. +template , + HSType HS = HS_FLAT> +class CompactHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + // Reserves space for table_size elements. If passing H and E to the + // constructor, this class owns them. + explicit CompactHashBiTable(size_t table_size = 0, H *h = nullptr, + E *e = nullptr) : + hash_func_(h ? h : new H()), hash_equal_(e ? e : new E()), + compact_hash_func_(*this), compact_hash_equal_(*this), + keys_(table_size, compact_hash_func_, compact_hash_equal_) { + if (table_size) id2entry_.reserve(table_size); + } + + CompactHashBiTable(const CompactHashBiTable &table) + : hash_func_(new H(*table.hash_func_)), + hash_equal_(new E(*table.hash_equal_)), + compact_hash_func_(*this), compact_hash_equal_(*this), + keys_(table.keys_.size(), compact_hash_func_, compact_hash_equal_), + id2entry_(table.id2entry_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + I FindId(const T &entry, bool insert = true) { + current_entry_ = &entry; + if (insert) { + auto result = keys_.insert(kCurrentKey); + if (!result.second) return *result.first; // Already exists. + // Overwrites kCurrentKey with a new key value; this is safe because it + // doesn't affect hashing or equality testing. + I key = id2entry_.size(); + const_cast(*result.first) = key; + id2entry_.push_back(entry); + return key; + } + const auto it = keys_.find(kCurrentKey); + return it == keys_.end() ? -1 : *it; + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + // Clears content; with argument, erases last n IDs. + void Clear(ssize_t n = -1) { + if (n < 0 || n >= id2entry_.size()) { // Clears completely. + keys_.clear(); + id2entry_.clear(); + } else if (n == id2entry_.size() - 1) { // Leaves only key 0. 
+ const T entry = FindEntry(0); + keys_.clear(); + id2entry_.clear(); + FindId(entry, true); + } else { + while (n-- > 0) { + I key = id2entry_.size() - 1; + keys_.erase(key); + id2entry_.pop_back(); + } + keys_.rehash(0); + } + } + + private: + static_assert(std::is_signed::value, "I must be a signed type"); + // ... otherwise >= kCurrentKey comparisons as used below don't work. + // TODO(rybach): (1) remove kEmptyKey, kDeletedKey, (2) don't use >= for key + // comparison, (3) allow unsigned key types. + static constexpr I kCurrentKey = -1; + static constexpr I kEmptyKey = -2; + static constexpr I kDeletedKey = -3; + + class HashFunc { + public: + explicit HashFunc(const CompactHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*ht_->hash_func_)(ht_->Key2Entry(k)); + } else { + return 0; + } + } + + private: + const CompactHashBiTable *ht_; + }; + + class HashEqual { + public: + explicit HashEqual(const CompactHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 == k2) { + return true; + } else if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return (*ht_->hash_equal_)(ht_->Key2Entry(k1), ht_->Key2Entry(k2)); + } else { + return false; + } + } + + private: + const CompactHashBiTable *ht_; + }; + + using KeyHashSet = HashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) { + return *current_entry_; + } else { + return id2entry_[k]; + } + } + + std::unique_ptr hash_func_; + std::unique_ptr hash_equal_; + HashFunc compact_hash_func_; + HashEqual compact_hash_equal_; + KeyHashSet keys_; + std::vector id2entry_; + const T *current_entry_; +}; + +template +constexpr I CompactHashBiTable::kCurrentKey; + +template +constexpr I CompactHashBiTable::kEmptyKey; + +template +constexpr I CompactHashBiTable::kDeletedKey; + +// An implementation using a vector for the entry to ID mapping. 
It is passed a +// function object FP that should fingerprint entries uniquely to an integer +// that can used as a vector index. Normally, VectorBiTable constructs the FP +// object. The user can instead pass in this object; in that case, VectorBiTable +// takes its ownership. +template +class VectorBiTable { + public: + // Reserves table_size cells of space. If passing FP argument to the + // constructor, this class owns it. + explicit VectorBiTable(FP *fp = nullptr, size_t table_size = 0) : + fp_(fp ? fp : new FP()) { + if (table_size) id2entry_.reserve(table_size); + } + + VectorBiTable(const VectorBiTable &table) + : fp_(new FP(*table.fp_)), fp2id_(table.fp2id_), + id2entry_(table.id2entry_) {} + + I FindId(const T &entry, bool insert = true) { + ssize_t fp = (*fp_)(entry); + if (fp >= fp2id_.size()) fp2id_.resize(fp + 1); + I &id_ref = fp2id_[fp]; + if (id_ref == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size(); + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + const FP &Fingerprint() const { return *fp_; } + + private: + std::unique_ptr fp_; + std::vector fp2id_; + std::vector id2entry_; +}; + +// An implementation using a vector and a compact hash table. The selecting +// functor S returns true for entries to be hashed in the vector. The +// fingerprinting functor FP returns a unique fingerprint for each entry to be +// hashed in the vector (these need to be suitable for indexing in a vector). +// The hash functor H is used when hashing entry into the compact hash table. +// If passed to the constructor, ownership is given to this class. 
+template +class VectorHashBiTable { + public: + friend class HashFunc; + friend class HashEqual; + + explicit VectorHashBiTable(S *s, FP *fp, H *h, size_t vector_size = 0, + size_t entry_size = 0) + : selector_(s), fp_(fp), h_(h), hash_func_(*this), hash_equal_(*this), + keys_(0, hash_func_, hash_equal_) { + if (vector_size) fp2id_.reserve(vector_size); + if (entry_size) id2entry_.reserve(entry_size); + } + + VectorHashBiTable(const VectorHashBiTable &table) + : selector_(new S(table.s_)), fp_(new FP(*table.fp_)), + h_(new H(*table.h_)), id2entry_(table.id2entry_), + fp2id_(table.fp2id_), hash_func_(*this), hash_equal_(*this), + keys_(table.keys_.size(), hash_func_, hash_equal_) { + keys_.insert(table.keys_.begin(), table.keys_.end()); + } + + I FindId(const T &entry, bool insert = true) { + if ((*selector_)(entry)) { // Uses the vector if selector_(entry) == true. + uint64 fp = (*fp_)(entry); + if (fp2id_.size() <= fp) fp2id_.resize(fp + 1, 0); + if (fp2id_[fp] == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + fp2id_[fp] = id2entry_.size(); + } else { + return -1; + } + } + return fp2id_[fp] - 1; // NB: assoc_value = ID + 1. + } else { // Uses the hash table otherwise. 
+ current_entry_ = &entry; + const auto it = keys_.find(kCurrentKey); + if (it == keys_.end()) { + if (insert) { + I key = id2entry_.size(); + id2entry_.push_back(entry); + keys_.insert(key); + return key; + } else { + return -1; + } + } else { + return *it; + } + } + } + + const T &FindEntry(I s) const { return id2entry_[s]; } + + I Size() const { return id2entry_.size(); } + + const S &Selector() const { return *selector_; } + + const FP &Fingerprint() const { return *fp_; } + + const H &Hash() const { return *h_; } + + private: + static constexpr I kCurrentKey = -1; + static constexpr I kEmptyKey = -2; + + class HashFunc { + public: + explicit HashFunc(const VectorHashBiTable &ht) : ht_(&ht) {} + + size_t operator()(I k) const { + if (k >= kCurrentKey) { + return (*(ht_->h_))(ht_->Key2Entry(k)); + } else { + return 0; + } + } + + private: + const VectorHashBiTable *ht_; + }; + + class HashEqual { + public: + explicit HashEqual(const VectorHashBiTable &ht) : ht_(&ht) {} + + bool operator()(I k1, I k2) const { + if (k1 >= kCurrentKey && k2 >= kCurrentKey) { + return ht_->Key2Entry(k1) == ht_->Key2Entry(k2); + } else { + return k1 == k2; + } + } + + private: + const VectorHashBiTable *ht_; + }; + + using KeyHashSet = HashSet; + + const T &Key2Entry(I k) const { + if (k == kCurrentKey) { + return *current_entry_; + } else { + return id2entry_[k]; + } + } + + std::unique_ptr selector_; // True if entry hashed into vector. + std::unique_ptr fp_; // Fingerprint used for hashing into vector. + std::unique_ptr h_; // Hash funcion used for hashing into hash_set. + + std::vector id2entry_; // Maps state IDs to entry. + std::vector fp2id_; // Maps entry fingerprints to IDs. + + // Compact implementation of the hash table mapping entries to state IDs + // using the hash function h_. 
+ HashFunc hash_func_; + HashEqual hash_equal_; + KeyHashSet keys_; + const T *current_entry_; +}; + +template +constexpr I VectorHashBiTable::kCurrentKey; + +template +constexpr I VectorHashBiTable::kEmptyKey; + +// An implementation using a hash map for the entry to ID mapping. This version +// permits erasing of arbitrary states. The entry T must have == defined and +// its default constructor must produce a entry that will never be seen. F is +// the hash function. +template +class ErasableBiTable { + public: + ErasableBiTable() : first_(0) {} + + I FindId(const T &entry, bool insert = true) { + I &id_ref = entry2id_[entry]; + if (id_ref == 0) { // T not found. + if (insert) { // Stores and assigns a new ID. + id2entry_.push_back(entry); + id_ref = id2entry_.size() + first_; + } else { + return -1; + } + } + return id_ref - 1; // NB: id_ref = ID + 1. + } + + const T &FindEntry(I s) const { return id2entry_[s - first_]; } + + I Size() const { return id2entry_.size(); } + + void Erase(I s) { + auto &ref = id2entry_[s - first_]; + entry2id_.erase(ref); + ref = empty_entry_; + while (!id2entry_.empty() && id2entry_.front() == empty_entry_) { + id2entry_.pop_front(); + ++first_; + } + } + + private: + std::unordered_map entry2id_; + std::deque id2entry_; + const T empty_entry_; + I first_; // I of first element in the deque. +}; + +} // namespace fst + +#endif // FST_BI_TABLE_H_ diff --git a/projects/llm_framework/include/fst/cache.h b/projects/llm_framework/include/fst/cache.h new file mode 100644 index 00000000..13b7cf81 --- /dev/null +++ b/projects/llm_framework/include/fst/cache.h @@ -0,0 +1,1327 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation that caches FST elements of a delayed computation. 
+ +#ifndef FST_CACHE_H_ +#define FST_CACHE_H_ + +#include +#include +#include +#include + +#include +#include + +#include + +#include + +DECLARE_bool(fst_default_cache_gc); +DECLARE_int64(fst_default_cache_gc_limit); + +namespace fst { + +// Options for controlling caching behavior; higher level than CacheImplOptions. +struct CacheOptions { + bool gc; // Enables GC. + size_t gc_limit; // Number of bytes allowed before GC. + + explicit CacheOptions(bool gc = FLAGS_fst_default_cache_gc, + size_t gc_limit = FLAGS_fst_default_cache_gc_limit) + : gc(gc), gc_limit(gc_limit) {} +}; + +// Options for controlling caching behavior, at a lower level than +// CacheOptions; templated on the cache store and allows passing the store. +template +struct CacheImplOptions { + bool gc; // Enables GC. + size_t gc_limit; // Number of bytes allowed before GC. + CacheStore *store; // Cache store. + bool own_store; // Should CacheImpl takes ownership of the store? + + explicit CacheImplOptions(bool gc = FLAGS_fst_default_cache_gc, + size_t gc_limit = FLAGS_fst_default_cache_gc_limit, + CacheStore *store = nullptr) + : gc(gc), gc_limit(gc_limit), store(store), own_store(true) {} + + explicit CacheImplOptions(const CacheOptions &opts) + : gc(opts.gc), gc_limit(opts.gc_limit), store(nullptr), own_store(true) {} +}; + +// Cache flags. +constexpr uint32 kCacheFinal = 0x0001; // Final weight has been cached. +constexpr uint32 kCacheArcs = 0x0002; // Arcs have been cached. +constexpr uint32 kCacheInit = 0x0004; // Initialized by GC. +constexpr uint32 kCacheRecent = 0x0008; // Visited since GC. +constexpr uint32 kCacheFlags = + kCacheFinal | kCacheArcs | kCacheInit | kCacheRecent; + +// Cache state, with arcs stored in a per-state std::vector. 
+template > +class CacheState { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ArcAllocator = M; + using StateAllocator = + typename ArcAllocator::template rebind>::other; + + // Provides STL allocator for arcs. + explicit CacheState(const ArcAllocator &alloc) + : final_(Weight::Zero()), + niepsilons_(0), + noepsilons_(0), + arcs_(alloc), + flags_(0), + ref_count_(0) {} + + CacheState(const CacheState &state, const ArcAllocator &alloc) + : final_(state.Final()), + niepsilons_(state.NumInputEpsilons()), + noepsilons_(state.NumOutputEpsilons()), + arcs_(state.arcs_.begin(), state.arcs_.end(), alloc), + flags_(state.Flags()), + ref_count_(0) {} + + void Reset() { + final_ = Weight::Zero(); + niepsilons_ = 0; + noepsilons_ = 0; + ref_count_ = 0; + flags_ = 0; + arcs_.clear(); + } + + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return arcs_.size(); } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + // Used by the ArcIterator> efficient implementation. + const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + // Accesses flags; used by the caller. + uint32 Flags() const { return flags_; } + + // Accesses ref count; used by the caller. + int RefCount() const { return ref_count_; } + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void ReserveArcs(size_t n) { arcs_.reserve(n); } + + // Adds one arc at a time with all needed book-keeping; use PushArc and + // SetArcs for a more efficient alternative. + void AddArc(const Arc &arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(arc); + } + + void AddArc(Arc &&arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(std::move(arc)); + } + + // Adds one arc at a time with delayed book-keeping; finalize with SetArcs(). 
+ void PushArc(const Arc &arc) { arcs_.push_back(arc); } + + void PushArc(Arc &&arc) { arcs_.push_back(std::move(arc)); } + + // Adds one arc at a time with delayed book-keeping; finalize with SetArcs(). + template + void EmplaceArc(T &&... ctor_args) { + arcs_.emplace_back(std::forward(ctor_args)...); + } + + // Finalizes arcs book-keeping; call only once. + void SetArcs() { + for (const auto &arc : arcs_) { + IncrementNumEpsilons(arc); + } + } + + // Modifies nth arc. + void SetArc(const Arc &arc, size_t n) { + if (arcs_[n].ilabel == 0) --niepsilons_; + if (arcs_[n].olabel == 0) --noepsilons_; + IncrementNumEpsilons(arc); + arcs_[n] = arc; + } + + // Deletes all arcs. + void DeleteArcs() { + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + void DeleteArcs(size_t n) { + for (size_t i = 0; i < n; ++i) { + if (arcs_.back().ilabel == 0) --niepsilons_; + if (arcs_.back().olabel == 0) --noepsilons_; + arcs_.pop_back(); + } + } + + // Sets status flags; used by the caller. + void SetFlags(uint32 flags, uint32 mask) const { + flags_ &= ~mask; + flags_ |= flags; + } + + // Mutates reference counts; used by the caller. + + int IncrRefCount() const { return ++ref_count_; } + + int DecrRefCount() const { return --ref_count_; } + + // Used by the ArcIterator> efficient implementation. + int *MutableRefCount() const { return &ref_count_; } + + // Used for state class allocation. + void *operator new(size_t size, StateAllocator *alloc) { + return alloc->allocate(1); + } + + // For state destruction and memory freeing. + static void Destroy(CacheState *state, StateAllocator *alloc) { + if (state) { + state->~CacheState(); + alloc->deallocate(state, 1); + } + } + + private: + // Update the number of epsilons as a result of having added an arc. + void IncrementNumEpsilons(const Arc &arc) { + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + } + + Weight final_; // Final weight. + size_t niepsilons_; // # of input epsilons. 
+ size_t noepsilons_; // # of output epsilons. + std::vector arcs_; // Arcs representation. + mutable uint32 flags_; + mutable int ref_count_; // If 0, available for GC. +}; + +// Cache store, allocating and storing states, providing a mapping from state +// IDs to cached states, and an iterator over these states. The state template +// argument must implement the CacheState interface. The state for a StateId s +// is constructed when requested by GetMutableState(s) if it is not yet stored. +// Initially, a state has a reference count of zero, but the user may increment +// or decrement this to control the time of destruction. In particular, a state +// is destroyed when: +// +// 1. This instance is destroyed, or +// 2. Clear() or Delete() is called, or +// 3. Possibly (implementation-dependently) when: +// - Garbage collection is enabled (as defined by opts.gc), +// - The cache store size exceeds the limits (as defined by opts.gc_limits), +// - The state's reference count is zero, and +// - The state is not the most recently requested state. +// +// template +// class CacheStore { +// public: +// using State = S; +// using Arc = typename State::Arc; +// using StateId = typename Arc::StateId; +// +// // Required constructors/assignment operators. +// explicit CacheStore(const CacheOptions &opts); +// +// // Returns nullptr if state is not stored. +// const State *GetState(StateId s); +// +// // Creates state if state is not stored. +// State *GetMutableState(StateId s); +// +// // Similar to State::AddArc() but updates cache store book-keeping. +// void AddArc(State *state, const Arc &arc); +// +// // Similar to State::SetArcs() but updates cache store book-keeping; call +// // only once. +// void SetArcs(State *state); +// +// // Similar to State::DeleteArcs() but updates cache store book-keeping. +// +// void DeleteArcs(State *state); +// +// void DeleteArcs(State *state, size_t n); +// +// // Deletes all cached states. 
+// void Clear(); +// +// // Number of cached states. +// StateId CountStates(); +// +// // Iterates over cached states (in an arbitrary order); only needed if +// // opts.gc is true. +// bool Done() const; // End of iteration. +// StateId Value() const; // Current state. +// void Next(); // Advances to next state (when !Done). +// void Reset(); // Returns to initial condition. +// void Delete(); // Deletes current state and advances to next. +// }; + +// Container cache stores. + +// This class uses a vector of pointers to states to store cached states. +template +class VectorCacheStore { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using StateList = std::list>; + + // Required constructors/assignment operators. + explicit VectorCacheStore(const CacheOptions &opts) : cache_gc_(opts.gc) { + Clear(); + Reset(); + } + + VectorCacheStore(const VectorCacheStore &store) + : cache_gc_(store.cache_gc_) { + CopyStates(store); + Reset(); + } + + ~VectorCacheStore() { Clear(); } + + VectorCacheStore &operator=(const VectorCacheStore &store) { + if (this != &store) { + CopyStates(store); + Reset(); + } + return *this; + } + + bool InBounds(StateId s) const { + return s < static_cast(state_vec_.size()); + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + return InBounds(s) ? state_vec_[s] : nullptr; + } + + // Creates state if state is not stored. 
+ State *GetMutableState(StateId s) { + State *state = nullptr; + if (InBounds(s)) { + state = state_vec_[s]; + } else { + state_vec_.resize(s + 1, nullptr); + } + if (!state) { + state = new (&state_alloc_) State(arc_alloc_); + state_vec_[s] = state; + if (cache_gc_) state_list_.push_back(s); + } + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping + void AddArc(State *state, const Arc &arc) { state->AddArc(arc); } + + // Similar to State::SetArcs() but updates cache store book-keeping; call + // only once. + void SetArcs(State *state) { state->SetArcs(); } + + // Deletes all arcs. + void DeleteArcs(State *state) { state->DeleteArcs(); } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); } + + // Deletes all cached states. + void Clear() { + for (State *s : state_vec_) { + State::Destroy(s, &state_alloc_); + } + state_vec_.clear(); + state_list_.clear(); + } + + StateId CountStates() const { + return std::count_if(state_vec_.begin(), state_vec_.end(), + [](const State *s) { return s != nullptr; }); + } + + // Iterates over cached states (in an arbitrary order); only works if GC is + // enabled (o.w. avoiding state_list_ overhead). + bool Done() const { return iter_ == state_list_.end(); } + + StateId Value() const { return *iter_; } + + void Next() { ++iter_; } + + void Reset() { iter_ = state_list_.begin(); } + + // Deletes current state and advances to next. 
+ void Delete() { + State::Destroy(state_vec_[*iter_], &state_alloc_); + state_vec_[*iter_] = nullptr; + state_list_.erase(iter_++); + } + + private: + void CopyStates(const VectorCacheStore &store) { + Clear(); + state_vec_.reserve(store.state_vec_.size()); + for (size_t s = 0; s < store.state_vec_.size(); ++s) { + State *state = nullptr; + const auto *store_state = store.state_vec_[s]; + if (store_state) { + state = new (&state_alloc_) State(*store_state, arc_alloc_); + if (cache_gc_) state_list_.push_back(s); + } + state_vec_.push_back(state); + } + } + + bool cache_gc_; // Supports iteration when true. + std::vector state_vec_; // Vector of states (or null). + StateList state_list_; // List of states. + typename StateList::iterator iter_; // State list iterator. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// This class uses a hash map from state IDs to pointers to cached states. +template +class HashCacheStore { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + using StateMap = + std::unordered_map, + std::equal_to, + PoolAllocator>>; + + // Required constructors/assignment operators. + explicit HashCacheStore(const CacheOptions &opts) { + Clear(); + Reset(); + } + + HashCacheStore(const HashCacheStore &store) { + CopyStates(store); + Reset(); + } + + ~HashCacheStore() { Clear(); } + + HashCacheStore &operator=(const HashCacheStore &store) { + if (this != &store) { + CopyStates(store); + Reset(); + } + return *this; + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + const auto it = state_map_.find(s); + return it != state_map_.end() ? it->second : nullptr; + } + + // Creates state if state is not stored. 
+ State *GetMutableState(StateId s) { + auto *&state = state_map_[s]; + if (!state) state = new (&state_alloc_) State(arc_alloc_); + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. + void AddArc(State *state, const Arc &arc) { state->AddArc(arc); } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { state->SetArcs(); } + + // Deletes all arcs. + void DeleteArcs(State *state) { state->DeleteArcs(); } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { state->DeleteArcs(n); } + + // Deletes all cached states. + void Clear() { + for (auto it = state_map_.begin(); it != state_map_.end(); ++it) { + State::Destroy(it->second, &state_alloc_); + } + state_map_.clear(); + } + + StateId CountStates() const { return state_map_.size(); } + + // Iterates over cached states (in an arbitrary order). + bool Done() const { return iter_ == state_map_.end(); } + + StateId Value() const { return iter_->first; } + + void Next() { ++iter_; } + + void Reset() { iter_ = state_map_.begin(); } + + // Deletes current state and advances to next. + void Delete() { + State::Destroy(iter_->second, &state_alloc_); + state_map_.erase(iter_++); + } + + private: + void CopyStates(const HashCacheStore &store) { + Clear(); + for (auto it = store.state_map_.begin(); it != store.state_map_.end(); + ++it) { + state_map_[it->first] = + new (&state_alloc_) State(*it->second, arc_alloc_); + } + } + + StateMap state_map_; // Map from state ID to state. + typename StateMap::iterator iter_; // State map iterator. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// Garbage-colllection cache stores. + +// This class implements a simple garbage collection scheme when +// 'opts.gc_limit = 0'. 
In particular, the first cached state is reused for each +// new state so long as the reference count is zero on the to-be-reused state. +// Otherwise, the full underlying store is used. The caller can increment the +// reference count to inhibit the GC of in-use states (e.g., in an ArcIterator). +// +// The typical use case for this optimization is when a single pass over a +// cached +// FST is performed with only one-state expanded at a time. +template +class FirstCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + // Required constructors/assignment operators. + explicit FirstCacheStore(const CacheOptions &opts) + : store_(opts), + cache_gc_(opts.gc_limit == 0), // opts.gc ignored historically. + cache_first_state_id_(kNoStateId), + cache_first_state_(nullptr) {} + + FirstCacheStore(const FirstCacheStore &store) + : store_(store.store_), + cache_gc_(store.cache_gc_), + cache_first_state_id_(store.cache_first_state_id_), + cache_first_state_(store.cache_first_state_id_ != kNoStateId + ? store_.GetMutableState(0) + : nullptr) {} + + FirstCacheStore &operator=( + const FirstCacheStore &store) { + if (this != &store) { + store_ = store.store_; + cache_gc_ = store.cache_gc_; + cache_first_state_id_ = store.cache_first_state_id_; + cache_first_state_ = store.cache_first_state_id_ != kNoStateId + ? store_.GetMutableState(0) + : nullptr; + } + return *this; + } + + // Returns nullptr if state is not stored. + const State *GetState(StateId s) const { + // store_ state 0 may hold first cached state; the rest are shifted by 1. + return s == cache_first_state_id_ ? cache_first_state_ + : store_.GetState(s + 1); + } + + // Creates state if state is not stored. + State *GetMutableState(StateId s) { + // store_ state 0 used to hold first cached state; the rest are shifted by + // 1. + if (cache_first_state_id_ == s) { + return cache_first_state_; // Request for first cached state. 
+ } + if (cache_gc_) { + if (cache_first_state_id_ == kNoStateId) { + cache_first_state_id_ = s; // Sets first cached state. + cache_first_state_ = store_.GetMutableState(0); + cache_first_state_->SetFlags(kCacheInit, kCacheInit); + cache_first_state_->ReserveArcs(2 * kAllocSize); + return cache_first_state_; + } else if (cache_first_state_->RefCount() == 0) { + cache_first_state_id_ = s; // Updates first cached state. + cache_first_state_->Reset(); + cache_first_state_->SetFlags(kCacheInit, kCacheInit); + return cache_first_state_; + } else { // Keeps first cached state. + cache_first_state_->SetFlags(0, kCacheInit); // Clears initialized bit. + cache_gc_ = false; // Disables GC. + } + } + auto *state = store_.GetMutableState(s + 1); + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. + void AddArc(State *state, const Arc &arc) { store_.AddArc(state, arc); } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { store_.SetArcs(state); } + + // Deletes all arcs + void DeleteArcs(State *state) { store_.DeleteArcs(state); } + + // Deletes some arcs + void DeleteArcs(State *state, size_t n) { store_.DeleteArcs(state, n); } + + // Deletes all cached states + void Clear() { + store_.Clear(); + cache_first_state_id_ = kNoStateId; + cache_first_state_ = nullptr; + } + + StateId CountStates() const { return store_.CountStates(); } + + // Iterates over cached states (in an arbitrary order). Only needed if GC is + // enabled. + bool Done() const { return store_.Done(); } + + StateId Value() const { + // store_ state 0 may hold first cached state; rest shifted + 1. + const auto s = store_.Value(); + return s ? s - 1 : cache_first_state_id_; + } + + void Next() { store_.Next(); } + + void Reset() { store_.Reset(); } + + // Deletes current state and advances to next. 
+ void Delete() { + if (Value() == cache_first_state_id_) { + cache_first_state_id_ = kNoStateId; + cache_first_state_ = nullptr; + } + store_.Delete(); + } + + private: + CacheStore store_; // Underlying store. + bool cache_gc_; // GC enabled. + StateId cache_first_state_id_; // First cached state ID. + State *cache_first_state_; // First cached state. +}; + +// This class implements mark-sweep garbage collection on an underlying cache +// store. If GC is enabled, garbage collection of states is performed in a +// rough approximation of LRU order once when 'gc_limit' bytes is reached. The +// caller can increment the reference count to inhibit the GC of in-use state +// (e.g., in an ArcIterator). With GC enabled, the 'gc_limit' parameter allows +// the caller to trade-off time vs. space. +template +class GCCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + + // Required constructors/assignment operators. + explicit GCCacheStore(const CacheOptions &opts) + : store_(opts), + cache_gc_request_(opts.gc), + cache_limit_(opts.gc_limit > kMinCacheLimit ? opts.gc_limit + : kMinCacheLimit), + cache_gc_(false), + cache_size_(0) {} + + // Returns 0 if state is not stored. + const State *GetState(StateId s) const { return store_.GetState(s); } + + // Creates state if state is not stored + State *GetMutableState(StateId s) { + auto *state = store_.GetMutableState(s); + if (cache_gc_request_ && !(state->Flags() & kCacheInit)) { + state->SetFlags(kCacheInit, kCacheInit); + cache_size_ += sizeof(State) + state->NumArcs() * sizeof(Arc); + // GC is enabled once an uninited state (from underlying store) is seen. + cache_gc_ = true; + if (cache_size_ > cache_limit_) GC(state, false); + } + return state; + } + + // Similar to State::AddArc() but updates cache store book-keeping. 
+ void AddArc(State *state, const Arc &arc) { + store_.AddArc(state, arc); + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ += sizeof(Arc); + if (cache_size_ > cache_limit_) GC(state, false); + } + } + + // Similar to State::SetArcs() but updates internal cache size; call only + // once. + void SetArcs(State *state) { + store_.SetArcs(state); + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ += state->NumArcs() * sizeof(Arc); + if (cache_size_ > cache_limit_) GC(state, false); + } + } + + // Deletes all arcs. + void DeleteArcs(State *state) { + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ -= state->NumArcs() * sizeof(Arc); + } + store_.DeleteArcs(state); + } + + // Deletes some arcs. + void DeleteArcs(State *state, size_t n) { + if (cache_gc_ && (state->Flags() & kCacheInit)) { + cache_size_ -= n * sizeof(Arc); + } + store_.DeleteArcs(state, n); + } + + // Deletes all cached states. + void Clear() { + store_.Clear(); + cache_size_ = 0; + } + + StateId CountStates() const { return store_.CountStates(); } + + // Iterates over cached states (in an arbitrary order); only needed if GC is + // enabled. + bool Done() const { return store_.Done(); } + + StateId Value() const { return store_.Value(); } + + void Next() { store_.Next(); } + + void Reset() { store_.Reset(); } + + // Deletes current state and advances to next. + void Delete() { + if (cache_gc_) { + const auto *state = store_.GetState(Value()); + if (state->Flags() & kCacheInit) { + cache_size_ -= sizeof(State) + state->NumArcs() * sizeof(Arc); + } + } + store_.Delete(); + } + + // Removes from the cache store (not referenced-counted and not the current) + // states that have not been accessed since the last GC until at most + // cache_fraction * cache_limit_ bytes are cached. If that fails to free + // enough, attempts to uncaching recently visited states as well. If still + // unable to free enough memory, then widens cache_limit_. 
+ void GC(const State *current, bool free_recent, float cache_fraction = 0.666); + + // Returns the current cache size in bytes or 0 if GC is disabled. + size_t CacheSize() const { return cache_size_; } + + // Returns the cache limit in bytes. + size_t CacheLimit() const { return cache_limit_; } + + private: + static constexpr size_t kMinCacheLimit = 8096; // Minimum cache limit. + + CacheStore store_; // Underlying store. + bool cache_gc_request_; // GC requested but possibly not yet enabled. + size_t cache_limit_; // Number of bytes allowed before GC. + bool cache_gc_; // GC enabled + size_t cache_size_; // Number of bytes cached. +}; + +template +void GCCacheStore::GC(const State *current, bool free_recent, + float cache_fraction) { + if (!cache_gc_) return; + VLOG(2) << "GCCacheStore: Enter GC: object = " + << "(" << this << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; + size_t cache_target = cache_fraction * cache_limit_; + store_.Reset(); + while (!store_.Done()) { + auto *state = store_.GetMutableState(store_.Value()); + if (cache_size_ > cache_target && state->RefCount() == 0 && + (free_recent || !(state->Flags() & kCacheRecent)) && state != current) { + if (state->Flags() & kCacheInit) { + size_t size = sizeof(State) + state->NumArcs() * sizeof(Arc); + if (size < cache_size_) { + cache_size_ -= size; + } + } + store_.Delete(); + } else { + state->SetFlags(0, kCacheRecent); + store_.Next(); + } + } + if (!free_recent && cache_size_ > cache_target) { // Recurses on recent. + GC(current, true, cache_fraction); + } else if (cache_target > 0) { // Widens cache limit. 
+ while (cache_size_ > cache_target) { + cache_limit_ *= 2; + cache_target *= 2; + } + } else if (cache_size_ > 0) { + FSTERROR() << "GCCacheStore:GC: Unable to free all cached states"; + } + VLOG(2) << "GCCacheStore: Exit GC: object = " + << "(" << this << "), free recently cached = " << free_recent + << ", cache size = " << cache_size_ + << ", cache frac = " << cache_fraction + << ", cache limit = " << cache_limit_ << "\n"; +} + +template +constexpr size_t GCCacheStore::kMinCacheLimit; + +// This class is the default cache state and store used by CacheBaseImpl. +// It uses VectorCacheStore for storage decorated by FirstCacheStore +// and GCCacheStore to do (optional) garbage collection. +template +class DefaultCacheStore + : public GCCacheStore>>> { + public: + explicit DefaultCacheStore(const CacheOptions &opts) + : GCCacheStore>>>(opts) { + } +}; + +namespace internal { + +// This class is used to cache FST elements stored in states of type State +// (see CacheState) with the flags used to indicate what has been cached. Use +// HasStart(), HasFinal(), and HasArcs() to determine if cached and SetStart(), +// SetFinal(), AddArc(), (or PushArc() and SetArcs()) to cache. Note that you +// must set the final weight even if the state is non-final to mark it as +// cached. The state storage method and any garbage collection policy are +// determined by the cache store. If the store is passed in with the options, +// CacheBaseImpl takes ownership. 
+template > +class CacheBaseImpl : public FstImpl { + public: + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = CacheStore; + + using FstImpl::Type; + using FstImpl::Properties; + + explicit CacheBaseImpl(const CacheOptions &opts = CacheOptions()) + : has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(opts.gc), + cache_limit_(opts.gc_limit), + cache_store_(new CacheStore(opts)), + new_cache_store_(true), + own_cache_store_(true) {} + + explicit CacheBaseImpl(const CacheImplOptions &opts) + : has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(opts.gc), + cache_limit_(opts.gc_limit), + cache_store_(opts.store ? opts.store : new CacheStore(CacheOptions( + opts.gc, opts.gc_limit))), + new_cache_store_(!opts.store), + own_cache_store_(opts.store ? opts.own_store : true) {} + + // Preserve gc parameters. If preserve_cache is true, also preserves + // cache data. 
+ CacheBaseImpl(const CacheBaseImpl &impl, + bool preserve_cache = false) + : FstImpl(), + has_start_(false), + cache_start_(kNoStateId), + nknown_states_(0), + min_unexpanded_state_id_(0), + max_expanded_state_id_(-1), + cache_gc_(impl.cache_gc_), + cache_limit_(impl.cache_limit_), + cache_store_(new CacheStore(CacheOptions(cache_gc_, cache_limit_))), + new_cache_store_(impl.new_cache_store_ || !preserve_cache), + own_cache_store_(true) { + if (preserve_cache) { + *cache_store_ = *impl.cache_store_; + has_start_ = impl.has_start_; + cache_start_ = impl.cache_start_; + nknown_states_ = impl.nknown_states_; + expanded_states_ = impl.expanded_states_; + min_unexpanded_state_id_ = impl.min_unexpanded_state_id_; + max_expanded_state_id_ = impl.max_expanded_state_id_; + } + } + + ~CacheBaseImpl() override { if (own_cache_store_) delete cache_store_; } + + void SetStart(StateId s) { + cache_start_ = s; + has_start_ = true; + if (s >= nknown_states_) nknown_states_ = s + 1; + } + + void SetFinal(StateId s, Weight weight) { + auto *state = cache_store_->GetMutableState(s); + state->SetFinal(std::move(weight)); + static constexpr auto flags = kCacheFinal | kCacheRecent; + state->SetFlags(flags, flags); + } + +// Disabled to ensure PushArc not AddArc is used in existing code +// TODO(sorenj): re-enable for backing store +#if 0 + // AddArc adds a single arc to a state and does incremental cache + // book-keeping. For efficiency, prefer PushArc and SetArcs below + // when possible. + void AddArc(StateId s, const Arc &arc) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->AddArc(state, arc); + if (arc.nextstate >= nknown_states_) + nknown_states_ = arc.nextstate + 1; + SetExpandedState(s); + static constexpr auto flags = kCacheArcs | kCacheRecent; + state->SetFlags(flags, flags); + } +#endif + + // Adds a single arc to a state but delays cache book-keeping. SetArcs must + // be called when all PushArc and EmplaceArc calls at a state are complete. 
+ // Do not mix with calls to AddArc. + void PushArc(StateId s, const Arc &arc) { + auto *state = cache_store_->GetMutableState(s); + state->PushArc(arc); + } + + void PushArc(StateId s, Arc &&arc) { + auto *state = cache_store_->GetMutableState(s); + state->PushArc(std::move(arc)); + } + + // Adds a single arc to a state but delays cache book-keeping. SetArcs must + // be called when all PushArc and EmplaceArc calls at a state are complete. + // Do not mix with calls to AddArc. + template + void EmplaceArc(StateId s, T &&... ctor_args) { + auto *state = cache_store_->GetMutableState(s); + state->EmplaceArc(std::forward(ctor_args)...); + } + + // Marks arcs of a state as cached and does cache book-keeping after all + // calls to PushArc have been completed. Do not mix with calls to AddArc. + void SetArcs(StateId s) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->SetArcs(state); + const auto narcs = state->NumArcs(); + for (size_t a = 0; a < narcs; ++a) { + const auto &arc = state->GetArc(a); + if (arc.nextstate >= nknown_states_) nknown_states_ = arc.nextstate + 1; + } + SetExpandedState(s); + static constexpr auto flags = kCacheArcs | kCacheRecent; + state->SetFlags(flags, flags); + } + + void ReserveArcs(StateId s, size_t n) { + auto *state = cache_store_->GetMutableState(s); + state->ReserveArcs(n); + } + + void DeleteArcs(StateId s) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->DeleteArcs(state); + } + + void DeleteArcs(StateId s, size_t n) { + auto *state = cache_store_->GetMutableState(s); + cache_store_->DeleteArcs(state, n); + } + + void Clear() { + nknown_states_ = 0; + min_unexpanded_state_id_ = 0; + max_expanded_state_id_ = -1; + has_start_ = false; + cache_start_ = kNoStateId; + cache_store_->Clear(); + } + + // Is the start state cached? + bool HasStart() const { + if (!has_start_ && Properties(kError)) has_start_ = true; + return has_start_; + } + + // Is the final weight of the state cached? 
+ bool HasFinal(StateId s) const { + const auto *state = cache_store_->GetState(s); + if (state && state->Flags() & kCacheFinal) { + state->SetFlags(kCacheRecent, kCacheRecent); + return true; + } else { + return false; + } + } + + // Are arcs of the state cached? + bool HasArcs(StateId s) const { + const auto *state = cache_store_->GetState(s); + if (state && state->Flags() & kCacheArcs) { + state->SetFlags(kCacheRecent, kCacheRecent); + return true; + } else { + return false; + } + } + + StateId Start() const { return cache_start_; } + + Weight Final(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->Final(); + } + + size_t NumArcs(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumArcs(); + } + + size_t NumInputEpsilons(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumInputEpsilons(); + } + + size_t NumOutputEpsilons(StateId s) const { + const auto *state = cache_store_->GetState(s); + return state->NumOutputEpsilons(); + } + + // Provides information needed for generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + const auto *state = cache_store_->GetState(s); + data->base = nullptr; + data->narcs = state->NumArcs(); + data->arcs = state->Arcs(); + data->ref_count = state->MutableRefCount(); + state->IncrRefCount(); + } + + // Number of known states. + StateId NumKnownStates() const { return nknown_states_; } + + // Updates number of known states, taking into account the passed state ID. + void UpdateNumKnownStates(StateId s) { + if (s >= nknown_states_) nknown_states_ = s + 1; + } + + // Finds the mininum never-expanded state ID. + StateId MinUnexpandedState() const { + while (min_unexpanded_state_id_ <= max_expanded_state_id_ && + ExpandedState(min_unexpanded_state_id_)) { + ++min_unexpanded_state_id_; + } + return min_unexpanded_state_id_; + } + + // Returns maximum ever-expanded state ID. 
+ StateId MaxExpandedState() const { return max_expanded_state_id_; } + + void SetExpandedState(StateId s) { + if (s > max_expanded_state_id_) max_expanded_state_id_ = s; + if (s < min_unexpanded_state_id_) return; + if (s == min_unexpanded_state_id_) ++min_unexpanded_state_id_; + if (cache_gc_ || cache_limit_ == 0) { + if (expanded_states_.size() <= static_cast(s)) + expanded_states_.resize(s + 1, false); + expanded_states_[s] = true; + } + } + + bool ExpandedState(StateId s) const { + if (cache_gc_ || cache_limit_ == 0) { + return expanded_states_[s]; + } else if (new_cache_store_) { + return cache_store_->GetState(s) != nullptr; + } else { + // If the cache was not created by this class, then the cached state needs + // to be inspected to update nknown_states_. + return false; + } + } + + const CacheStore *GetCacheStore() const { return cache_store_; } + + CacheStore *GetCacheStore() { return cache_store_; } + + // Caching on/off switch, limit and size accessors. + + bool GetCacheGc() const { return cache_gc_; } + + size_t GetCacheLimit() const { return cache_limit_; } + + private: + mutable bool has_start_; // Is the start state cached? + StateId cache_start_; // ID of start state. + StateId nknown_states_; // Number of known states. + std::vector expanded_states_; // States that have been expanded. + mutable StateId min_unexpanded_state_id_; // Minimum never-expanded state ID + mutable StateId max_expanded_state_id_; // Maximum ever-expanded state ID + bool cache_gc_; // GC enabled. + size_t cache_limit_; // Number of bytes allowed before GC. + CacheStore *cache_store_; // The store of cached states. + bool new_cache_store_; // Was the store was created by class? + bool own_cache_store_; // Is the store owned by class? + + CacheBaseImpl &operator=(const CacheBaseImpl &impl) = delete; +}; + +// A CacheBaseImpl with the default cache state type. 
+template +class CacheImpl : public CacheBaseImpl> { + public: + using State = CacheState; + + CacheImpl() {} + + explicit CacheImpl(const CacheOptions &opts) + : CacheBaseImpl>(opts) {} + + CacheImpl(const CacheImpl &impl, bool preserve_cache = false) + : CacheBaseImpl(impl, preserve_cache) {} + + private: + CacheImpl &operator=(const CacheImpl &impl) = delete; +}; + +} // namespace internal + +// Use this to make a state iterator for a CacheBaseImpl-derived FST, which must +// have Arc and Store types defined. Note this iterator only returns those +// states reachable from the initial state, so consider implementing a +// class-specific one. +// +// This class may be derived from. +template +class CacheStateIterator : public StateIteratorBase { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + CacheStateIterator(const FST &fst, Impl *impl) + : fst_(fst), impl_(impl), s_(0) { + fst_.Start(); // Forces start state. + } + + bool Done() const final { + if (s_ < impl_->NumKnownStates()) return false; + for (StateId u = impl_->MinUnexpandedState(); u < impl_->NumKnownStates(); + u = impl_->MinUnexpandedState()) { + // Forces state expansion. + ArcIterator aiter(fst_, u); + aiter.SetFlags(kArcValueFlags, kArcValueFlags | kArcNoCache); + for (; !aiter.Done(); aiter.Next()) { + impl_->UpdateNumKnownStates(aiter.Value().nextstate); + } + impl_->SetExpandedState(u); + if (s_ < impl_->NumKnownStates()) return false; + } + return true; + } + + StateId Value() const final { return s_; } + + void Next() final { ++s_; } + + void Reset() final { s_ = 0; } + + private: + const FST &fst_; + Impl *impl_; + StateId s_; +}; + +// Used to make an arc iterator for a CacheBaseImpl-derived FST, which must +// have Arc and State types defined. 
+template +class CacheArcIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + CacheArcIterator(Impl *impl, StateId s) : i_(0) { + state_ = impl->GetCacheStore()->GetMutableState(s); + state_->IncrRefCount(); + } + + ~CacheArcIterator() { state_->DecrRefCount(); } + + bool Done() const { return i_ >= state_->NumArcs(); } + + const Arc &Value() const { return state_->GetArc(i_); } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32 flags, uint32 mask) {} + + private: + const State *state_; + size_t i_; + + CacheArcIterator(const CacheArcIterator &) = delete; + CacheArcIterator &operator=(const CacheArcIterator &) = delete; +}; + +// Use this to make a mutable arc iterator for a CacheBaseImpl-derived FST, +// which must have types Arc and Store defined. +template +class CacheMutableArcIterator + : public MutableArcIteratorBase { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = typename FST::Store; + using State = typename Store::State; + using Impl = internal::CacheBaseImpl; + + // User must call MutateCheck() in the constructor. 
+ CacheMutableArcIterator(Impl *impl, StateId s) : i_(0), s_(s), impl_(impl) { + state_ = impl_->GetCacheStore()->GetMutableState(s_); + state_->IncrRefCount(); + } + + ~CacheMutableArcIterator() override { state_->DecrRefCount(); } + + bool Done() const final { return i_ >= state_->NumArcs(); } + + const Arc &Value() const final { return state_->GetArc(i_); } + + void Next() final { ++i_; } + + size_t Position() const final { return i_; } + + void Reset() final { i_ = 0; } + + void Seek(size_t a) final { i_ = a; } + + void SetValue(const Arc &arc) final { state_->SetArc(arc, i_); } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + size_t i_; + StateId s_; + Impl *impl_; + State *state_; + + CacheMutableArcIterator(const CacheMutableArcIterator &) = delete; + CacheMutableArcIterator &operator=(const CacheMutableArcIterator &) = delete; +}; + +// Wrap existing CacheStore implementation to use with ExpanderFst. +template +class ExpanderCacheStore { + public: + using State = typename CacheStore::State; + using Arc = typename CacheStore::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ExpanderCacheStore(const CacheOptions &opts = CacheOptions()) + : store_(opts) {} + + template + State *FindOrExpand(Expander &expander, StateId s) { // NOLINT + auto *state = store_.GetMutableState(s); + if (state->Flags()) { + state->SetFlags(kCacheRecent, kCacheRecent); + } else { + StateBuilder builder(state); + expander.Expand(s, &builder); + state->SetFlags(kCacheFlags, kCacheFlags); + store_.SetArcs(state); + } + return state; + } + + private: + CacheStore store_; + + struct StateBuilder { + State *state; + + explicit StateBuilder(State *state_) : state(state_) {} + + void AddArc(const Arc &arc) { state->PushArc(arc); } + + void AddArc(Arc &&arc) { state->PushArc(std::move(arc)); } + + void SetFinal(Weight weight) { state->SetFinal(std::move(weight)); } + }; +}; + +} 
// namespace fst + +#endif // FST_CACHE_H_ diff --git a/projects/llm_framework/include/fst/closure.h b/projects/llm_framework/include/fst/closure.h new file mode 100644 index 00000000..13beea9c --- /dev/null +++ b/projects/llm_framework/include/fst/closure.h @@ -0,0 +1,134 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the concatenative closure of an FST. + +#ifndef FST_CLOSURE_H_ +#define FST_CLOSURE_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Computes the concatenative closure. This version modifies its +// MutableFst input. If an FST transduces string x to y with weight a, +// then its closure transduces x to y with weight a, xx to yy with +// weight Times(a, a), xxx to yyy with with Times(Times(a, a), a), +// etc. If closure_type == CLOSURE_STAR, then the empty string is +// transduced to itself with weight Weight::One() as well. +// +// Complexity: +// +// Time: O(V) +// Space: O(V) +// +// where V is the number of states. +template +void Closure(MutableFst *fst, ClosureType closure_type) { + using Weight = typename Arc::Weight; + const auto props = fst->Properties(kFstProperties, false); + const auto start = fst->Start(); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + const auto s = siter.Value(); + const auto weight = fst->Final(s); + if (weight != Weight::Zero()) fst->AddArc(s, Arc(0, 0, weight, start)); + } + if (closure_type == CLOSURE_STAR) { + fst->ReserveStates(fst->NumStates() + 1); + const auto nstart = fst->AddState(); + fst->SetStart(nstart); + fst->SetFinal(nstart, Weight::One()); + if (start != kNoLabel) fst->AddArc(nstart, Arc(0, 0, Weight::One(), start)); + } + fst->SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR), + kFstProperties); +} + +// Computes the concatenative closure. This version modifies its +// RationalFst input. 
+template +void Closure(RationalFst *fst, ClosureType closure_type) { + fst->GetMutableImpl()->AddClosure(closure_type); +} + +struct ClosureFstOptions : RationalFstOptions { + ClosureType type; + + ClosureFstOptions(const RationalFstOptions &opts, + ClosureType type = CLOSURE_STAR) + : RationalFstOptions(opts), type(type) {} + + explicit ClosureFstOptions(ClosureType type = CLOSURE_STAR) : type(type) {} +}; + +// Computes the concatenative closure. This version is a delayed FST. If an FST +// transduces string x to y with weight a, then its closure transduces x to y +// with weight a, xx to yy with weight Times(a, a), xxx to yyy with weight +// Times(Times(a, a), a), etc. If closure_type == CLOSURE_STAR, then the empty +// string is transduced to itself with weight Weight::One() as well. +// +// Complexity: +// +// Time: O(v) +// Space: O(v) +// +// where v is the number of states visited. Constant time and space to visit an +// input state or arc is assumed and exclusive of caching. +template +class ClosureFst : public RationalFst { + public: + using Arc = A; + + ClosureFst(const Fst &fst, ClosureType closure_type) { + GetMutableImpl()->InitClosure(fst, closure_type); + } + + ClosureFst(const Fst &fst, const ClosureFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitClosure(fst, opts.type); + } + + // See Fst<>::Copy() for doc. + ClosureFst(const ClosureFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Gets a copy of this ClosureFst. See Fst<>::Copy() for further doc. + ClosureFst *Copy(bool safe = false) const override { + return new ClosureFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for ClosureFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const ClosureFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for ClosureFst. 
+template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ClosureFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdClosureFst = ClosureFst; + +} // namespace fst + +#endif // FST_CLOSURE_H_ diff --git a/projects/llm_framework/include/fst/compact-fst.h b/projects/llm_framework/include/fst/compact-fst.h new file mode 100644 index 00000000..402c87b7 --- /dev/null +++ b/projects/llm_framework/include/fst/compact-fst.h @@ -0,0 +1,1564 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST Class for memory-efficient representation of common types of +// FSTs: linear automata, acceptors, unweighted FSTs, ... + +#ifndef FST_COMPACT_FST_H_ +#define FST_COMPACT_FST_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include // For optional argument declarations +#include +#include +#include +#include + + +namespace fst { + +struct CompactFstOptions : public CacheOptions { + // The default caching behaviour is to do no caching. Most compactors are + // cheap and therefore we save memory by not doing caching. + CompactFstOptions() : CacheOptions(true, 0) {} + + explicit CompactFstOptions(const CacheOptions &opts) : CacheOptions(opts) {} +}; + +// New upcoming (Fst) Compactor interface - currently used internally +// by CompactFstImpl. +// +// class Compactor { +// public: +// // Constructor from the Fst to be compacted. +// Compactor(const Fst &fst, ...); +// // Copy constructor +// Compactor(const Compactor &compactor, bool safe = false) +// // Default constructor (optional, see comment below). 
//   Compactor();
//
//   // Returns the start state, number of states, and total number of arcs
//   // of the compacted Fst
//   StateId Start() const;
//   StateId NumStates() const;
//   size_t NumArcs() const;
//
//   // Accessor class for state attributes.
//   class State {
//    public:
//     State();  // Required, corresponds to kNoStateId.
//     State(const Compactor *c, StateId);  // Accessor for StateId 's'.
//     StateId GetStateId() const;
//     Weight Final() const;
//     size_t NumArcs() const;
//     Arc GetArc(size_t i, uint32 f) const;
//   };
//
//   // Modifies 'state' accessor to provide access to state id 's'.
//   void SetState(StateId s, State *state);
//   // Tests whether 'fst' can be compacted by this compactor.
//   bool IsCompatible(const Fst<Arc> &fst) const;
//   // Return the properties that are always true for an fst
//   // compacted using this compactor
//   uint64 Properties() const;
//   // Return a string identifying the type of compactor.
//   static const string &Type();
//   // Return true if an error has occurred.
//   bool Error() const;
//   // Writes a compactor to a file.
//   bool Write(std::ostream &strm, const FstWriteOptions &opts) const;
//   // Reads a compactor from a file.
//   static Compactor *Read(std::istream &strm, const FstReadOptions &opts,
//                          const FstHeader &hdr);
// };
//

// Old (Arc) Compactor Interface:
//
// The ArcCompactor class determines how arcs and final weights are compacted
// and expanded.
//
// Final weights are treated as transitions to the superfinal state, i.e.,
// ilabel = olabel = kNoLabel and nextstate = kNoStateId.
//
// There are two types of compactors:
//
// * Fixed out-degree compactors: 'compactor.Size()' returns a positive integer
//   's'. An FST can be compacted by this compactor only if each state has
//   exactly 's' outgoing transitions (counting a non-Zero() final weight as a
//   transition). A typical example is a compactor for string FSTs, i.e.,
//   's == 1'.
+// +// * Variable out-degree compactors: 'compactor.Size() == -1'. There are no +// out-degree restrictions for these compactors. +// +// Interface: +// +// class ArcCompactor { +// public: +// // Element is the type of the compacted transitions. +// using Element = ... +// +// // Returns the compacted representation of a transition 'arc' +// // at a state 's'. +// Element Compact(StateId s, const Arc &arc); +// +// // Returns the transition at state 's' represented by the compacted +// // transition 'e'. +// Arc Expand(StateId s, const Element &e) const; +// +// // Returns -1 for variable out-degree compactors, and the mandatory +// // out-degree otherwise. +// ssize_t Size() const; +// +// // Tests whether an FST can be compacted by this compactor. +// bool Compatible(const Fst &fst) const; +// +// // Returns the properties that are always true for an FST compacted using +// // this compactor +// uint64 Properties() const; +// +// // Returns a string identifying the type of compactor. +// static const string &Type(); +// +// // Writes a compactor to a file. +// bool Write(std::ostream &strm) const; +// +// // Reads a compactor from a file. +// static ArcCompactor *Read(std::istream &strm); +// +// // Default constructor (optional, see comment below). +// ArcCompactor(); +// }; +// +// The default constructor is only required for FST_REGISTER to work (i.e., +// enabling Convert() and the command-line utilities to work with this new +// compactor). However, a default constructor always needs to be specified for +// this code to compile, but one can have it simply raise an error when called, +// like so: +// +// Compactor::Compactor() { +// FSTERROR() << "Compactor: No default constructor"; +// } + +// Default implementation data for CompactFst, which can shared between +// otherwise independent copies. +// +// The implementation contains two arrays: 'states_' and 'compacts_'. +// +// For fixed out-degree compactors, the 'states_' array is unallocated. 
The +// 'compacts_' contains the compacted transitions. Its size is 'ncompacts_'. +// The outgoing transitions at a given state are stored consecutively. For a +// given state 's', its 'compactor.Size()' outgoing transitions (including +// superfinal transition when 's' is final), are stored in position +// ['s*compactor.Size()', '(s+1)*compactor.Size()'). +// +// For variable out-degree compactors, the states_ array has size +// 'nstates_ + 1' and contains pointers to positions into 'compacts_'. For a +// given state 's', the compacted transitions of 's' are stored in positions +// ['states_[s]', 'states_[s + 1]') in 'compacts_'. By convention, +// 'states_[nstates_] == ncompacts_'. +// +// In both cases, the superfinal transitions (when 's' is final, i.e., +// 'Final(s) != Weight::Zero()') are stored first. +// +// The unsigned type U is used to represent indices into the compacts_ array. +template +class DefaultCompactStore { + public: + DefaultCompactStore() + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) {} + + template + DefaultCompactStore(const Fst &fst, const Compactor &compactor); + + template + DefaultCompactStore(const Iterator &begin, const Iterator &end, + const Compactor &compactor); + + ~DefaultCompactStore() { + if (!states_region_) delete[] states_; + if (!compacts_region_) delete[] compacts_; + } + + template + static DefaultCompactStore *Read( + std::istream &strm, const FstReadOptions &opts, const FstHeader &hdr, + const Compactor &compactor); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const; + + Unsigned States(ssize_t i) const { return states_[i]; } + + const Element &Compacts(size_t i) const { return compacts_[i]; } + + size_t NumStates() const { return nstates_; } + + size_t NumCompacts() const { return ncompacts_; } + + size_t NumArcs() const { return narcs_; } + + ssize_t Start() const { return start_; } + + bool Error() const { return 
error_; } + + // Returns a string identifying the type of data storage container. + static const string &Type(); + + private: + std::unique_ptr states_region_; + std::unique_ptr compacts_region_; + Unsigned *states_; + Element *compacts_; + size_t nstates_; + size_t ncompacts_; + size_t narcs_; + ssize_t start_; + bool error_; +}; + +template +template +DefaultCompactStore::DefaultCompactStore( + const Fst &fst, const Compactor &compactor) + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + start_ = fst.Start(); + // Counts # of states and arcs. + StateId nfinals = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + const auto s = siter.Value(); + narcs_ += fst.NumArcs(s); + if (fst.Final(s) != Weight::Zero()) ++nfinals; + } + if (compactor.Size() == -1) { + states_ = new Unsigned[nstates_ + 1]; + ncompacts_ = narcs_ + nfinals; + compacts_ = new Element[ncompacts_]; + states_[nstates_] = ncompacts_; + } else { + states_ = nullptr; + ncompacts_ = nstates_ * compactor.Size(); + if ((narcs_ + nfinals) != ncompacts_) { + FSTERROR() << "DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } + compacts_ = new Element[ncompacts_]; + } + size_t pos = 0; + size_t fpos = 0; + for (size_t s = 0; s < nstates_; ++s) { + fpos = pos; + if (compactor.Size() == -1) states_[s] = pos; + if (fst.Final(s) != Weight::Zero()) { + compacts_[pos++] = compactor.Compact( + s, Arc(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + compacts_[pos++] = compactor.Compact(s, aiter.Value()); + } + if ((compactor.Size() != -1) && (pos != fpos + compactor.Size())) { + FSTERROR() << "DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } + } + if (pos != ncompacts_) { + FSTERROR() << 
"DefaultCompactStore: Compactor incompatible with FST"; + error_ = true; + return; + } +} + +template +template +DefaultCompactStore::DefaultCompactStore( + const Iterator &begin, const Iterator &end, const Compactor &compactor) + : states_(nullptr), + compacts_(nullptr), + nstates_(0), + ncompacts_(0), + narcs_(0), + start_(kNoStateId), + error_(false) { + using Arc = typename Compactor::Arc; + using Weight = typename Arc::Weight; + if (compactor.Size() != -1) { + ncompacts_ = std::distance(begin, end); + if (compactor.Size() == 1) { + // For strings, allows implicit final weight. Empty input is the empty + // string. + if (ncompacts_ == 0) { + ++ncompacts_; + } else { + const auto arc = + compactor.Expand(ncompacts_ - 1, *(begin + (ncompacts_ - 1))); + if (arc.ilabel != kNoLabel) ++ncompacts_; + } + } + if (ncompacts_ % compactor.Size()) { + FSTERROR() << "DefaultCompactStore: Size of input container incompatible" + << " with compactor"; + error_ = true; + return; + } + if (ncompacts_ == 0) return; + start_ = 0; + nstates_ = ncompacts_ / compactor.Size(); + compacts_ = new Element[ncompacts_]; + size_t i = 0; + Iterator it = begin; + for (; it != end; ++it, ++i) { + compacts_[i] = *it; + if (compactor.Expand(i, *it).ilabel != kNoLabel) ++narcs_; + } + if (i < ncompacts_) { + compacts_[i] = compactor.Compact( + i, Arc(kNoLabel, kNoLabel, Weight::One(), kNoStateId)); + } + } else { + if (std::distance(begin, end) == 0) return; + // Count # of states, arcs and compacts. 
+ auto it = begin; + for (size_t i = 0; it != end; ++it, ++i) { + const auto arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + ++narcs_; + ++ncompacts_; + } else { + ++nstates_; + if (arc.weight != Weight::Zero()) ++ncompacts_; + } + } + start_ = 0; + compacts_ = new Element[ncompacts_]; + states_ = new Unsigned[nstates_ + 1]; + states_[nstates_] = ncompacts_; + size_t i = 0; + size_t s = 0; + for (it = begin; it != end; ++it) { + const auto arc = compactor.Expand(i, *it); + if (arc.ilabel != kNoLabel) { + compacts_[i++] = *it; + } else { + states_[s++] = i; + if (arc.weight != Weight::Zero()) compacts_[i++] = *it; + } + } + if ((s != nstates_) || (i != ncompacts_)) { + FSTERROR() << "DefaultCompactStore: Ill-formed input container"; + error_ = true; + return; + } + } +} + +template +template +DefaultCompactStore + *DefaultCompactStore::Read(std::istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr, + const Compactor &compactor) { + std::unique_ptr> data( + new DefaultCompactStore()); + data->start_ = hdr.Start(); + data->nstates_ = hdr.NumStates(); + data->narcs_ = hdr.NumArcs(); + if (compactor.Size() == -1) { + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Read: Alignment failed: " + << opts.source; + return nullptr; + } + auto b = (data->nstates_ + 1) * sizeof(Unsigned); + data->states_region_.reset(MappedFile::Map( + &strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !data->states_region_) { + LOG(ERROR) << "DefaultCompactStore::Read: Read failed: " << opts.source; + return nullptr; + } + data->states_ = + static_cast(data->states_region_->mutable_data()); + } else { + data->states_ = nullptr; + } + data->ncompacts_ = compactor.Size() == -1 ? 
data->states_[data->nstates_] + : data->nstates_ * compactor.Size(); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Read: Alignment failed: " + << opts.source; + return nullptr; + } + size_t b = data->ncompacts_ * sizeof(Element); + data->compacts_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !data->compacts_region_) { + LOG(ERROR) << "DefaultCompactStore::Read: Read failed: " << opts.source; + return nullptr; + } + data->compacts_ = + static_cast(data->compacts_region_->mutable_data()); + return data.release(); +} + +template +bool DefaultCompactStore::Write( + std::ostream &strm, const FstWriteOptions &opts) const { + if (states_) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Write: Alignment failed: " + << opts.source; + return false; + } + strm.write(reinterpret_cast(states_), + (nstates_ + 1) * sizeof(Unsigned)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "DefaultCompactStore::Write: Alignment failed: " + << opts.source; + return false; + } + strm.write(reinterpret_cast(compacts_), ncompacts_ * sizeof(Element)); + strm.flush(); + if (!strm) { + LOG(ERROR) << "DefaultCompactStore::Write: Write failed: " << opts.source; + return false; + } + return true; +} + +template +const string &DefaultCompactStore::Type() { + static const string *const type = new string("compact"); + return *type; +} + +template class DefaultCompactState; + +// Wraps an arc compactor and a compact store as a new Fst compactor. 
+template > +class DefaultCompactor { + public: + using ArcCompactor = C; + using Unsigned = U; + using CompactStore = S; + using Element = typename C::Element; + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using State = DefaultCompactState; + friend State; + + DefaultCompactor() + : arc_compactor_(nullptr), compact_store_(nullptr) {} + + // Constructs from Fst. + DefaultCompactor(const Fst &fst, + std::shared_ptr arc_compactor) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::make_shared(fst, *arc_compactor_)) {} + + DefaultCompactor(const Fst &fst, + std::shared_ptr> compactor) + : arc_compactor_(compactor->arc_compactor_), + compact_store_(compactor->compact_store_ == nullptr ? + std::make_shared(fst, *arc_compactor_) : + compactor->compact_store_) {} + + // Constructs from CompactStore. + DefaultCompactor(std::shared_ptr arc_compactor, + std::shared_ptr compact_store) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::move(compact_store)) {} + + // Constructs from set of compact elements (when arc_compactor.Size() != -1). + template + DefaultCompactor(const Iterator &b, const Iterator &e, + std::shared_ptr arc_compactor) + : arc_compactor_(std::move(arc_compactor)), + compact_store_(std::make_shared(b, e, *arc_compactor_)) {} + + // Copy constructor. 
+ DefaultCompactor(const DefaultCompactor &compactor) + : arc_compactor_(std::make_shared(*compactor.GetArcCompactor())), + compact_store_(compactor.SharedCompactStore()) {} + + template + explicit DefaultCompactor(const DefaultCompactor &compactor) + : arc_compactor_(std::make_shared(*compactor.GetArcCompactor())), + compact_store_(compactor.SharedCompactStore()) {} + + StateId Start() const { return compact_store_->Start(); } + StateId NumStates() const { return compact_store_->NumStates(); } + size_t NumArcs() const { return compact_store_->NumArcs(); } + + void SetState(StateId s, State *state) const { + if (state->GetStateId() != s) state->Set(this, s); + } + + static DefaultCompactor *Read(std::istream &strm, + const FstReadOptions &opts, + const FstHeader &hdr) { + std::shared_ptr arc_compactor(C::Read(strm)); + if (arc_compactor == nullptr) return nullptr; + std::shared_ptr compact_store(S::Read(strm, opts, hdr, *arc_compactor)); + if (compact_store == nullptr) return nullptr; + return new DefaultCompactor(arc_compactor, compact_store); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + return arc_compactor_->Write(strm) && compact_store_->Write(strm, opts); + } + + uint64 Properties() const { return arc_compactor_->Properties(); } + + bool IsCompatible(const Fst &fst) const { + return arc_compactor_->Compatible(fst); + } + + bool Error() const { return compact_store_->Error(); } + + bool HasFixedOutdegree() const { return arc_compactor_->Size() != -1; } + + static const string &Type() { + static const string *const type = [] { + string type = "compact"; + if (sizeof(U) != sizeof(uint32)) type += std::to_string(8 * sizeof(U)); + type += "_"; + type += C::Type(); + if (CompactStore::Type() != "compact") { + type += "_"; + type += CompactStore::Type(); + } + return new string(type); + }(); + return *type; + } + + const ArcCompactor *GetArcCompactor() const { return arc_compactor_.get(); } + CompactStore *GetCompactStore() const { 
return compact_store_.get(); } + + std::shared_ptr SharedArcCompactor() const { + return arc_compactor_; + } + + std::shared_ptr SharedCompactStore() const { + return compact_store_; + } + + // TODO(allauzen): remove dependencies on this method and make private. + Arc ComputeArc(StateId s, Unsigned i, uint32 f) const { + return arc_compactor_->Expand(s, compact_store_->Compacts(i), f); + } + + private: + std::pair CompactsRange(StateId s) const { + std::pair range; + if (HasFixedOutdegree()) { + range.first = s * arc_compactor_->Size(); + range.second = arc_compactor_->Size(); + } else { + range.first = compact_store_->States(s); + range.second = compact_store_->States(s + 1) - range.first; + } + return range; + } + + private: + std::shared_ptr arc_compactor_; + std::shared_ptr compact_store_; +}; + +// Default implementation of state attributes accessor class for +// DefaultCompactor. Use of efficient specialization strongly encouraged. +template +class DefaultCompactState { + public: + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DefaultCompactState() = default; + + DefaultCompactState(const DefaultCompactor *compactor, StateId s) + : compactor_(compactor), + s_(s), + range_(compactor->CompactsRange(s)), + has_final_( + range_.second != 0 && + compactor->ComputeArc(s, range_.first, + kArcILabelValue).ilabel == kNoLabel) { + if (has_final_) { + ++range_.first; + --range_.second; + } + } + + void Set(const DefaultCompactor *compactor, StateId s) { + compactor_ = compactor; + s_ = s; + range_ = compactor->CompactsRange(s); + if (range_.second != 0 && + compactor->ComputeArc(s, range_.first, kArcILabelValue).ilabel + == kNoLabel) { + has_final_ = true; + ++range_.first; + --range_.second; + } else { + has_final_ = false; + } + } + + StateId GetStateId() const { return s_; } + + Weight Final() const { + if (!has_final_) return Weight::Zero(); + return compactor_->ComputeArc(s_, range_.first - 1, 
kArcWeightValue).weight; + } + + size_t NumArcs() const { return range_.second; } + + Arc GetArc(size_t i, uint32 f) const { + return compactor_->ComputeArc(s_, range_.first + i, f); + } + + private: + const DefaultCompactor *compactor_ = nullptr; // borrowed ref. + StateId s_ = kNoStateId; + std::pair range_ = {0, 0}; + bool has_final_ = false; +}; + +// Specialization for DefaultCompactStore. +template +class DefaultCompactState> { + public: + using Arc = typename C::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using CompactStore = DefaultCompactStore; + + DefaultCompactState() = default; + + DefaultCompactState( + const DefaultCompactor *compactor, StateId s) + : arc_compactor_(compactor->GetArcCompactor()), s_(s) { + Init(compactor); + } + + void Set(const DefaultCompactor *compactor, StateId s) { + arc_compactor_ = compactor->GetArcCompactor(); + s_ = s; + has_final_ = false; + Init(compactor); + } + + StateId GetStateId() const { return s_; } + + Weight Final() const { + if (!has_final_) return Weight::Zero(); + return arc_compactor_->Expand(s_, *(compacts_ - 1), kArcWeightValue).weight; + } + + size_t NumArcs() const { return num_arcs_; } + + Arc GetArc(size_t i, uint32 f) const { + return arc_compactor_->Expand(s_, compacts_[i], f); + } + + private: + void Init(const DefaultCompactor *compactor) { + const auto *store = compactor->GetCompactStore(); + U offset; + if (!compactor->HasFixedOutdegree()) { // Variable out-degree compactor. + offset = store->States(s_); + num_arcs_ = store->States(s_ + 1) - offset; + } else { // Fixed out-degree compactor. + offset = s_ * arc_compactor_->Size(); + num_arcs_ = arc_compactor_->Size(); + } + if (num_arcs_ > 0) { + compacts_ = &(store->Compacts(offset)); + if (arc_compactor_->Expand(s_, *compacts_, kArcILabelValue).ilabel + == kNoStateId) { + ++compacts_; + --num_arcs_; + has_final_ = true; + } + } + } + + private: + const C *arc_compactor_ = nullptr; // Borrowed reference. 
+ const typename C::Element *compacts_ = nullptr; // Borrowed reference. + StateId s_ = kNoStateId; + U num_arcs_ = 0; + bool has_final_ = false; +}; + +template +class CompactFst; + +template +void Cast(const F &, G *); + +namespace internal { + +// Implementation class for CompactFst, which contains parametrizeable +// Fst data storage (DefaultCompactStore by default) and Fst cache. +template > +class CompactFstImpl + : public CacheBaseImpl { + public: + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + using Compactor = C; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using ImplBase = CacheBaseImpl; + using ImplBase::PushArc; + using ImplBase::HasArcs; + using ImplBase::HasFinal; + using ImplBase::HasStart; + using ImplBase::SetArcs; + using ImplBase::SetFinal; + using ImplBase::SetStart; + + CompactFstImpl() + : ImplBase(CompactFstOptions()), + compactor_() { + SetType(Compactor::Type()); + SetProperties(kNullProperties | kStaticProperties); + } + + CompactFstImpl(const Fst &fst, std::shared_ptr compactor, + const CompactFstOptions &opts) + : ImplBase(opts), + compactor_(std::make_shared(fst, compactor)) { + SetType(Compactor::Type()); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + if (compactor_->Error()) SetProperties(kError, kError); + uint64 copy_properties = fst.Properties(kMutable, false) ? 
+ fst.Properties(kCopyProperties, true): + CheckProperties(fst, + kCopyProperties & ~kWeightedCycles & ~kUnweightedCycles, + kCopyProperties); + if ((copy_properties & kError) || !compactor_->IsCompatible(fst)) { + FSTERROR() << "CompactFstImpl: Input Fst incompatible with compactor"; + SetProperties(kError, kError); + return; + } + SetProperties(copy_properties | kStaticProperties); + } + + CompactFstImpl(std::shared_ptr compactor, + const CompactFstOptions &opts) + : ImplBase(opts), + compactor_(compactor) { + SetType(Compactor::Type()); + SetProperties(kStaticProperties | compactor_->Properties()); + if (compactor_->Error()) SetProperties(kError, kError); + } + + CompactFstImpl(const CompactFstImpl &impl) + : ImplBase(impl), + compactor_(impl.compactor_ == nullptr ? + std::make_shared() : + std::make_shared(*impl.compactor_)) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + // Allows to change the cache store from OtherI to I. + template + CompactFstImpl(const CompactFstImpl &impl) + : ImplBase(CacheOptions(impl.GetCacheGc(), impl.GetCacheLimit())), + compactor_(impl.compactor_ == nullptr ? 
+ std::make_shared() : + std::make_shared(*impl.compactor_)) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(compactor_->Start()); + return ImplBase::Start(); + } + + Weight Final(StateId s) { + if (HasFinal(s)) return ImplBase::Final(s); + compactor_->SetState(s, &state_); + return state_.Final(); + } + + StateId NumStates() const { + if (Properties(kError)) return 0; + return compactor_->NumStates(); + } + + size_t NumArcs(StateId s) { + if (HasArcs(s)) return ImplBase::NumArcs(s); + compactor_->SetState(s, &state_); + return state_.NumArcs(); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kILabelSorted)) Expand(s); + if (HasArcs(s)) return ImplBase::NumInputEpsilons(s); + return CountEpsilons(s, false); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s) && !Properties(kOLabelSorted)) Expand(s); + if (HasArcs(s)) return ImplBase::NumOutputEpsilons(s); + return CountEpsilons(s, true); + } + + size_t CountEpsilons(StateId s, bool output_epsilons) { + compactor_->SetState(s, &state_); + const uint32 f = output_epsilons ? kArcOLabelValue : kArcILabelValue; + size_t num_eps = 0; + for (size_t i = 0; i < state_.NumArcs(); ++i) { + const auto& arc = state_.GetArc(i, f); + const auto label = output_epsilons ? arc.olabel : arc.ilabel; + if (label == 0) + ++num_eps; + else if (label > 0) + break; + } + return num_eps; + } + + static CompactFstImpl *Read( + std::istream &strm, const FstReadOptions &opts) { + std::unique_ptr> impl( + new CompactFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) { + return nullptr; + } + // Ensures compatibility. 
+ if (hdr.Version() == kAlignedFileVersion) { + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + } + impl->compactor_ = std::shared_ptr( + Compactor::Read(strm, opts, hdr)); + if (!impl->compactor_) { + return nullptr; + } + return impl.release(); + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(compactor_->Start()); + hdr.SetNumStates(compactor_->NumStates()); + hdr.SetNumArcs(compactor_->NumArcs()); + // Ensures compatibility. + const auto file_version = opts.align ? kAlignedFileVersion : kFileVersion; + WriteHeader(strm, opts, file_version, &hdr); + return compactor_->Write(strm, opts); + } + + // Provides information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = compactor_->NumStates(); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + ImplBase::InitArcIterator(s, data); + } + + void Expand(StateId s) { + compactor_->SetState(s, &state_); + for (size_t i = 0; i < state_.NumArcs(); ++i) + PushArc(s, state_.GetArc(i, kArcValueFlags)); + SetArcs(s); + if (!HasFinal(s)) SetFinal(s, state_.Final()); + } + + const Compactor *GetCompactor() const { return compactor_.get(); } + std::shared_ptr SharedCompactor() const { return compactor_; } + void SetCompactor(std::shared_ptr compactor) { + // TODO(allauzen): is this correct? is this needed? + // TODO(allauzen): consider removing and forcing this through direct calls + // to compactor. + compactor_ = compactor; + } + + // Properties always true of this FST class. 
+ static constexpr uint64 kStaticProperties = kExpanded; + + protected: + template + explicit CompactFstImpl( + const CompactFstImpl &impl) + : compactor_(std::make_shared(*impl.GetCompactor())) { + SetType(impl.Type()); + SetProperties(impl.Properties()); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + private: + // Allows access during write. + template + friend class ::fst::CompactFst; // allow access during write. + + // Current unaligned file format version. + static constexpr int kFileVersion = 2; + // Current aligned file format version. + static constexpr int kAlignedFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + std::shared_ptr compactor_; + typename Compactor::State state_; +}; + +template +constexpr uint64 CompactFstImpl::kStaticProperties; + +template +constexpr int CompactFstImpl::kFileVersion; + +template +constexpr int CompactFstImpl::kAlignedFileVersion; + +template +constexpr int CompactFstImpl::kMinFileVersion; + +} // namespace internal + +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToExpandedFst. The Unsigned type +// is used to represent indices into the compact arc array. (Template +// argument defaults are declared in fst-decl.h.) +template +class CompactFst + : public ImplToExpandedFst, + CacheStore>> { + public: + template + void friend Cast(const F &, G *); + + using Arc = A; + using StateId = typename A::StateId; + using Compactor = DefaultCompactor; + using Impl = internal::CompactFstImpl; + using Store = CacheStore; // for CacheArcIterator + + friend class StateIterator< + CompactFst>; + friend class ArcIterator< + CompactFst>; + + CompactFst() : ImplToExpandedFst(std::make_shared()) {} + + // If data is not nullptr, it is assumed to be already initialized. 
+ explicit CompactFst( + const Fst &fst, + const ArcCompactor &compactor = ArcCompactor(), + const CompactFstOptions &opts = CompactFstOptions(), + std::shared_ptr data = std::shared_ptr()) + : ImplToExpandedFst( + std::make_shared( + fst, + std::make_shared( + std::make_shared(compactor), data), + opts)) {} + + // If data is not nullptr, it is assumed to be already initialized. + CompactFst( + const Fst &fst, + std::shared_ptr compactor, + const CompactFstOptions &opts = CompactFstOptions(), + std::shared_ptr data = std::shared_ptr()) + : ImplToExpandedFst( + std::make_shared(fst, + std::make_shared(compactor, data), + opts)) {} + + // The following 2 constructors take as input two iterators delimiting a set + // of (already) compacted transitions, starting with the transitions out of + // the initial state. The format of the input differs for fixed out-degree + // and variable out-degree compactors. + // + // - For fixed out-degree compactors, the final weight (encoded as a + // compacted transition) needs to be given only for final states. All strings + // (compactor of size 1) will be assume to be terminated by a final state + // even when the final state is not implicitely given. + // + // - For variable out-degree compactors, the final weight (encoded as a + // compacted transition) needs to be given for all states and must appeared + // first in the list (for state s, final weight of s, followed by outgoing + // transitons in s). + // + // These 2 constructors allows the direct construction of a CompactFst + // without first creating a more memory-hungry regular FST. This is useful + // when memory usage is severely constrained. 
+ template + explicit CompactFst(const Iterator &begin, const Iterator &end, + const ArcCompactor &compactor = ArcCompactor(), + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst( + std::make_shared( + std::make_shared( + begin, end, std::make_shared(compactor)), + opts)) {} + + template + CompactFst(const Iterator &begin, const Iterator &end, + std::shared_ptr compactor, + const CompactFstOptions &opts = CompactFstOptions()) + : ImplToExpandedFst( + std::make_shared( + std::make_shared(begin, end, compactor), opts)) {} + + // See Fst<>::Copy() for doc. + CompactFst( + const CompactFst + &fst, + bool safe = false) + : ImplToExpandedFst(fst, safe) {} + + // Get a copy of this CompactFst. See Fst<>::Copy() for further doc. + CompactFst *Copy( + bool safe = false) const override { + return new CompactFst( + *this, safe); + } + + // Read a CompactFst from an input stream; return nullptr on error + static CompactFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new CompactFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a CompactFst from a file; return nullptr on error + // Empty filename reads from standard input + static CompactFst *Read( + const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? 
new CompactFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, const ArcCompactor &compactor, + std::ostream &strm, const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new SortedMatcher< + CompactFst>( + *this, match_type); + } + + template + void SetCompactElements(const Iterator &b, const Iterator &e) { + GetMutableImpl()->SetCompactor(std::make_shared( + b, e, std::make_shared())); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; + + explicit CompactFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + // Use overloading to extract the type of the argument. + static Impl *GetImplIfCompactFst( + const CompactFst + &compact_fst) { + return compact_fst.GetImpl(); + } + + // This does not give privileged treatment to subclasses of CompactFst. + template + static Impl *GetImplIfCompactFst(const NonCompactFst &fst) { + return nullptr; + } + + CompactFst &operator=(const CompactFst &fst) = delete; +}; + +// Writes FST in Compact format, with a possible pass over the machine before +// writing to compute the number of states and arcs. +template +template +bool CompactFst::WriteFst( + const FST &fst, const ArcCompactor &compactor, std::ostream &strm, + const FstWriteOptions &opts) { + using Arc = A; + using Weight = typename A::Weight; + using Element = typename ArcCompactor::Element; + const auto file_version = + opts.align ? 
Impl::kAlignedFileVersion : Impl::kFileVersion; + size_t num_arcs = -1; + size_t num_states = -1; + auto first_pass_compactor = compactor; + if (auto *impl = GetImplIfCompactFst(fst)) { + num_arcs = impl->GetCompactor()->GetCompactStore()->NumArcs(); + num_states = impl->GetCompactor()->GetCompactStore()->NumStates(); + first_pass_compactor = *impl->GetCompactor()->GetArcCompactor(); + } else { + // A first pass is needed to compute the state of the compactor, which + // is saved ahead of the rest of the data structures. This unfortunately + // means forcing a complete double compaction when writing in this format. + // TODO(allauzen): eliminate mutable state from compactors. + num_arcs = 0; + num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + ++num_states; + if (fst.Final(s) != Weight::Zero()) { + first_pass_compactor.Compact( + s, Arc(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + } + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + ++num_arcs; + first_pass_compactor.Compact(s, aiter.Value()); + } + } + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "compact"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + type += "_"; + type += ArcCompactor::Type(); + if (CompactStore::Type() != "compact") { + type += "_"; + type += CompactStore::Type(); + } + const auto copy_properties = fst.Properties(kCopyProperties, true); + if ((copy_properties & kError) || !compactor.Compatible(fst)) { + FSTERROR() << "Fst incompatible with compactor"; + return false; + } + uint64 properties = copy_properties | Impl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, type, + properties, &hdr); + first_pass_compactor.Write(strm); + if (first_pass_compactor.Size() == -1) { + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << 
"CompactFst::Write: Alignment failed: " << opts.source; + return false; + } + Unsigned compacts = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + strm.write(reinterpret_cast(&compacts), sizeof(compacts)); + if (fst.Final(s) != Weight::Zero()) { + ++compacts; + } + compacts += fst.NumArcs(s); + } + strm.write(reinterpret_cast(&compacts), sizeof(compacts)); + } + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + const auto &second_pass_compactor = compactor; + Element element; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) { + element = second_pass_compactor.Compact( + s, A(kNoLabel, kNoLabel, fst.Final(s), kNoStateId)); + strm.write(reinterpret_cast(&element), sizeof(element)); + } + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + element = second_pass_compactor.Compact(s, aiter.Value()); + strm.write(reinterpret_cast(&element), sizeof(element)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "CompactFst write failed: " << opts.source; + return false; + } + return true; +} + +// Specialization for CompactFst; see generic version in fst.h for sample +// usage (but use the CompactFst type!). This version should inline. +template +class StateIterator< + CompactFst> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator( + const CompactFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + StateId nstates_; + StateId s_; +}; + +// Specialization for CompactFst. Never caches, +// always iterates over the underlying compact elements. 
+template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + using Element = typename ArcCompactor::Element; + using Compactor = DefaultCompactor; + using State = typename Compactor::State; + + ArcIterator(const CompactFst &fst, + StateId s) + : state_(fst.GetImpl()->GetCompactor(), s), + pos_(0), + flags_(kArcValueFlags) {} + + bool Done() const { return pos_ >= state_.NumArcs(); } + + const Arc &Value() const { + arc_ = state_.GetArc(pos_, flags_); + return arc_; + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 f, uint32 m) { + flags_ &= ~m; + flags_ |= (f & kArcValueFlags); + } + + private: + State state_; + size_t pos_; + mutable Arc arc_; + uint32 flags_; +}; + +// ArcCompactor for unweighted string FSTs. +template +class StringCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = Label; + + Element Compact(StateId s, const Arc &arc) const { return arc.ilabel; } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p, p, Weight::One(), p != kNoLabel ? s + 1 : kNoStateId); + } + + constexpr ssize_t Size() const { return 1; } + + constexpr uint64 Properties() const { + return kString | kAcceptor | kUnweighted; + } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("string"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static StringCompactor *Read(std::istream &strm) { + return new StringCompactor; + } +}; + +// ArcCompactor for weighted string FSTs. 
+template +class WeightedStringCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(arc.ilabel, arc.weight); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, p.second, + p.first != kNoLabel ? s + 1 : kNoStateId); + } + + constexpr ssize_t Size() const { return 1; } + + constexpr uint64 Properties() const { return kString | kAcceptor; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("weighted_string"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static WeightedStringCompactor *Read(std::istream &strm) { + return new WeightedStringCompactor; + } +}; + +// ArcCompactor for unweighted acceptor FSTs. 
+template +class UnweightedAcceptorCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(arc.ilabel, arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first, p.first, Weight::One(), p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kAcceptor | kUnweighted; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("unweighted_acceptor"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static UnweightedAcceptorCompactor *Read(std::istream &istrm) { + return new UnweightedAcceptorCompactor; + } +}; + +// ArcCompactor for weighted acceptor FSTs. 
+template +class AcceptorCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair, StateId>; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(std::make_pair(arc.ilabel, arc.weight), + arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.first, p.first.second, p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kAcceptor; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("acceptor"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static AcceptorCompactor *Read(std::istream &strm) { + return new AcceptorCompactor; + } +}; + +// ArcCompactor for unweighted FSTs. 
+template +class UnweightedCompactor { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Element = std::pair, StateId>; + + Element Compact(StateId s, const Arc &arc) const { + return std::make_pair(std::make_pair(arc.ilabel, arc.olabel), + arc.nextstate); + } + + Arc Expand(StateId s, const Element &p, uint32 f = kArcValueFlags) const { + return Arc(p.first.first, p.first.second, Weight::One(), p.second); + } + + constexpr ssize_t Size() const { return -1; } + + constexpr uint64 Properties() const { return kUnweighted; } + + bool Compatible(const Fst &fst) const { + const auto props = Properties(); + return fst.Properties(props, true) == props; + } + + static const string &Type() { + static const string *const type = new string("unweighted"); + return *type; + } + + bool Write(std::ostream &strm) const { return true; } + + static UnweightedCompactor *Read(std::istream &strm) { + return new UnweightedCompactor; + } +}; + +template +using CompactStringFst = CompactFst, Unsigned>; + +template +using CompactWeightedStringFst = + CompactFst, Unsigned>; + +template +using CompactAcceptorFst = CompactFst, Unsigned>; + +template +using CompactUnweightedFst = + CompactFst, Unsigned>; + +template +using CompactUnweightedAcceptorFst = + CompactFst, Unsigned>; + +using StdCompactStringFst = CompactStringFst; + +using StdCompactWeightedStringFst = CompactWeightedStringFst; + +using StdCompactAcceptorFst = CompactAcceptorFst; + +using StdCompactUnweightedFst = CompactUnweightedFst; + +using StdCompactUnweightedAcceptorFst = + CompactUnweightedAcceptorFst; + +} // namespace fst + +#endif // FST_COMPACT_FST_H_ diff --git a/projects/llm_framework/include/fst/compat.h b/projects/llm_framework/include/fst/compat.h new file mode 100644 index 00000000..73ed5737 --- /dev/null +++ b/projects/llm_framework/include/fst/compat.h @@ -0,0 +1,130 @@ +// Licensed under the Apache License, 
Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_LIB_COMPAT_H_ +#define FST_LIB_COMPAT_H_ + +#include +#include +#include +#include +#include +#include + +// Makes copy constructor and operator= private +// Deprecated: now just use =delete. +#define DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator=(const type&) + +#if defined(__GNUC__) || defined(__clang__) +#define OPENFST_DEPRECATED(message) __attribute__((deprecated(message))) +#elif defined(_MSC_VER) +#define OPENFST_DEPRECATED(message) __declspec(deprecated(message)) +#else +#define OPENFST_DEPRECATED(message) +#endif + +#include +#include +#include +#include +#include +#include + +using std::string; + +void FailedNewHandler(); + +#ifdef _MSC_VER +#include +const char* basename(const char* path); +#define __builtin_popcount __popcnt + +#ifdef _M_X64 +// Using 64-bit MSVC intrinsics. +#define __builtin_popcountll __popcnt64 +inline unsigned int __builtin_ctzll(std::uint64_t w) { + unsigned long v; + return _BitScanForward64(&v, std::uint32_t(w)) ? v : 0; +} +#else +// Using 32-bit MSVC intrinsics. +inline unsigned int __builtin_popcountll(std::uint64_t w) { + return __popcnt(std::uint32_t(w)) + __popcnt(std::uint32_t(w >> 32)); +} +inline unsigned int __builtin_ctzll(std::uint64_t w) { + unsigned long v; + return (_BitScanForward(&v, std::uint32_t(w)) ? 
v : + _BitScanForward(&v, std::uint32_t(w >> 32)) ? v + 32 : 0); +} +#endif // _M_X64 +#endif // _MSC_VER + +namespace fst { + +// Downcasting. +template +inline To down_cast(From* f) { return static_cast(f); } + +// Bitcasting. +template +inline Dest bit_cast(const Source &source) { + static_assert(sizeof(Dest) == sizeof(Source), + "Bitcasting unsafe for specified types"); + Dest dest; + memcpy(&dest, &source, sizeof(dest)); + return dest; +} + +// Check sums +class CheckSummer { + public: + CheckSummer() : count_(0) { + check_sum_.resize(kCheckSumLength, '\0'); + } + + void Reset() { + count_ = 0; + for (int i = 0; i < kCheckSumLength; ++i) check_sum_[i] = '\0'; + } + + void Update(void const *data, int size) { + const char *p = reinterpret_cast(data); + for (int i = 0; i < size; ++i) { + check_sum_[(count_++) % kCheckSumLength] ^= p[i]; + } + } + + void Update(string const &data) { + for (int i = 0; i < data.size(); ++i) { + check_sum_[(count_++) % kCheckSumLength] ^= data[i]; + } + } + + string Digest() { return check_sum_; } + + private: + static const int kCheckSumLength = 32; + int count_; + string check_sum_; + + CheckSummer(const CheckSummer &) = delete; + CheckSummer &operator=(const CheckSummer &) = delete; +}; + +} // namespace fst + +#endif // FST_LIB_COMPAT_H_ diff --git a/projects/llm_framework/include/fst/complement.h b/projects/llm_framework/include/fst/complement.h new file mode 100644 index 00000000..64eebc03 --- /dev/null +++ b/projects/llm_framework/include/fst/complement.h @@ -0,0 +1,277 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to complement an FST. + +#ifndef FST_COMPLEMENT_H_ +#define FST_COMPLEMENT_H_ + +#include +#include +#include +#include + +#include +#include + + +namespace fst { + +template +class ComplementFst; + +namespace internal { + +// Implementation of delayed ComplementFst. 
The algorithm used completes the +// (deterministic) FSA and then exchanges final and non-final states. +// Completion, i.e. ensuring that all labels can be read from every state, is +// accomplished by using ρ-labels, which match all labels that are otherwise +// not found leaving a state. The first state in the output is reserved to be a +// new state that is the destination of all ρ-labels. Each remaining output +// state s corresponds to input state s - 1. The first arc in the output at +// these states is the ρ-label, the remaining arcs correspond to the input +// arcs. +template +class ComplementFstImpl : public FstImpl { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + friend class StateIterator>; + friend class ArcIterator>; + + explicit ComplementFstImpl(const Fst &fst) : fst_(fst.Copy()) { + SetType("complement"); + uint64 props = fst.Properties(kILabelSorted, false); + SetProperties(ComplementProperties(props), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + ComplementFstImpl(const ComplementFstImpl &impl) + : fst_(impl.fst_->Copy()) { + SetType("complement"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() const { + if (Properties(kError)) return kNoStateId; + auto start = fst_->Start(); + return start != kNoStateId ? start + 1 : 0; + } + + // Exchange final and non-final states; makes ρ-destination state final. + Weight Final(StateId s) const { + if (s == 0 || fst_->Final(s - 1) == Weight::Zero()) { + return Weight::One(); + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId s) const { + return s == 0 ? 
1 : fst_->NumArcs(s - 1) + 1; + } + + size_t NumInputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumInputEpsilons(s - 1); + } + + size_t NumOutputEpsilons(StateId s) const { + return s == 0 ? 0 : fst_->NumOutputEpsilons(s - 1); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && fst_->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + private: + std::unique_ptr> fst_; +}; + +} // namespace internal + +// Complements an automaton. This is a library-internal operation that +// introduces a (negative) ρ-label; use Difference/DifferenceFst in user code, +// which will not see this label. This version is a delayed FST. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. +template +class ComplementFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Impl = internal::ComplementFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + + explicit ComplementFst(const Fst &fst) + : ImplToFst(std::make_shared(fst)) { + static constexpr auto props = + kUnweighted | kNoEpsilons | kIDeterministic | kAcceptor; + if (fst.Properties(props, true) != props) { + FSTERROR() << "ComplementFst: Argument not an unweighted " + << "epsilon-free deterministic acceptor"; + GetImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + ComplementFst(const ComplementFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this FST. See Fst<>::Copy() for further doc. 
+ ComplementFst *Copy(bool safe = false) const override { + return new ComplementFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + inline void InitArcIterator(StateId s, + ArcIteratorData *data) const override; + + // Label that represents the ρ-transition; we use a negative value private to + // the library and which will preserve FST label sort order. + static const Label kRhoLabel = -2; + + private: + using ImplToFst::GetImpl; + + ComplementFst &operator=(const ComplementFst &) = delete; +}; + +template +const typename Arc::Label ComplementFst::kRhoLabel; + +// Specialization for ComplementFst. +template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const ComplementFst &fst) + : siter_(*fst.GetImpl()->fst_), s_(0) {} + + bool Done() const final { return s_ > 0 && siter_.Done(); } + + StateId Value() const final { return s_; } + + void Next() final { + if (s_ != 0) siter_.Next(); + ++s_; + } + + void Reset() final { + siter_.Reset(); + s_ = 0; + } + + private: + StateIterator> siter_; + StateId s_; +}; + +// Specialization for ComplementFst. +template +class ArcIterator> : public ArcIteratorBase { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + ArcIterator(const ComplementFst &fst, StateId s) : s_(s), pos_(0) { + if (s_ != 0) { + aiter_.reset(new ArcIterator>(*fst.GetImpl()->fst_, s - 1)); + } + } + + bool Done() const final { + if (s_ != 0) { + return pos_ > 0 && aiter_->Done(); + } else { + return pos_ > 0; + } + } + + // Adds the ρ-label to the ρ destination state. 
+ const Arc &Value() const final { + if (pos_ == 0) { + arc_.ilabel = arc_.olabel = ComplementFst::kRhoLabel; + arc_.weight = Weight::One(); + arc_.nextstate = 0; + } else { + arc_ = aiter_->Value(); + ++arc_.nextstate; + } + return arc_; + } + + void Next() final { + if (s_ != 0 && pos_ > 0) aiter_->Next(); + ++pos_; + } + + size_t Position() const final { return pos_; } + + void Reset() final { + if (s_ != 0) aiter_->Reset(); + pos_ = 0; + } + + void Seek(size_t a) final { + if (s_ != 0) { + if (a == 0) { + aiter_->Reset(); + } else { + aiter_->Seek(a - 1); + } + } + pos_ = a; + } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + std::unique_ptr>> aiter_; + StateId s_; + size_t pos_; + mutable Arc arc_; +}; + +template +inline void ComplementFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +template +inline void ComplementFst::InitArcIterator(StateId s, + ArcIteratorData *data) const { + data->base = new ArcIterator>(*this, s); +} + +// Useful alias when using StdArc. +using StdComplementFst = ComplementFst; + +} // namespace fst + +#endif // FST_COMPLEMENT_H_ diff --git a/projects/llm_framework/include/fst/compose-filter.h b/projects/llm_framework/include/fst/compose-filter.h new file mode 100644 index 00000000..7251e273 --- /dev/null +++ b/projects/llm_framework/include/fst/compose-filter.h @@ -0,0 +1,571 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for filtering the composition matches, e.g. for correct epsilon +// handling. + +#ifndef FST_COMPOSE_FILTER_H_ +#define FST_COMPOSE_FILTER_H_ + +#include +#include // For optional argument declarations +#include +#include + + +namespace fst { + +// Composition filters determine which matches are allowed to proceed. The +// filter's state is represeted by the type ComposeFilter::FilterState. 
// The basic filters handle correct epsilon matching. Their interface is:
//
// template <class M1, class M2>
// class ComposeFilter {
//  public:
//   using Matcher1 = ...;
//   using Matcher2 = ...;
//   using FST1 = typename M1::FST;
//   using FST2 = typename M2::FST;
//   using FilterState = ...;
//
//   using Arc = typename FST1::Arc;
//   using StateId = typename Arc::StateId;
//   using Weight = typename Arc::Weight;
//
//   // Required constructor.
//   ComposeFilter(const FST1 &fst1, const FST2 &fst2,
//                 M1 *matcher1 = nullptr, M2 *matcher2 = nullptr);
//
//   // If safe=true, the copy is thread-safe. See Fst<>::Copy()
//   // for further doc.
//   ComposeFilter(const ComposeFilter<M1, M2> &filter,
//                 bool safe = false);
//
//   // Return start state of filter.
//   FilterState Start() const;
//
//   // Specifies current composition state.
//   void SetState(StateId s1, StateId s2, const FilterState &fs);
//
//   // Apply filter at current composition state to these transitions. If an
//   // arc label to be matched is kNoLabel, then that side does not consume a
//   // symbol. Returns the new filter state or, if disallowed,
//   // FilterState::NoState(). The filter is permitted to modify its inputs
//   // (e.g. for optimization reasons).
//   FilterState FilterArc(Arc *arc1, Arc *arc2) const;
//
//   // Apply filter at current composition state to these final weights
//   // (cf. superfinal transitions). The filter may modify its inputs
//   // (e.g. for optimization reasons).
//   void FilterFinal(Weight *w1, Weight *w2) const;
//
//   // Return the respective matchers. Ownership stays with filter. These
//   // methods allow the filter to access and possibly modify the composition
//   // matchers (useful, e.g., with lookahead).
//
//   Matcher1 *GetMatcher1();
//
//   Matcher2 *GetMatcher2();
//
//   // This specifies how the filter affects the composition result properties.
//   // It takes as argument the properties that would apply with a trivial
//   // composition filter.
+// uint64 Properties(uint64 props) const; +// }; +// +// This filter allows only exact matching of symbols from FST1 with on FST2; +// e.g., no special interpretation of epsilons. +template +class NullComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = TrivialFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + NullComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + NullComposeFilter(const NullComposeFilter &filter, bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + FilterState Start() const { return FilterState(true); } + + void SetState(StateId, StateId, const FilterState &) {} + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + return (arc1->olabel == kNoLabel || arc2->ilabel == kNoLabel) + ? FilterState::NoState() + : FilterState(true); + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; +}; + +// This filter allows all epsilon matches, potentially resulting in redundant +// epsilon paths. 
The use of this filter gives correct results iff one of the +// following conditions hold: +// +// (1) The semiring is idempotent, +// (2) the first FST is output-epsilon free, or +// (3) the second FST is input-epsilon free. +// +// For (1), redundant epsilon paths may be created but won't hurt correctness. +// For (2) and (3), no redundant paths are created. +template +class TrivialComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = TrivialFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + TrivialComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + TrivialComposeFilter(const TrivialComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + fst2_(matcher2_->GetFst()) {} + + FilterState Start() const { return FilterState(true); } + + void SetState(StateId, StateId, const FilterState &) {} + + FilterState FilterArc(Arc *, Arc *) const { return FilterState(true); } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + const FST2 &fst2_; +}; + +// This filter requires epsilons on FST1 to be read before epsilons on FST2. 
+template +class SequenceComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = CharFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + SequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + SequenceComposeFilter(const SequenceComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst1_(matcher1_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + if (s1_ == s1 && s2_ == s2 && fs == fs_) return; + s1_ = s1; + s2_ = s2; + fs_ = fs; + const auto na1 = internal::NumArcs(fst1_, s1); + const auto ne1 = internal::NumOutputEpsilons(fst1_, s1); + const bool fin1 = internal::Final(fst1_, s1) != Weight::Zero(); + alleps1_ = na1 == ne1 && !fin1; + noeps1_ = ne1 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc1->olabel == kNoLabel) { + return alleps1_ ? FilterState::NoState() : noeps1_ ? FilterState(0) + : FilterState(1); + } else if (arc2->ilabel == kNoLabel) { + return fs_ != FilterState(0) ? FilterState::NoState() : FilterState(0); + } else { + return arc1->olabel == 0 ? 
FilterState::NoState() : FilterState(0); + } + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST1 &fst1_; + StateId s1_; // Current fst1_ state. + StateId s2_; // Current fst2_ state. + FilterState fs_; // Current filter state. + bool alleps1_; // Only epsilons (and non-final) leaving s1_? + bool noeps1_; // No epsilons leaving s1_? +}; + +// This filter requires epsilons on FST2 to be read before epsilons on FST1. +template +class AltSequenceComposeFilter { + public: + using Matcher1 = M1; + using Matcher2 = M2; + using FST1 = typename M1::FST; + using FST2 = typename M2::FST; + using FilterState = CharFilterState; + + using Arc = typename FST1::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + AltSequenceComposeFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr) + : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)), + matcher2_(matcher2 ? 
matcher2 : new Matcher2(fst2, MATCH_INPUT)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + AltSequenceComposeFilter( + const AltSequenceComposeFilter &filter, + bool safe = false) + : matcher1_(filter.matcher1_->Copy(safe)), + matcher2_(filter.matcher2_->Copy(safe)), + fst2_(matcher2_->GetFst()), + s1_(kNoStateId), + s2_(kNoStateId), + fs_(kNoStateId) {} + + FilterState Start() const { return FilterState(0); } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + if (s1_ == s1 && s2_ == s2 && fs == fs_) return; + s1_ = s1; + s2_ = s2; + fs_ = fs; + const auto na2 = internal::NumArcs(fst2_, s2); + const auto ne2 = internal::NumInputEpsilons(fst2_, s2); + const bool fin2 = internal::Final(fst2_, s2) != Weight::Zero(); + alleps2_ = na2 == ne2 && !fin2; + noeps2_ = ne2 == 0; + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + if (arc2->ilabel == kNoLabel) { + return alleps2_ ? FilterState::NoState() : noeps2_ ? FilterState(0) + : FilterState(1); + } else if (arc1->olabel == kNoLabel) { + return fs_ == FilterState(1) ? FilterState::NoState() : FilterState(0); + } else { + return arc1->olabel == 0 ? FilterState::NoState() : FilterState(0); + } + } + + void FilterFinal(Weight *, Weight *) const {} + + Matcher1 *GetMatcher1() { return matcher1_.get(); } + + Matcher2 *GetMatcher2() { return matcher2_.get(); } + + uint64 Properties(uint64 props) const { return props; } + + private: + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + const FST2 &fst2_; + StateId s1_; // Current fst1_ state. + StateId s2_; // Current fst2_ state. + FilterState fs_; // Current filter state. + bool alleps2_; // Only epsilons (and non-final) leaving s2_? + bool noeps2_; // No epsilons leaving s2_? +}; + +// This filter requires epsilons on FST1 to be matched with epsilons on FST2 +// whenever possible. (Template arg default declared in fst-decl.h.) 
+template <class M1, class M2>
+class MatchComposeFilter {
+ public:
+  using Matcher1 = M1;
+  using Matcher2 = M2;
+  using FST1 = typename M1::FST;
+  using FST2 = typename M2::FST;
+  using FilterState = CharFilterState;
+
+  using Arc = typename FST1::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  // Builds the filter for fst1 and fst2. A non-null matcher argument is stored
+  // in (and later released by) a unique_ptr; a null argument is replaced by a
+  // newly constructed matcher (MATCH_OUTPUT on fst1, MATCH_INPUT on fst2).
+  MatchComposeFilter(const FST1 &fst1, const FST2 &fst2,
+                     Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr)
+      : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)),
+        matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)),
+        fst1_(matcher1_->GetFst()),
+        fst2_(matcher2_->GetFst()),
+        s1_(kNoStateId),
+        s2_(kNoStateId),
+        fs_(kNoStateId) {}
+
+  // Copies the matchers of 'filter' (forwarding 'safe' to their Copy()) and
+  // resets the cached state so the next SetState() call recomputes the flags.
+  MatchComposeFilter(const MatchComposeFilter &filter,
+                     bool safe = false)
+      : matcher1_(filter.matcher1_->Copy(safe)),
+        matcher2_(filter.matcher2_->Copy(safe)),
+        fst1_(matcher1_->GetFst()),
+        fst2_(matcher2_->GetFst()),
+        s1_(kNoStateId),
+        s2_(kNoStateId),
+        fs_(kNoStateId) {}
+
+  FilterState Start() const { return FilterState(0); }
+
+  // Records the current state pair and filter state, and recomputes the
+  // epsilon summary flags of s1 (output side) and s2 (input side). Does
+  // nothing when called again with the same triple.
+  void SetState(StateId s1, StateId s2, const FilterState &fs) {
+    if (s1_ == s1 && s2_ == s2 && fs == fs_) return;
+    s1_ = s1;
+    s2_ = s2;
+    fs_ = fs;
+    size_t na1 = internal::NumArcs(fst1_, s1);
+    size_t ne1 = internal::NumOutputEpsilons(fst1_, s1);
+    bool f1 = internal::Final(fst1_, s1) != Weight::Zero();
+    alleps1_ = na1 == ne1 && !f1;
+    noeps1_ = ne1 == 0;
+    size_t na2 = internal::NumArcs(fst2_, s2);
+    size_t ne2 = internal::NumInputEpsilons(fst2_, s2);
+    bool f2 = internal::Final(fst2_, s2) != Weight::Zero();
+    alleps2_ = na2 == ne2 && !f2;
+    noeps2_ = ne2 == 0;
+  }
+
+  // Returns the filter state reached by pairing arc1 with arc2 from the
+  // current state, or FilterState::NoState() when the pairing is rejected.
+  // A kNoLabel on one side marks the implicit loop used to pair an epsilon
+  // on the other side.
+  FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+    if (arc2->ilabel == kNoLabel) {  // Epsilon in FST1.
+      return fs_ == FilterState(0)
+                 ? (noeps2_
+                        ? FilterState(0)
+                        : (alleps2_ ? FilterState::NoState() : FilterState(1)))
+                 : (fs_ == FilterState(1) ? FilterState(1)
+                                          : FilterState::NoState());
+    } else if (arc1->olabel == kNoLabel) {  // Epsilon in FST2.
+      return fs_ == FilterState(0)
+                 ? (noeps1_
+                        ? FilterState(0)
+                        : (alleps1_ ? FilterState::NoState() : FilterState(2)))
+                 : (fs_ == FilterState(2) ? FilterState(2)
+                                          : FilterState::NoState());
+    } else if (arc1->olabel == 0) {  // Epsilon in both.
+      return fs_ == FilterState(0) ? FilterState(0) : FilterState::NoState();
+    } else {  // Both are non-epsilons.
+      return FilterState(0);
+    }
+  }
+
+  // This filter leaves final weights untouched.
+  void FilterFinal(Weight *, Weight *) const {}
+
+  Matcher1 *GetMatcher1() { return matcher1_.get(); }
+
+  Matcher2 *GetMatcher2() { return matcher2_.get(); }
+
+  // Passes the given properties through unchanged.
+  uint64 Properties(uint64 props) const { return props; }
+
+ private:
+  std::unique_ptr<Matcher1> matcher1_;
+  std::unique_ptr<Matcher2> matcher2_;
+  const FST1 &fst1_;
+  const FST2 &fst2_;
+  StateId s1_;      // Current fst1_ state.
+  StateId s2_;      // Current fst2_ state.
+  FilterState fs_;  // Current filter state ID.
+  // The flags below are recomputed by SetState(); they start out false so
+  // that FilterArc() never reads an indeterminate value.
+  bool alleps1_ = false;  // Only epsilons (and non-final) leaving s1?
+  bool alleps2_ = false;  // Only epsilons (and non-final) leaving s2?
+  bool noeps1_ = false;   // No epsilons leaving s1?
+  bool noeps2_ = false;   // No epsilons leaving s2?
+};
+
+// This filter disallows matching epsilons on FST1 with epsilons on FST2,
+// but allows all other matches, potentially resulting in redundant
+// epsilon paths. The use of this filter gives correct results iff one of the
+// following conditions hold:
+//
+// (1) The semiring is idempotent,
+// (2) the first FST is output-epsilon free, or
+// (3) the second FST is input-epsilon free.
+//
+// For (1), redundant epsilon paths may be created but won't hurt correctness.
+// For (2) and (3), no redundant paths are created.
+template <class M1, class M2>
+class NoMatchComposeFilter {
+ public:
+  using Matcher1 = M1;
+  using Matcher2 = M2;
+  using FST1 = typename M1::FST;
+  using FST2 = typename M2::FST;
+  using FilterState = TrivialFilterState;
+
+  using Arc = typename FST1::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  // Builds the filter for fst1 and fst2. A non-null matcher argument is stored
+  // in (and later released by) a unique_ptr; a null argument is replaced by a
+  // newly constructed matcher (MATCH_OUTPUT on fst1, MATCH_INPUT on fst2).
+  NoMatchComposeFilter(const FST1 &fst1, const FST2 &fst2,
+                       Matcher1 *matcher1 = nullptr,
+                       Matcher2 *matcher2 = nullptr)
+      : matcher1_(matcher1 ? matcher1 : new Matcher1(fst1, MATCH_OUTPUT)),
+        matcher2_(matcher2 ? matcher2 : new Matcher2(fst2, MATCH_INPUT)),
+        fst1_(matcher1_->GetFst()),
+        fst2_(matcher2_->GetFst()) {}
+
+  // Copies the matchers of 'filter', forwarding 'safe' to their Copy().
+  NoMatchComposeFilter(const NoMatchComposeFilter &filter,
+                       bool safe = false)
+      : matcher1_(filter.matcher1_->Copy(safe)),
+        matcher2_(filter.matcher2_->Copy(safe)),
+        fst1_(matcher1_->GetFst()),
+        fst2_(matcher2_->GetFst()) {}
+
+  FilterState Start() const { return FilterState(true); }
+
+  // The filter keeps no per-state data, so there is nothing to update.
+  void SetState(StateId, StateId, const FilterState &) {}
+
+  // Yields a false filter state only when arc1's output label and arc2's
+  // input label are both 0 (epsilon on both sides); true otherwise.
+  FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+    return FilterState(arc1->olabel != 0 || arc2->ilabel != 0);
+  }
+
+  // This filter leaves final weights untouched.
+  void FilterFinal(Weight *, Weight *) const {}
+
+  Matcher1 *GetMatcher1() { return matcher1_.get(); }
+
+  Matcher2 *GetMatcher2() { return matcher2_.get(); }
+
+  // Passes the given properties through unchanged.
+  uint64 Properties(uint64 props) const { return props; }
+
+ private:
+  std::unique_ptr<Matcher1> matcher1_;
+  std::unique_ptr<Matcher2> matcher2_;
+  const FST1 &fst1_;
+  const FST2 &fst2_;
+};
+
+// This filter works with the MultiEpsMatcher to determine if multi-epsilons are
+// preserved in the composition output (rather than rewritten as 0) and
+// ensures correct properties.
+template <class Filter>
+class MultiEpsFilter {
+ public:
+  using Matcher1 = typename Filter::Matcher1;
+  using Matcher2 = typename Filter::Matcher2;
+  using FST1 = typename Filter::FST1;
+  using FST2 = typename Filter::FST2;
+  using FilterState = typename Filter::FilterState;
+
+  using Arc = typename Filter::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  // Wraps a Filter built from the same arguments. 'keep_multi_eps' selects
+  // whether FilterArc() copies labels across kNoLabel (implicit-loop) arcs.
+  MultiEpsFilter(const FST1 &fst1, const FST2 &fst2,
+                 Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr,
+                 bool keep_multi_eps = false)
+      : filter_(fst1, fst2, matcher1, matcher2),
+        keep_multi_eps_(keep_multi_eps) {}
+
+  // Copies the wrapped filter (forwarding 'safe') and the keep flag.
+  MultiEpsFilter(const MultiEpsFilter &filter, bool safe = false)
+      : filter_(filter.filter_, safe),
+        keep_multi_eps_(filter.keep_multi_eps_) {}
+
+  FilterState Start() const { return filter_.Start(); }
+
+  void SetState(StateId s1, StateId s2, const FilterState &fs) {
+    return filter_.SetState(s1, s2, fs);
+  }
+
+  // Delegates the decision to the wrapped filter. When keep_multi_eps_ is
+  // set, an arc whose matching-side label is kNoLabel then takes the label of
+  // the other arc on its opposite side, so that label reaches the output.
+  FilterState FilterArc(Arc *arc1, Arc *arc2) const {
+    const auto fs = filter_.FilterArc(arc1, arc2);
+    if (keep_multi_eps_) {
+      if (arc1->olabel == kNoLabel) arc1->ilabel = arc2->ilabel;
+      if (arc2->ilabel == kNoLabel) arc2->olabel = arc1->olabel;
+    }
+    return fs;
+  }
+
+  void FilterFinal(Weight *w1, Weight *w2) const {
+    return filter_.FilterFinal(w1, w2);
+  }
+
+  Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); }
+
+  Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); }
+
+  // Restricts the wrapped filter's properties to those that are invariant
+  // under both input-label and output-label changes.
+  uint64 Properties(uint64 iprops) const {
+    const auto oprops = filter_.Properties(iprops);
+    return oprops & kILabelInvariantProperties & kOLabelInvariantProperties;
+  }
+
+ private:
+  Filter filter_;
+  bool keep_multi_eps_;
+};
+
+}  // namespace fst
+
+#endif  // FST_COMPOSE_FILTER_H_
diff --git a/projects/llm_framework/include/fst/compose.h b/projects/llm_framework/include/fst/compose.h
new file mode 100644
index 00000000..1066d097
--- /dev/null
+++ b/projects/llm_framework/include/fst/compose.h
@@ -0,0 +1,1035 @@
+// See
www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Class to compute the composition of two FSTs.
+
+#ifndef FST_COMPOSE_H_
+#define FST_COMPOSE_H_
+
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include  // For optional argument declarations
+#include
+#include
+#include
+#include
+
+
+namespace fst {
+
+// Delayed composition options templated on the arc type, the matcher,
+// the composition filter, and the composition state table. By
+// default, the matchers, filter, and state table are constructed by
+// composition. If set below, the user can instead pass in these
+// objects; in that case, ComposeFst takes their ownership. This
+// version controls composition implemented between generic Fst
+// types and a shared matcher type M for Fst. This should be
+// adequate for most applications, giving a reasonable tradeoff
+// between efficiency and code sharing (but see ComposeFstImplOptions).
+//
+// NOTE(review): the template parameter list below and the #include targets
+// above look truncated (angle-bracket contents missing) -- confirm against the
+// original header before building.
+template >,
+          class Filter = SequenceComposeFilter,
+          class StateTable =
+              GenericComposeStateTable>
+struct ComposeFstOptions : public CacheOptions {
+  M *matcher1;              // FST1 matcher.
+  M *matcher2;              // FST2 matcher.
+  Filter *filter;           // Composition filter.
+  StateTable *state_table;  // Composition state table.
+
+  // Copies the cache options and stores the given component pointers as is;
+  // every component defaults to nullptr.
+  explicit ComposeFstOptions(const CacheOptions &opts = CacheOptions(),
+                             M *matcher1 = nullptr, M *matcher2 = nullptr,
+                             Filter *filter = nullptr,
+                             StateTable *state_table = nullptr)
+      : CacheOptions(opts),
+        matcher1(matcher1),
+        matcher2(matcher2),
+        filter(filter),
+        state_table(state_table) {}
+};
+
+// Forward declaration of ComposeFstMatcher.
+template
+class ComposeFstMatcher;
+
+// Delayed composition options templated on the two matcher types, the
+// composition filter, the composition state table and the cache store. By
+// default, the matchers, filter, state table and cache store are constructed
+// by composition.
If set below, the user can instead pass in these objects; in
+// that case, ComposeFst takes their ownership. This version controls
+// composition implemented using arbitrary matchers (of the same arc type but
+// otherwise arbitrary FST type). The user must ensure the matchers are
+// compatible. These options permit the most efficient use, but share the
+// least code. This is for advanced use only in the most demanding or
+// specialized applications that can benefit from it; otherwise, prefer
+// ComposeFstOptions.
+template ,
+          class StateTable = GenericComposeStateTable<
+              typename M1::Arc, typename Filter::FilterState>,
+          class CacheStore = DefaultCacheStore>
+struct ComposeFstImplOptions : public CacheImplOptions {
+  M1 *matcher1;    // FST1 matcher (see matcher.h).
+  M2 *matcher2;    // FST2 matcher.
+  Filter *filter;  // Composition filter (see compose-filter.h).
+  StateTable
+      *state_table;       // Composition state table (see compose-state-table.h).
+  bool own_state_table;   // ComposeFstImpl takes ownership of 'state_table'?
+  bool allow_noncommute;  // Allow non-commutative weights
+
+  // Builds the options from plain cache options. own_state_table starts as
+  // true and allow_noncommute as false; callers change the fields afterwards.
+  explicit ComposeFstImplOptions(const CacheOptions &opts,
+                                 M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
+                                 Filter *filter = nullptr,
+                                 StateTable *state_table = nullptr)
+      : CacheImplOptions(opts),
+        matcher1(matcher1),
+        matcher2(matcher2),
+        filter(filter),
+        state_table(state_table),
+        own_state_table(true),
+        allow_noncommute(false) {}
+
+  // Same as above, but starting from cache implementation options.
+  explicit ComposeFstImplOptions(const CacheImplOptions &opts,
+                                 M1 *matcher1 = nullptr, M2 *matcher2 = nullptr,
+                                 Filter *filter = nullptr,
+                                 StateTable *state_table = nullptr)
+      : CacheImplOptions(opts),
+        matcher1(matcher1),
+        matcher2(matcher2),
+        filter(filter),
+        state_table(state_table),
+        own_state_table(true),
+        allow_noncommute(false) {}
+
+  // Default options: no user-supplied components.
+  ComposeFstImplOptions()
+      : matcher1(nullptr),
+        matcher2(nullptr),
+        filter(nullptr),
+        state_table(nullptr),
+        own_state_table(true),
+        allow_noncommute(false) {}
+};
+
+namespace internal {
+
+// Implementation of delayed composition. This base class is common to the
+// variants with different matchers, composition filters and state tables.
+template ,
+          class F = ComposeFst>
+class ComposeFstImplBase
+    : public CacheBaseImpl {
+ public:
+  using FST = F;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  using State = typename CacheStore::State;
+  using CacheImpl = CacheBaseImpl;
+
+  using FstImpl::SetType;
+  using FstImpl::SetProperties;
+  using FstImpl::Properties;
+  using FstImpl::SetInputSymbols;
+  using FstImpl::SetOutputSymbols;
+
+  using CacheImpl::HasStart;
+  using CacheImpl::HasFinal;
+  using CacheImpl::HasArcs;
+  using CacheImpl::SetFinal;
+  using CacheImpl::SetStart;
+
+  ComposeFstImplBase(const CacheImplOptions &opts)
+      : CacheImpl(opts) {}
+
+  ComposeFstImplBase(const CacheOptions &opts) : CacheImpl(opts) {}
+
+  // Copy constructor: copies the cache and carries over the type string,
+  // properties and both symbol tables of 'impl'.
+  ComposeFstImplBase(const ComposeFstImplBase &impl) : CacheImpl(impl, true) {
+    SetType(impl.Type());
+    SetProperties(impl.Properties(), kCopyProperties);
+    SetInputSymbols(impl.InputSymbols());
+    SetOutputSymbols(impl.OutputSymbols());
+  }
+
+  virtual ComposeFstImplBase *Copy() const = 0;
+
+  ~ComposeFstImplBase() override {}
+
+  // Returns the start state, computing and caching it on first use. Nothing
+  // is cached when ComputeStart() yields kNoStateId.
+  StateId Start() {
+    if (!HasStart()) {
+      const auto start = ComputeStart();
+      if (start != kNoStateId) SetStart(start);
+    }
+    return CacheImpl::Start();
+  }
+
+  // Returns the final weight of 's', computing and caching it on first use.
+  Weight Final(StateId s) {
+    if (!HasFinal(s)) SetFinal(s, ComputeFinal(s));
+    return CacheImpl::Final(s);
+  }
+
+  virtual void Expand(StateId s) = 0;
+
+  // The accessors below expand state 's' first if its arcs are not cached.
+  size_t NumArcs(StateId s) {
+    if (!HasArcs(s)) Expand(s);
+    return CacheImpl::NumArcs(s);
+  }
+
+  size_t NumInputEpsilons(StateId s) {
+    if (!HasArcs(s)) Expand(s);
+    return CacheImpl::NumInputEpsilons(s);
+  }
+
+  size_t NumOutputEpsilons(StateId s) {
+    if (!HasArcs(s)) Expand(s);
+    return CacheImpl::NumOutputEpsilons(s);
+  }
+
+  void InitArcIterator(StateId s, ArcIteratorData *data) {
+    if (!HasArcs(s)) Expand(s);
+    CacheImpl::InitArcIterator(s, data);
+  }
+
+  virtual MatcherBase *InitMatcher(const F &fst,
+                                   MatchType match_type) const {
+    // Use the default matcher if no override is provided.
+    return nullptr;
+  }
+
+ protected:
+  virtual StateId ComputeStart() = 0;
+  virtual Weight ComputeFinal(StateId s) = 0;
+};
+
+// Implementation of delayed composition templated on the matchers (see
+// matcher.h), composition filter (see compose-filter.h) and the composition
+// state table (see compose-state-table.h).
+template
+class ComposeFstImpl
+    : public ComposeFstImplBase {
+ public:
+  using Matcher1 = typename Filter::Matcher1;
+  using Matcher2 = typename Filter::Matcher2;
+
+  using FST1 = typename Matcher1::FST;
+  using FST2 = typename Matcher2::FST;
+
+  using Arc = typename CacheStore::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  using FilterState = typename Filter::FilterState;
+  using State = typename CacheStore::State;
+
+  using CacheImpl = CacheBaseImpl;
+
+  using StateTuple = typename StateTable::StateTuple;
+
+  friend class ComposeFstMatcher;
+
+  using FstImpl::SetInputSymbols;
+  using FstImpl::SetOutputSymbols;
+  using FstImpl::SetType;
+  using FstImpl::SetProperties;
+
+  template
+  ComposeFstImpl(const FST1 &fst1, const FST2 &fst2,
+                 const ComposeFstImplOptions &opts);
+
+  // Copy constructor: makes a 'safe' copy of the filter (and thus of its
+  // matchers) and a private copy of the state table, which this copy owns.
+  ComposeFstImpl(const ComposeFstImpl &impl)
+      : ComposeFstImplBase(impl),
+        filter_(new Filter(*impl.filter_, true)),
+        matcher1_(filter_->GetMatcher1()),
+        matcher2_(filter_->GetMatcher2()),
+        fst1_(matcher1_->GetFst()),
+        fst2_(matcher2_->GetFst()),
+        state_table_(new StateTable(*impl.state_table_)),
+        own_state_table_(true),
+        match_type_(impl.match_type_) {}
+
+  // Frees the state table only when this object owns it.
+  ~ComposeFstImpl() override {
+    if (own_state_table_) delete state_table_;
+  }
+
+  ComposeFstImpl *Copy() const override { return new ComposeFstImpl(*this); }
+
+  uint64 Properties() const override { return Properties(kFstProperties); }
+
+  // Sets error if found, and returns other FST impl properties.
+  uint64 Properties(uint64 mask) const override {
+    // All sub-tests are joined with logical OR so that evaluation stops at the
+    // first component (FSTs, matchers, filter, state table) reporting an error.
+    if ((mask & kError) &&
+        (fst1_.Properties(kError, false) || fst2_.Properties(kError, false) ||
+         (matcher1_->Properties(0) & kError) ||
+         (matcher2_->Properties(0) & kError) ||
+         (filter_->Properties(0) & kError) ||
+         state_table_->Error())) {
+      SetProperties(kError, kError);
+    }
+    return FstImpl<Arc>::Properties(mask);
+  }
+
+  // Arranges it so that the first arg to OrderedExpand is the Fst
+  // that will be matched on.
+  void Expand(StateId s) override {
+    const auto &tuple = state_table_->Tuple(s);
+    const auto s1 = tuple.StateId1();
+    const auto s2 = tuple.StateId2();
+    filter_->SetState(s1, s2, tuple.GetFilterState());
+    if (MatchInput(s1, s2)) {
+      OrderedExpand(s, fst2_, s2, fst1_, s1, matcher2_, true);
+    } else {
+      OrderedExpand(s, fst1_, s1, fst2_, s2, matcher1_, false);
+    }
+  }
+
+  const FST1 &GetFst1() const { return fst1_; }
+
+  const FST2 &GetFst2() const { return fst2_; }
+
+  const Matcher1 *GetMatcher1() const { return matcher1_; }
+
+  Matcher1 *GetMatcher1() { return matcher1_; }
+
+  const Matcher2 *GetMatcher2() const { return matcher2_; }
+
+  Matcher2 *GetMatcher2() { return matcher2_; }
+
+  const Filter *GetFilter() const { return filter_.get(); }
+
+  Filter *GetFilter() { return filter_.get(); }
+
+  const StateTable *GetStateTable() const { return state_table_; }
+
+  StateTable *GetStateTable() { return state_table_; }
+
+  // Returns a newly allocated ComposeFstMatcher for 'fst' when the conditions
+  // described below hold, and nullptr otherwise.
+  MatcherBase<Arc> *InitMatcher(const ComposeFst<Arc, CacheStore> &fst,
+                                MatchType match_type) const override {
+    const auto test_props = match_type == MATCH_INPUT
+                                ? kFstProperties & ~kILabelInvariantProperties
+                                : kFstProperties & ~kOLabelInvariantProperties;
+    // If both matchers support 'match_type' and we have a guarantee that a
+    // call to 'filter_->FilterArc(arc1, arc2)' will not modify the ilabel of
+    // arc1 when MATCH_INPUT or the olabel of arc2 when MATCH_OUTPUT, then
+    // ComposeFstMatcher can be used.
+    if ((matcher1_->Type(false) == match_type) &&
+        (matcher2_->Type(false) == match_type) &&
+        (filter_->Properties(test_props) == test_props)) {
+      return new ComposeFstMatcher<
+          CacheStore, Filter, StateTable>(&fst, match_type);
+    }
+    return nullptr;
+  }
+
+ private:
+  // This does the actual matching of labels in the composition. The
+  // arguments are ordered so matching is called on state 'sa' of
+  // 'fsta' for each arc leaving state 'sb' of 'fstb'. The 'match_input' arg
+  // determines whether the input or output label of arcs at 'sb' is
+  // the one to match on.
+  template <class FST, class Matcher>
+  void OrderedExpand(StateId s, const Fst<Arc> &, StateId sa, const FST &fstb,
+                     StateId sb, Matcher *matchera, bool match_input) {
+    matchera->SetState(sa);
+    // First processes non-consuming symbols (e.g., epsilons) on FSTA.
+    const Arc loop(match_input ? 0 : kNoLabel, match_input ? kNoLabel : 0,
+                   Weight::One(), sb);
+    MatchArc(s, matchera, loop, match_input);
+    // Then processes matches on FSTB.
+    for (ArcIterator<FST> iterb(fstb, sb); !iterb.Done(); iterb.Next()) {
+      MatchArc(s, matchera, iterb.Value(), match_input);
+    }
+    CacheImpl::SetArcs(s);
+  }
+
+  // Matches a single transition from 'fstb' against 'fsta' at 's'.
+  template <class Matcher>
+  void MatchArc(StateId s, Matcher *matchera, const Arc &arc,
+                bool match_input) {
+    if (matchera->Find(match_input ? arc.olabel : arc.ilabel)) {
+      for (; !matchera->Done(); matchera->Next()) {
+        // Work on copies: the filter is allowed to modify the arcs it is given.
+        auto arca = matchera->Value();
+        auto arcb = arc;
+        if (match_input) {
+          const auto &fs = filter_->FilterArc(&arcb, &arca);
+          if (fs != FilterState::NoState()) AddArc(s, arcb, arca, fs);
+        } else {
+          const auto &fs = filter_->FilterArc(&arca, &arcb);
+          if (fs != FilterState::NoState()) AddArc(s, arca, arcb, fs);
+        }
+      }
+    }
+  }
+
+  // Add a matching transition at 's'.
+  void AddArc(StateId s, const Arc &arc1, const Arc &arc2,
+              const FilterState &f) {
+    const StateTuple tuple(arc1.nextstate, arc2.nextstate, f);
+    CacheImpl::EmplaceArc(
+        s, arc1.ilabel, arc2.olabel, Times(arc1.weight, arc2.weight),
+        state_table_->FindState(tuple));
+  }
+
+  // Returns the composed start state, or kNoStateId when either input FST
+  // has no start state.
+  StateId ComputeStart() override {
+    const auto s1 = fst1_.Start();
+    if (s1 == kNoStateId) return kNoStateId;
+    const auto s2 = fst2_.Start();
+    if (s2 == kNoStateId) return kNoStateId;
+    const auto &fs = filter_->Start();
+    const StateTuple tuple(s1, s2, fs);
+    return state_table_->FindState(tuple);
+  }
+
+  // Returns the product of both component final weights after the filter has
+  // seen them; returns Zero as soon as either component is non-final.
+  Weight ComputeFinal(StateId s) override {
+    const auto &tuple = state_table_->Tuple(s);
+    const auto s1 = tuple.StateId1();
+    auto final1 = matcher1_->Final(s1);
+    if (final1 == Weight::Zero()) return final1;
+    const auto s2 = tuple.StateId2();
+    auto final2 = matcher2_->Final(s2);
+    if (final2 == Weight::Zero()) return final2;
+    filter_->SetState(s1, s2, tuple.GetFilterState());
+    filter_->FilterFinal(&final1, &final2);
+    return Times(final1, final2);
+  }
+
+  // Determines which side to match on per composition state.
+  bool MatchInput(StateId s1, StateId s2) {
+    switch (match_type_) {
+      case MATCH_INPUT:
+        return true;
+      case MATCH_OUTPUT:
+        return false;
+      default:  // MATCH_BOTH
+        const auto priority1 = matcher1_->Priority(s1);
+        const auto priority2 = matcher2_->Priority(s2);
+        if (priority1 == kRequirePriority && priority2 == kRequirePriority) {
+          FSTERROR() << "ComposeFst: Both sides can't require match";
+          SetProperties(kError, kError);
+          return true;
+        }
+        if (priority1 == kRequirePriority) return false;
+        if (priority2 == kRequirePriority) {
+          return true;
+        }
+        return priority1 <= priority2;
+    }
+  }
+
+  // Identifies and verifies the capabilities of the matcher to be used for
+  // composition.
+  void SetMatchType();
+
+  std::unique_ptr filter_;
+  Matcher1 *matcher1_;  // Borrowed reference.
+  Matcher2 *matcher2_;  // Borrowed reference.
+  const FST1 &fst1_;
+  const FST2 &fst2_;
+  StateTable *state_table_;
+  bool own_state_table_;
+
+  MatchType match_type_;
+};
+
+// Builds the implementation from the two FSTs and the options: adopts or
+// creates the filter and state table, checks symbol-table compatibility,
+// selects the match type and derives the composed properties. Problems are
+// reported by setting kError rather than by throwing.
+template
+template
+ComposeFstImpl::ComposeFstImpl(
+    const FST1 &fst1, const FST2 &fst2,
+    const ComposeFstImplOptions &opts)
+    : ComposeFstImplBase(opts),
+      filter_(opts.filter
+                  ? opts.filter
+                  : new Filter(fst1, fst2, opts.matcher1, opts.matcher2)),
+      matcher1_(filter_->GetMatcher1()),
+      matcher2_(filter_->GetMatcher2()),
+      fst1_(matcher1_->GetFst()),
+      fst2_(matcher2_->GetFst()),
+      state_table_(opts.state_table ? opts.state_table
+                                    : new StateTable(fst1_, fst2_)),
+      own_state_table_(opts.state_table ? opts.own_state_table : true) {
+  SetType("compose");
+  if (!CompatSymbols(fst2.InputSymbols(), fst1.OutputSymbols())) {
+    FSTERROR() << "ComposeFst: Output symbol table of 1st argument "
+               << "does not match input symbol table of 2nd argument";
+    SetProperties(kError, kError);
+  }
+  SetInputSymbols(fst1_.InputSymbols());
+  SetOutputSymbols(fst2_.OutputSymbols());
+  SetMatchType();
+  VLOG(2) << "ComposeFstImpl: Match type: " << match_type_;
+  if (match_type_ == MATCH_NONE) SetProperties(kError, kError);
+  const auto fprops1 = fst1.Properties(kFstProperties, false);
+  const auto fprops2 = fst2.Properties(kFstProperties, false);
+  const auto mprops1 = matcher1_->Properties(fprops1);
+  const auto mprops2 = matcher2_->Properties(fprops2);
+  const auto cprops = ComposeProperties(mprops1, mprops2);
+  SetProperties(filter_->Properties(cprops), kCopyProperties);
+  if (state_table_->Error()) SetProperties(kError, kError);
+}
+
+template
+void ComposeFstImpl::SetMatchType() {
+  // Ensures any required matching is possible and known.
+  if ((matcher1_->Flags() & kRequireMatch) &&
+      matcher1_->Type(true) != MATCH_OUTPUT) {
+    FSTERROR() << "ComposeFst: 1st argument cannot perform required matching "
+               << "(sort?).";
+    match_type_ = MATCH_NONE;
+    return;
+  }
+  if ((matcher2_->Flags() & kRequireMatch) &&
+      matcher2_->Type(true) != MATCH_INPUT) {
+    FSTERROR() << "ComposeFst: 2nd argument cannot perform required matching "
+               << "(sort?).";
+    match_type_ = MATCH_NONE;
+    return;
+  }
+  // Finds which sides to match on (favoring minimal testing of capabilities).
+  const auto type1 = matcher1_->Type(false);
+  const auto type2 = matcher2_->Type(false);
+  if (type1 == MATCH_OUTPUT && type2 == MATCH_INPUT) {
+    match_type_ = MATCH_BOTH;
+  } else if (type1 == MATCH_OUTPUT) {
+    match_type_ = MATCH_OUTPUT;
+  } else if (type2 == MATCH_INPUT) {
+    match_type_ = MATCH_INPUT;
+  } else if (matcher1_->Type(true) == MATCH_OUTPUT) {
+    match_type_ = MATCH_OUTPUT;
+  } else if (matcher2_->Type(true) == MATCH_INPUT) {
+    match_type_ = MATCH_INPUT;
+  } else {
+    FSTERROR() << "ComposeFst: 1st argument cannot match on output labels "
+               << "and 2nd argument cannot match on input labels (sort?).";
+    match_type_ = MATCH_NONE;
+  }
+}
+
+}  // namespace internal
+
+// Computes the composition of two transducers. This version is a delayed FST.
+// If FST1 transduces string x to y with weight a and FST2 transduces y to z
+// with weight b, then their composition transduces string x to z with weight
+// Times(a, b).
+//
+// The output labels of the first transducer or the input labels of the second
+// transducer must be sorted (with the default matcher). The weights need to
+// form a commutative semiring (valid for TropicalWeight and LogWeight).
+//
+// Complexity:
+//
+// Assuming the first FST is unsorted and the second is sorted,
+//
+//   Time: O(v1 v2 d1 (log d2 + m2)),
+//   Space: O(v1 v2)
+//
+// where vi = # of states visited, di = maximum out-degree, and mi the
+// maximum multiplicity of the states visited, for the ith FST.
Constant time
+// and space to visit an input state or arc is assumed and exclusive of caching.
+//
+// Caveats:
+// - ComposeFst does not trim its output (since it is a delayed operation).
+// - The efficiency of composition can be strongly affected by several factors:
+//   - the choice of which transducer is sorted - prefer sorting the FST
+//     that has the greater average out-degree.
+//   - the amount of non-determinism
+//   - the presence and location of epsilon transitions - avoid epsilon
+//     transitions on the output side of the first transducer or
+//     the input side of the second transducer or prefer placing
+//     them later in a path since they delay matching and can
+//     introduce non-coaccessible states and transitions.
+//
+// This class attaches interface to implementation and handles reference
+// counting, delegating most methods to ImplToFst. The CacheStore specifies the
+// cache store (default declared in fst-decl.h).
+template */>
+class ComposeFst
+    : public ImplToFst> {
+ public:
+  using Arc = A;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  using Store = CacheStore;
+  using State = typename CacheStore::State;
+
+  using Impl = internal::ComposeFstImplBase;
+
+  friend class ArcIterator>;
+  friend class StateIterator>;
+  template friend class ComposeFstMatcher;
+
+  // Compose specifying only caching options.
+  ComposeFst(const Fst &fst1, const Fst &fst2,
+             const CacheOptions &opts = CacheOptions())
+      : ImplToFst(CreateBase(fst1, fst2, opts)) {}
+
+  // Compose specifying one shared matcher type M. Requires that the input FSTs
+  // and matcher FST types be Fst. Recommended for best code-sharing and
+  // matcher compatibility.
+  template
+  ComposeFst(const Fst &fst1, const Fst &fst2,
+             const ComposeFstOptions &opts)
+      : ImplToFst(CreateBase1(fst1, fst2, opts)) {}
+
+  // Compose specifying two matcher types Matcher1 and Matcher2. Requires input
+  // FST (of the same Arc type, but o.w. arbitrary) match the corresponding
+  // matcher FST types). Recommended only for advanced use in demanding or
+  // specialized applications due to potential code bloat and matcher
+  // incompatibilities.
+  template
+  ComposeFst(const typename Matcher1::FST &fst1,
+             const typename Matcher2::FST &fst2,
+             const ComposeFstImplOptions &opts)
+      : ImplToFst(CreateBase2(fst1, fst2, opts)) {}
+
+  // See Fst<>::Copy() for doc.
+  ComposeFst(const ComposeFst &fst, bool safe = false)
+      : ImplToFst(safe ? std::shared_ptr(fst.GetImpl()->Copy())
+                   : fst.GetSharedImpl()) {}
+
+  // Get a copy of this ComposeFst. See Fst<>::Copy() for further doc.
+  ComposeFst *Copy(bool safe = false) const override {
+    return new ComposeFst(*this, safe);
+  }
+
+  inline void InitStateIterator(StateIteratorData *data) const override;
+
+  void InitArcIterator(StateId s, ArcIteratorData *data) const override {
+    GetMutableImpl()->InitArcIterator(s, data);
+  }
+
+  MatcherBase *InitMatcher(MatchType match_type) const override {
+    return GetImpl()->InitMatcher(*this, match_type);
+  }
+
+ protected:
+  using ImplToFst::GetImpl;
+  using ImplToFst::GetMutableImpl;
+
+  explicit ComposeFst(std::shared_ptr impl) : ImplToFst(impl) {}
+
+  // Create compose implementation specifying two matcher types.
+  // Unless opts.allow_noncommute is set, a non-commutative weight type is
+  // accepted only if at least one input is unweighted; otherwise kError is
+  // set on the returned implementation.
+  template
+  static std::shared_ptr CreateBase2(
+      const typename Matcher1::FST &fst1, const typename Matcher2::FST &fst2,
+      const ComposeFstImplOptions &opts) {
+    auto impl = std::make_shared<
+        internal::ComposeFstImpl>(fst1, fst2,
+                                                                   opts);
+    if (!(Weight::Properties() & kCommutative) && !opts.allow_noncommute) {
+      const auto props1 = fst1.Properties(kUnweighted, true);
+      const auto props2 = fst2.Properties(kUnweighted, true);
+      if (!(props1 & kUnweighted) && !(props2 & kUnweighted)) {
+        FSTERROR() << "ComposeFst: Weights must be a commutative semiring: "
+                   << Weight::Type();
+        impl->SetProperties(kError, kError);
+      }
+    }
+    return impl;
+  }
+
+  // Create compose implementation specifying one matcher type; requires that
+  // input and matcher FST types be Fst.
+  template
+  static std::shared_ptr CreateBase1(
+      const Fst &fst1, const Fst &fst2,
+      const ComposeFstOptions &opts) {
+    ComposeFstImplOptions
+        nopts(opts, opts.matcher1, opts.matcher2, opts.filter,
+              opts.state_table);
+    return CreateBase2(fst1, fst2, nopts);
+  }
+
+  // Create compose implementation specifying no matcher type.
+  static std::shared_ptr CreateBase(const Fst &fst1,
+                                          const Fst &fst2,
+                                          const CacheOptions &opts) {
+    switch (LookAheadMatchType(fst1, fst2)) {  // Check for lookahead matchers
+      default:
+      case MATCH_NONE: {  // Default composition (no look-ahead).
+        ComposeFstOptions nopts(opts);
+        return CreateBase1(fst1, fst2, nopts);
+      }
+      case MATCH_OUTPUT: {  // Lookahead on fst1.
+        using M = typename DefaultLookAhead::FstMatcher;
+        using F = typename DefaultLookAhead::ComposeFilter;
+        ComposeFstOptions nopts(opts);
+        return CreateBase1(fst1, fst2, nopts);
+      }
+      case MATCH_INPUT: {  // Lookahead on fst2
+        using M = typename DefaultLookAhead::FstMatcher;
+        using F = typename DefaultLookAhead::ComposeFilter;
+        ComposeFstOptions nopts(opts);
+        return CreateBase1(fst1, fst2, nopts);
+      }
+    }
+  }
+
+ private:
+  ComposeFst &operator=(const ComposeFst &fst) = delete;
+};
+
+// Specialization for ComposeFst.
+template
+class StateIterator>
+    : public CacheStateIterator> {
+ public:
+  explicit StateIterator(const ComposeFst &fst)
+      : CacheStateIterator>(fst,
+                                                        fst.GetMutableImpl()) {}
+};
+
+// Specialization for ComposeFst.
+template
+class ArcIterator>
+    : public CacheArcIterator> {
+ public:
+  using StateId = typename Arc::StateId;
+
+  // Expands state 's' on construction if its arcs are not cached yet.
+  ArcIterator(const ComposeFst &fst, StateId s)
+      : CacheArcIterator>(fst.GetMutableImpl(), s) {
+    if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s);
+  }
+};
+
+template
+inline void ComposeFst::InitStateIterator(
+    StateIteratorData *data) const {
+  data->base = new StateIterator>(*this);
+}
+
+// Specialized matcher for ComposeFst. Supports MATCH_INPUT or MATCH_OUTPUT,
+// iff the underlying matchers for the two FSTS being composed support
+// MATCH_INPUT or MATCH_OUTPUT, respectively.
+template
+class ComposeFstMatcher : public MatcherBase {
+ public:
+  using Arc = typename CacheStore::Arc;
+  using Label = typename Arc::Label;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  using Matcher1 = typename Filter::Matcher1;
+  using Matcher2 = typename Filter::Matcher2;
+  using FilterState = typename Filter::FilterState;
+
+  using StateTuple = typename StateTable::StateTuple;
+  using Impl = internal::ComposeFstImpl;
+
+  // The compose FST arg must match the filter and state table types.
+  // This makes a copy of the FST.
+ ComposeFstMatcher(const ComposeFst &fst, + MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(match_type), + matcher1_(impl_->matcher1_->Copy()), + matcher2_(impl_->matcher2_->Copy()), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + // The compose FST arg must match the filter and state table types. + // This doesn't copy the FST (although it may copy components). + ComposeFstMatcher(const ComposeFst *fst, + MatchType match_type) + : fst_(*fst), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(match_type), + matcher1_(impl_->matcher1_->Copy()), + matcher2_(impl_->matcher2_->Copy()), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + // This makes a copy of the FST. + ComposeFstMatcher( + const ComposeFstMatcher &matcher, + bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + impl_(static_cast(fst_.GetImpl())), + s_(kNoStateId), + match_type_(matcher.match_type_), + matcher1_(matcher.matcher1_->Copy(safe)), + matcher2_(matcher.matcher2_->Copy(safe)), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) std::swap(loop_.ilabel, loop_.olabel); + } + + ComposeFstMatcher *Copy( + bool safe = false) const override { + return new ComposeFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if ((matcher1_->Type(test) == MATCH_NONE) || + (matcher2_->Type(test) == MATCH_NONE)) { + return MATCH_NONE; + } + if (((matcher1_->Type(test) == MATCH_UNKNOWN) && + (matcher2_->Type(test) == MATCH_UNKNOWN)) || + ((matcher1_->Type(test) == MATCH_UNKNOWN) && + (matcher2_->Type(test) == match_type_)) || + ((matcher1_->Type(test) == match_type_) && + 
(matcher2_->Type(test) == MATCH_UNKNOWN))) { + return MATCH_UNKNOWN; + } + if ((matcher1_->Type(test) == match_type_) && + (matcher2_->Type(test) == match_type_)) { + return match_type_; + } + return MATCH_NONE; + } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops; + } + + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + const auto &tuple = impl_->state_table_->Tuple(s); + matcher1_->SetState(tuple.StateId1()); + matcher2_->SetState(tuple.StateId2()); + loop_.nextstate = s_; + } + + bool Find(Label label) final { + bool found = false; + current_loop_ = false; + if (label == 0) { + current_loop_ = true; + found = true; + } + if (match_type_ == MATCH_INPUT) { + found = found || FindLabel(label, matcher1_.get(), matcher2_.get()); + } else { // match_type_ == MATCH_OUTPUT + found = found || FindLabel(label, matcher2_.get(), matcher1_.get()); + } + return found; + } + + bool Done() const final { + return !current_loop_ && matcher1_->Done() && matcher2_->Done(); + } + + const Arc &Value() const final { return current_loop_ ? loop_ : arc_; } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else if (match_type_ == MATCH_INPUT) { + FindNext(matcher1_.get(), matcher2_.get()); + } else { // match_type_ == MATCH_OUTPUT + FindNext(matcher2_.get(), matcher1_.get()); + } + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + // Processes a match with the filter and creates resulting arc. + bool MatchArc(StateId s, Arc arc1, + Arc arc2) { // FIXME(kbg): copy but not assignment. 
+ const auto &fs = impl_->filter_->FilterArc(&arc1, &arc2); + if (fs == FilterState::NoState()) return false; + const StateTuple tuple(arc1.nextstate, arc2.nextstate, fs); + arc_.ilabel = arc1.ilabel; + arc_.olabel = arc2.olabel; + arc_.weight = Times(arc1.weight, arc2.weight); + arc_.nextstate = impl_->state_table_->FindState(tuple); + return true; + } + + // Finds the first match allowed by the filter. + template + bool FindLabel(Label label, MatcherA *matchera, MatcherB *matcherb) { + if (matchera->Find(label)) { + matcherb->Find(match_type_ == MATCH_INPUT ? matchera->Value().olabel + : matchera->Value().ilabel); + return FindNext(matchera, matcherb); + } + return false; + } + + // Finds the next match allowed by the filter, returning true iff such a + // match is found. + template + bool FindNext(MatcherA *matchera, MatcherB *matcherb) { + // State when entering this function: + // 'matchera' is pointed to a match x, y for label x, and a match for y was + // requested on 'matcherb'. + while (!matchera->Done() || !matcherb->Done()) { + if (matcherb->Done()) { + // If no more matches for y on 'matcherb', moves forward on 'matchera' + // until a match x, y' is found such that there is a match for y' on + // 'matcherb'. + matchera->Next(); + while (!matchera->Done() && + !matcherb->Find(match_type_ == MATCH_INPUT + ? matchera->Value().olabel + : matchera->Value().ilabel)) { + matchera->Next(); + } + } + while (!matcherb->Done()) { + // 'matchera' is pointing to a match x, y' ('arca') and 'matcherb' is + // pointing to a match y', z' ('arcb'). If combining these two arcs is + // allowed by the filter (hence resulting in an arc x, z') return true. + // Position 'matcherb' on the next potential match for y' before + // returning. + const auto &arca = matchera->Value(); + const auto &arcb = matcherb->Value(); + // Position 'matcherb' on the next potential match for y'. 
+ matcherb->Next(); + // Returns true if combining these two arcs is allowed by the filter + // (hence resulting in an arc x, z'); otherwise consider next match + // for y' on 'matcherb'. + if (MatchArc(s_, match_type_ == MATCH_INPUT ? arca : arcb, + match_type_ == MATCH_INPUT ? arcb : arca)) { + return true; + } + } + } + // Both 'matchera' and 'matcherb' are done, no more match to analyse. + return false; + } + + std::unique_ptr> owned_fst_; + const ComposeFst &fst_; + const Impl *impl_; + StateId s_; + MatchType match_type_; + std::unique_ptr matcher1_; + std::unique_ptr matcher2_; + bool current_loop_; + Arc loop_; + Arc arc_; +}; + +// Useful alias when using StdArc. +using StdComposeFst = ComposeFst; + +enum ComposeFilter { + AUTO_FILTER, + NULL_FILTER, + TRIVIAL_FILTER, + SEQUENCE_FILTER, + ALT_SEQUENCE_FILTER, + MATCH_FILTER, + NO_MATCH_FILTER +}; + +struct ComposeOptions { + bool connect; // Connect output? + ComposeFilter filter_type; // Pre-defined filter to use. + + explicit ComposeOptions(bool connect = true, + ComposeFilter filter_type = AUTO_FILTER) + : connect(connect), filter_type(filter_type) {} +}; + +// Computes the composition of two transducers. This version writes +// the composed FST into a MutableFst. If FST1 transduces string x to +// y with weight a and FST2 transduces y to z with weight b, then +// their composition transduces string x to z with weight +// Times(a, b). +// +// The output labels of the first transducer or the input labels of +// the second transducer must be sorted. The weights need to form a +// commutative semiring (valid for TropicalWeight and LogWeight). +// +// Complexity: +// +// Assuming the first FST is unsorted and the second is sorted: +// +// Time: O(V1 V2 D1 (log D2 + M2)), +// Space: O(V1 V2 D1 M2) +// +// where Vi = # of states, Di = maximum out-degree, and Mi is the maximum +// multiplicity, for the ith FST. +// +// Caveats: +// +// - Compose trims its output. 
+// - The efficiency of composition can be strongly affected by several factors: +// - the choice of which transducer is sorted - prefer sorting the FST +// that has the greater average out-degree. +// - the amount of non-determinism +// - the presence and location of epsilon transitions - avoid epsilon +// transitions on the output side of the first transducer or +// the input side of the second transducer or prefer placing +// them later in a path since they delay matching and can +// introduce non-coaccessible states and transitions. +template +void Compose(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const ComposeOptions &opts = ComposeOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, nopts); + break; + } + case NULL_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case SEQUENCE_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case ALT_SEQUENCE_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case MATCH_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case NO_MATCH_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + case TRIVIAL_FILTER: { + ComposeFstOptions> copts; + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/concat.h b/projects/llm_framework/include/fst/concat.h new file mode 100644 index 00000000..74d22c22 --- /dev/null +++ 
b/projects/llm_framework/include/fst/concat.h @@ -0,0 +1,220 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the concatenation of two FSTs. + +#ifndef FST_CONCAT_H_ +#define FST_CONCAT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Computes the concatenation (product) of two FSTs. If FST1 transduces string +// x to y with weight a and FST2 transduces string w to v with weight b, then +// their concatenation transduces string xw to yv with weight Times(a, b). +// +// This version modifies its MutableFst argument (in first position). +// +// Complexity: +// +// Time: O(V1 + V2 + E2) +// Space: O(V1 + V2 + E2) +// +// where Vi is the number of states, and Ei is the number of arcs, of the ith +// FST. +template +void Concat(MutableFst *fst1, const Fst &fst2) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + // Checks that the symbol table are compatible. 
+ if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Concat: Input/output symbol tables of 1st argument " + << "does not match input/output symbol tables of 2nd argument"; + fst1->SetProperties(kError, kError); + return; + } + const auto props1 = fst1->Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + const auto start1 = fst1->Start(); + if (start1 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + const auto numstates1 = fst1->NumStates(); + if (fst2.Properties(kExpanded, false)) { + fst1->ReserveStates(numstates1 + CountStates(fst2)); + } + for (StateIterator> siter2(fst2); !siter2.Done(); siter2.Next()) { + const auto s1 = fst1->AddState(); + const auto s2 = siter2.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator> aiter(fst2, s2); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate += numstates1; + fst1->AddArc(s1, arc); + } + } + const auto start2 = fst2.Start(); + for (StateId s1 = 0; s1 < numstates1; ++s1) { + const auto weight = fst1->Final(s1); + if (weight != Weight::Zero()) { + fst1->SetFinal(s1, Weight::Zero()); + if (start2 != kNoStateId) { + fst1->AddArc(s1, Arc(0, 0, weight, start2 + numstates1)); + } + } + } + if (start2 != kNoStateId) { + fst1->SetProperties(ConcatProperties(props1, props2), kFstProperties); + } +} + +// Computes the concatenation of two FSTs. This version modifies its +// MutableFst argument (in second position). +// +// Complexity: +// +// Time: O(V1 + E1) +// Space: O(V1 + E1) +// +// where Vi is the number of states, and Ei is the number of arcs, of the ith +// FST. +template +void Concat(const Fst &fst1, MutableFst *fst2) { + using Weight = typename Arc::Weight; + // Checks that the symbol table are compatible. 
+ if (!CompatSymbols(fst1.InputSymbols(), fst2->InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2->OutputSymbols())) { + FSTERROR() << "Concat: Input/output symbol tables of 1st argument " + << "does not match input/output symbol tables of 2nd argument"; + fst2->SetProperties(kError, kError); + return; + } + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2->Properties(kFstProperties, false); + const auto start2 = fst2->Start(); + if (start2 == kNoStateId) { + if (props1 & kError) fst2->SetProperties(kError, kError); + return; + } + const auto numstates2 = fst2->NumStates(); + if (fst1.Properties(kExpanded, false)) { + fst2->ReserveStates(numstates2 + CountStates(fst1)); + } + for (StateIterator> siter(fst1); !siter.Done(); siter.Next()) { + const auto s1 = siter.Value(); + const auto s2 = fst2->AddState(); + const auto weight = fst1.Final(s1); + if (weight != Weight::Zero()) { + fst2->ReserveArcs(s2, fst1.NumArcs(s1) + 1); + fst2->AddArc(s2, Arc(0, 0, weight, start2)); + } else { + fst2->ReserveArcs(s2, fst1.NumArcs(s1)); + } + for (ArcIterator> aiter(fst1, s1); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate += numstates2; + fst2->AddArc(s2, arc); + } + } + const auto start1 = fst1.Start(); + if (start1 != kNoStateId) { + fst2->SetStart(start1 + numstates2); + fst2->SetProperties(ConcatProperties(props1, props2), kFstProperties); + } else { + fst2->SetStart(fst2->AddState()); + } +} + +// Computes the concatenation of two FSTs. This version modifies its +// RationalFst input (in first position). +template +void Concat(RationalFst *fst1, const Fst &fst2) { + fst1->GetMutableImpl()->AddConcat(fst2, true); +} + +// Computes the concatenation of two FSTs. This version modifies its +// RationalFst input (in second position). 
+template +void Concat(const Fst &fst1, RationalFst *fst2) { + fst2->GetMutableImpl()->AddConcat(fst1, false); +} + +using ConcatFstOptions = RationalFstOptions; + +// Computes the concatenation (product) of two FSTs; this version is a delayed +// FST. If FST1 transduces string x to y with weight a and FST2 transduces +// string w to v with weight b, then their concatenation transduces string xw +// to yv with Times(a, b). +// +// Complexity: +// +// Time: O(v1 + e1 + v2 + e2), +// Space: O(v1 + v2) +// +// where vi is the number of states visited, and ei is the number of arcs +// visited, of the ith FST. Constant time and space to visit an input state or +// arc is assumed and exclusive of caching. +template +class ConcatFst : public RationalFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + ConcatFst(const Fst &fst1, const Fst &fst2) { + GetMutableImpl()->InitConcat(fst1, fst2); + } + + ConcatFst(const Fst &fst1, const Fst &fst2, + const ConcatFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitConcat(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + ConcatFst(const ConcatFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Get a copy of this ConcatFst. See Fst<>::Copy() for further doc. + ConcatFst *Copy(bool safe = false) const override { + return new ConcatFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for ConcatFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const ConcatFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for ConcatFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ConcatFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. 
+using StdConcatFst = ConcatFst; + +} // namespace fst + +#endif // FST_CONCAT_H_ diff --git a/projects/llm_framework/include/fst/config.h b/projects/llm_framework/include/fst/config.h new file mode 100644 index 00000000..32f2a653 --- /dev/null +++ b/projects/llm_framework/include/fst/config.h @@ -0,0 +1,3 @@ +// Windows-specific OpenFst config file +// No dynamic registration. +#define FST_NO_DYNAMIC_LINKING 1 diff --git a/projects/llm_framework/include/fst/config.h.in b/projects/llm_framework/include/fst/config.h.in new file mode 100644 index 00000000..7815dfcd --- /dev/null +++ b/projects/llm_framework/include/fst/config.h.in @@ -0,0 +1,11 @@ +// OpenFst config file + +/* Define to 1 if you have the ICU library. */ +#undef HAVE_ICU + +/* Define to 1 if the system has the type `std::tr1::hash'. */ +#define HAVE_STD__TR1__HASH_LONG_LONG_UNSIGNED_ 1 + +/* Define to 1 if the system has the type `__gnu_cxx::slist'. */ +#define HAVE___GNU_CXX__SLIST_INT_ 1 diff --git a/projects/llm_framework/include/fst/connect.h b/projects/llm_framework/include/fst/connect.h new file mode 100644 index 00000000..4c989292 --- /dev/null +++ b/projects/llm_framework/include/fst/connect.h @@ -0,0 +1,323 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions to remove unsuccessful paths from an FST. + +#ifndef FST_CONNECT_H_ +#define FST_CONNECT_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +// Finds and returns connected components. Use with Visit(). +template +class CcVisitor { + public: + using Weight = typename Arc::Weight; + using StateId = typename Arc::StateId; + + // cc[i]: connected component number for state i. + explicit CcVisitor(std::vector *cc) + : comps_(new UnionFind(0, kNoStateId)), cc_(cc), nstates_(0) {} + + // comps: connected components equiv classes. 
+ explicit CcVisitor(UnionFind *comps) + : comps_(comps), cc_(nullptr), nstates_(0) {} + + ~CcVisitor() { + if (cc_) delete comps_; + } + + void InitVisit(const Fst &fst) {} + + bool InitState(StateId s, StateId root) { + ++nstates_; + if (comps_->FindSet(s) == kNoStateId) comps_->MakeSet(s); + return true; + } + + bool WhiteArc(StateId s, const Arc &arc) { + comps_->MakeSet(arc.nextstate); + comps_->Union(s, arc.nextstate); + return true; + } + + bool GreyArc(StateId s, const Arc &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + bool BlackArc(StateId s, const Arc &arc) { + comps_->Union(s, arc.nextstate); + return true; + } + + void FinishState(StateId s) {} + + void FinishVisit() { + if (cc_) GetCcVector(cc_); + } + + // Returns number of components. + // cc[i]: connected component number for state i. + int GetCcVector(std::vector *cc) { + cc->clear(); + cc->resize(nstates_, kNoStateId); + StateId ncomp = 0; + for (StateId s = 0; s < nstates_; ++s) { + const auto rep = comps_->FindSet(s); + auto &comp = (*cc)[rep]; + if (comp == kNoStateId) { + comp = ncomp; + ++ncomp; + } + (*cc)[s] = comp; + } + return ncomp; + } + + private: + UnionFind *comps_; // Components. + std::vector *cc_; // State's cc number. + StateId nstates_; // State count. +}; + +// Finds and returns strongly-connected components, accessible and +// coaccessible states and related properties. Uses Tarjan's single +// DFS SCC algorithm (see Aho, et al, "Design and Analysis of Computer +// Algorithms", 189pp). Use with DfsVisit(); +template +class SccVisitor { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // scc[i]: strongly-connected component number for state i. + // SCC numbers will be in topological order for acyclic input. + // access[i]: accessibility of state i. + // coaccess[i]: coaccessibility of state i. + // Any of above can be NULL. 
+ // props: related property bits (cyclicity, initial cyclicity, + // accessibility, coaccessibility) set/cleared (o.w. unchanged). + SccVisitor(std::vector *scc, std::vector *access, + std::vector *coaccess, uint64 *props) + : scc_(scc), access_(access), coaccess_(coaccess), props_(props) {} + explicit SccVisitor(uint64 *props) + : scc_(nullptr), access_(nullptr), coaccess_(nullptr), props_(props) {} + + void InitVisit(const Fst &fst); + + bool InitState(StateId s, StateId root); + + bool TreeArc(StateId s, const Arc &arc) { return true; } + + bool BackArc(StateId s, const Arc &arc) { + const auto t = arc.nextstate; + if ((*dfnumber_)[t] < (*lowlink_)[s]) (*lowlink_)[s] = (*dfnumber_)[t]; + if ((*coaccess_)[t]) (*coaccess_)[s] = true; + *props_ |= kCyclic; + *props_ &= ~kAcyclic; + if (t == start_) { + *props_ |= kInitialCyclic; + *props_ &= ~kInitialAcyclic; + } + return true; + } + + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + const auto t = arc.nextstate; + if ((*dfnumber_)[t] < (*dfnumber_)[s] /* cross edge */ && (*onstack_)[t] && + (*dfnumber_)[t] < (*lowlink_)[s]) { + (*lowlink_)[s] = (*dfnumber_)[t]; + } + if ((*coaccess_)[t]) (*coaccess_)[s] = true; + return true; + } + + // Last argument always ignored, but required by the interface. + void FinishState(StateId state, StateId p, const Arc *); + + void FinishVisit() { + // Numbers SCCs in topological order when acyclic. + if (scc_) { + for (size_t s = 0; s < scc_->size(); ++s) { + (*scc_)[s] = nscc_ - 1 - (*scc_)[s]; + } + } + if (coaccess_internal_) delete coaccess_; + dfnumber_.reset(); + lowlink_.reset(); + onstack_.reset(); + scc_stack_.reset(); + } + + private: + std::vector *scc_; // State's scc number. + std::vector *access_; // State's accessibility. + std::vector *coaccess_; // State's coaccessibility. + uint64 *props_; + const Fst *fst_; + StateId start_; + StateId nstates_; // State count. + StateId nscc_; // SCC count. 
+ bool coaccess_internal_; + std::unique_ptr> dfnumber_; // State discovery times. + std::unique_ptr> + lowlink_; // lowlink[state] == dfnumber[state] => SCC root + std::unique_ptr> onstack_; // Is a state on the SCC stack? + std::unique_ptr> + scc_stack_; // SCC stack, with random access. +}; + +template +inline void SccVisitor::InitVisit(const Fst &fst) { + if (scc_) scc_->clear(); + if (access_) access_->clear(); + if (coaccess_) { + coaccess_->clear(); + coaccess_internal_ = false; + } else { + coaccess_ = new std::vector; + coaccess_internal_ = true; + } + *props_ |= kAcyclic | kInitialAcyclic | kAccessible | kCoAccessible; + *props_ &= ~(kCyclic | kInitialCyclic | kNotAccessible | kNotCoAccessible); + fst_ = &fst; + start_ = fst.Start(); + nstates_ = 0; + nscc_ = 0; + dfnumber_.reset(new std::vector()); + lowlink_.reset(new std::vector()); + onstack_.reset(new std::vector()); + scc_stack_.reset(new std::vector()); +} + +template +inline bool SccVisitor::InitState(StateId s, StateId root) { + scc_stack_->push_back(s); + if (static_cast(dfnumber_->size()) <= s) { + if (scc_) scc_->resize(s + 1, -1); + if (access_) access_->resize(s + 1, false); + coaccess_->resize(s + 1, false); + dfnumber_->resize(s + 1, -1); + lowlink_->resize(s + 1, -1); + onstack_->resize(s + 1, false); + } + (*dfnumber_)[s] = nstates_; + (*lowlink_)[s] = nstates_; + (*onstack_)[s] = true; + if (root == start_) { + if (access_) (*access_)[s] = true; + } else { + if (access_) (*access_)[s] = false; + *props_ |= kNotAccessible; + *props_ &= ~kAccessible; + } + ++nstates_; + return true; +} + +template +inline void SccVisitor::FinishState(StateId s, StateId p, const Arc *) { + if (fst_->Final(s) != Weight::Zero()) (*coaccess_)[s] = true; + if ((*dfnumber_)[s] == (*lowlink_)[s]) { // Root of new SCC. 
+ bool scc_coaccess = false; + auto i = scc_stack_->size(); + StateId t; + do { + t = (*scc_stack_)[--i]; + if ((*coaccess_)[t]) scc_coaccess = true; + } while (s != t); + do { + t = scc_stack_->back(); + if (scc_) (*scc_)[t] = nscc_; + if (scc_coaccess) (*coaccess_)[t] = true; + (*onstack_)[t] = false; + scc_stack_->pop_back(); + } while (s != t); + if (!scc_coaccess) { + *props_ |= kNotCoAccessible; + *props_ &= ~kCoAccessible; + } + ++nscc_; + } + if (p != kNoStateId) { + if ((*coaccess_)[s]) (*coaccess_)[p] = true; + if ((*lowlink_)[s] < (*lowlink_)[p]) (*lowlink_)[p] = (*lowlink_)[s]; + } +} + +// Trims an FST, removing states and arcs that are not on successful paths. +// This version modifies its input. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(V + E) +// +// where V = # of states and E = # of arcs. +template +void Connect(MutableFst *fst) { + using StateId = typename Arc::StateId; + std::vector access; + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + std::vector dstates; + dstates.reserve(access.size()); + for (StateId s = 0; s < access.size(); ++s) { + if (!access[s] || !coaccess[s]) dstates.push_back(s); + } + fst->DeleteStates(dstates); + fst->SetProperties(kAccessible | kCoAccessible, kAccessible | kCoAccessible); +} + +// Returns an acyclic FST where each SCC in the input FST has been condensed to +// a single state with transitions between SCCs retained and within SCCs +// dropped. Also populates 'scc' with a mapping from input to output states. 
+template +void Condense(const Fst &ifst, MutableFst *ofst, + std::vector *scc) { + using StateId = typename Arc::StateId; + ofst->DeleteStates(); + uint64 props = 0; + SccVisitor scc_visitor(scc, nullptr, nullptr, &props); + DfsVisit(ifst, &scc_visitor); + const auto iter = std::max_element(scc->cbegin(), scc->cend()); + if (iter == scc->cend()) return; + const StateId num_condensed_states = 1 + *iter; + ofst->ReserveStates(num_condensed_states); + for (StateId c = 0; c < num_condensed_states; ++c) { + ofst->AddState(); + } + for (StateId s = 0; s < scc->size(); ++s) { + const auto c = (*scc)[s]; + if (s == ifst.Start()) ofst->SetStart(c); + const auto weight = ifst.Final(s); + if (weight != Arc::Weight::Zero()) + ofst->SetFinal(c, Plus(ofst->Final(c), weight)); + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + const auto nextc = (*scc)[arc.nextstate]; + if (nextc != c) { + Arc condensed_arc = arc; + condensed_arc.nextstate = nextc; + ofst->AddArc(c, std::move(condensed_arc)); + } + } + } + ofst->SetProperties(kAcyclic | kInitialAcyclic, kAcyclic | kInitialAcyclic); +} + +} // namespace fst + +#endif // FST_CONNECT_H_ diff --git a/projects/llm_framework/include/fst/const-fst.h b/projects/llm_framework/include/fst/const-fst.h new file mode 100644 index 00000000..09c81c7b --- /dev/null +++ b/projects/llm_framework/include/fst/const-fst.h @@ -0,0 +1,485 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Simple concrete immutable FST whose states and arcs are each stored in +// single arrays. + +#ifndef FST_CONST_FST_H_ +#define FST_CONST_FST_H_ + +#include +#include +#include + +// Google-only... 
+// ...Google-only +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +template +class ConstFst; + +template +void Cast(const F &, G *); + +namespace internal { + +// States and arcs each implemented by single arrays, templated on the +// Arc definition. Unsigned is used to represent indices into the arc array. +template +class ConstFstImpl : public FstImpl { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + + ConstFstImpl() + : states_(nullptr), + arcs_(nullptr), + narcs_(0), + nstates_(0), + start_(kNoStateId) { + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + SetType(type); + SetProperties(kNullProperties | kStaticProperties); + } + + explicit ConstFstImpl(const Fst &fst); + + StateId Start() const { return start_; } + + Weight Final(StateId s) const { return states_[s].weight; } + + StateId NumStates() const { return nstates_; } + + size_t NumArcs(StateId s) const { return states_[s].narcs; } + + size_t NumInputEpsilons(StateId s) const { return states_[s].niepsilons; } + + size_t NumOutputEpsilons(StateId s) const { return states_[s].noepsilons; } + + static ConstFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + const Arc *Arcs(StateId s) const { return arcs_ + states_[s].pos; } + + // Provide information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = nstates_; + } + + // Provide information needed for the generic arc iterator. 
+ void InitArcIterator(StateId s, ArcIteratorData *data) const { + data->base = nullptr; + data->arcs = arcs_ + states_[s].pos; + data->narcs = states_[s].narcs; + data->ref_count = nullptr; + } + + private: + // Used to find narcs_ and nstates_ in Write. + friend class ConstFst; + + // States implemented by array *states_ below, arcs by (single) *arcs_. + struct ConstState { + Weight weight; // Final weight. + Unsigned pos; // Start of state's arcs in *arcs_. + Unsigned narcs; // Number of arcs (per state). + Unsigned niepsilons; // Number of input epsilons. + Unsigned noepsilons; // Number of output epsilons. + + ConstState() : weight(Weight::Zero()) {} + }; + + // Properties always true of this FST class. + static constexpr uint64 kStaticProperties = kExpanded; + // Current unaligned file format version. The unaligned version was added and + // made the default since the aligned version does not work on pipes. + static constexpr int kFileVersion = 2; + // Current aligned file format version. + static constexpr int kAlignedFileVersion = 1; + // Minimum file format version supported. + static constexpr int kMinFileVersion = 1; + + std::unique_ptr states_region_; // Mapped file for states. + std::unique_ptr arcs_region_; // Mapped file for arcs. + ConstState *states_; // States representation. + Arc *arcs_; // Arcs representation. + size_t narcs_; // Number of arcs. + StateId nstates_; // Number of states. + StateId start_; // Initial state. 
+ + ConstFstImpl(const ConstFstImpl &) = delete; + ConstFstImpl &operator=(const ConstFstImpl &) = delete; +}; + +template +constexpr uint64 ConstFstImpl::kStaticProperties; + +template +constexpr int ConstFstImpl::kFileVersion; + +template +constexpr int ConstFstImpl::kAlignedFileVersion; + +template +constexpr int ConstFstImpl::kMinFileVersion; + +template +ConstFstImpl::ConstFstImpl(const Fst &fst) + : narcs_(0), nstates_(0) { + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + SetType(type); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + start_ = fst.Start(); + // Counts states and arcs. + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + narcs_ += fst.NumArcs(siter.Value()); + } + states_region_.reset(MappedFile::Allocate(nstates_ * sizeof(*states_))); + arcs_region_.reset(MappedFile::Allocate(narcs_ * sizeof(*arcs_))); + states_ = reinterpret_cast(states_region_->mutable_data()); + arcs_ = reinterpret_cast(arcs_region_->mutable_data()); + size_t pos = 0; + for (StateId s = 0; s < nstates_; ++s) { + states_[s].weight = fst.Final(s); + states_[s].pos = pos; + states_[s].narcs = 0; + states_[s].niepsilons = 0; + states_[s].noepsilons = 0; + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + ++states_[s].narcs; + if (arc.ilabel == 0) ++states_[s].niepsilons; + if (arc.olabel == 0) ++states_[s].noepsilons; + arcs_[pos] = arc; + ++pos; + } + } + const auto props = + fst.Properties(kMutable, false) + ? 
fst.Properties(kCopyProperties, true) + : CheckProperties( + fst, kCopyProperties & ~kWeightedCycles & ~kUnweightedCycles, + kCopyProperties); + SetProperties(props | kStaticProperties); +} + +template +ConstFstImpl *ConstFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { + using ConstState = typename ConstFstImpl::ConstState; + std::unique_ptr> impl( + new ConstFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->start_ = hdr.Start(); + impl->nstates_ = hdr.NumStates(); + impl->narcs_ = hdr.NumArcs(); + // Ensures compatibility. + if (hdr.Version() == kAlignedFileVersion) { + hdr.SetFlags(hdr.GetFlags() | FstHeader::IS_ALIGNED); + } + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + return nullptr; + } + size_t b = impl->nstates_ * sizeof(ConstState); + impl->states_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !impl->states_region_) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->states_ = + reinterpret_cast(impl->states_region_->mutable_data()); + if ((hdr.GetFlags() & FstHeader::IS_ALIGNED) && !AlignInput(strm)) { + LOG(ERROR) << "ConstFst::Read: Alignment failed: " << opts.source; + return nullptr; + } + b = impl->narcs_ * sizeof(Arc); + impl->arcs_region_.reset( + MappedFile::Map(&strm, opts.mode == FstReadOptions::MAP, opts.source, b)); + if (!strm || !impl->arcs_region_) { + LOG(ERROR) << "ConstFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->arcs_ = reinterpret_cast(impl->arcs_region_->mutable_data()); + return impl.release(); +} + +} // namespace internal + +// Simple concrete immutable FST. This class attaches interface to +// implementation and handles reference counting, delegating most methods to +// ImplToExpandedFst. 
The unsigned type U is used to represent indices into the +// arc array (default declared in fst-decl.h). +template +class ConstFst : public ImplToExpandedFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::ConstFstImpl; + using ConstState = typename Impl::ConstState; + + friend class StateIterator>; + friend class ArcIterator>; + + template + void friend Cast(const F &, G *); + + ConstFst() : ImplToExpandedFst(std::make_shared()) {} + + explicit ConstFst(const Fst &fst) + : ImplToExpandedFst(std::make_shared(fst)) {} + + ConstFst(const ConstFst &fst, bool safe = false) + : ImplToExpandedFst(fst) {} + + // Gets a copy of this ConstFst. See Fst<>::Copy() for further doc. + ConstFst *Copy(bool safe = false) const override { + return new ConstFst(*this, safe); + } + + // Reads a ConstFst from an input stream, returning nullptr on error. + static ConstFst *Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new ConstFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a ConstFst from a file; return nullptr on error; empty filename reads + // from standard input. + static ConstFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? 
new ConstFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return WriteFst(*this, strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + private: + explicit ConstFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + using ImplToFst>::GetImpl; + + // Uses overloading to extract the type of the argument. + static const Impl *GetImplIfConstFst(const ConstFst &const_fst) { + return const_fst.GetImpl(); + } + + // NB: this does not give privileged treatment to subtypes of ConstFst. + template + static Impl *GetImplIfConstFst(const FST &fst) { + return nullptr; + } + + ConstFst &operator=(const ConstFst &) = delete; +}; + +// Writes FST in Const format, potentially with a pass over the machine before +// writing to compute number of states and arcs. +template +template +bool ConstFst::WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts) { + const auto file_version = + opts.align ? internal::ConstFstImpl::kAlignedFileVersion + : internal::ConstFstImpl::kFileVersion; + size_t num_arcs = 0; // To silence -Wsometimes-uninitialized warnings. + size_t num_states = 0; // Ditto. + std::streamoff start_offset = 0; + bool update_header = true; + if (const auto *impl = GetImplIfConstFst(fst)) { + num_arcs = impl->narcs_; + num_states = impl->nstates_; + update_header = false; + } else if (opts.stream_write || (start_offset = strm.tellp()) == -1) { + // precompute values needed for header when we cannot seek to rewrite it. 
+ num_arcs = 0; + num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + num_arcs += fst.NumArcs(siter.Value()); + ++num_states; + } + update_header = false; + } + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(num_states); + hdr.SetNumArcs(num_arcs); + string type = "const"; + if (sizeof(Unsigned) != sizeof(uint32)) { + type += std::to_string(CHAR_BIT * sizeof(Unsigned)); + } + const auto properties = + fst.Properties(kCopyProperties, true) | + internal::ConstFstImpl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, type, + properties, &hdr); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after header"; + return false; + } + size_t pos = 0; + size_t states = 0; + typename ConstFst::ConstState state; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + state.weight = fst.Final(s); + state.pos = pos; + state.narcs = fst.NumArcs(s); + state.niepsilons = fst.NumInputEpsilons(s); + state.noepsilons = fst.NumOutputEpsilons(s); + strm.write(reinterpret_cast(&state), sizeof(state)); + pos += state.narcs; + ++states; + } + hdr.SetNumStates(states); + hdr.SetNumArcs(pos); + if (opts.align && !AlignOutput(strm)) { + LOG(ERROR) << "Could not align file during write after writing states"; + } + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + for (ArcIterator aiter(fst, siter.Value()); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); +// Google-only... +#ifdef MEMORY_SANITIZER + // arc may contain padding which has unspecified contents. Tell MSAN to + // not complain about it when writing it to a file. 
+ ANNOTATE_MEMORY_IS_INITIALIZED(reinterpret_cast(&arc), + sizeof(arc)); +#endif + // ...Google-only + strm.write(reinterpret_cast(&arc), sizeof(arc)); + } + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "ConstFst::WriteFst: write failed: " << opts.source; + return false; + } + if (update_header) { + return internal::FstImpl::UpdateFstHeader( + fst, strm, opts, file_version, type, properties, &hdr, start_offset); + } else { + if (hdr.NumStates() != num_states) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + if (hdr.NumArcs() != num_arcs) { + LOG(ERROR) << "Inconsistent number of arcs observed during write"; + return false; + } + } + return true; +} + +// Specialization for ConstFst; see generic version in fst.h for sample usage +// (but use the ConstFst type instead). This version should inline. +template +class StateIterator> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const ConstFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + const StateId nstates_; + StateId s_; +}; + +// Specialization for ConstFst; see generic version in fst.h for sample usage +// (but use the ConstFst type instead). This version should inline. 
+template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const ConstFst &fst, StateId s) + : arcs_(fst.GetImpl()->Arcs(s)), + narcs_(fst.GetImpl()->NumArcs(s)), + i_(0) {} + + bool Done() const { return i_ >= narcs_; } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + size_t Position() const { return i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32, uint32) {} + + private: + const Arc *arcs_; + size_t narcs_; + size_t i_; +}; + +// A useful alias when using StdArc. +using StdConstFst = ConstFst; + +} // namespace fst + +#endif // FST_CONST_FST_H_ diff --git a/projects/llm_framework/include/fst/determinize.h b/projects/llm_framework/include/fst/determinize.h new file mode 100644 index 00000000..736a140e --- /dev/null +++ b/projects/llm_framework/include/fst/determinize.h @@ -0,0 +1,1093 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to determinize an FST. + +#ifndef FST_DETERMINIZE_H_ +#define FST_DETERMINIZE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Common divisors are used in determinization to compute transition weights. +// In the simplest case, it is the same as semiring Plus, but other choices +// permit more efficient determinization when the output contains strings. + +// The default common divisor uses the semiring Plus. +template +struct DefaultCommonDivisor { + public: + using Weight = W; + + Weight operator()(const Weight &w1, const Weight &w2) const { + return Plus(w1, w2); + } +}; + +// The label common divisor for a (left) string semiring selects a single +// letter common prefix or the empty string. 
This is used in the +// determinization of output strings so that at most a single letter will +// appear in the output of a transtion. +template +struct LabelCommonDivisor { + public: + using Weight = StringWeight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + typename Weight::Iterator iter1(w1); + typename Weight::Iterator iter2(w2); + if (!(StringWeight::Properties() & kLeftSemiring)) { + FSTERROR() << "LabelCommonDivisor: Weight needs to be left semiring"; + return Weight::NoWeight(); + } else if (w1.Size() == 0 || w2.Size() == 0) { + return Weight::One(); + } else if (w1 == Weight::Zero()) { + return Weight(iter2.Value()); + } else if (w2 == Weight::Zero()) { + return Weight(iter1.Value()); + } else if (iter1.Value() == iter2.Value()) { + return Weight(iter1.Value()); + } else { + return Weight::One(); + } + } +}; + +// The gallic common divisor uses the label common divisor on the string +// component and the common divisor on the weight component, which defaults to +// the default common divisor. +template > +class GallicCommonDivisor { + public: + using Weight = GallicWeight; + + Weight operator()(const Weight &w1, const Weight &w2) const { + return Weight(label_common_divisor_(w1.Value1(), w2.Value1()), + weight_common_divisor_(w1.Value2(), w2.Value2())); + } + + private: + LabelCommonDivisor label_common_divisor_; + CommonDivisor weight_common_divisor_; +}; + +// Specialization for general GALLIC weight. +template +class GallicCommonDivisor { + public: + using Weight = GallicWeight; + using GRWeight = GallicWeight; + using Iterator = + UnionWeightIterator>; + + Weight operator()(const Weight &w1, const Weight &w2) const { + auto weight = GRWeight::Zero(); + for (Iterator iter(w1); !iter.Done(); iter.Next()) { + weight = common_divisor_(weight, iter.Value()); + } + for (Iterator iter(w2); !iter.Done(); iter.Next()) { + weight = common_divisor_(weight, iter.Value()); + } + return weight == GRWeight::Zero() ? 
Weight::Zero() : Weight(weight); + } + + private: + GallicCommonDivisor common_divisor_; +}; + +namespace internal { + +// Represents an element in a subset +template +struct DeterminizeElement { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + DeterminizeElement(StateId s, Weight weight) + : state_id(s), weight(std::move(weight)) {} + + inline bool operator==(const DeterminizeElement &element) const { + return state_id == element.state_id && weight == element.weight; + } + + inline bool operator!=(const DeterminizeElement &element) const { + return !(*this == element); + } + + inline bool operator<(const DeterminizeElement &element) const { + return state_id < element.state_id; + } + + StateId state_id; // Input state ID. + Weight weight; // Residual weight. +}; + +// Represents a weighted subset and determinization filter state +template +struct DeterminizeStateTuple { + using Arc = A; + using Element = DeterminizeElement; + using Subset = std::forward_list; + + DeterminizeStateTuple() : filter_state(FilterState::NoState()) {} + + inline bool operator==( + const DeterminizeStateTuple &tuple) const { + return (tuple.filter_state == filter_state) && (tuple.subset == subset); + } + + inline bool operator!=( + const DeterminizeStateTuple &tuple) const { + return (tuple.filter_state != filter_state) || (tuple.subset != subset); + } + + Subset subset; + FilterState filter_state; +}; + +// Proto-transition for determinization. +template +struct DeterminizeArc { + using Arc = typename StateTuple::Arc; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + DeterminizeArc() + : label(kNoLabel), weight(Weight::Zero()), dest_tuple(nullptr) {} + + explicit DeterminizeArc(const Arc &arc) + : label(arc.ilabel), weight(Weight::Zero()), dest_tuple(new StateTuple) {} + + Label label; // Arc label. + Weight weight; // Arc weight. + StateTuple *dest_tuple; // Destination subset and filter state. 
+}; + +} // namespace internal + +// Determinization filters are used to compute destination state tuples based +// on the source tuple, transition, and destination element or on similar +// super-final transition information. The filter operates on a map between a +// label and the corresponding destination state tuples. It must define the map +// type LabelMap. The default filter is used for weighted determinization. +// A determinize filter for implementing weighted determinization. +template +class DefaultDeterminizeFilter { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = CharFilterState; + using Element = internal::DeterminizeElement; + using StateTuple = internal::DeterminizeStateTuple; + using LabelMap = std::map>; + + // This is needed e.g. to go into the gallic domain for transducers. + template + struct rebind { + using Other = DefaultDeterminizeFilter; + }; + + explicit DefaultDeterminizeFilter(const Fst &fst) : fst_(fst.Copy()) {} + + // This is needed (e.g.) to go into the gallic domain for transducers. + // Ownership of the templated filter argument is given to this class. + template + DefaultDeterminizeFilter(const Fst &fst, Filter *filter) + : fst_(fst.Copy()) { + delete filter; + } + + // Copy constructor; the FST can be passed if it has been deep-copied. + DefaultDeterminizeFilter(const DefaultDeterminizeFilter &filter, + const Fst *fst = nullptr) + : fst_(fst ? fst->Copy() : filter.fst_->Copy()) {} + + FilterState Start() const { return FilterState(0); } + + // Does no work. + void SetState(StateId s, const StateTuple &tuple) {} + + // Filters transition, possibly modifying label map. Returns true if arc is + // added to the label map. + bool FilterArc(const Arc &arc, const Element &src_element, + Element &&dest_element, LabelMap *label_map) const { + // Adds element to unique state tuple for arc label. 
+ auto &det_arc = (*label_map)[arc.ilabel]; + if (det_arc.label == kNoLabel) { + det_arc = internal::DeterminizeArc(arc); + det_arc.dest_tuple->filter_state = FilterState(0); + } + det_arc.dest_tuple->subset.push_front(std::move(dest_element)); + return true; + } + + // Filters super-final transition, returning new final weight. + Weight FilterFinal(Weight weight, const Element &element) { return weight; } + + static uint64 Properties(uint64 props) { return props; } + + private: + std::unique_ptr> fst_; +}; + +// Determinization state table interface: +// +// template +// class DeterminizeStateTable { +// public: +// using StateId = typename Arc::StateId; +// using StateTuple = internal::DeterminizeStateTuple; +// +// // Required sub-class. This is needed (e.g.) to go into the gallic domain. +// template +// struct rebind { +// using Other = DeterminizeStateTable; +// } +// +// // Required constuctor. +// DeterminizeStateTable(); +// +// // Required copy constructor that does not copy state. +// DeterminizeStateTable(const DeterminizeStateTable +// &table); +// +// // Looks up state ID by state tuple; if it doesn't exist, then adds it. +// // FindState takes ownership of the state tuple argument so that it +// // doesn't have to copy it if it creates a new state. +// StateId FindState(StateTuple *tuple); +// +// // Looks up state tuple by ID. +// const StateTuple *Tuple(StateId id) const; +// }; + +// The default determinization state table based on the compact hash bi-table. 
+template +class DefaultDeterminizeStateTable { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StateTuple = internal::DeterminizeStateTuple; + using Element = typename StateTuple::Element; + using Subset = typename StateTuple::Subset; + + template + struct rebind { + using Other = DefaultDeterminizeStateTable; + }; + + explicit DefaultDeterminizeStateTable(size_t table_size = 0) + : table_size_(table_size), tuples_(table_size_) {} + + DefaultDeterminizeStateTable( + const DefaultDeterminizeStateTable &table) + : table_size_(table.table_size_), tuples_(table_size_) {} + + ~DefaultDeterminizeStateTable() { + for (StateId s = 0; s < tuples_.Size(); ++s) delete tuples_.FindEntry(s); + } + + // Finds the state corresponding to a state tuple. Only creates a new state if + // the tuple is not found. FindState takes ownership of the tuple argument so + // that it doesn't have to copy it if it creates a new state. + StateId FindState(StateTuple *tuple) { + const StateId ns = tuples_.Size(); + const auto s = tuples_.FindId(tuple); + if (s != ns) delete tuple; // Tuple found. + return s; + } + + const StateTuple *Tuple(StateId s) { return tuples_.FindEntry(s); } + + private: + // Comparison object for StateTuples. + class StateTupleEqual { + public: + bool operator()(const StateTuple *tuple1, const StateTuple *tuple2) const { + return *tuple1 == *tuple2; + } + }; + + // Hash function for StateTuples. 
+ class StateTupleKey { + public: + size_t operator()(const StateTuple *tuple) const { + size_t h = tuple->filter_state.Hash(); + for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) { + const size_t h1 = it->state_id; + static constexpr auto lshift = 5; + static constexpr auto rshift = CHAR_BIT * sizeof(size_t) - 5; + h ^= h << 1 ^ h1 << lshift ^ h1 >> rshift ^ it->weight.Hash(); + } + return h; + } + }; + + size_t table_size_; + CompactHashBiTable + tuples_; + + DefaultDeterminizeStateTable &operator=( + const DefaultDeterminizeStateTable &) = delete; +}; + +// Determinization type. +enum DeterminizeType { + // Input transducer is known to be functional (or error). + DETERMINIZE_FUNCTIONAL, // Input transducer is functional (error if not). + // Input transducer is not known to be functional. + DETERMINIZE_NONFUNCTIONAL, + // Input transducer is not known to be functional but only keep the min of + // of ambiguous outputs. + DETERMINIZE_DISAMBIGUATE +}; + +// Options for finite-state transducer determinization templated on the arc +// type, common divisor, the determinization filter and the state table. +// DeterminizeFst takes ownership of the determinization filter and state table, +// if provided. +template , + class Filter = DefaultDeterminizeFilter, + class StateTable = + DefaultDeterminizeStateTable> +struct DeterminizeFstOptions : public CacheOptions { + using Label = typename Arc::Label; + + float delta; // Quantization delta for subset weights. + Label subsequential_label; // Label used for residual final output + // when producing subsequential transducers. + DeterminizeType type; // Determinization type. + bool increment_subsequential_label; // When creating several subsequential + // arcs at a given state, make their + // label distinct by incrementing. + Filter *filter; // Determinization filter; + // DeterminizeFst takes ownership. + StateTable *state_table; // Determinization state table; + // DeterminizeFst takes ownership. 
+ + explicit DeterminizeFstOptions(const CacheOptions &opts, float delta = kDelta, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : CacheOptions(opts), + delta(delta), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label), + filter(filter), + state_table(state_table) {} + + explicit DeterminizeFstOptions(float delta = kDelta, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : delta(delta), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label), + filter(filter), + state_table(state_table) {} +}; + +namespace internal { + +// Implementation of delayed DeterminizeFst. This base class is +// common to the variants that implement acceptor and transducer +// determinization. 
+template +class DeterminizeFstImplBase : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + template + DeterminizeFstImplBase( + const Fst &fst, + const DeterminizeFstOptions &opts) + : CacheImpl(opts), fst_(fst.Copy()) { + SetType("determinize"); + const auto iprops = fst.Properties(kFstProperties, false); + const auto dprops = + DeterminizeProperties(iprops, opts.subsequential_label != 0, + opts.type == DETERMINIZE_NONFUNCTIONAL + ? opts.increment_subsequential_label + : true); + SetProperties(Filter::Properties(dprops), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + DeterminizeFstImplBase(const DeterminizeFstImplBase &impl) + : CacheImpl(impl), fst_(impl.fst_->Copy(true)) { + SetType("determinize"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + virtual DeterminizeFstImplBase *Copy() const = 0; + + StateId Start() { + if (!HasStart()) { + const auto start = ComputeStart(); + if (start != kNoStateId) SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) SetFinal(s, ComputeFinal(s)); + return CacheImpl::Final(s); + } + + virtual void Expand(StateId s) = 0; + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + 
+ size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + virtual StateId ComputeStart() = 0; + + virtual Weight ComputeFinal(StateId s) = 0; + + const Fst &GetFst() const { return *fst_; } + + private: + std::unique_ptr> fst_; // Input FST. +}; + +// Implementation of delayed determinization for weighted acceptors. +template +class DeterminizeFsaImpl : public DeterminizeFstImplBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = typename Filter::FilterState; + using StateTuple = internal::DeterminizeStateTuple; + using Element = typename StateTuple::Element; + using Subset = typename StateTuple::Subset; + using LabelMap = typename Filter::LabelMap; + + using FstImpl::SetProperties; + using DeterminizeFstImplBase::GetFst; + using DeterminizeFstImplBase::SetArcs; + + DeterminizeFsaImpl( + const Fst &fst, const std::vector *in_dist, + std::vector *out_dist, + const DeterminizeFstOptions &opts) + : DeterminizeFstImplBase(fst, opts), + delta_(opts.delta), + in_dist_(in_dist), + out_dist_(out_dist), + filter_(opts.filter ? opts.filter : new Filter(fst)), + state_table_(opts.state_table ? 
opts.state_table : new StateTable()) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst: Argument not an acceptor"; + SetProperties(kError, kError); + } + if (!(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "DeterminizeFst: Weight must be left distributive: " + << Weight::Type(); + SetProperties(kError, kError); + } + if (out_dist_) out_dist_->clear(); + } + + DeterminizeFsaImpl( + const DeterminizeFsaImpl &impl) + : DeterminizeFstImplBase(impl), + delta_(impl.delta_), + in_dist_(nullptr), + out_dist_(nullptr), + filter_(new Filter(*impl.filter_, &GetFst())), + state_table_(new StateTable(*impl.state_table_)) { + if (impl.out_dist_) { + FSTERROR() << "DeterminizeFsaImpl: Cannot copy with out_dist vector"; + SetProperties(kError, kError); + } + } + + DeterminizeFsaImpl *Copy() + const override { + return new DeterminizeFsaImpl( + *this); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. 
+ uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (GetFst().Properties(kError, false))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + StateId ComputeStart() override { + const auto s = GetFst().Start(); + if (s == kNoStateId) return kNoStateId; + auto *tuple = new StateTuple; + tuple->subset.emplace_front(s, Weight::One()); + tuple->filter_state = filter_->Start(); + return FindState(tuple); + } + + Weight ComputeFinal(StateId s) override { + const auto *tuple = state_table_->Tuple(s); + filter_->SetState(s, *tuple); + auto final_weight = Weight::Zero(); + for (auto it = tuple->subset.begin(); it != tuple->subset.end(); ++it) { + const auto &element = *it; + final_weight = + Plus(final_weight, + Times(element.weight, GetFst().Final(element.state_id))); + final_weight = filter_->FilterFinal(final_weight, element); + if (!final_weight.Member()) SetProperties(kError, kError); + } + return final_weight; + } + + StateId FindState(StateTuple *tuple) { + const auto s = state_table_->FindState(tuple); + if (in_dist_ && out_dist_->size() <= s) { + out_dist_->push_back(ComputeDistance(tuple->subset)); + } + return s; + } + + // Computes distance from a state to the final states in the DFA given the + // distances in the NFA. + Weight ComputeDistance(const Subset &subset) { + auto outd = Weight::Zero(); + for (auto it = subset.begin(); it != subset.end(); ++it) { + const auto &element = *it; + const auto ind = + (element.state_id < in_dist_->size() ? (*in_dist_)[element.state_id] + : Weight::Zero()); + outd = Plus(outd, Times(element.weight, ind)); + } + return outd; + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. 
+ void Expand(StateId s) override { + LabelMap label_map; + GetLabelMap(s, &label_map); + for (auto it = label_map.begin(); it != label_map.end(); ++it) { + AddArc(s, std::move(it->second)); + } + SetArcs(s); + } + + private: + using DetArc = internal::DeterminizeArc; + + // Constructs proto-determinization transition, including destination subset, + // per label. + void GetLabelMap(StateId s, LabelMap *label_map) { + const auto *src_tuple = state_table_->Tuple(s); + filter_->SetState(s, *src_tuple); + for (auto it = src_tuple->subset.begin(); it != src_tuple->subset.end(); + ++it) { + const auto &src_element = *it; + for (ArcIterator> aiter(GetFst(), src_element.state_id); + !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + Element dest_element(arc.nextstate, + Times(src_element.weight, arc.weight)); + filter_->FilterArc(arc, src_element, std::move(dest_element), + label_map); + } + } + for (auto it = label_map->begin(); it != label_map->end(); ++it) { + NormArc(&it->second); + } + } + + // Sorts subsets and removes duplicate elements, normalizing transition and + // subset weights. + void NormArc(DetArc *det_arc) { + auto *dest_tuple = det_arc->dest_tuple; + dest_tuple->subset.sort(); + auto piter = dest_tuple->subset.begin(); + for (auto diter = dest_tuple->subset.begin(); + diter != dest_tuple->subset.end();) { + auto &dest_element = *diter; + auto &prev_element = *piter; + // Computes arc weight. + det_arc->weight = common_divisor_(det_arc->weight, dest_element.weight); + if (piter != diter && dest_element.state_id == prev_element.state_id) { + // Found duplicate state: sums state weight and deletes duplicate. 
+ prev_element.weight = Plus(prev_element.weight, dest_element.weight); + if (!prev_element.weight.Member()) SetProperties(kError, kError); + ++diter; + dest_tuple->subset.erase_after(piter); + } else { + piter = diter; + ++diter; + } + } + // Divides out label weight from destination subset elements, quantizing to + // ensure comparisons are effective. + for (auto diter = dest_tuple->subset.begin(); + diter != dest_tuple->subset.end(); ++diter) { + auto &dest_element = *diter; + dest_element.weight = + Divide(dest_element.weight, det_arc->weight, DIVIDE_LEFT); + dest_element.weight = dest_element.weight.Quantize(delta_); + } + } + + // Adds an arc from state S to the destination state associated with state + // tuple in det_arc as created by GetLabelMap. + void AddArc(StateId s, DetArc &&det_arc) { + CacheImpl::EmplaceArc( + s, det_arc.label, det_arc.label, std::move(det_arc.weight), + FindState(det_arc.dest_tuple)); + } + + float delta_; // Quantization delta for weights. + const std::vector *in_dist_; // Distance to final NFA states. + std::vector *out_dist_; // Distance to final DFA states. + + // FIXME(kbg): Ought to be static const? + CommonDivisor common_divisor_; + std::unique_ptr filter_; + std::unique_ptr state_table_; +}; + +// Implementation of delayed determinization for transducers. Transducer +// determinization is implemented by mapping the input to the Gallic semiring as +// an acceptor whose weights contain the output strings and using acceptor +// determinization above to determinize that acceptor. 
+template +class DeterminizeFstImpl : public DeterminizeFstImplBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ToMapper = ToGallicMapper; + using ToArc = typename ToMapper::ToArc; + using ToFst = ArcMapFst; + using FromMapper = FromGallicMapper; + using FromFst = ArcMapFst; + + using ToCommonDivisor = GallicCommonDivisor; + using ToFilter = typename Filter::template rebind::Other; + using ToFilterState = typename ToFilter::FilterState; + using ToStateTable = + typename StateTable::template rebind::Other; + using FactorIterator = GallicFactor; + + using FstImpl::SetProperties; + using DeterminizeFstImplBase::GetFst; + using CacheBaseImpl>::GetCacheGc; + using CacheBaseImpl>::GetCacheLimit; + + DeterminizeFstImpl( + const Fst &fst, + const DeterminizeFstOptions &opts) + : DeterminizeFstImplBase(fst, opts), + delta_(opts.delta), + subsequential_label_(opts.subsequential_label), + increment_subsequential_label_(opts.increment_subsequential_label) { + if (opts.state_table) { + FSTERROR() << "DeterminizeFst: " + << "A state table can not be passed with transducer input"; + SetProperties(kError, kError); + return; + } + Init(GetFst(), opts.filter); + } + + DeterminizeFstImpl( + const DeterminizeFstImpl &impl) + : DeterminizeFstImplBase(impl), + delta_(impl.delta_), + subsequential_label_(impl.subsequential_label_), + increment_subsequential_label_(impl.increment_subsequential_label_) { + Init(GetFst(), nullptr); + } + + DeterminizeFstImpl *Copy() + const override { + return new DeterminizeFstImpl( + *this); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. 
+ uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (GetFst().Properties(kError, false) || + from_fst_->Properties(kError, false))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + StateId ComputeStart() override { return from_fst_->Start(); } + + Weight ComputeFinal(StateId s) override { return from_fst_->Final(s); } + + void Expand(StateId s) override { + for (ArcIterator aiter(*from_fst_, s); !aiter.Done(); + aiter.Next()) { + CacheImpl::PushArc(s, aiter.Value()); + } + CacheImpl::SetArcs(s); + } + + private: + // Initialization of transducer determinization implementation, which is + // defined after DeterminizeFst since it calls it. + void Init(const Fst &fst, Filter *filter); + + float delta_; + Label subsequential_label_; + bool increment_subsequential_label_; + std::unique_ptr from_fst_; +}; + +} // namespace internal + +// Determinizes a weighted transducer. This version is a delayed +// FST. The result will be an equivalent FST that has the property +// that no state has two transitions with the same input label. +// For this algorithm, epsilon transitions are treated as regular +// symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) left +// divisible (valid for TropicalWeight and LogWeight for instance) and be +// zero-sum-free if for all a, b: (Plus(a, b) == 0) => a = b = 0. +// +// Complexity: +// +// Determinizable: exponential (polynomial in the size of the output). +// Non-determinizable: does not terminate. +// +// The determinizable automata include all unweighted and all acyclic input. +// +// For more information, see: +// +// Mohri, M. 1997. Finite-state transducers in language and speech processing. +// Computational Linguistics 23(2): 269-311. +// +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToFst. 
+template +class DeterminizeFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::DeterminizeFstImplBase; + + friend class ArcIterator>; + friend class StateIterator>; + + template + friend class DeterminizeFstImpl; + + explicit DeterminizeFst(const Fst &fst) + : ImplToFst(CreateImpl(fst)) {} + + template + DeterminizeFst( + const Fst &fst, + const DeterminizeFstOptions + &opts = + DeterminizeFstOptions()) + : ImplToFst(CreateImpl(fst, opts)) {} + + // This acceptor-only version additionally computes the distance to final + // states in the output if provided with those distances for the input; this + // is useful for e.g., computing the k-shortest unique paths. + template + DeterminizeFst( + const Fst &fst, const std::vector *in_dist, + std::vector *out_dist, + const DeterminizeFstOptions + &opts = + DeterminizeFstOptions()) + : ImplToFst( + std::make_shared>( + fst, in_dist, out_dist, opts)) { + if (!fst.Properties(kAcceptor, true)) { + FSTERROR() << "DeterminizeFst: " + << "Distance to final states computed for acceptors only"; + GetMutableImpl()->SetProperties(kError, kError); + } + } + + // See Fst<>::Copy() for doc. + DeterminizeFst(const DeterminizeFst &fst, bool safe = false) + : ImplToFst(safe ? std::shared_ptr(fst.GetImpl()->Copy()) + : fst.GetSharedImpl()) {} + + // Get a copy of this DeterminizeFst. See Fst<>::Copy() for further doc. 
+ DeterminizeFst *Copy(bool safe = false) const override { + return new DeterminizeFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + static std::shared_ptr CreateImpl(const Fst &fst) { + using D = DefaultCommonDivisor; + using F = DefaultDeterminizeFilter; + using T = DefaultDeterminizeStateTable; + const DeterminizeFstOptions opts; + return CreateImpl(fst, opts); + } + + template + static std::shared_ptr CreateImpl( + const Fst &fst, + const DeterminizeFstOptions + &opts) { + if (fst.Properties(kAcceptor, true)) { + // Calls implementation for acceptors. + return std::make_shared< + internal::DeterminizeFsaImpl>( + fst, nullptr, nullptr, opts); + } else if (opts.type == DETERMINIZE_DISAMBIGUATE) { + auto rv = std::make_shared>(fst, opts); + if (!(Weight::Properties() & kPath)) { + FSTERROR() << "DeterminizeFst: Weight needs to have the " + << "path property to disambiguate output: " + << Weight::Type(); + rv->SetProperties(kError, kError); + } + // Calls disambiguating implementation for non-functional transducers. + return rv; + } else if (opts.type == DETERMINIZE_FUNCTIONAL) { + // Calls implementation for functional transducers. + return std::make_shared>(fst, opts); + } else { // opts.type == DETERMINIZE_NONFUNCTIONAL + // Calls implementation for non functional transducers; + return std::make_shared>(fst, opts); + } + } + + DeterminizeFst &operator=(const DeterminizeFst &) = delete; +}; + +namespace internal { + +// Initialization of transducer determinization implementation, which is defined +// after DeterminizeFst since it calls it. +template +void DeterminizeFstImpl::Init(const Fst &fst, Filter *filter) { + // Mapper to an acceptor. + const ToFst to_fst(fst, ToMapper()); + auto *to_filter = filter ? 
new ToFilter(to_fst, filter) : nullptr; + // This recursive call terminates since it is to a (non-recursive) + // different constructor. + const CacheOptions copts(GetCacheGc(), GetCacheLimit()); + const DeterminizeFstOptions + dopts(copts, delta_, 0, DETERMINIZE_FUNCTIONAL, false, to_filter); + // Uses acceptor-only constructor to avoid template recursion. + const DeterminizeFst det_fsa(to_fst, nullptr, nullptr, dopts); + // Mapper back to transducer. + const FactorWeightOptions fopts( + CacheOptions(true, 0), delta_, kFactorFinalWeights, subsequential_label_, + subsequential_label_, increment_subsequential_label_, + increment_subsequential_label_); + const FactorWeightFst factored_fst(det_fsa, fopts); + from_fst_.reset(new FromFst(factored_fst, FromMapper(subsequential_label_))); +} + +} // namespace internal + +// Specialization for DeterminizeFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const DeterminizeFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for DeterminizeFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const DeterminizeFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void DeterminizeFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful aliases when using StdArc. +using StdDeterminizeFst = DeterminizeFst; + +template +struct DeterminizeOptions { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + float delta; // Quantization delta for subset weights. + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. 
+ Label subsequential_label; // Label used for residual final output. + DeterminizeType type; + bool increment_subsequential_label; // When creating several subsequential + // arcs at a given state, make their + // label distinct by incrementation? + + explicit DeterminizeOptions(float delta = kDelta, + Weight weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId, + Label subsequential_label = 0, + DeterminizeType type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false) + : delta(delta), + weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold), + subsequential_label(subsequential_label), + type(type), + increment_subsequential_label(increment_subsequential_label) {} +}; + +// Determinizes a weighted transducer. This version writes the +// determinized Fst to an output MutableFst. The result will be an +// equivalent FST that has the property that no state has two +// transitions with the same input label. For this algorithm, epsilon +// transitions are treated as regular symbols (cf. RmEpsilon). +// +// The transducer must be functional. The weights must be (weakly) +// left divisible (valid for TropicalWeight and LogWeight). +// +// Complexity: +// +// Determinizable: exponential (polynomial in the size of the output) +// Non-determinizable: does not terminate +// +// The determinizable automata include all unweighted and all acyclic input. +template +void Determinize( + const Fst &ifst, MutableFst *ofst, + const DeterminizeOptions &opts = DeterminizeOptions()) { + using Weight = typename Arc::Weight; + DeterminizeFstOptions nopts; + nopts.delta = opts.delta; + nopts.subsequential_label = opts.subsequential_label; + nopts.type = opts.type; + nopts.increment_subsequential_label = opts.increment_subsequential_label; + nopts.gc_limit = 0; // Caches only the last state for fastest copy. 
+ if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + if (ifst.Properties(kAcceptor, false)) { + std::vector idistance; + std::vector odistance; + ShortestDistance(ifst, &idistance, true); + DeterminizeFst dfst(ifst, &idistance, &odistance, nopts); + PruneOptions> popts( + opts.weight_threshold, opts.state_threshold, AnyArcFilter(), + &odistance); + Prune(dfst, ofst, popts); + } else { + *ofst = DeterminizeFst(ifst, nopts); + Prune(ofst, opts.weight_threshold, opts.state_threshold); + } + } else { + *ofst = DeterminizeFst(ifst, nopts); + } +} + +} // namespace fst + +#endif // FST_DETERMINIZE_H_ diff --git a/projects/llm_framework/include/fst/dfs-visit.h b/projects/llm_framework/include/fst/dfs-visit.h new file mode 100644 index 00000000..a7b18a6c --- /dev/null +++ b/projects/llm_framework/include/fst/dfs-visit.h @@ -0,0 +1,202 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Depth-first search visitation. See visit.h for more general search queue +// disciplines. + +#ifndef FST_DFS_VISIT_H_ +#define FST_DFS_VISIT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Visitor Interface: class determining actions taken during a depth-first +// search-style visit. If any of the boolean member functions return false, the +// DFS is aborted by first calling FinishState() on all currently grey states +// and then calling FinishVisit(). +// +// This is similar to the more general visitor interface in visit.h, except +// that FinishState returns additional information appropriate only for a DFS +// and some methods names here are better suited to a DFS. +// +// template +// class Visitor { +// public: +// using StateId = typename Arc::StateId; +// +// Visitor(T *return_data); +// +// // Invoked before DFS visit. +// void InitVisit(const Fst &fst); +// +// // Invoked when state discovered (2nd arg is DFS tree root). 
+// bool InitState(StateId s, StateId root); +// +// // Invoked when tree arc to white/undiscovered state examined. +// bool TreeArc(StateId s, const Arc &arc); +// +// // Invoked when back arc to grey/unfinished state examined. +// bool BackArc(StateId s, const Arc &arc); +// +// // Invoked when forward or cross arc to black/finished state examined. +// bool ForwardOrCrossArc(StateId s, const Arc &arc); +// +// // Invoked when state finished ('s' is tree root, 'parent' is kNoStateId, +// // and 'arc' is nullptr). +// void FinishState(StateId s, StateId parent, const Arc *arc); +// +// // Invoked after DFS visit. +// void FinishVisit(); +// }; + +namespace internal { + +// An FST state's DFS stack state. +template +struct DfsState { + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + DfsState(const FST &fst, StateId s) : state_id(s), arc_iter(fst, s) {} + + void *operator new(size_t size, MemoryPool> *pool) { + return pool->Allocate(); + } + + static void Destroy(DfsState *dfs_state, + MemoryPool> *pool) { + if (dfs_state) { + dfs_state->~DfsState(); + pool->Free(dfs_state); + } + } + + StateId state_id; // FST state. + ArcIterator arc_iter; // The corresponding arcs. +}; + +} // namespace internal + +// Performs depth-first visitation. Visitor class argument determines actions +// and contains any return data. ArcFilter determines arcs that are considered. +// If 'access_only' is true, performs visitation only to states accessible from +// the initial state. +// +// Note this is similar to Visit() in visit.h called with a LIFO queue, except +// this version has a Visitor class specialized and augmented for a DFS. +template +void DfsVisit(const FST &fst, Visitor *visitor, ArcFilter filter, + bool access_only = false) { + visitor->InitVisit(fst); + const auto start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + // An FST state's DFS status + static constexpr uint8 kDfsWhite = 0; // Undiscovered. 
+ static constexpr uint8 kDfsGrey = 1; // Discovered but unfinished. + static constexpr uint8 kDfsBlack = 2; // Finished. + std::vector state_color; + std::stack *> state_stack; // DFS execution stack. + MemoryPool> state_pool; // Pool for DFSStates. + auto nstates = start + 1; // Number of known states in general case. + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // Tests if expanded case, then + nstates = CountStates(fst); // uses ExpandedFst::NumStates(). + expanded = true; + } + state_color.resize(nstates, kDfsWhite); + StateIterator siter(fst); + // Continue DFS while true. + bool dfs = true; + // Iterate over trees in DFS forest. + for (auto root = start; dfs && root < nstates;) { + state_color[root] = kDfsGrey; + state_stack.push(new (&state_pool) internal::DfsState(fst, root)); + dfs = visitor->InitState(root, root); + while (!state_stack.empty()) { + auto *dfs_state = state_stack.top(); + const auto s = dfs_state->state_id; + if (s >= static_cast(state_color.size())) { + nstates = s + 1; + state_color.resize(nstates, kDfsWhite); + } + ArcIterator &aiter = dfs_state->arc_iter; + if (!dfs || aiter.Done()) { + state_color[s] = kDfsBlack; + internal::DfsState::Destroy(dfs_state, &state_pool); + state_stack.pop(); + if (!state_stack.empty()) { + auto *parent_state = state_stack.top(); + auto &piter = parent_state->arc_iter; + visitor->FinishState(s, parent_state->state_id, &piter.Value()); + piter.Next(); + } else { + visitor->FinishState(s, kNoStateId, nullptr); + } + continue; + } + const auto &arc = aiter.Value(); + if (arc.nextstate >= state_color.size()) { + nstates = arc.nextstate + 1; + state_color.resize(nstates, kDfsWhite); + } + if (!filter(arc)) { + aiter.Next(); + continue; + } + const auto next_color = state_color[arc.nextstate]; + switch (next_color) { + default: + case kDfsWhite: + dfs = visitor->TreeArc(s, arc); + if (!dfs) break; + state_color[arc.nextstate] = kDfsGrey; + state_stack.push(new (&state_pool) + 
internal::DfsState(fst, arc.nextstate)); + dfs = visitor->InitState(arc.nextstate, root); + break; + case kDfsGrey: + dfs = visitor->BackArc(s, arc); + aiter.Next(); + break; + case kDfsBlack: + dfs = visitor->ForwardOrCrossArc(s, arc); + aiter.Next(); + break; + } + } + if (access_only) break; + // Finds next tree root. + for (root = root == start ? 0 : root + 1; + root < nstates && state_color[root] != kDfsWhite; ++root) { + } + // Checks for a state beyond the largest known state. + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_color.push_back(kDfsWhite); + break; + } + } + } + } + visitor->FinishVisit(); +} + +template +void DfsVisit(const Fst &fst, Visitor *visitor) { + DfsVisit(fst, visitor, AnyArcFilter()); +} + +} // namespace fst + +#endif // FST_DFS_VISIT_H_ diff --git a/projects/llm_framework/include/fst/difference.h b/projects/llm_framework/include/fst/difference.h new file mode 100644 index 00000000..f073b0e1 --- /dev/null +++ b/projects/llm_framework/include/fst/difference.h @@ -0,0 +1,205 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to compute the difference between two FSAs. + +#ifndef FST_DIFFERENCE_H_ +#define FST_DIFFERENCE_H_ + +#include + + +#include +#include +#include + + +namespace fst { + +template >, + class Filter = SequenceComposeFilter, + class StateTable = + GenericComposeStateTable> +struct DifferenceFstOptions + : public ComposeFstOptions { + explicit DifferenceFstOptions(const CacheOptions &opts = CacheOptions(), + M *matcher1 = nullptr, M *matcher2 = nullptr, + Filter *filter = nullptr, + StateTable *state_table = nullptr) + : ComposeFstOptions(opts, matcher1, matcher2, + filter, state_table) {} +}; + +// Computes the difference between two FSAs. This version is a delayed FST. 
+// Only strings that are in the first automaton but not in the second are
+// retained in the result.
+//
+// The first argument must be an acceptor; the second argument must be an
+// unweighted, epsilon-free, deterministic acceptor. One of the arguments must
+// be label-sorted.
+//
+// Complexity: same as ComposeFst.
+//
+// Caveats: same as ComposeFst.
+template <class A>
+class DifferenceFst : public ComposeFst<A> {
+ public:
+  using Arc = A;
+  using Weight = typename Arc::Weight;
+  using StateId = typename Arc::StateId;
+
+  using ComposeFst<Arc>::CreateBase1;
+
+  // A - B = A ^ B'. 'opts' sets the cache policy of the delayed composition.
+  DifferenceFst(const Fst<Arc> &fst1, const Fst<Arc> &fst2,
+                const CacheOptions &opts = CacheOptions())
+      : ComposeFst<Arc>(CreateDifferenceImplWithCacheOpts(fst1, fst2, opts)) {
+    if (!fst1.Properties(kAcceptor, true)) {
+      FSTERROR() << "DifferenceFst: 1st argument not an acceptor";
+      GetImpl()->SetProperties(kError, kError);
+    }
+  }
+  // As above, but the matchers are taken from 'opts' (see the factory below).
+  template <class Matcher, class Filter, class StateTable>
+  DifferenceFst(
+      const Fst<Arc> &fst1, const Fst<Arc> &fst2,
+      const DifferenceFstOptions<Arc, Matcher, Filter, StateTable> &opts)
+      : ComposeFst<Arc>(
+            CreateDifferenceImplWithDifferenceOpts(fst1, fst2, opts)) {
+    if (!fst1.Properties(kAcceptor, true)) {
+      FSTERROR() << "DifferenceFst: 1st argument not an acceptor";
+      GetImpl()->SetProperties(kError, kError);
+    }
+  }
+
+  // See Fst<>::Copy() for doc.
+  DifferenceFst(const DifferenceFst<Arc> &fst, bool safe = false)
+      : ComposeFst<Arc>(fst, safe) {}
+
+  // Get a copy of this DifferenceFst. See Fst<>::Copy() for further doc.
+  DifferenceFst<Arc> *Copy(bool safe = false) const override {
+    return new DifferenceFst<Arc>(*this, safe);
+  }
+
+ private:
+  using Impl = internal::ComposeFstImplBase<Arc>;
+  using ImplToFst<Impl>::GetImpl;
+  // Composes fst1 with the complement of fst2 through rho matchers.
+  static std::shared_ptr<Impl> CreateDifferenceImplWithCacheOpts(
+      const Fst<Arc> &fst1, const Fst<Arc> &fst2, const CacheOptions &opts) {
+    using RM = RhoMatcher<Matcher<Fst<Arc>>>;
+    ComplementFst<Arc> cfst(fst2);
+    ComposeFstOptions<Arc, RM> copts(
+        opts, new RM(fst1, MATCH_NONE),  // Honors the caller's cache options.
+        new RM(cfst, MATCH_INPUT, ComplementFst<Arc>::kRhoLabel));
+    return CreateBase1(fst1, cfst, copts);
+  }
+  // Same composition, wrapping the matchers given in 'opts' in rho matchers.
+  template <class Matcher, class Filter, class StateTable>
+  static std::shared_ptr<Impl> CreateDifferenceImplWithDifferenceOpts(
+      const Fst<Arc> &fst1, const Fst<Arc> &fst2,
+      const DifferenceFstOptions<Arc, Matcher, Filter, StateTable> &opts) {
+    using RM = RhoMatcher<Matcher>;
+    ComplementFst<Arc> cfst(fst2);
+    ComposeFstOptions<Arc, RM> copts(opts);
+    copts.matcher1 = new RM(fst1, MATCH_NONE, kNoLabel, MATCHER_REWRITE_ALWAYS,
+                            opts.matcher1);
+    copts.matcher2 = new RM(cfst, MATCH_INPUT, ComplementFst<Arc>::kRhoLabel,
+                            MATCHER_REWRITE_ALWAYS, opts.matcher2);
+    return CreateBase1(fst1, cfst, copts);
+  }
+};
+
+// Specialization for DifferenceFst.
+template <class Arc>
+class StateIterator<DifferenceFst<Arc>>
+    : public StateIterator<ComposeFst<Arc>> {
+ public:
+  explicit StateIterator(const DifferenceFst<Arc> &fst)
+      : StateIterator<ComposeFst<Arc>>(fst) {}
+};
+
+// Specialization for DifferenceFst.
+template <class Arc>
+class ArcIterator<DifferenceFst<Arc>> : public ArcIterator<ComposeFst<Arc>> {
+ public:
+  using StateId = typename Arc::StateId;
+
+  ArcIterator(const DifferenceFst<Arc> &fst, StateId s)
+      : ArcIterator<ComposeFst<Arc>>(fst, s) {}
+};
+
+// Useful alias when using StdArc.
+using StdDifferenceFst = DifferenceFst<StdArc>;
+
+// Options for Difference(). These are the options of Compose(); Difference()
+// reads the filter type and the connect flag from them.
+using DifferenceOptions = ComposeOptions;
+
+// Computes the difference between two FSAs. This version writes the difference
+// to an output MutableFst. Only strings that are in the first automaton but not
+// in the second are retained in the result.
+//
+// The first argument must be an acceptor; the second argument must be an
One of the arguments must +// be label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template +void Difference(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const DifferenceOptions &opts = DifferenceOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, nopts); + break; + } + case SEQUENCE_FILTER: { + DifferenceFstOptions dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case ALT_SEQUENCE_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case MATCH_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case NO_MATCH_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case NULL_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + case TRIVIAL_FILTER: { + DifferenceFstOptions> dopts; + dopts.gc_limit = 0; + *ofst = DifferenceFst(ifst1, ifst2, dopts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_DIFFERENCE_H_ diff --git a/projects/llm_framework/include/fst/disambiguate.h b/projects/llm_framework/include/fst/disambiguate.h new file mode 100644 index 00000000..7e3ba37c --- /dev/null +++ b/projects/llm_framework/include/fst/disambiguate.h @@ -0,0 +1,564 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to disambiguate an FST. 
+ +#ifndef FST_DISAMBIGUATE_H_ +#define FST_DISAMBIGUATE_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +struct DisambiguateOptions : public DeterminizeOptions { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit DisambiguateOptions(float delta = kDelta, + Weight weight = Weight::Zero(), + StateId n = kNoStateId, Label label = 0) + : DeterminizeOptions(delta, std::move(weight), n, label, + DETERMINIZE_FUNCTIONAL) {} +}; + +namespace internal { + +// A determinization filter based on a subset element relation. The relation is +// assumed to be reflexive and symmetric. +template +class RelationDeterminizeFilter { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FilterState = IntegerFilterState; + using StateTuple = DeterminizeStateTuple; + using Subset = typename StateTuple::Subset; + using Element = typename StateTuple::Element; + using LabelMap = std::multimap>; + + // This is needed (e.g.) to go into the gallic domain for transducers; there + // is no need to rebind the relation since its use here only depends on the + // state IDs. + template + struct rebind { + using Other = RelationDeterminizeFilter; + }; + + explicit RelationDeterminizeFilter(const Fst &fst) + : fst_(fst.Copy()), r_(new Relation()), s_(kNoStateId), head_(nullptr) {} + + // Ownership of the relation is given to this class. + RelationDeterminizeFilter(const Fst &fst, Relation *r) + : fst_(fst.Copy()), r_(r), s_(kNoStateId), head_(0) {} + + // Ownership of the relation is given to this class. 
+ RelationDeterminizeFilter(const Fst &fst, Relation *r, + std::vector *head) + : fst_(fst.Copy()), r_(r), s_(kNoStateId), head_(head) {} + + // This is needed, e.g., to go into the gallic domain for transducers. + // Ownership of the templated filter argument is given to this class. + template + RelationDeterminizeFilter(const Fst &fst, Filter *filter) + : fst_(fst.Copy()), + r_(new Relation(filter->GetRelation())), + s_(kNoStateId), + head_(filter->GetHeadStates()) { + delete filter; + } + + // Copy constructor; the FST can be passed if it has been deep-copied. + RelationDeterminizeFilter(const RelationDeterminizeFilter &filter, + const Fst *fst = nullptr) + : fst_(fst ? fst->Copy() : filter.fst_->Copy()), + r_(new Relation(*filter.r_)), + s_(kNoStateId), + head_() {} + + FilterState Start() const { return FilterState(fst_->Start()); } + + void SetState(StateId s, const StateTuple &tuple) { + if (s_ != s) { + s_ = s; + tuple_ = &tuple; + const auto head = tuple.filter_state.GetState(); + is_final_ = fst_->Final(head) != Weight::Zero(); + if (head_) { + if (head_->size() <= s) head_->resize(s + 1, kNoStateId); + (*head_)[s] = head; + } + } + } + + // Filters transition, possibly modifying label map. Returns true if arc is + // added to label map. + bool FilterArc(const Arc &arc, const Element &src_element, + const Element &dest_element, LabelMap *label_map) const; + + // Filters super-final transition, returning new final weight. + Weight FilterFinal(const Weight final_weight, const Element &element) const { + return is_final_ ? final_weight : Weight::Zero(); + } + + static uint64 Properties(uint64 props) { + return props & ~(kIDeterministic | kODeterministic); + } + + const Relation &GetRelation() { return *r_; } + + std::vector *GetHeadStates() { return head_; } + + private: + // Pairs arc labels with state tuples with possible heads and empty subsets. + void InitLabelMap(LabelMap *label_map) const; + + std::unique_ptr> fst_; // Input FST. 
+ std::unique_ptr r_; // Relation compatible with inv. trans. fnc. + StateId s_; // Current state. + const StateTuple *tuple_; // Current tuple. + bool is_final_; // Is the current head state final? + std::vector *head_; // Head state for a given state, + // owned by the Disambiguator. +}; + +template +bool RelationDeterminizeFilter::FilterArc( + const Arc &arc, const Element &src_element, const Element &dest_element, + LabelMap *label_map) const { + bool added = false; + if (label_map->empty()) InitLabelMap(label_map); + // Adds element to state tuple if element state is related to tuple head. + for (auto liter = label_map->lower_bound(arc.ilabel); + liter != label_map->end() && liter->first == arc.ilabel; ++liter) { + auto *dest_tuple = liter->second.dest_tuple; + const auto dest_head = dest_tuple->filter_state.GetState(); + if ((*r_)(dest_element.state_id, dest_head)) { + dest_tuple->subset.push_front(dest_element); + added = true; + } + } + return added; +} + +template +void RelationDeterminizeFilter::InitLabelMap( + LabelMap *label_map) const { + const auto src_head = tuple_->filter_state.GetState(); + Label label = kNoLabel; + StateId nextstate = kNoStateId; + for (ArcIterator> aiter(*fst_, src_head); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); + // Continues if multiarc. + if (arc.ilabel == label && arc.nextstate == nextstate) continue; + DeterminizeArc det_arc(arc); + det_arc.dest_tuple->filter_state = FilterState(arc.nextstate); + label_map->insert(std::make_pair(arc.ilabel, det_arc)); + label = arc.ilabel; + nextstate = arc.nextstate; + } +} + +// Helper class to disambiguate an FST via Disambiguate(). +template +class Disambiguator { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // IDs arcs with state ID and arc position. Arc position -1 indicates final + // (super-final transition). 
+ using ArcId = std::pair; + + Disambiguator() : error_(false) {} + + void Disambiguate( + const Fst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts = DisambiguateOptions()) { + VectorFst sfst(ifst); + Connect(&sfst); + ArcSort(&sfst, ArcCompare()); + PreDisambiguate(sfst, ofst, opts); + ArcSort(ofst, ArcCompare()); + FindAmbiguities(*ofst); + RemoveSplits(ofst); + MarkAmbiguities(); + RemoveAmbiguities(ofst); + if (error_) ofst->SetProperties(kError, kError); + } + + private: + // Comparison functor for comparing input labels and next states of arcs. This + // sort order facilitates the predisambiguation. + class ArcCompare { + public: + bool operator()(const Arc &arc1, const Arc &arc2) const { + return arc1.ilabel < arc2.ilabel || + (arc1.ilabel == arc2.ilabel && arc1.nextstate < arc2.nextstate); + } + + uint64 Properties(uint64 props) const { + return (props & kArcSortProperties) | kILabelSorted | + (props & kAcceptor ? kOLabelSorted : 0); + } + }; + + // Comparison functor for comparing transitions represented by their arc ID. + // This sort order facilitates ambiguity detection. + class ArcIdCompare { + public: + explicit ArcIdCompare(const std::vector &head) : head_(head) {} + + bool operator()(const ArcId &a1, const ArcId &a2) const { + // Sort first by source head state... + const auto src1 = a1.first; + const auto src2 = a2.first; + const auto head1 = head_[src1]; + const auto head2 = head_[src2]; + if (head1 < head2) return true; + if (head2 < head1) return false; + // ...then by source state... + if (src1 < src2) return true; + if (src2 < src1) return false; + // ...then by position. + return a1.second < a2.second; + } + + private: + const std::vector &head_; + }; + + // A relation that determines if two states share a common future. + class CommonFuture { + public: + using StateTable = GenericComposeStateTable; + using StateTuple = typename StateTable::StateTuple; + + // Needed for compilation with DeterminizeRelationFilter. 
+ CommonFuture() { + FSTERROR() << "Disambiguate::CommonFuture: FST not provided"; + } + + explicit CommonFuture(const Fst &ifst) { + using M = Matcher>; + ComposeFstOptions> opts; + // Ensures composition is between acceptors. + const bool trans = ifst.Properties(kNotAcceptor, true); + const auto *fsa = + trans ? new ProjectFst(ifst, PROJECT_INPUT) : &ifst; + opts.state_table = new StateTable(*fsa, *fsa); + const ComposeFst cfst(*fsa, *fsa, opts); + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, nullptr, &coaccess, &props); + DfsVisit(cfst, &scc_visitor); + for (StateId s = 0; s < coaccess.size(); ++s) { + if (coaccess[s]) { + related_.insert(opts.state_table->Tuple(s).StatePair()); + } + } + if (trans) delete fsa; + } + + bool operator()(const StateId s1, StateId s2) const { + return related_.count(std::make_pair(s1, s2)) > 0; + } + + private: + // States s1 and s2 resp. are in this relation iff they there is a + // path from s1 to a final state that has the same label as some + // path from s2 to a final state. + std::set> related_; + }; + + using ArcIdMap = std::multimap; + + // Inserts candidate into the arc ID map. + inline void InsertCandidate(StateId s1, StateId s2, const ArcId &a1, + const ArcId &a2) { + candidates_->insert(head_[s1] > head_[s2] ? std::make_pair(a1, a2) + : std::make_pair(a2, a1)); + } + + // Returns the arc corresponding to ArcId a. + static Arc GetArc(const Fst &fst, ArcId aid) { + if (aid.second == -1) { // Returns super-final transition. + return Arc(kNoLabel, kNoLabel, fst.Final(aid.first), kNoStateId); + } else { + ArcIterator> aiter(fst, aid.first); + aiter.Seek(aid.second); + return aiter.Value(); + } + } + + // Outputs an equivalent FST whose states are subsets of states that have a + // future path in common. + void PreDisambiguate(const ExpandedFst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts); + + // Finds transitions that are ambiguous candidates in the result of + // PreDisambiguate. 
+ void FindAmbiguities(const ExpandedFst &fst); + + // Finds transition pairs that are ambiguous candidates from two specified + // source states. + void FindAmbiguousPairs(const ExpandedFst &fst, StateId s1, StateId s2); + + // Marks ambiguous transitions to be removed. + void MarkAmbiguities(); + + // Deletes spurious ambiguous transitions (due to quantization). + void RemoveSplits(MutableFst *ofst); + + // Deletes actual ambiguous transitions. + void RemoveAmbiguities(MutableFst *ofst); + + // States s1 and s2 are in this relation iff there is a path from the initial + // state to s1 that has the same label as some path from the initial state to + // s2. We store only state pairs s1, s2 such that s1 <= s2. + std::set> coreachable_; + + // Queue of disambiguation-related states to be processed. We store only + // state pairs s1, s2 such that s1 <= s2. + std::list> queue_; + + // Head state in the pre-disambiguation for a given state. + std::vector head_; + + // Maps from a candidate ambiguous arc A to each ambiguous candidate arc B + // with the same label and destination state as A, whose source state s' is + // coreachable with the source state s of A, and for which head(s') < head(s). + std::unique_ptr candidates_; + + // Set of ambiguous transitions to be removed. + std::set ambiguous_; + + // States to merge due to quantization issues. + std::unique_ptr> merge_; + // Marks error condition. + bool error_; + + Disambiguator(const Disambiguator &) = delete; + Disambiguator &operator=(const Disambiguator &) = delete; +}; + +template +void Disambiguator::PreDisambiguate(const ExpandedFst &ifst, + MutableFst *ofst, + const DisambiguateOptions &opts) { + using CommonDivisor = DefaultCommonDivisor; + using Filter = RelationDeterminizeFilter; + // Subset elements with states s1 and s2 (resp.) are in this relation iff they + // there is a path from s1 to a final state that has the same label as some + // path from s2 to a final state. 
+ auto *common_future = new CommonFuture(ifst); + DeterminizeFstOptions nopts; + nopts.delta = opts.delta; + nopts.subsequential_label = opts.subsequential_label; + nopts.filter = new Filter(ifst, common_future, &head_); + // The filter takes ownership of 'common_future', and determinization takes + // ownership of the filter itself. + nopts.gc_limit = 0; // Cache only the last state for fastest copy. + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + /* TODO(riley): fails regression test; understand why + if (ifst.Properties(kAcceptor, true)) { + std::vector idistance, odistance; + ShortestDistance(ifst, &idistance, true); + DeterminizeFst dfst(ifst, &idistance, &odistance, nopts); + PruneOptions< Arc, AnyArcFilter> popts(opts.weight_threshold, + opts.state_threshold, + AnyArcFilter(), + &odistance); + Prune(dfst, ofst, popts); + } else */ { + *ofst = DeterminizeFst(ifst, nopts); + Prune(ofst, opts.weight_threshold, opts.state_threshold); + } + } else { + *ofst = DeterminizeFst(ifst, nopts); + } + head_.resize(ofst->NumStates(), kNoStateId); +} + +template +void Disambiguator::FindAmbiguities(const ExpandedFst &fst) { + if (fst.Start() == kNoStateId) return; + candidates_.reset(new ArcIdMap(ArcIdCompare(head_))); + const auto start_pr = std::make_pair(fst.Start(), fst.Start()); + coreachable_.insert(start_pr); + queue_.push_back(start_pr); + while (!queue_.empty()) { + const auto &pr = queue_.front(); + const auto s1 = pr.first; + const auto s2 = pr.second; + queue_.pop_front(); + FindAmbiguousPairs(fst, s1, s2); + } +} + +template +void Disambiguator::FindAmbiguousPairs(const ExpandedFst &fst, + StateId s1, StateId s2) { + if (fst.NumArcs(s2) > fst.NumArcs(s1)) FindAmbiguousPairs(fst, s2, s1); + SortedMatcher> matcher(fst, MATCH_INPUT); + matcher.SetState(s2); + for (ArcIterator> aiter(fst, s1); !aiter.Done(); aiter.Next()) { + const auto &arc1 = aiter.Value(); + const ArcId a1(s1, aiter.Position()); + if 
(matcher.Find(arc1.ilabel)) { + for (; !matcher.Done(); matcher.Next()) { + const auto &arc2 = matcher.Value(); + // Continues on implicit epsilon match. + if (arc2.ilabel == kNoLabel) continue; + const ArcId a2(s2, matcher.Position()); + // Actual transition is ambiguous. + if (s1 != s2 && arc1.nextstate == arc2.nextstate) { + InsertCandidate(s1, s2, a1, a2); + } + const auto spr = arc1.nextstate <= arc2.nextstate + ? std::make_pair(arc1.nextstate, arc2.nextstate) + : std::make_pair(arc2.nextstate, arc1.nextstate); + // Not already marked as coreachable? + if (coreachable_.insert(spr).second) { + // Only possible if state split by quantization issues. + if (spr.first != spr.second && + head_[spr.first] == head_[spr.second]) { + if (!merge_) { + merge_.reset(new UnionFind(fst.NumStates(), kNoStateId)); + merge_->MakeAllSet(fst.NumStates()); + } + merge_->Union(spr.first, spr.second); + } else { + queue_.push_back(spr); + } + } + } + } + } + // Super-final transition is ambiguous. + if (s1 != s2 && fst.Final(s1) != Weight::Zero() && + fst.Final(s2) != Weight::Zero()) { + const ArcId a1(s1, -1); + const ArcId a2(s2, -1); + InsertCandidate(s1, s2, a1, a2); + } +} + +template +void Disambiguator::MarkAmbiguities() { + if (!candidates_) return; + for (auto it = candidates_->begin(); it != candidates_->end(); ++it) { + const auto a = it->first; + const auto b = it->second; + // If b is not to be removed, then a is. + if (ambiguous_.count(b) == 0) ambiguous_.insert(a); + } + coreachable_.clear(); + candidates_.reset(); +} + +template +void Disambiguator::RemoveSplits(MutableFst *ofst) { + if (!merge_) return; + // Merges split states to remove spurious ambiguities. 
+ for (StateIterator> siter(*ofst); !siter.Done(); + siter.Next()) { + for (MutableArcIterator> aiter(ofst, siter.Value()); + !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + const auto nextstate = merge_->FindSet(arc.nextstate); + if (nextstate != arc.nextstate) { + arc.nextstate = nextstate; + aiter.SetValue(arc); + } + } + } + // Repeats search for actual ambiguities on modified FST. + coreachable_.clear(); + merge_.reset(); + candidates_.reset(); + FindAmbiguities(*ofst); + if (merge_) { // Shouldn't get here; sanity test. + FSTERROR() << "Disambiguate: Unable to remove spurious ambiguities"; + error_ = true; + return; + } +} + +template +void Disambiguator::RemoveAmbiguities(MutableFst *ofst) { + if (ambiguous_.empty()) return; + // Adds dead state to redirect ambiguous transitions to be removed. + const auto dead = ofst->AddState(); + for (auto it = ambiguous_.begin(); it != ambiguous_.end(); ++it) { + const auto pos = it->second; + if (pos >= 0) { // Actual transition. + MutableArcIterator> aiter(ofst, it->first); + aiter.Seek(pos); + auto arc = aiter.Value(); + arc.nextstate = dead; + aiter.SetValue(arc); + } else { // Super-final transition. + ofst->SetFinal(it->first, Weight::Zero()); + } + } + Connect(ofst); + ambiguous_.clear(); +} + +} // namespace internal + +// Disambiguates a weighted FST. This version writes the disambiguated FST to an +// output MutableFst. The result will be an equivalent FST that has the +// property that there are not two distinct paths from the initial state to a +// final state with the same input labeling. +// +// The weights must be (weakly) left divisible (valid for Tropical and +// LogWeight). +// +// Complexity: +// +// Disambiguable: exponential (polynomial in the size of the output). +// Non-disambiguable: does not terminate. +// +// The disambiguable transducers include all automata and functional transducers +// that are unweighted or that are acyclic or that are unambiguous. 
+// +// For more information, see: +// +// Mohri, M. and Riley, M. 2015. On the disambiguation of weighted automata. +// In CIAA, pages 263-278. +template +void Disambiguate( + const Fst &ifst, MutableFst *ofst, + const DisambiguateOptions &opts = DisambiguateOptions()) { + internal::Disambiguator disambiguator; + disambiguator.Disambiguate(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_DISAMBIGUATE_H_ diff --git a/projects/llm_framework/include/fst/edit-fst.h b/projects/llm_framework/include/fst/edit-fst.h new file mode 100644 index 00000000..e7bdfd61 --- /dev/null +++ b/projects/llm_framework/include/fst/edit-fst.h @@ -0,0 +1,702 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation that allows non-destructive edit operations on an +// existing FST. +// +// The EditFst class enables non-destructive edit operations on a wrapped +// ExpandedFst. The implementation uses copy-on-write semantics at the node +// level: if a user has an underlying fst on which he or she wants to perform a +// relatively small number of edits (read: mutations), then this implementation +// will copy the edited node to an internal MutableFst and perform any edits in +// situ on that copied node. This class supports all the methods of MutableFst +// except for DeleteStates(const std::vector &); thus, new nodes may +// also be +// added, and one may add transitions from existing nodes of the wrapped fst to +// new nodes. +// +// N.B.: The documentation for Fst::Copy(true) says that its behavior is +// undefined if invoked on an fst that has already been accessed. This class +// requires that the Fst implementation it wraps provides consistent, reliable +// behavior when its Copy(true) method is invoked, where consistent means +// the graph structure, graph properties and state numbering and do not change. +// VectorFst and CompactFst, for example, are both well-behaved in this regard. 
+ +#ifndef FST_EDIT_FST_H_ +#define FST_EDIT_FST_H_ + +#include +#include +#include + +#include + +#include + + +namespace fst { +namespace internal { + +// The EditFstData class is a container for all mutable data for EditFstImpl; +// also, this class provides most of the actual implementation of what EditFst +// does (that is, most of EditFstImpl's methods delegate to methods in this, the +// EditFstData class). Instances of this class are reference-counted and can be +// shared between otherwise independent EditFstImpl instances. This scheme +// allows EditFstImpl to implement the thread-safe, copy-on-write semantics +// required by Fst::Copy(true). +// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstData instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template , + typename MutableFstT = VectorFst> +class EditFstData { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + EditFstData() : num_new_states_(0) {} + + EditFstData(const EditFstData &other) + : edits_(other.edits_), + external_to_internal_ids_(other.external_to_internal_ids_), + edited_final_weights_(other.edited_final_weights_), + num_new_states_(other.num_new_states_) {} + + ~EditFstData() {} + + static EditFstData *Read( + std::istream &strm, const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + // Serialize all private data members of this class. + FstWriteOptions edits_opts(opts); + edits_opts.write_header = true; // Force writing contained header. 
+ edits_.Write(strm, edits_opts); + WriteType(strm, external_to_internal_ids_); + WriteType(strm, edited_final_weights_); + WriteType(strm, num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFstData::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + StateId NumNewStates() const { return num_new_states_; } + + // accessor methods for the fst holding edited states + StateId EditedStart() const { return edits_.Start(); } + + Weight Final(StateId s, const WrappedFstT *wrapped) const { + auto final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->Final(s) + : edits_.Final(it->second); + } else { + return final_weight_it->second; + } + } + + size_t NumArcs(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumArcs(s) + : edits_.NumArcs(it->second); + } + + size_t NumInputEpsilons(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumInputEpsilons(s) + : edits_.NumInputEpsilons(it->second); + } + + size_t NumOutputEpsilons(StateId s, const WrappedFstT *wrapped) const { + auto it = GetEditedIdMapIterator(s); + return it == NotInEditedMap() ? wrapped->NumOutputEpsilons(s) + : edits_.NumOutputEpsilons(it->second); + } + + void SetEditedProperties(uint64 props, uint64 mask) { + edits_.SetProperties(props, mask); + } + + // Non-const MutableFst operations. + + // Sets the start state for this FST. + void SetStart(StateId s) { edits_.SetStart(s); } + + // Sets the final state for this FST. 
+ Weight SetFinal(StateId s, Weight w, const WrappedFstT *wrapped) { + Weight old_weight = Final(s, wrapped); + auto it = GetEditedIdMapIterator(s); + // If we haven't already edited state s, don't add it to edited_ (which can + // be expensive if s has many transitions); just use the + // edited_final_weights_ map. + if (it == NotInEditedMap()) { + edited_final_weights_[s] = w; + } else { + edits_.SetFinal(GetEditableInternalId(s, wrapped), w); + } + return old_weight; + } + + // Adds a new state to this FST, initially with no arcs. + StateId AddState(StateId curr_num_states) { + StateId internal_state_id = edits_.AddState(); + StateId external_state_id = curr_num_states; + external_to_internal_ids_[external_state_id] = internal_state_id; + num_new_states_++; + return external_state_id; + } + + // Adds the specified arc to the specified state of this FST. + const Arc *AddArc(StateId s, const Arc &arc, const WrappedFstT *wrapped) { + const auto internal_id = GetEditableInternalId(s, wrapped); + const auto num_arcs = edits_.NumArcs(internal_id); + ArcIterator arc_it(edits_, internal_id); + const Arc *prev_arc = nullptr; + if (num_arcs > 0) { + // grab the final arc associated with this state in edits_ + arc_it.Seek(num_arcs - 1); + prev_arc = &(arc_it.Value()); + } + edits_.AddArc(internal_id, arc); + return prev_arc; + } + + void DeleteStates() { + edits_.DeleteStates(); + num_new_states_ = 0; + external_to_internal_ids_.clear(); + edited_final_weights_.clear(); + } + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped), n); + } + + // Removes all outgoing arcs from the specified state. + void DeleteArcs(StateId s, const WrappedFstT *wrapped) { + edits_.DeleteArcs(GetEditableInternalId(s, wrapped)); + } + + // End methods for non-const MutableFst operations. + + // Provides information for the generic arc iterator. 
+ void InitArcIterator(StateId s, ArcIteratorData *data, + const WrappedFstT *wrapped) const { + auto id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + VLOG(3) << "EditFstData::InitArcIterator: iterating on state " << s + << " of original fst"; + wrapped->InitArcIterator(s, data); + } else { + VLOG(2) << "EditFstData::InitArcIterator: iterating on edited state " << s + << " (internal state id: " << id_map_it->second << ")"; + edits_.InitArcIterator(id_map_it->second, data); + } + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData *data, + const WrappedFstT *wrapped) { + data->base = new MutableArcIterator( + &edits_, GetEditableInternalId(s, wrapped)); + } + + // Prints out the map from external to internal state id's (for debugging + // purposes). + void PrintMap() { + for (auto map_it = external_to_internal_ids_.begin(); + map_it != NotInEditedMap(); ++map_it) { + LOG(INFO) << "(external,internal)=(" << map_it->first << "," + << map_it->second << ")"; + } + } + + private: + // Returns the iterator of the map from external to internal state id's + // of edits_ for the specified external state id. + typename std::unordered_map::const_iterator + GetEditedIdMapIterator(StateId s) const { + return external_to_internal_ids_.find(s); + } + + typename std::unordered_map::const_iterator + NotInEditedMap() const { + return external_to_internal_ids_.end(); + } + + typename std::unordered_map::const_iterator + GetFinalWeightIterator(StateId s) const { + return edited_final_weights_.find(s); + } + + typename std::unordered_map::const_iterator + NotInFinalWeightMap() const { + return edited_final_weights_.end(); + } + + // Returns the internal state ID of the specified external ID if the state has + // already been made editable, or else copies the state from wrapped_ to + // edits_ and returns the state id of the newly editable state in edits_. 
+ StateId GetEditableInternalId(StateId s, const WrappedFstT *wrapped) { + auto id_map_it = GetEditedIdMapIterator(s); + if (id_map_it == NotInEditedMap()) { + StateId new_internal_id = edits_.AddState(); + VLOG(2) << "EditFstData::GetEditableInternalId: editing state " << s + << " of original fst; new internal state id:" << new_internal_id; + external_to_internal_ids_[s] = new_internal_id; + for (ArcIterator> arc_iterator(*wrapped, s); + !arc_iterator.Done(); arc_iterator.Next()) { + edits_.AddArc(new_internal_id, arc_iterator.Value()); + } + // Copies the final weight. + auto final_weight_it = GetFinalWeightIterator(s); + if (final_weight_it == NotInFinalWeightMap()) { + edits_.SetFinal(new_internal_id, wrapped->Final(s)); + } else { + edits_.SetFinal(new_internal_id, final_weight_it->second); + edited_final_weights_.erase(s); + } + return new_internal_id; + } else { + return id_map_it->second; + } + } + + // A mutable FST (by default, a VectorFst) to contain new states, and/or + // copies of states from a wrapped ExpandedFst that have been modified in + // some way. + MutableFstT edits_; + // A mapping from external state IDs to the internal IDs of states that + // appear in edits_. + std::unordered_map external_to_internal_ids_; + // A mapping from external state IDs to final state weights assigned to + // those states. The states in this map are *only* those whose final weight + // has been modified; if any other part of the state has been modified, + // the entire state is copied to edits_, and all modifications reside there. + std::unordered_map edited_final_weights_; + // The number of new states added to this mutable fst impl, which is <= the + // number of states in edits_ (since edits_ contains both edited *and* new + // states). + StateId num_new_states_; +}; + +// EditFstData method implementations: just the Read method. 
+template +EditFstData * +EditFstData::Read(std::istream &strm, + const FstReadOptions &opts) { + auto *data = new EditFstData(); + // next read in MutabelFstT machine that stores edits + FstReadOptions edits_opts(opts); + // Contained header was written out, so read it in. + edits_opts.header = nullptr; + + // Because our internal representation of edited states is a solid object + // of type MutableFstT (defaults to VectorFst) and not a pointer, + // and because the static Read method allocates a new object on the heap, + // we need to call Read, check if there was a failure, use + // MutableFstT::operator= to assign the object (not the pointer) to the + // edits_ data member (which will increase the ref count by 1 on the impl) + // and, finally, delete the heap-allocated object. + std::unique_ptr edits(MutableFstT::Read(strm, edits_opts)); + if (!edits) return nullptr; + data->edits_ = *edits; + edits.reset(); + // Finally, reads in rest of private data members. + ReadType(strm, &data->external_to_internal_ids_); + ReadType(strm, &data->edited_final_weights_); + ReadType(strm, &data->num_new_states_); + if (!strm) { + LOG(ERROR) << "EditFst::Read: read failed: " << opts.source; + return nullptr; + } + return data; +} + +// This class enables non-destructive edit operations on a wrapped ExpandedFst. +// The implementation uses copy-on-write semantics at the node level: if a user +// has an underlying fst on which he or she wants to perform a relatively small +// number of edits (read: mutations), then this implementation will copy the +// edited node to an internal MutableFst and perform any edits in situ on that +// copied node. This class supports all the methods of MutableFst except for +// DeleteStates(const std::vector &); thus, new nodes may also be +// added, and +// one may add transitions from existing nodes of the wrapped fst to new nodes. 
+// +// template parameters: +// A the type of arc to use +// WrappedFstT the type of fst wrapped by the EditFst instance that +// this EditFstImpl instance is backing +// MutableFstT the type of mutable fst to use internally for edited states; +// crucially, MutableFstT::Copy(false) *must* yield an fst that is +// thread-safe for reading (VectorFst, for example, has this property) +template , + typename MutableFstT = VectorFst> +class EditFstImpl : public FstImpl { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + // Constructs an editable FST implementation with no states. Effectively, this + // initially-empty fst will in every way mimic the behavior of a + // VectorFst---more precisely, a VectorFstImpl instance---but with slightly + // slower performance (by a constant factor), due to the fact that + // this class maintains a mapping between external state id's and + // their internal equivalents. + EditFstImpl() : wrapped_(new MutableFstT()) { + FstImpl::SetType("edit"); + InheritPropertiesFromWrapped(); + data_ = std::make_shared>(); + } + + // Wraps the specified ExpandedFst. This constructor requires that the + // specified Fst is an ExpandedFst instance. This requirement is only enforced + // at runtime. (See below for the reason.) + // + // This library uses the pointer-to-implementation or "PIMPL" design pattern. + // In particular, to make it convenient to bind an implementation class to its + // interface, there are a pair of template "binder" classes, one for immutable + // and one for mutable fst's (ImplToFst and ImplToMutableFst, respectively). + // As it happens, the API for the ImplToMutableFst class requires that + // the implementation class--the template parameter "I"--have a constructor + // taking a const Fst reference. 
Accordingly, the constructor here must + // perform a static_cast to the WrappedFstT type required by EditFst and + // therefore EditFstImpl. + explicit EditFstImpl(const Fst &wrapped) + : wrapped_(static_cast(wrapped.Copy())) { + FstImpl::SetType("edit"); + data_ = std::make_shared>(); + // have edits_ inherit all properties from wrapped_ + data_->SetEditedProperties(wrapped_->Properties(kFstProperties, false), + kFstProperties); + InheritPropertiesFromWrapped(); + } + + // A copy constructor for this implementation class, used to implement + // the Copy() method of the Fst interface. + EditFstImpl(const EditFstImpl &impl) + : FstImpl(), + wrapped_(static_cast(impl.wrapped_->Copy(true))), + data_(impl.data_) { + SetProperties(impl.Properties()); + } + + // const Fst/ExpandedFst operations, declared in the Fst and ExpandedFst + // interfaces + StateId Start() const { + const auto edited_start = data_->EditedStart(); + return edited_start == kNoStateId ? wrapped_->Start() : edited_start; + } + + Weight Final(StateId s) const { return data_->Final(s, wrapped_.get()); } + + size_t NumArcs(StateId s) const { return data_->NumArcs(s, wrapped_.get()); } + + size_t NumInputEpsilons(StateId s) const { + return data_->NumInputEpsilons(s, wrapped_.get()); + } + + size_t NumOutputEpsilons(StateId s) const { + return data_->NumOutputEpsilons(s, wrapped_.get()); + } + + StateId NumStates() const { + return wrapped_->NumStates() + data_->NumNewStates(); + } + + static EditFstImpl *Read( + std::istream &strm, const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(NumStates()); + FstWriteOptions header_opts(opts); + // Allows the contained FST to hold any symbols. + header_opts.write_isymbols = false; + header_opts.write_osymbols = false; + WriteHeader(strm, header_opts, kFileVersion, &hdr); + // First, serializes the wrapped FST to stream. 
+ FstWriteOptions wrapped_opts(opts); + // Forcse writing the contained header. + wrapped_opts.write_header = true; + wrapped_->Write(strm, wrapped_opts); + data_->Write(strm, opts); + strm.flush(); + if (!strm) { + LOG(ERROR) << "EditFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + // Sets the start state for this FST. + void SetStart(StateId s) { + MutateCheck(); + data_->SetStart(s); + SetProperties(SetStartProperties(FstImpl::Properties())); + } + + // Sets the final state for this fst. + void SetFinal(StateId s, Weight weight) { + MutateCheck(); + Weight old_weight = data_->SetFinal(s, weight, wrapped_.get()); + SetProperties( + SetFinalProperties(FstImpl::Properties(), old_weight, weight)); + } + + // Adds a new state to this fst, initially with no arcs. + StateId AddState() { + MutateCheck(); + SetProperties(AddStateProperties(FstImpl::Properties())); + return data_->AddState(NumStates()); + } + + // Adds the specified arc to the specified state of this fst. + void AddArc(StateId s, const Arc &arc) { + MutateCheck(); + const auto *prev_arc = data_->AddArc(s, arc, wrapped_.get()); + SetProperties( + AddArcProperties(FstImpl::Properties(), s, arc, prev_arc)); + } + + void DeleteStates(const std::vector &dstates) { + FSTERROR() << ": EditFstImpl::DeleteStates(const std::vector&): " + << " not implemented"; + SetProperties(kError, kError); + } + + // Deletes all states in this fst. + void DeleteStates(); + + // Removes all but the first n outgoing arcs of the specified state. + void DeleteArcs(StateId s, size_t n) { + MutateCheck(); + data_->DeleteArcs(s, n, wrapped_.get()); + SetProperties(DeleteArcsProperties(FstImpl::Properties())); + } + + // Removes all outgoing arcs from the specified state. 
+ void DeleteArcs(StateId s) { + MutateCheck(); + data_->DeleteArcs(s, wrapped_.get()); + SetProperties(DeleteArcsProperties(FstImpl::Properties())); + } + + void ReserveStates(StateId s) {} + + void ReserveArcs(StateId s, size_t n) {} + + // Ends non-const MutableFst operations. + + // Provides information for the generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = NumStates(); + } + + // Provides information for the generic arc iterator. + void InitArcIterator(StateId s, ArcIteratorData *data) const { + data_->InitArcIterator(s, data, wrapped_.get()); + } + + // Provides information for the generic mutable arc iterator. + void InitMutableArcIterator(StateId s, MutableArcIteratorData *data) { + MutateCheck(); + data_->InitMutableArcIterator(s, data, wrapped_.get()); + } + + private: + // Properties always true of this FST class. + static constexpr uint64 kStaticProperties = kExpanded | kMutable; + // Current file format version. + static constexpr int kFileVersion = 2; + // Minimum file format version supported + static constexpr int kMinFileVersion = 2; + + // Causes this FST to inherit all the properties from its wrapped FST, except + // for the two properties that always apply to EditFst instances: kExpanded + // and kMutable. + void InheritPropertiesFromWrapped() { + SetProperties(wrapped_->Properties(kCopyProperties, false) | + kStaticProperties); + SetInputSymbols(wrapped_->InputSymbols()); + SetOutputSymbols(wrapped_->OutputSymbols()); + } + + // This method ensures that any operations that alter the mutable data + // portion of this EditFstImpl cause the data_ member to be copied when its + // reference count is greater than 1. Note that this method is distinct from + // MutableFst::Mutate, which gets invoked whenever one of the basic mutation + // methods defined in MutableFst is invoked, such as SetInputSymbols. 
+ // The MutateCheck here in EditFstImpl is invoked whenever one of the + // mutating methods specifically related to the types of edits provided + // by EditFst is performed, such as changing an arc of an existing state + // of the wrapped fst via a MutableArcIterator, or adding a new state via + // AddState(). + void MutateCheck() { + if (!data_.unique()) { + data_ = + std::make_shared>(*data_); + } + } + + // The FST that this FST wraps. The purpose of this class is to enable + // non-destructive edits on this wrapped FST. + std::unique_ptr wrapped_; + // The mutable data for this EditFst instance, with delegates for all the + // methods that can mutate data. + std::shared_ptr> data_; +}; + +template +constexpr uint64 EditFstImpl::kStaticProperties; + +template +constexpr int EditFstImpl::kFileVersion; + +template +constexpr int EditFstImpl::kMinFileVersion; + +template +inline void EditFstImpl::DeleteStates() { + data_->DeleteStates(); + // we are deleting all states, so just forget about pointer to wrapped_ + // and do what default constructor does: set wrapped_ to a new VectorFst + wrapped_.reset(new MutableFstT()); + const auto new_props = + DeleteAllStatesProperties(FstImpl::Properties(), kStaticProperties); + FstImpl::SetProperties(new_props); +} + +template +EditFstImpl * +EditFstImpl::Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = new EditFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->SetStart(hdr.Start()); + // Reads in wrapped FST. + FstReadOptions wrapped_opts(opts); + // Contained header was written out, so reads it in too. 
+ wrapped_opts.header = nullptr; + std::unique_ptr> wrapped_fst(Fst::Read(strm, wrapped_opts)); + if (!wrapped_fst) return nullptr; + impl->wrapped_.reset(static_cast(wrapped_fst.release())); + impl->data_ = std::shared_ptr>( + EditFstData::Read(strm, opts)); + if (!impl->data_) return nullptr; + return impl; +} + +} // namespace internal + +// Concrete, editable FST. This class attaches interface to implementation. +template , + typename MutableFstT = VectorFst> +class EditFst : public ImplToMutableFst< + internal::EditFstImpl> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::EditFstImpl; + + friend class MutableArcIterator>; + + EditFst() : ImplToMutableFst(std::make_shared()) {} + + explicit EditFst(const Fst &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + explicit EditFst(const WrappedFstT &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + // See Fst<>::Copy() for doc. + EditFst(const EditFst &fst, bool safe = false) + : ImplToMutableFst(fst, safe) {} + + ~EditFst() override {} + + // Gets a copy of this EditFst. See Fst<>::Copy() for further doc. + EditFst *Copy( + bool safe = false) const override { + return new EditFst(*this, safe); + } + + EditFst &operator=( + const EditFst &fst) { + SetImpl(fst.GetSharedImpl()); + return *this; + } + + EditFst &operator=( + const Fst &fst) override { + SetImpl(std::make_shared(fst)); + return *this; + } + + // Reads an EditFst from an input stream, returning nullptr on error. + static EditFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? new EditFst(std::shared_ptr(impl)) : nullptr; + } + + // Reads an EditFst from a file, returning nullptr on error. If the filename + // argument is an empty string, it reads from standard input. + static EditFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst>::Read(filename); + return impl ? 
new EditFst( + std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + void InitMutableArcIterator(StateId s, + MutableArcIteratorData *data) override { + GetMutableImpl()->InitMutableArcIterator(s, data); + } + + private: + explicit EditFst(std::shared_ptr impl) : ImplToMutableFst(impl) {} + + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; + using ImplToFst>::SetImpl; +}; + +} // namespace fst + +#endif // FST_EDIT_FST_H_ diff --git a/projects/llm_framework/include/fst/encode.h b/projects/llm_framework/include/fst/encode.h new file mode 100644 index 00000000..f251bbfc --- /dev/null +++ b/projects/llm_framework/include/fst/encode.h @@ -0,0 +1,556 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to encode and decode an FST. 
+ +#ifndef FST_ENCODE_H_ +#define FST_ENCODE_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + + +namespace fst { + +enum EncodeType { ENCODE = 1, DECODE = 2 }; + +static constexpr uint32 kEncodeLabels = 0x0001; +static constexpr uint32 kEncodeWeights = 0x0002; +static constexpr uint32 kEncodeFlags = 0x0003; + +namespace internal { + +static constexpr uint32 kEncodeHasISymbols = 0x0004; +static constexpr uint32 kEncodeHasOSymbols = 0x0008; + +// Identifies stream data as an encode table (and its endianity) +static const int32 kEncodeMagicNumber = 2129983209; + +// The following class encapsulates implementation details for the encoding and +// decoding of label/weight tuples used for encoding and decoding of FSTs. The +// EncodeTable is bidirectional. I.e, it stores both the Tuple of encode labels +// and weights to a unique label, and the reverse. +template +class EncodeTable { + public: + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + // Encoded data consists of arc input/output labels and arc weight. + struct Tuple { + Tuple() {} + + Tuple(Label ilabel_, Label olabel_, Weight weight_) + : ilabel(ilabel_), olabel(olabel_), weight(std::move(weight_)) {} + + Tuple(const Tuple &tuple) + : ilabel(tuple.ilabel), + olabel(tuple.olabel), + weight(std::move(tuple.weight)) {} + + Label ilabel; + Label olabel; + Weight weight; + }; + + // Comparison object for hashing EncodeTable Tuple(s). + class TupleEqual { + public: + bool operator()(const Tuple *x, const Tuple *y) const { + return (x->ilabel == y->ilabel && x->olabel == y->olabel && + x->weight == y->weight); + } + }; + + // Hash function for EncodeTabe Tuples. Based on the encode flags + // we either hash the labels, weights or combination of them. 
+ class TupleKey { + public: + TupleKey() : encode_flags_(kEncodeLabels | kEncodeWeights) {} + + TupleKey(const TupleKey &key) : encode_flags_(key.encode_flags_) {} + + explicit TupleKey(uint32 encode_flags) : encode_flags_(encode_flags) {} + + size_t operator()(const Tuple *x) const { + size_t hash = x->ilabel; + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5; + if (encode_flags_ & kEncodeLabels) { + hash = hash << lshift ^ hash >> rshift ^ x->olabel; + } + if (encode_flags_ & kEncodeWeights) { + hash = hash << lshift ^ hash >> rshift ^ x->weight.Hash(); + } + return hash; + } + + private: + int32 encode_flags_; + }; + + explicit EncodeTable(uint32 encode_flags) + : flags_(encode_flags), encode_hash_(1024, TupleKey(encode_flags)) {} + + using EncodeHash = std::unordered_map; + + // Given an arc, encodes either input/output labels or input/costs or both. + Label Encode(const Arc &arc) { + std::unique_ptr tuple( + new Tuple(arc.ilabel, flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One())); + auto insert_result = encode_hash_.insert( + std::make_pair(tuple.get(), encode_tuples_.size() + 1)); + if (insert_result.second) encode_tuples_.push_back(std::move(tuple)); + return insert_result.first->second; + } + + // Given an arc, looks up its encoded label or returns kNoLabel if not found. + Label GetLabel(const Arc &arc) const { + const Tuple tuple(arc.ilabel, flags_ & kEncodeLabels ? arc.olabel : 0, + flags_ & kEncodeWeights ? arc.weight : Weight::One()); + auto it = encode_hash_.find(&tuple); + return (it == encode_hash_.end()) ? kNoLabel : it->second; + } + + // Given an encoded arc label, decodes back to input/output labels and costs. 
+ const Tuple *Decode(Label key) const { + if (key < 1 || key > encode_tuples_.size()) { + LOG(ERROR) << "EncodeTable::Decode: Unknown decode key: " << key; + return nullptr; + } + return encode_tuples_[key - 1].get(); + } + + size_t Size() const { return encode_tuples_.size(); } + + bool Write(std::ostream &strm, const string &source) const; + + static EncodeTable *Read(std::istream &strm, const string &source); + + uint32 Flags() const { return flags_ & kEncodeFlags; } + + const SymbolTable *InputSymbols() const { return isymbols_.get(); } + + const SymbolTable *OutputSymbols() const { return osymbols_.get(); } + + void SetInputSymbols(const SymbolTable *syms) { + if (syms) { + isymbols_.reset(syms->Copy()); + flags_ |= kEncodeHasISymbols; + } else { + isymbols_.reset(); + flags_ &= ~kEncodeHasISymbols; + } + } + + void SetOutputSymbols(const SymbolTable *syms) { + if (syms) { + osymbols_.reset(syms->Copy()); + flags_ |= kEncodeHasOSymbols; + } else { + osymbols_.reset(); + flags_ &= ~kEncodeHasOSymbols; + } + } + + private: + uint32 flags_; + std::vector> encode_tuples_; + EncodeHash encode_hash_; + std::unique_ptr isymbols_; // Pre-encoded input symbol table. + std::unique_ptr osymbols_; // Pre-encoded output symbol table. 
+
+  // The table holds its tuples and symbol tables through unique_ptr members,
+  // so copying is disabled.
+  EncodeTable(const EncodeTable &) = delete;
+  EncodeTable &operator=(const EncodeTable &) = delete;
+};
+
+// Serializes the table to strm: magic number, flags, tuple count, then every
+// tuple (ilabel, olabel, weight), followed by the input and/or output symbol
+// table when the matching kEncodeHas*Symbols flag is set. The source argument
+// is only used in the error message. Returns false if the stream is in a
+// failed state after the final flush.
+template <class Arc>
+bool EncodeTable<Arc>::Write(std::ostream &strm,
+                             const string &source) const {
+  WriteType(strm, kEncodeMagicNumber);
+  WriteType(strm, flags_);
+  const int64 size = encode_tuples_.size();
+  WriteType(strm, size);
+  for (const auto &tuple : encode_tuples_) {
+    WriteType(strm, tuple->ilabel);
+    WriteType(strm, tuple->olabel);
+    tuple->weight.Write(strm);
+  }
+  if (flags_ & kEncodeHasISymbols) isymbols_->Write(strm);
+  if (flags_ & kEncodeHasOSymbols) osymbols_->Write(strm);
+  strm.flush();
+  if (!strm) {
+    LOG(ERROR) << "EncodeTable::Write: Write failed: " << source;
+    return false;
+  }
+  return true;
+}
+
+// Deserializes a table in the layout produced by Write(). Returns a
+// heap-allocated table that the caller owns, or nullptr (after logging) on a
+// bad magic number, a stream failure, or a symbol table that the flags
+// announce but that could not be read.
+template <class Arc>
+EncodeTable<Arc> *EncodeTable<Arc>::Read(std::istream &strm,
+                                         const string &source) {
+  int32 magic_number = 0;
+  ReadType(strm, &magic_number);
+  if (magic_number != kEncodeMagicNumber) {
+    LOG(ERROR) << "EncodeTable::Read: Bad encode table header: " << source;
+    return nullptr;
+  }
+  // Zero-initialized so that no indeterminate value is held if a read fails.
+  uint32 flags = 0;
+  ReadType(strm, &flags);
+  int64 size = 0;
+  ReadType(strm, &size);
+  if (!strm) {
+    LOG(ERROR) << "EncodeTable::Read: Read failed: " << source;
+    return nullptr;
+  }
+  std::unique_ptr<EncodeTable<Arc>> table(new EncodeTable<Arc>(flags));
+  for (int64 i = 0; i < size; ++i) {
+    std::unique_ptr<Tuple> tuple(new Tuple());
+    ReadType(strm, &tuple->ilabel);
+    ReadType(strm, &tuple->olabel);
+    tuple->weight.Read(strm);
+    if (!strm) {
+      LOG(ERROR) << "EncodeTable::Read: Read failed: " << source;
+      return nullptr;
+    }
+    table->encode_tuples_.push_back(std::move(tuple));
+    // Encoded labels are 1-based: the label is the tuple's position + 1.
+    table->encode_hash_[table->encode_tuples_.back().get()] =
+        table->encode_tuples_.size();
+  }
+  // The table was constructed with these flags, and Write() dereferences the
+  // symbol table whenever the matching flag is set, so a table whose flag is
+  // set but whose symbols are missing must not be handed back.
+  if (flags & kEncodeHasISymbols) {
+    table->isymbols_.reset(SymbolTable::Read(strm, source));
+    if (!table->isymbols_) {
+      LOG(ERROR) << "EncodeTable::Read: Read failed: " << source;
+      return nullptr;
+    }
+  }
+  if (flags & kEncodeHasOSymbols) {
+    table->osymbols_.reset(SymbolTable::Read(strm, source));
+    if (!table->osymbols_) {
+      LOG(ERROR) << "EncodeTable::Read: Read failed: " << source;
+      return nullptr;
+    }
+  }
+  return table.release();
+}
+
+}  // namespace internal
+
+// A mapper to encode/decode weighted transducers.
Encoding of an FST is used
+// for performing classical determinization or minimization on a weighted
+// transducer viewing it as an unweighted acceptor over encoded labels.
+//
+// The mapper stores the encoding in a local hash table (EncodeTable). This
+// table is shared (and reference-counted) between the encoder and decoder.
+// A decoder has read-only access to the EncodeTable.
+//
+// The EncodeMapper allows on the fly encoding of the machine. As the
+// EncodeTable is generated the same table may be used to decode the machine
+// on the fly. For example in the following sequence of operations
+//
+//      Encode -> Determinize -> Decode
+//
+// we will use the encoding table generated during the encode step in the
+// decode, even though the encoding is not complete.
+template <class Arc>
+class EncodeMapper {
+  using Label = typename Arc::Label;
+  using Weight = typename Arc::Weight;
+
+ public:
+  // Creates a mapper of the given type backed by a fresh, empty encode table.
+  EncodeMapper(uint32 flags, EncodeType type)
+      : flags_(flags),
+        type_(type),
+        table_(std::make_shared<internal::EncodeTable<Arc>>(flags)),
+        error_(false) {}
+
+  // The copy shares the encode table with the original mapper.
+  // NOTE(review): error_ is reset here but carried over by the two-argument
+  // overload below; confirm that the asymmetry is intended.
+  EncodeMapper(const EncodeMapper &mapper)
+      : flags_(mapper.flags_),
+        type_(mapper.type_),
+        table_(mapper.table_),
+        error_(false) {}
+
+  // Copy constructor but setting the type, typically to DECODE.
+  EncodeMapper(const EncodeMapper &mapper, EncodeType type)
+      : flags_(mapper.flags_),
+        type_(type),
+        table_(mapper.table_),
+        error_(mapper.error_) {}
+
+  // Encodes or decodes a single arc depending on the mapper type.
+  Arc operator()(const Arc &arc);
+
+  // A superfinal state is required only when weights are being encoded.
+  MapFinalAction FinalAction() const {
+    return (type_ == ENCODE && (flags_ & kEncodeWeights))
+               ? MAP_REQUIRE_SUPERFINAL
+               : MAP_NO_SUPERFINAL;
+  }
+
+  constexpr MapSymbolsAction InputSymbolsAction() const {
+    return MAP_CLEAR_SYMBOLS;
+  }
+
+  constexpr MapSymbolsAction OutputSymbolsAction() const {
+    return MAP_CLEAR_SYMBOLS;
+  }
+
+  // Returns the input properties restricted to those that survive the
+  // configured label and/or weight encoding; kError is added first when a
+  // decode error has been recorded.
+  uint64 Properties(uint64 inprops) {
+    uint64 outprops = inprops;
+    if (error_) outprops |= kError;
+    uint64 mask = kFstProperties;
+    if (flags_ & kEncodeLabels) {
+      mask &= kILabelInvariantProperties & kOLabelInvariantProperties;
+    }
+    if (flags_ & kEncodeWeights) {
+      mask &= kILabelInvariantProperties & kWeightInvariantProperties &
+              (type_ == ENCODE ? kAddSuperFinalProperties
+                               : kRmSuperFinalProperties);
+    }
+    return outprops & mask;
+  }
+
+  uint32 Flags() const { return flags_; }
+
+  EncodeType Type() const { return type_; }
+
+  // Writes the underlying encode table to the stream.
+  bool Write(std::ostream &strm, const string &source) const {
+    return table_->Write(strm, source);
+  }
+
+  // Writes the underlying encode table to a binary file.
+  bool Write(const string &filename) const {
+    std::ofstream strm(filename,
+                             std::ios_base::out | std::ios_base::binary);
+    if (!strm) {
+      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
+      return false;
+    }
+    return Write(strm, filename);
+  }
+
+  // Reads an encode table from the stream and wraps it in a new mapper that
+  // takes ownership of it; returns nullptr if the table cannot be read.
+  static EncodeMapper<Arc> *Read(std::istream &strm, const string &source,
+                                 EncodeType type = ENCODE) {
+    auto *table = internal::EncodeTable<Arc>::Read(strm, source);
+    return table ? new EncodeMapper(table->Flags(), type, table) : nullptr;
+  }
+
+  // As above, reading from a binary file.
+  static EncodeMapper<Arc> *Read(const string &filename,
+                                 EncodeType type = ENCODE) {
+    std::ifstream strm(filename,
+                            std::ios_base::in | std::ios_base::binary);
+    if (!strm) {
+      LOG(ERROR) << "EncodeMap: Can't open file: " << filename;
+      return nullptr;
+    }
+    return Read(strm, filename, type);
+  }
+
+  const SymbolTable *InputSymbols() const { return table_->InputSymbols(); }
+
+  const SymbolTable *OutputSymbols() const { return table_->OutputSymbols(); }
+
+  void SetInputSymbols(const SymbolTable *syms) {
+    table_->SetInputSymbols(syms);
+  }
+
+  void SetOutputSymbols(const SymbolTable *syms) {
+    table_->SetOutputSymbols(syms);
+  }
+
+ private:
+  uint32 flags_;
+  EncodeType type_;
+  std::shared_ptr<internal::EncodeTable<Arc>> table_;
+  bool error_;
+
+  // Used by Read(): the shared_ptr member takes ownership of the raw table.
+  explicit EncodeMapper(uint32 flags, EncodeType type,
+                        internal::EncodeTable<Arc> *table)
+      : flags_(flags), type_(type), table_(table), error_(false) {}
+
+  EncodeMapper &operator=(const EncodeMapper &) = delete;
+};
+
+template <class Arc>
+Arc EncodeMapper<Arc>::operator()(const Arc &arc) {
+  if (type_ == ENCODE) {
+    // An arc with no next state is passed through unchanged unless weights
+    // are being encoded and its weight is not Zero.
+    if (arc.nextstate == kNoStateId &&
+        (!(flags_ & kEncodeWeights) || arc.weight == Weight::Zero())) {
+      return arc;
+    }
+    const auto label = table_->Encode(arc);
+    return Arc(label, (flags_ & kEncodeLabels) ? label : arc.olabel,
+               (flags_ & kEncodeWeights) ? Weight::One() : arc.weight,
+               arc.nextstate);
+  }
+  // type_ == DECODE
+  if (arc.nextstate == kNoStateId) return arc;
+  if (arc.ilabel == 0) return arc;
+  if ((flags_ & kEncodeLabels) && arc.ilabel != arc.olabel) {
+    FSTERROR() << "EncodeMapper: Label-encoded arc has different "
+                  "input and output labels";
+    error_ = true;
+  }
+  if ((flags_ & kEncodeWeights) && arc.weight != Weight::One()) {
+    FSTERROR() << "EncodeMapper: Weight-encoded arc has non-trivial weight";
+    error_ = true;
+  }
+  const auto tuple = table_->Decode(arc.ilabel);
+  if (!tuple) {
+    FSTERROR() << "EncodeMapper: Decode failed";
+    error_ = true;
+    return Arc(kNoLabel, kNoLabel, Weight::NoWeight(), arc.nextstate);
+  }
+  return Arc(tuple->ilabel,
+             (flags_ & kEncodeLabels) ? tuple->olabel : arc.olabel,
+             (flags_ & kEncodeWeights) ? tuple->weight : arc.weight,
+             arc.nextstate);
+}
+
+// Encodes fst in place, first copying its symbol tables into the mapper.
+//
+// Complexity: O(E + V).
+template <class Arc>
+inline void Encode(MutableFst<Arc> *fst, EncodeMapper<Arc> *mapper) {
+  mapper->SetInputSymbols(fst->InputSymbols());
+  mapper->SetOutputSymbols(fst->OutputSymbols());
+  ArcMap(fst, mapper);
+}
+
+// Decodes fst in place with a DECODE-typed copy of mapper, then applies
+// RmFinalEpsilon and puts back the symbol tables stored in the mapper.
+template <class Arc>
+inline void Decode(MutableFst<Arc> *fst, const EncodeMapper<Arc> &mapper) {
+  ArcMap(fst, EncodeMapper<Arc>(mapper, DECODE));
+  RmFinalEpsilon(fst);
+  fst->SetInputSymbols(mapper.InputSymbols());
+  fst->SetOutputSymbols(mapper.OutputSymbols());
+}
+
+// On-the-fly encoding of an input FST.
+//
+// Complexity:
+//
+//   Construction: O(1)
+//   Traversal: O(e + v)
+//
+// where e is the number of arcs visited and v is the number of states visited.
+// Constant time and space to visit an input state or arc is assumed and
+// exclusive of caching.
+template +class EncodeFst : public ArcMapFst> { + public: + using Mapper = EncodeMapper; + using Impl = internal::ArcMapFstImpl; + + EncodeFst(const Fst &fst, Mapper *encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) { + encoder->SetInputSymbols(fst.InputSymbols()); + encoder->SetOutputSymbols(fst.OutputSymbols()); + } + + EncodeFst(const Fst &fst, const Mapper &encoder) + : ArcMapFst(fst, encoder, ArcMapFstOptions()) {} + + // See Fst<>::Copy() for doc. + EncodeFst(const EncodeFst &fst, bool copy = false) + : ArcMapFst(fst, copy) {} + + // Makes a copy of this EncodeFst. See Fst<>::Copy() for further doc. + EncodeFst *Copy(bool safe = false) const override { + if (safe) { + FSTERROR() << "EncodeFst::Copy(true): Not allowed"; + GetImpl()->SetProperties(kError, kError); + } + return new EncodeFst(*this); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; +}; + +// On-the-fly decoding of an input FST. +// +// Complexity: +// +// Construction: O(1). +// Traversal: O(e + v) +// +// Constant time and space to visit an input state or arc is assumed and +// exclusive of caching. +template +class DecodeFst : public ArcMapFst> { + public: + using Mapper = EncodeMapper; + using Impl = internal::ArcMapFstImpl; + using ImplToFst::GetImpl; + + DecodeFst(const Fst &fst, const Mapper &encoder) + : ArcMapFst(fst, Mapper(encoder, DECODE), + ArcMapFstOptions()) { + GetMutableImpl()->SetInputSymbols(encoder.InputSymbols()); + GetMutableImpl()->SetOutputSymbols(encoder.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + DecodeFst(const DecodeFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Makes a copy of this DecodeFst. See Fst<>::Copy() for further doc. + DecodeFst *Copy(bool safe = false) const override { + return new DecodeFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for EncodeFst. 
+template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const EncodeFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for EncodeFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const EncodeFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Specialization for DecodeFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const DecodeFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for DecodeFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + ArcIterator(const DecodeFst &fst, typename Arc::StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful aliases when using StdArc. + +using StdEncodeFst = EncodeFst; + +using StdDecodeFst = DecodeFst; + +} // namespace fst + +#endif // FST_ENCODE_H_ diff --git a/projects/llm_framework/include/fst/epsnormalize.h b/projects/llm_framework/include/fst/epsnormalize.h new file mode 100644 index 00000000..18105fb1 --- /dev/null +++ b/projects/llm_framework/include/fst/epsnormalize.h @@ -0,0 +1,61 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function that implements epsilon-normalization. + +#ifndef FST_EPSNORMALIZE_H_ +#define FST_EPSNORMALIZE_H_ + + +#include +#include +#include +#include + + +namespace fst { + +enum EpsNormalizeType { EPS_NORM_INPUT, EPS_NORM_OUTPUT }; + +// Returns an equivalent FST that is epsilon-normalized. An acceptor is +// epsilon-normalized if it is epsilon-removed. A transducer is input +// epsilon-normalized if additionally if on each path any epsilon input +// label follows all non-epsilon input labels. Output epsilon-normalized +// is defined similarly. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. 
International Journal of Computer +// Science, 13(1): 129-143, 2002. +template +void EpsNormalize(const Fst &ifst, MutableFst *ofst, + EpsNormalizeType type = EPS_NORM_INPUT) { + EpsNormalize(ifst, ofst, type); +} + +// Same as above, except allows specifying explicitly the gallic weight type. +template +void EpsNormalize(const Fst &ifst, MutableFst *ofst, + EpsNormalizeType type) { + VectorFst> gfst; + std::unique_ptr symbols; + if (type == EPS_NORM_INPUT) { + ArcMap(ifst, &gfst, ToGallicMapper()); + if (ifst.OutputSymbols()) symbols.reset(ifst.OutputSymbols()->Copy()); + } else { // type == EPS_NORM_OUTPUT + ArcMap(InvertFst(ifst), &gfst, ToGallicMapper()); + if (ifst.InputSymbols()) symbols.reset(ifst.InputSymbols()->Copy()); + } + RmEpsilon(&gfst); + FactorWeightFst, + GallicFactor> + fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper()); + ofst->SetOutputSymbols(symbols.get()); + if (type == EPS_NORM_OUTPUT) Invert(ofst); +} + +} // namespace fst + +#endif // FST_EPSNORMALIZE_H_ diff --git a/projects/llm_framework/include/fst/equal.h b/projects/llm_framework/include/fst/equal.h new file mode 100644 index 00000000..ed89c6ce --- /dev/null +++ b/projects/llm_framework/include/fst/equal.h @@ -0,0 +1,169 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to test equality of two FSTs. 
+ +#ifndef FST_EQUAL_H_ +#define FST_EQUAL_H_ + +#include + +#include +#include + + +namespace fst { + +constexpr uint32 kEqualFsts = 0x0001; +constexpr uint32 kEqualFstTypes = 0x0002; +constexpr uint32 kEqualCompatProperties = 0x0004; +constexpr uint32 kEqualCompatSymbols = 0x0008; +constexpr uint32 kEqualAll = + kEqualFsts | kEqualFstTypes | kEqualCompatProperties | kEqualCompatSymbols; + +class WeightApproxEqual { + public: + explicit WeightApproxEqual(float delta) : delta_(delta) {} + + template + bool operator()(const Weight &w1, const Weight &w2) const { + return ApproxEqual(w1, w2, delta_); + } + + private: + float delta_; +}; + +// Tests if two Fsts have the same states and arcs in the same order (when +// etype & kEqualFst). +// Also optional checks equality of Fst types (etype & kEqualFstTypes) and +// compatibility of stored properties (etype & kEqualCompatProperties) and +// of symbol tables (etype & kEqualCompatSymbols). +template +bool Equal(const Fst &fst1, const Fst &fst2, + WeightEqual weight_equal, uint32 etype = kEqualFsts) { + if ((etype & kEqualFstTypes) && (fst1.Type() != fst2.Type())) { + VLOG(1) << "Equal: Mismatched FST types (" << fst1.Type() << " != " + << fst2.Type() << ")"; + return false; + } + if ((etype & kEqualCompatProperties) && + !CompatProperties(fst1.Properties(kCopyProperties, false), + fst2.Properties(kCopyProperties, false))) { + VLOG(1) << "Equal: Properties not compatible"; + return false; + } + if (etype & kEqualCompatSymbols) { + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols(), false)) { + VLOG(1) << "Equal: Input symbols not compatible"; + return false; + } + if (!CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols(), false)) { + VLOG(1) << "Equal: Output symbols not compatible"; + return false; + } + } + if (!(etype & kEqualFsts)) return true; + if (fst1.Start() != fst2.Start()) { + VLOG(1) << "Equal: Mismatched start states (" << fst1.Start() << " != " + << fst2.Start() << ")"; + return false; + } + 
StateIterator> siter1(fst1); + StateIterator> siter2(fst2); + while (!siter1.Done() || !siter2.Done()) { + if (siter1.Done() || siter2.Done()) { + VLOG(1) << "Equal: Mismatched number of states"; + return false; + } + const auto s1 = siter1.Value(); + const auto s2 = siter2.Value(); + if (s1 != s2) { + VLOG(1) << "Equal: Mismatched states (" << s1 << "!= " + << s2 << ")"; + return false; + } + const auto &final1 = fst1.Final(s1); + const auto &final2 = fst2.Final(s2); + if (!weight_equal(final1, final2)) { + VLOG(1) << "Equal: Mismatched final weights at state " << s1 + << " (" << final1 << " != " << final2 << ")"; + return false; + } + ArcIterator> aiter1(fst1, s1); + ArcIterator> aiter2(fst2, s2); + for (auto a = 0; !aiter1.Done() || !aiter2.Done(); ++a) { + if (aiter1.Done() || aiter2.Done()) { + VLOG(1) << "Equal: Mismatched number of arcs at state " << s1; + return false; + } + const auto &arc1 = aiter1.Value(); + const auto &arc2 = aiter2.Value(); + if (arc1.ilabel != arc2.ilabel) { + VLOG(1) << "Equal: Mismatched arc input labels at state " << s1 + << ", arc " << a << " (" << arc1.ilabel << " != " + << arc2.ilabel << ")"; + return false; + } else if (arc1.olabel != arc2.olabel) { + VLOG(1) << "Equal: Mismatched arc output labels at state " << s1 + << ", arc " << a << " (" << arc1.olabel << " != " + << arc2.olabel << ")"; + return false; + } else if (!weight_equal(arc1.weight, arc2.weight)) { + VLOG(1) << "Equal: Mismatched arc weights at state " << s1 + << ", arc " << a << " (" << arc1.weight << " != " + << arc2.weight << ")"; + return false; + } else if (arc1.nextstate != arc2.nextstate) { + VLOG(1) << "Equal: Mismatched next state at state " << s1 + << ", arc " << a << " (" << arc1.nextstate << " != " + << arc2.nextstate << ")"; + return false; + } + aiter1.Next(); + aiter2.Next(); + } + // Sanity checks: should never fail. 
+ if (fst1.NumArcs(s1) != fst2.NumArcs(s2)) { + FSTERROR() << "Equal: Inconsistent arc counts at state " << s1 + << " (" << fst1.NumArcs(s1) << " != " + << fst2.NumArcs(s2) << ")"; + return false; + } + if (fst1.NumInputEpsilons(s1) != fst2.NumInputEpsilons(s2)) { + FSTERROR() << "Equal: Inconsistent input epsilon counts at state " << s1 + << " (" << fst1.NumInputEpsilons(s1) << " != " + << fst2.NumInputEpsilons(s2) << ")"; + return false; + } + if (fst1.NumOutputEpsilons(s1) != fst2.NumOutputEpsilons(s2)) { + FSTERROR() << "Equal: Inconsistent output epsilon counts at state " << s1 + << " (" << fst1.NumOutputEpsilons(s1) << " != " + << fst2.NumOutputEpsilons(s2) << ")"; + } + siter1.Next(); + siter2.Next(); + } + return true; +} + +template +bool Equal(const Fst &fst1, const Fst &fst2, + float delta = kDelta, uint32 etype = kEqualFsts) { + return Equal(fst1, fst2, WeightApproxEqual(delta), etype); +} + +// Support double deltas without forcing all clients to cast to float. +// Without this overload, Equal will be chosen, +// since it is a better match than double -> float narrowing, but +// the instantiation will fail. +template +bool Equal(const Fst &fst1, const Fst &fst2, + double delta, uint32 etype = kEqualFsts) { + return Equal(fst1, fst2, WeightApproxEqual(static_cast(delta)), etype); +} + + +} // namespace fst + +#endif // FST_EQUAL_H_ diff --git a/projects/llm_framework/include/fst/equivalent.h b/projects/llm_framework/include/fst/equivalent.h new file mode 100644 index 00000000..cf3fdb61 --- /dev/null +++ b/projects/llm_framework/include/fst/equivalent.h @@ -0,0 +1,230 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to determine the equivalence of two FSTs. 
+ +#ifndef FST_EQUIVALENT_H_ +#define FST_EQUIVALENT_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + + +namespace fst { +namespace internal { + +// Traits-like struct holding utility functions/typedefs/constants for +// the equivalence algorithm. +// +// Encoding device: in order to make the statesets of the two acceptors +// disjoint, we map Arc::StateId on the type MappedId. The states of +// the first acceptor are mapped on odd numbers (s -> 2s + 1), and +// those of the second one on even numbers (s -> 2s + 2). The number 0 +// is reserved for an implicit (non-final) dead state (required for +// the correct treatment of non-coaccessible states; kNoStateId is mapped to +// kDeadState for both acceptors). The union-find algorithm operates on the +// mapped IDs. +template +struct EquivalenceUtil { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MappedId = StateId; // ID for an equivalence class. + + // MappedId for an implicit dead state. + static constexpr MappedId kDeadState = 0; + + // MappedId for lookup failure. + static constexpr MappedId kInvalidId = -1; + + // Maps state ID to the representative of the corresponding + // equivalence class. The parameter 'which_fst' takes the values 1 + // and 2, identifying the input FST. + static MappedId MapState(StateId s, int32 which_fst) { + return (kNoStateId == s) ? kDeadState + : (static_cast(s) << 1) + which_fst; + } + + // Maps set ID to State ID. + static StateId UnMapState(MappedId id) { + return static_cast((--id) >> 1); + } + + // Convenience function: checks if state with MappedId s is final in + // acceptor fa. + static bool IsFinal(const Fst &fa, MappedId s) { + return (kDeadState == s) ? false + : (fa.Final(UnMapState(s)) != Weight::Zero()); + } + // Convenience function: returns the representative of ID in sets, + // creating a new set if needed. 
+ static MappedId FindSet(UnionFind *sets, MappedId id) { + const auto repr = sets->FindSet(id); + if (repr != kInvalidId) { + return repr; + } else { + sets->MakeSet(id); + return id; + } + } +}; + +template +constexpr + typename EquivalenceUtil::MappedId EquivalenceUtil::kDeadState; + +template +constexpr + typename EquivalenceUtil::MappedId EquivalenceUtil::kInvalidId; + +} // namespace internal + +// Equivalence checking algorithm: determines if the two FSTs fst1 and fst2 +// are equivalent. The input FSTs must be deterministic input-side epsilon-free +// acceptors, unweighted or with weights over a left semiring. Two acceptors are +// considered equivalent if they accept exactly the same set of strings (with +// the same weights). +// +// The algorithm (cf. Aho, Hopcroft and Ullman, "The Design and Analysis of +// Computer Programs") successively constructs sets of states that can be +// reached by the same prefixes, starting with a set containing the start states +// of both acceptors. A disjoint tree forest (the union-find algorithm) is used +// to represent the sets of states. The algorithm returns false if one of the +// constructed sets contains both final and non-final states. Returns an +// optional error value (useful when FLAGS_error_fatal = false). +// +// Complexity: +// +// Quasi-linear, i.e., O(n G(n)), where +// +// n = |S1| + |S2| is the number of states in both acceptors +// +// G(n) is a very slowly growing function that can be approximated +// by 4 by all practical purposes. +template +bool Equivalent(const Fst &fst1, const Fst &fst2, + float delta = kDelta, bool *error = nullptr) { + using Weight = typename Arc::Weight; + if (error) *error = false; + // Check that the symbol table are compatible. 
+ if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Equivalent: Input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + if (error) *error = true; + return false; + } + // Check properties first. + static constexpr auto props = kNoEpsilons | kIDeterministic | kAcceptor; + if (fst1.Properties(props, true) != props) { + FSTERROR() << "Equivalent: 1st argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + if (fst2.Properties(props, true) != props) { + FSTERROR() << "Equivalent: 2nd argument not an" + << " epsilon-free deterministic acceptor"; + if (error) *error = true; + return false; + } + if ((fst1.Properties(kUnweighted, true) != kUnweighted) || + (fst2.Properties(kUnweighted, true) != kUnweighted)) { + VectorFst efst1(fst1); + VectorFst efst2(fst2); + Push(&efst1, REWEIGHT_TO_INITIAL, delta); + Push(&efst2, REWEIGHT_TO_INITIAL, delta); + ArcMap(&efst1, QuantizeMapper(delta)); + ArcMap(&efst2, QuantizeMapper(delta)); + EncodeMapper mapper(kEncodeWeights | kEncodeLabels, ENCODE); + ArcMap(&efst1, &mapper); + ArcMap(&efst2, &mapper); + return Equivalent(efst1, efst2); + } + using Util = internal::EquivalenceUtil; + using MappedId = typename Util::MappedId; + enum { FST1 = 1, FST2 = 2 }; // Required by Util::MapState(...) + auto s1 = Util::MapState(fst1.Start(), FST1); + auto s2 = Util::MapState(fst2.Start(), FST2); + // The union-find structure. + UnionFind eq_classes(1000, Util::kInvalidId); + // Initializes the union-find structure. + eq_classes.MakeSet(s1); + eq_classes.MakeSet(s2); + // Data structure for the (partial) acceptor transition function of fst1 and + // fst2: input labels mapped to pairs of MappedIds representing destination + // states of the corresponding arcs in fst1 and fst2, respectively. 
+ using Label2StatePairMap = + std::unordered_map>; + Label2StatePairMap arc_pairs; + // Pairs of MappedId's to be processed, organized in a queue. + std::deque> q; + bool ret = true; + // Returns early if the start states differ w.r.t. finality. + if (Util::IsFinal(fst1, s1) != Util::IsFinal(fst2, s2)) ret = false; + // Main loop: explores the two acceptors in a breadth-first manner, updating + // the equivalence relation on the statesets. Loop invariant: each block of + // the states contains either final states only or non-final states only. + for (q.push_back(std::make_pair(s1, s2)); ret && !q.empty(); q.pop_front()) { + s1 = q.front().first; + s2 = q.front().second; + // Representatives of the equivalence classes of s1/s2. + const auto rep1 = Util::FindSet(&eq_classes, s1); + const auto rep2 = Util::FindSet(&eq_classes, s2); + if (rep1 != rep2) { + eq_classes.Union(rep1, rep2); + arc_pairs.clear(); + // Copies outgoing arcs starting at s1 into the hash-table. + if (Util::kDeadState != s1) { + ArcIterator> arc_iter(fst1, Util::UnMapState(s1)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const auto &arc = arc_iter.Value(); + // Zero-weight arcs are treated as if they did not exist. + if (arc.weight != Weight::Zero()) { + arc_pairs[arc.ilabel].first = Util::MapState(arc.nextstate, FST1); + } + } + } + // Copies outgoing arcs starting at s2 into the hashtable. + if (Util::kDeadState != s2) { + ArcIterator> arc_iter(fst2, Util::UnMapState(s2)); + for (; !arc_iter.Done(); arc_iter.Next()) { + const auto &arc = arc_iter.Value(); + // Zero-weight arcs are treated as if they did not exist. + if (arc.weight != Weight::Zero()) { + arc_pairs[arc.ilabel].second = Util::MapState(arc.nextstate, FST2); + } + } + } + // Iterates through the hashtable and process pairs of target states. 
+ for (const auto &arc_iter : arc_pairs) { + const auto &pair = arc_iter.second; + if (Util::IsFinal(fst1, pair.first) != + Util::IsFinal(fst2, pair.second)) { + // Detected inconsistency: return false. + ret = false; + break; + } + q.push_back(pair); + } + } + } + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + return ret; +} + +} // namespace fst + +#endif // FST_EQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/expanded-fst.h b/projects/llm_framework/include/fst/expanded-fst.h new file mode 100644 index 00000000..2c7d514c --- /dev/null +++ b/projects/llm_framework/include/fst/expanded-fst.h @@ -0,0 +1,179 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Generic FST augmented with state count-interface class definition. + +#ifndef FST_EXPANDED_FST_H_ +#define FST_EXPANDED_FST_H_ + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +// A generic FST plus state count. +template +class ExpandedFst : public Fst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + virtual StateId NumStates() const = 0; // State count + + // Get a copy of this ExpandedFst. See Fst<>::Copy() for further doc. + ExpandedFst *Copy(bool safe = false) const override = 0; + + // Read an ExpandedFst from an input stream; return NULL on error. 
+ static ExpandedFst *Read(std::istream &strm, + const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) { + hdr = *opts.header; + } else { + if (!hdr.Read(strm, opts.source)) return nullptr; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kExpanded)) { + LOG(ERROR) << "ExpandedFst::Read: Not an ExpandedFst: " << ropts.source; + return nullptr; + } + const auto reader = + FstRegister::GetRegister()->GetReader(hdr.FstType()); + if (!reader) { + LOG(ERROR) << "ExpandedFst::Read: Unknown FST type \"" << hdr.FstType() + << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source; + return nullptr; + } + auto *fst = reader(strm, ropts); + if (!fst) return nullptr; + return static_cast *>(fst); + } + + // Read an ExpandedFst from a file; return NULL on error. + // Empty filename reads from standard input. + static ExpandedFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } +}; + +namespace internal { + +// ExpandedFst case - abstract methods. +template +inline typename Arc::Weight Final(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.Final(s); +} + +template +inline ssize_t NumArcs(const ExpandedFst &fst, typename Arc::StateId s) { + return fst.NumArcs(s); +} + +template +inline ssize_t NumInputEpsilons(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.NumInputEpsilons(s); +} + +template +inline ssize_t NumOutputEpsilons(const ExpandedFst &fst, + typename Arc::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + +// A useful alias when using StdArc. 
+using StdExpandedFst = ExpandedFst; + +// This is a helper class template useful for attaching an ExpandedFst +// interface to its implementation, handling reference counting. It +// delegates to ImplToFst the handling of the Fst interface methods. +template > +class ImplToExpandedFst : public ImplToFst { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + StateId NumStates() const override { return GetImpl()->NumStates(); } + + protected: + using ImplToFst::GetImpl; + + explicit ImplToExpandedFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + ImplToExpandedFst(const ImplToExpandedFst &fst, bool safe) + : ImplToFst(fst, safe) {} + + static Impl *Read(std::istream &strm, const FstReadOptions &opts) { + return Impl::Read(strm, opts); + } + + // Read FST implementation from a file; return NULL on error. + // Empty filename reads from standard input. + static Impl *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ExpandedFst::Read: Can't open file: " << filename; + return nullptr; + } + return Impl::Read(strm, FstReadOptions(filename)); + } else { + return Impl::Read(std::cin, FstReadOptions("standard input")); + } + } +}; + +// Function to return the number of states in an FST, counting them +// if necessary. +template +typename Arc::StateId CountStates(const Fst &fst) { + if (fst.Properties(kExpanded, false)) { + const auto *efst = static_cast *>(&fst); + return efst->NumStates(); + } else { + typename Arc::StateId nstates = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates; + } + return nstates; + } +} + +// Function to return the number of arcs in an FST. 
+template +typename Arc::StateId CountArcs(const Fst &fst) { + size_t narcs = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + narcs += fst.NumArcs(siter.Value()); + } + return narcs; +} + +} // namespace fst + +#endif // FST_EXPANDED_FST_H_ diff --git a/projects/llm_framework/include/fst/expectation-weight.h b/projects/llm_framework/include/fst/expectation-weight.h new file mode 100644 index 00000000..f996cbc6 --- /dev/null +++ b/projects/llm_framework/include/fst/expectation-weight.h @@ -0,0 +1,134 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expectation semiring as described by Jason Eisner: +// See: doi=10.1.1.22.9398 +// Multiplex semiring operations and identities: +// One: +// Zero: +// Plus: + = < (a1 + a2) , (b1 + b2) > +// Times: * = < (a1 * a2) , [(a1 * b2) + (a2 * b1)] > +// Division: Undefined (currently) +// +// Usually used to store the pair so that +// ShortestDistance[Fst>>] +// == < PosteriorProbability, Expected_Value[V] > + +#ifndef FST_EXPECTATION_WEIGHT_H_ +#define FST_EXPECTATION_WEIGHT_H_ + +#include + +#include + +#include +#include + + +namespace fst { + +// X1 is usually a probability weight like LogWeight. +// X2 is usually a random variable or vector (see SignedLogWeight or +// SparsePowerWeight). +// +// If X1 is distinct from X2, it is required that there is an external product +// between X1 and X2 and if both semriring are commutative, or left or right +// semirings, then result must have those properties. 
+template +class ExpectationWeight : public PairWeight { + public: + using PairWeight::Value1; + using PairWeight::Value2; + + using PairWeight::Reverse; + using PairWeight::Quantize; + using PairWeight::Member; + + using ReverseWeight = + ExpectationWeight; + + ExpectationWeight() : PairWeight(Zero()) {} + + explicit ExpectationWeight(const PairWeight &weight) + : PairWeight(weight) {} + + ExpectationWeight(const X1 &x1, const X2 &x2) : PairWeight(x1, x2) {} + + static const ExpectationWeight &Zero() { + static const ExpectationWeight zero(X1::Zero(), X2::Zero()); + return zero; + } + + static const ExpectationWeight &One() { + static const ExpectationWeight one(X1::One(), X2::Zero()); + return one; + } + + static const ExpectationWeight &NoWeight() { + static const ExpectationWeight no_weight(X1::NoWeight(), X2::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string("expectation_" + X1::Type() + "_" + X2::Type()); + return *type; + } + + PairWeight Quantize(float delta = kDelta) const { + return ExpectationWeight(PairWeight::Quantize()); + } + + ReverseWeight Reverse() const { + return ReverseWeight(PairWeight::Reverse()); + } + + bool Member() const { return PairWeight::Member(); } + + static constexpr uint64 Properties() { + return X1::Properties() & X2::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } +}; + +template +inline ExpectationWeight Plus(const ExpectationWeight &w1, + const ExpectationWeight &w2) { + return ExpectationWeight(Plus(w1.Value1(), w2.Value1()), + Plus(w1.Value2(), w2.Value2())); +} + +template +inline ExpectationWeight Times(const ExpectationWeight &w1, + const ExpectationWeight &w2) { + return ExpectationWeight( + Times(w1.Value1(), w2.Value1()), + Plus(Times(w1.Value1(), w2.Value2()), Times(w1.Value2(), w2.Value1()))); +} + +template +inline ExpectationWeight Divide(const ExpectationWeight &w1, + const ExpectationWeight &w2, + DivideType 
typ = DIVIDE_ANY) { + FSTERROR() << "ExpectationWeight::Divide: Not implemented"; + return ExpectationWeight::NoWeight(); +} + +// This function object generates weights by calling the underlying generators +// for the template weight types, like all other pair weight types. This is +// intended primarily for testing. +template +class WeightGenerate> + : public WeightGenerate> { + public: + using Weight = ExpectationWeight; + using Generate = WeightGenerate>; + + explicit WeightGenerate(bool allow_zero = true) : Generate(allow_zero) {} + + Weight operator()() const { return Weight(Generate::operator()()); } +}; + +} // namespace fst + +#endif // FST_EXPECTATION_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/extensions/compress/compress-script.h b/projects/llm_framework/include/fst/extensions/compress/compress-script.h new file mode 100644 index 00000000..bad238aa --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/compress/compress-script.h @@ -0,0 +1,53 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Declarations of 'scriptable' versions of compression operations, that is, +// those that can be called with FstClass-type arguments. 
+ +#ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ +#define FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +typedef std::tuple CompressArgs; + +template +void Compress(CompressArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + const string &filename = std::get<1>(*args); + const bool gzip = std::get<2>(*args); + + if (!fst::Compress(fst, filename, gzip)) FSTERROR() << "Compress: failed"; +} + +void Compress(const FstClass &fst, const string &filename, const bool gzip); + +typedef std::tuple + DecompressArgs; + +template +void Decompress(DecompressArgs *args) { + const string &filename = std::get<0>(*args); + MutableFst *fst = std::get<1>(*args)->GetMutableFst(); + const bool gzip = std::get<2>(*args); + + if (!fst::Decompress(filename, fst, gzip)) + FSTERROR() << "Decompress: failed"; +} + +void Decompress(const string &filename, MutableFstClass *fst, const bool gzip); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_COMPRESS_COMPRESS_SCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/compress/compress.h b/projects/llm_framework/include/fst/extensions/compress/compress.h new file mode 100644 index 00000000..aa94848f --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/compress/compress.h @@ -0,0 +1,906 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Compresses and decompresses unweighted FSTs. + +#ifndef FST_EXTENSIONS_COMPRESS_COMPRESS_H_ +#define FST_EXTENSIONS_COMPRESS_COMPRESS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Identifies stream data as a vanilla compressed FST. 
+static const int32 kCompressMagicNumber = 1858869554; +// Identifies stream data as (probably) a Gzip file accidentally read from +// a vanilla stream, without gzip support. +static const int32 kGzipMagicNumber = 0x8b1f; +// Selects the two most significant bytes. +constexpr uint32 kGzipMask = 0xffffffff >> 16; + +namespace internal { + +// Expands a Lempel Ziv code and returns the set of code words. expanded_code[i] +// is the i^th Lempel Ziv codeword. +template +bool ExpandLZCode(const std::vector> &code, + std::vector> *expanded_code) { + expanded_code->resize(code.size()); + for (int i = 0; i < code.size(); ++i) { + if (code[i].first > i) { + LOG(ERROR) << "ExpandLZCode: Not a valid code"; + return false; + } + if (code[i].first == 0) { + (*expanded_code)[i].resize(1, code[i].second); + } else { + (*expanded_code)[i].resize((*expanded_code)[code[i].first - 1].size() + + 1); + std::copy((*expanded_code)[code[i].first - 1].begin(), + (*expanded_code)[code[i].first - 1].end(), + (*expanded_code)[i].begin()); + (*expanded_code)[i][(*expanded_code)[code[i].first - 1].size()] = + code[i].second; + } + } + return true; +} + +} // namespace internal + +// Lempel Ziv on data structure Edge, with a less than operator +// EdgeLessThan and an equals operator EdgeEquals. +// Edge has a value defaultedge which it never takes and +// Edge is defined, it is initialized to defaultedge +template +class LempelZiv { + public: + LempelZiv() : dict_number_(0), default_edge_() { + root_.current_number = dict_number_++; + root_.current_edge = default_edge_; + decode_vector_.push_back(std::make_pair(0, default_edge_)); + } + // Encodes a vector input into output + void BatchEncode(const std::vector &input, + std::vector> *output); + + // Decodes codedvector to output. Returns false if + // the index exceeds the size. + bool BatchDecode(const std::vector> &input, + std::vector *output); + + // Decodes a single dictionary element. Returns false + // if the index exceeds the size. 
+ bool SingleDecode(const Var &index, Edge *output) { + if (index >= decode_vector_.size()) { + LOG(ERROR) << "LempelZiv::SingleDecode: " + << "Index exceeded the dictionary size"; + return false; + } else { + *output = decode_vector_[index].second; + return true; + } + } + + ~LempelZiv() { + for (auto it = (root_.next_number).begin(); it != (root_.next_number).end(); + ++it) { + CleanUp(it->second); + } + } + // Adds a single dictionary element while decoding + // void AddDictElement(const std::pair &newdict) { + // EdgeEquals InstEdgeEquals; + // if (InstEdgeEquals(newdict.second, default_edge_) != 1) + // decode_vector_.push_back(newdict); + // } + + private: + // Node datastructure is used for encoding + + struct Node { + Var current_number; + Edge current_edge; + std::map next_number; + }; + + void CleanUp(Node *temp) { + for (auto it = (temp->next_number).begin(); it != (temp->next_number).end(); + ++it) { + CleanUp(it->second); + } + delete temp; + } + Node root_; + Var dict_number_; + // decode_vector_ is used for decoding + std::vector> decode_vector_; + Edge default_edge_; +}; + +template +void LempelZiv::BatchEncode( + const std::vector &input, std::vector> *output) { + for (typename std::vector::const_iterator it = input.begin(); + it != input.end(); ++it) { + Node *temp_node = &root_; + while (it != input.end()) { + auto next = (temp_node->next_number).find(*it); + if (next != (temp_node->next_number).end()) { + temp_node = next->second; + ++it; + } else { + break; + } + } + if (it == input.end() && temp_node->current_number != 0) { + output->push_back( + std::make_pair(temp_node->current_number, default_edge_)); + } else if (it != input.end()) { + output->push_back(std::make_pair(temp_node->current_number, *it)); + Node *new_node = new (Node); + new_node->current_number = dict_number_++; + new_node->current_edge = *it; + (temp_node->next_number)[*it] = new_node; + } + if (it == input.end()) break; + } +} + +template +bool LempelZiv::BatchDecode( + 
const std::vector> &input, std::vector *output) { + for (typename std::vector>::const_iterator it = + input.begin(); + it != input.end(); ++it) { + std::vector temp_output; + EdgeEquals InstEdgeEquals; + if (InstEdgeEquals(it->second, default_edge_) != 1) { + decode_vector_.push_back(*it); + temp_output.push_back(it->second); + } + Var temp_integer = it->first; + if (temp_integer >= decode_vector_.size()) { + LOG(ERROR) << "LempelZiv::BatchDecode: " + << "Index exceeded the dictionary size"; + return false; + } else { + while (temp_integer != 0) { + temp_output.push_back(decode_vector_[temp_integer].second); + temp_integer = decode_vector_[temp_integer].first; + } + std::reverse(temp_output.begin(), temp_output.end()); + output->insert(output->end(), temp_output.begin(), temp_output.end()); + } + } + return true; +} + +// The main Compressor class +template +class Compressor { + public: + typedef typename Arc::StateId StateId; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + + Compressor() {} + + // Compresses fst into a boolean vector code. Returns true on sucesss. + bool Compress(const Fst &fst, std::ostream &strm); + + // Decompresses the boolean vector into Fst. Returns true on sucesss. + bool Decompress(std::istream &strm, const string &source, + MutableFst *fst); + + // Finds the BFS order of a fst + void BfsOrder(const ExpandedFst &fst, std::vector *order); + + // Preprocessing step to convert fst to a isomorphic fst + // Returns a preproccess fst and a dictionary + void Preprocess(const Fst &fst, MutableFst *preprocessedfst, + EncodeMapper *encoder); + + // Performs Lempel Ziv and outputs a stream of integers + // and sends it to a stream + void EncodeProcessedFst(const ExpandedFst &fst, std::ostream &strm); + + // Decodes fst from the stream + void DecodeProcessedFst(const std::vector &input, + MutableFst *fst, bool unweighted); + + // Converts buffer_code_ to uint8 and writes to a stream. 
+ + // Writes the boolean file to the stream + void WriteToStream(std::ostream &strm); + + // Writes the weights to the stream + void WriteWeight(const std::vector &input, std::ostream &strm); + + void ReadWeight(std::istream &strm, std::vector *output); + + // Same as fst::Decode without the line RmFinalEpsilon(fst) + void DecodeForCompress(MutableFst *fst, const EncodeMapper &mapper); + + // Updates the buffer_code_ + template + void WriteToBuffer(CVar input) { + std::vector current_code; + Elias::DeltaEncode(input, ¤t_code); + if (!buffer_code_.empty()) { + buffer_code_.insert(buffer_code_.end(), current_code.begin(), + current_code.end()); + } else { + buffer_code_.assign(current_code.begin(), current_code.end()); + } + } + + private: + struct LZLabel { + LZLabel() : label(0) {} + Label label; + }; + + struct LabelLessThan { + bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { + return labelone.label < labeltwo.label; + } + }; + + struct LabelEquals { + bool operator()(const LZLabel &labelone, const LZLabel &labeltwo) const { + return labelone.label == labeltwo.label; + } + }; + + struct Transition { + Transition() : nextstate(0), label(0), weight(Weight::Zero()) {} + + StateId nextstate; + Label label; + Weight weight; + }; + + struct TransitionLessThan { + bool operator()(const Transition &transition_one, + const Transition &transition_two) const { + if (transition_one.nextstate == transition_two.nextstate) + return transition_one.label < transition_two.label; + else + return transition_one.nextstate < transition_two.nextstate; + } + } transition_less_than; + + struct TransitionEquals { + bool operator()(const Transition &transition_one, + const Transition &transition_two) const { + return transition_one.nextstate == transition_two.nextstate && + transition_one.label == transition_two.label; + } + } transition_equals; + + struct OldDictCompare { + bool operator()(const std::pair &pair_one, + const std::pair &pair_two) const { + if 
((pair_one.second).nextstate == (pair_two.second).nextstate) + return (pair_one.second).label < (pair_two.second).label; + else + return (pair_one.second).nextstate < (pair_two.second).nextstate; + } + } old_dict_compare; + + std::vector buffer_code_; + std::vector arc_weight_; + std::vector final_weight_; +}; + +template +inline void Compressor::DecodeForCompress( + MutableFst *fst, const EncodeMapper &mapper) { + ArcMap(fst, EncodeMapper(mapper, DECODE)); + fst->SetInputSymbols(mapper.InputSymbols()); + fst->SetOutputSymbols(mapper.OutputSymbols()); +} + +// Compressor::BfsOrder +template +void Compressor::BfsOrder(const ExpandedFst &fst, + std::vector *order) { + Arc arc; + StateId bfs_visit_number = 0; + std::queue states_queue; + order->assign(fst.NumStates(), kNoStateId); + states_queue.push(fst.Start()); + (*order)[fst.Start()] = bfs_visit_number++; + while (!states_queue.empty()) { + for (ArcIterator> aiter(fst, states_queue.front()); !aiter.Done(); + aiter.Next()) { + arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if ((*order)[nextstate] == kNoStateId) { + (*order)[nextstate] = bfs_visit_number++; + states_queue.push(nextstate); + } + } + states_queue.pop(); + } + + // If the FST is unconnected, then the following + // code finds them + while (bfs_visit_number < fst.NumStates()) { + int unseen_state = 0; + for (unseen_state = 0; unseen_state < fst.NumStates(); ++unseen_state) { + if ((*order)[unseen_state] == kNoStateId) break; + } + states_queue.push(unseen_state); + (*order)[unseen_state] = bfs_visit_number++; + while (!states_queue.empty()) { + for (ArcIterator> aiter(fst, states_queue.front()); + !aiter.Done(); aiter.Next()) { + arc = aiter.Value(); + StateId nextstate = arc.nextstate; + if ((*order)[nextstate] == kNoStateId) { + (*order)[nextstate] = bfs_visit_number++; + states_queue.push(nextstate); + } + } + states_queue.pop(); + } + } +} + +template +void Compressor::Preprocess(const Fst &fst, + MutableFst *preprocessedfst, + 
EncodeMapper *encoder) { + *preprocessedfst = fst; + if (!preprocessedfst->NumStates()) { + return; + } + // Relabels the edges and develops a dictionary + Encode(preprocessedfst, encoder); + std::vector order; + // Finds the BFS sorting order of the fst + BfsOrder(*preprocessedfst, &order); + // Reorders the states according to the BFS order + StateSort(preprocessedfst, order); +} + +template +void Compressor::EncodeProcessedFst(const ExpandedFst &fst, + std::ostream &strm) { + std::vector output; + LempelZiv dict_new; + LempelZiv dict_old; + std::vector current_new_input; + std::vector current_old_input; + std::vector> current_new_output; + std::vector> current_old_output; + std::vector final_states; + + StateId number_of_states = fst.NumStates(); + + StateId seen_states = 0; + // Adding the number of states + WriteToBuffer(number_of_states); + + for (StateId state = 0; state < number_of_states; ++state) { + current_new_input.clear(); + current_old_input.clear(); + current_new_output.clear(); + current_old_output.clear(); + if (state > seen_states) ++seen_states; + + // Collecting the final states + if (fst.Final(state) != Weight::Zero()) { + final_states.push_back(state); + final_weight_.push_back(fst.Final(state)); + } + + // Reading the states + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + Arc arc = aiter.Value(); + if (arc.nextstate > seen_states) { // RILEY: > or >= ? 
+ ++seen_states; + LZLabel temp_label; + temp_label.label = arc.ilabel; + arc_weight_.push_back(arc.weight); + current_new_input.push_back(temp_label); + } else { + Transition temp_transition; + temp_transition.nextstate = arc.nextstate; + temp_transition.label = arc.ilabel; + temp_transition.weight = arc.weight; + current_old_input.push_back(temp_transition); + } + } + // Adding new states + dict_new.BatchEncode(current_new_input, ¤t_new_output); + WriteToBuffer(current_new_output.size()); + + for (auto it = current_new_output.begin(); it != current_new_output.end(); + ++it) { + WriteToBuffer(it->first); + WriteToBuffer. +// See the FarReader interface in far.h for the exact semantics. +class FarReaderImplBase { + public: + virtual const string &ArcType() const = 0; + virtual bool Done() const = 0; + virtual bool Error() const = 0; + virtual const string &GetKey() const = 0; + virtual const FstClass *GetFstClass() const = 0; + virtual bool Find(const string &key) = 0; + virtual void Next() = 0; + virtual void Reset() = 0; + virtual FarType Type() const = 0; + virtual ~FarReaderImplBase() {} +}; + +// Templated implementation. 
+template +class FarReaderClassImpl : public FarReaderImplBase { + public: + explicit FarReaderClassImpl(const string &filename) + : impl_(FarReader::Open(filename)) {} + + explicit FarReaderClassImpl(const std::vector &filenames) + : impl_(FarReader::Open(filenames)) {} + + const string &ArcType() const final { return Arc::Type(); } + + bool Done() const final { return impl_->Done(); } + + bool Error() const final { return impl_->Error(); } + + bool Find(const string &key) final { return impl_->Find(key); } + + const FstClass *GetFstClass() const final { + fstc_.reset(new FstClass(*impl_->GetFst())); + return fstc_.get(); + } + + const string &GetKey() const final { return impl_->GetKey(); } + + void Next() final { return impl_->Next(); } + + void Reset() final { impl_->Reset(); } + + FarType Type() const final { return impl_->Type(); } + + const FarReader *GetImpl() const { return impl_.get(); } + + FarReader *GetImpl() { return impl_.get(); } + + private: + std::unique_ptr> impl_; + mutable std::unique_ptr fstc_; +}; + + +class FarReaderClass; + +using OpenFarReaderClassArgs = + WithReturnValue &>; + +// Untemplated user-facing class holding a templated pimpl. +class FarReaderClass { + public: + const string &ArcType() const { return impl_->ArcType(); } + + bool Done() const { return impl_->Done(); } + + // Returns True if the impl is null (i.e., due to read failure). + // Attempting to call any other function will result in null dereference. + bool Error() const { return (impl_) ? 
impl_->Error() : true; } + + bool Find(const string &key) { return impl_->Find(key); } + + const FstClass *GetFstClass() const { return impl_->GetFstClass(); } + + const string &GetKey() const { return impl_->GetKey(); } + + void Next() { impl_->Next(); } + + void Reset() { impl_->Reset(); } + + FarType Type() const { return impl_->Type(); } + + template + const FarReader *GetFarReader() const { + if (Arc::Type() != ArcType()) return nullptr; + const FarReaderClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + FarReader *GetFarReader() { + if (Arc::Type() != ArcType()) return nullptr; + FarReaderClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + friend void OpenFarReaderClass(OpenFarReaderClassArgs *args); + + // Defined in the CC. + + static FarReaderClass *Open(const string &filename); + + static FarReaderClass *Open(const std::vector &filenames); + + private: + template + explicit FarReaderClass(FarReaderClassImpl *impl) : impl_(impl) {} + + std::unique_ptr impl_; +}; + +// These exist solely for registration purposes; users should call the +// static method FarReaderClass::Open instead. + +template +void OpenFarReaderClass(OpenFarReaderClassArgs *args) { + args->retval = new FarReaderClass(new FarReaderClassImpl(args->args)); +} + +// FarWriter API. + +// Virtual interface implemented by each concrete FarWriterImpl. +class FarWriterImplBase { + public: + // Unlike the lower-level library, this returns a boolean to signal failure + // due to non-conformant arc types. + virtual bool Add(const string &key, const FstClass &fst) = 0; + virtual const string &ArcType() const = 0; + virtual bool Error() const = 0; + virtual FarType Type() const = 0; + virtual ~FarWriterImplBase() {} +}; + + +// Templated implementation. 
+template +class FarWriterClassImpl : public FarWriterImplBase { + public: + explicit FarWriterClassImpl(const string &filename, + FarType type = FAR_DEFAULT) + : impl_(FarWriter::Create(filename, type)) {} + + bool Add(const string &key, const FstClass &fst) final { + if (ArcType() != fst.ArcType()) { + FSTERROR() << "Cannot write FST with " << fst.ArcType() << " arcs to " + << "FAR with " << ArcType() << " arcs"; + return false; + } + impl_->Add(key, *(fst.GetFst())); + return true; + } + + const string &ArcType() const final { return Arc::Type(); } + + bool Error() const final { return impl_->Error(); } + + FarType Type() const final { return impl_->Type(); } + + const FarWriter *GetImpl() const { return impl_.get(); } + + FarWriter *GetImpl() { return impl_.get(); } + + private: + std::unique_ptr> impl_; +}; + + +class FarWriterClass; + +using CreateFarWriterClassInnerArgs = std::pair; + +using CreateFarWriterClassArgs = + WithReturnValue; + +// Untemplated user-facing class holding a templated pimpl. +class FarWriterClass { + public: + static FarWriterClass *Create(const string &filename, const string &arc_type, + FarType type = FAR_DEFAULT); + + bool Add(const string &key, const FstClass &fst) { + return impl_->Add(key, fst); + } + + // Returns True if the impl is null (i.e., due to construction failure). + // Attempting to call any other function will result in null dereference. + bool Error() const { return (impl_) ? 
impl_->Error() : true; } + + const string &ArcType() const { return impl_->ArcType(); } + + FarType Type() const { return impl_->Type(); } + + template + const FarWriter *GetFarWriter() const { + if (Arc::Type() != ArcType()) return nullptr; + const FarWriterClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + FarWriter *GetFarWriter() { + if (Arc::Type() != ArcType()) return nullptr; + FarWriterClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + + template + friend void CreateFarWriterClass(CreateFarWriterClassArgs *args); + + private: + template + explicit FarWriterClass(FarWriterClassImpl *impl) : impl_(impl) {} + + std::unique_ptr impl_; +}; + +// This exists solely for registration purposes; users should call the +// static method FarWriterClass::Create instead. +template +void CreateFarWriterClass(CreateFarWriterClassArgs *args) { + args->retval = new FarWriterClass(new FarWriterClassImpl( + std::get<0>(args->args), std::get<1>(args->args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/far.h b/projects/llm_framework/include/fst/extensions/far/far.h new file mode 100644 index 00000000..c24c7dab --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/far.h @@ -0,0 +1,481 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Finite-State Transducer (FST) archive classes. 
+ +#ifndef FST_EXTENSIONS_FAR_FAR_H_ +#define FST_EXTENSIONS_FAR_FAR_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace fst { + +enum FarEntryType { FET_LINE, FET_FILE }; + +enum FarTokenType { FTT_SYMBOL, FTT_BYTE, FTT_UTF8 }; + +inline bool IsFst(const string &filename) { + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) return false; + return IsFstHeader(strm, filename); +} + +// FST archive header class +class FarHeader { + public: + const string &ArcType() const { return arctype_; } + + const string &FarType() const { return fartype_; } + + bool Read(const string &filename) { + FstHeader fsthdr; + if (filename.empty()) { + // Header reading unsupported on stdin. Assumes STList and StdArc. + fartype_ = "stlist"; + arctype_ = "standard"; + return true; + } else if (IsSTTable(filename)) { // Checks if STTable. + ReadSTTableHeader(filename, &fsthdr); + fartype_ = "sttable"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsSTList(filename)) { // Checks if STList. + ReadSTListHeader(filename, &fsthdr); + fartype_ = "stlist"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } else if (IsFst(filename)) { // Checks if FST. + std::ifstream istrm(filename, + std::ios_base::in | std::ios_base::binary); + fsthdr.Read(istrm, filename); + fartype_ = "fst"; + arctype_ = fsthdr.ArcType().empty() ? "unknown" : fsthdr.ArcType(); + return true; + } + return false; + } + + private: + string fartype_; + string arctype_; +}; + +enum FarType { + FAR_DEFAULT = 0, + FAR_STTABLE = 1, + FAR_STLIST = 2, + FAR_FST = 3, +}; + +// This class creates an archive of FSTs. +template +class FarWriter { + public: + using Arc = A; + + // Creates a new (empty) FST archive; returns null on error. + static FarWriter *Create(const string &filename, FarType type = FAR_DEFAULT); + + // Adds an FST to the end of an archive. 
Keys must be non-empty and + // in lexicographic order. FSTs must have a suitable write method. + virtual void Add(const string &key, const Fst &fst) = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarWriter() {} + + protected: + FarWriter() {} +}; + +// This class iterates through an existing archive of FSTs. +template +class FarReader { + public: + using Arc = A; + + // Opens an existing FST archive in a single file; returns null on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const string &filename); + + // Opens an existing FST archive in multiple files; returns null on error. + // Sets current position to the beginning of the achive. + static FarReader *Open(const std::vector &filenames); + + // Resets current position to beginning of archive. + virtual void Reset() = 0; + + // Sets current position to first entry >= key. Returns true if a match. + virtual bool Find(const string &key) = 0; + + // Current position at end of archive? + virtual bool Done() const = 0; + + // Move current position to next FST. + virtual void Next() = 0; + + // Returns key at the current position. This reference is invalidated if + // the current position in the archive is changed. + virtual const string &GetKey() const = 0; + + // Returns pointer to FST at the current position. This is invalidated if + // the current position in the archive is changed. 
+ virtual const Fst *GetFst() const = 0; + + virtual FarType Type() const = 0; + + virtual bool Error() const = 0; + + virtual ~FarReader() {} + + protected: + FarReader() {} +}; + +template +class FstWriter { + public: + void operator()(std::ostream &strm, const Fst &fst) const { + fst.Write(strm, FstWriteOptions()); + } +}; + +template +class STTableFarWriter : public FarWriter { + public: + using Arc = A; + + static STTableFarWriter *Create(const string &filename) { + auto *writer = STTableWriter, FstWriter>::Create(filename); + return new STTableFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) final { + writer_->Add(key, fst); + } + + FarType Type() const final { return FAR_STTABLE; } + + bool Error() const final { return writer_->Error(); } + + private: + explicit STTableFarWriter(STTableWriter, FstWriter> *writer) + : writer_(writer) {} + + std::unique_ptr, FstWriter>> writer_; +}; + +template +class STListFarWriter : public FarWriter { + public: + using Arc = A; + + static STListFarWriter *Create(const string &filename) { + auto *writer = STListWriter, FstWriter>::Create(filename); + return new STListFarWriter(writer); + } + + void Add(const string &key, const Fst &fst) final { + writer_->Add(key, fst); + } + + constexpr FarType Type() const final { return FAR_STLIST; } + + bool Error() const final { return writer_->Error(); } + + private: + explicit STListFarWriter(STListWriter, FstWriter> *writer) + : writer_(writer) {} + + std::unique_ptr, FstWriter>> writer_; +}; + +template +class FstFarWriter : public FarWriter { + public: + using Arc = A; + + explicit FstFarWriter(const string &filename) + : filename_(filename), error_(false), written_(false) {} + + static FstFarWriter *Create(const string &filename) { + return new FstFarWriter(filename); + } + + void Add(const string &key, const Fst &fst) final { + if (written_) { + LOG(WARNING) << "FstFarWriter::Add: only one FST supported," + << " subsequent entries discarded."; + } else { + 
error_ = !fst.Write(filename_); + written_ = true; + } + } + + constexpr FarType Type() const final { return FAR_FST; } + + bool Error() const final { return error_; } + + ~FstFarWriter() final {} + + private: + string filename_; + bool error_; + bool written_; +}; + +template +FarWriter *FarWriter::Create(const string &filename, FarType type) { + switch (type) { + case FAR_DEFAULT: + if (filename.empty()) return STListFarWriter::Create(filename); + case FAR_STTABLE: + return STTableFarWriter::Create(filename); + case FAR_STLIST: + return STListFarWriter::Create(filename); + case FAR_FST: + return FstFarWriter::Create(filename); + default: + LOG(ERROR) << "FarWriter::Create: Unknown FAR type"; + return nullptr; + } +} + +template +class FstReader { + public: + Fst *operator()(std::istream &strm) const { + return Fst::Read(strm, FstReadOptions()); + } +}; + +template +class STTableFarReader : public FarReader { + public: + using Arc = A; + + static STTableFarReader *Open(const string &filename) { + auto *reader = STTableReader, FstReader>::Open(filename); + if (!reader || reader->Error()) return nullptr; + return new STTableFarReader(reader); + } + + static STTableFarReader *Open(const std::vector &filenames) { + auto *reader = STTableReader, FstReader>::Open(filenames); + if (!reader || reader->Error()) return nullptr; + return new STTableFarReader(reader); + } + + void Reset() final { reader_->Reset(); } + + bool Find(const string &key) final { return reader_->Find(key); } + + bool Done() const final { return reader_->Done(); } + + void Next() final { return reader_->Next(); } + + const string &GetKey() const final { return reader_->GetKey(); } + + const Fst *GetFst() const final { return reader_->GetEntry(); } + + constexpr FarType Type() const final { return FAR_STTABLE; } + + bool Error() const final { return reader_->Error(); } + + private: + explicit STTableFarReader(STTableReader, FstReader> *reader) + : reader_(reader) {} + + std::unique_ptr, FstReader>> 
reader_; +}; + +template +class STListFarReader : public FarReader { + public: + using Arc = A; + + static STListFarReader *Open(const string &filename) { + auto *reader = STListReader, FstReader>::Open(filename); + if (!reader || reader->Error()) return nullptr; + return new STListFarReader(reader); + } + + static STListFarReader *Open(const std::vector &filenames) { + auto *reader = STListReader, FstReader>::Open(filenames); + if (!reader || reader->Error()) return nullptr; + return new STListFarReader(reader); + } + + void Reset() final { reader_->Reset(); } + + bool Find(const string &key) final { return reader_->Find(key); } + + bool Done() const final { return reader_->Done(); } + + void Next() final { return reader_->Next(); } + + const string &GetKey() const final { return reader_->GetKey(); } + + const Fst *GetFst() const final { return reader_->GetEntry(); } + + constexpr FarType Type() const final { return FAR_STLIST; } + + bool Error() const final { return reader_->Error(); } + + private: + explicit STListFarReader(STListReader, FstReader> *reader) + : reader_(reader) {} + + std::unique_ptr, FstReader>> reader_; +}; + +template +class FstFarReader : public FarReader { + public: + using Arc = A; + + static FstFarReader *Open(const string &filename) { + std::vector filenames; + filenames.push_back(filename); + return new FstFarReader(filenames); + } + + static FstFarReader *Open(const std::vector &filenames) { + return new FstFarReader(filenames); + } + + explicit FstFarReader(const std::vector &filenames) + : keys_(filenames), has_stdin_(false), pos_(0), error_(false) { + std::sort(keys_.begin(), keys_.end()); + streams_.resize(keys_.size(), 0); + for (size_t i = 0; i < keys_.size(); ++i) { + if (keys_[i].empty()) { + if (!has_stdin_) { + streams_[i] = &std::cin; + has_stdin_ = true; + } else { + FSTERROR() << "FstFarReader::FstFarReader: standard input should " + "only appear once in the input file list"; + error_ = true; + return; + } + } else { + 
streams_[i] = new std::ifstream( + keys_[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "FstFarReader::FstFarReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + } + if (pos_ >= keys_.size()) return; + ReadFst(); + } + + void Reset() final { + if (has_stdin_) { + FSTERROR() + << "FstFarReader::Reset: Operation not supported on standard input"; + error_ = true; + return; + } + pos_ = 0; + ReadFst(); + } + + bool Find(const string &key) final { + if (has_stdin_) { + FSTERROR() + << "FstFarReader::Find: Operation not supported on standard input"; + error_ = true; + return false; + } + pos_ = 0; // TODO + ReadFst(); + return true; + } + + bool Done() const final { return error_ || pos_ >= keys_.size(); } + + void Next() final { + ++pos_; + ReadFst(); + } + + const string &GetKey() const final { return keys_[pos_]; } + + const Fst *GetFst() const final { return fst_.get(); } + + constexpr FarType Type() const final { return FAR_FST; } + + bool Error() const final { return error_; } + + ~FstFarReader() final { + for (size_t i = 0; i < keys_.size(); ++i) { + if (streams_[i] != &std::cin) { + delete streams_[i]; + } + } + } + + private: + void ReadFst() { + fst_.reset(); + if (pos_ >= keys_.size()) return; + streams_[pos_]->seekg(0); + fst_.reset(Fst::Read(*streams_[pos_], FstReadOptions())); + if (!fst_) { + FSTERROR() << "FstFarReader: Error reading Fst from: " << keys_[pos_]; + error_ = true; + } + } + + std::vector keys_; + std::vector streams_; + bool has_stdin_; + size_t pos_; + mutable std::unique_ptr> fst_; + mutable bool error_; +}; + +template +FarReader *FarReader::Open(const string &filename) { + if (filename.empty()) + return STListFarReader::Open(filename); + else if (IsSTTable(filename)) + return STTableFarReader::Open(filename); + else if (IsSTList(filename)) + return STListFarReader::Open(filename); + else if (IsFst(filename)) + return FstFarReader::Open(filename); + return 
nullptr; +} + +template +FarReader *FarReader::Open(const std::vector &filenames) { + if (!filenames.empty() && filenames[0].empty()) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsSTTable(filenames[0])) + return STTableFarReader::Open(filenames); + else if (!filenames.empty() && IsSTList(filenames[0])) + return STListFarReader::Open(filenames); + else if (!filenames.empty() && IsFst(filenames[0])) + return FstFarReader::Open(filenames); + return nullptr; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_FAR_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/farlib.h b/projects/llm_framework/include/fst/extensions/far/farlib.h new file mode 100644 index 00000000..c9bb1710 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/farlib.h @@ -0,0 +1,19 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A finite-state archive (FAR) is used to store an indexable collection of +// FSTs in a single file. Utilities are provided to create FARs from FSTs, +// to iterate over FARs, and to extract specific FSTs from FARs. + +#ifndef FST_EXTENSIONS_FAR_FARLIB_H_ +#define FST_EXTENSIONS_FAR_FARLIB_H_ + +#include +#include +#include +#include +#include +#include +#include + +#endif // FST_EXTENSIONS_FAR_FARLIB_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/farscript.h b/projects/llm_framework/include/fst/extensions/far/farscript.h new file mode 100644 index 00000000..4bd11a94 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/farscript.h @@ -0,0 +1,269 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Convenience file for including all of the FAR operations, or registering +// them for new arc types. 
+ +#ifndef FST_EXTENSIONS_FAR_FARSCRIPT_H_ +#define FST_EXTENSIONS_FAR_FARSCRIPT_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because this struct is +// only used to pass them deeper in the call graph. Be sure you understand why +// this is so before using this struct for anything else! +struct FarCompileStringsArgs { + const std::vector &in_fnames; + const string &out_fname; + const string &fst_type; + const FarType &far_type; + const int32 generate_keys; + const FarEntryType fet; + const FarTokenType tt; + const string &symbols_fname; + const string &unknown_symbol; + const bool keep_symbols; + const bool initial_symbols; + const bool allow_negative_labels; + const string &key_prefix; + const string &key_suffix; + + FarCompileStringsArgs(const std::vector &in_fnames, + const string &out_fname, const string &fst_type, + const FarType &far_type, int32 generate_keys, + FarEntryType fet, FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, + const string &key_prefix, const string &key_suffix) + : in_fnames(in_fnames), + out_fname(out_fname), + fst_type(fst_type), + far_type(far_type), + generate_keys(generate_keys), + fet(fet), + tt(tt), + symbols_fname(symbols_fname), + unknown_symbol(unknown_symbol), + keep_symbols(keep_symbols), + initial_symbols(initial_symbols), + allow_negative_labels(allow_negative_labels), + key_prefix(key_prefix), + key_suffix(key_suffix) {} +}; + +template +void FarCompileStrings(FarCompileStringsArgs *args) { + FarCompileStrings( + args->in_fnames, args->out_fname, args->fst_type, args->far_type, + args->generate_keys, args->fet, args->tt, args->symbols_fname, + args->unknown_symbol, args->keep_symbols, args->initial_symbols, + 
args->allow_negative_labels, args->key_prefix, args->key_suffix); +} + +void FarCompileStrings(const std::vector &in_fnames, + const string &out_fname, const string &arc_type, + const string &fst_type, const FarType &far_type, + int32 generate_keys, FarEntryType fet, FarTokenType tt, + const string &symbols_fname, + const string &unknown_symbol, bool keep_symbols, + bool initial_symbols, bool allow_negative_labels, + const string &key_prefix, const string &key_suffix); + +// Note: it is safe to pass these strings as references because this struct is +// only used to pass them deeper in the call graph. Be sure you understand why +// this is so before using this struct for anything else! +struct FarCreateArgs { + const std::vector &in_fnames; + const string &out_fname; + const int32 generate_keys; + const FarType &far_type; + const string &key_prefix; + const string &key_suffix; + + FarCreateArgs(const std::vector &in_fnames, const string &out_fname, + const int32 generate_keys, const FarType &far_type, + const string &key_prefix, const string &key_suffix) + : in_fnames(in_fnames), + out_fname(out_fname), + generate_keys(generate_keys), + far_type(far_type), + key_prefix(key_prefix), + key_suffix(key_suffix) {} +}; + +template +void FarCreate(FarCreateArgs *args) { + FarCreate(args->in_fnames, args->out_fname, args->generate_keys, + args->far_type, args->key_prefix, args->key_suffix); +} + +void FarCreate(const std::vector &in_fnames, const string &out_fname, + const string &arc_type, const int32 generate_keys, + const FarType &far_type, const string &key_prefix, + const string &key_suffix); + +using FarEqualInnerArgs = std::tuple; + +using FarEqualArgs = WithReturnValue; + +template +void FarEqual(FarEqualArgs *args) { + args->retval = fst::FarEqual( + std::get<0>(args->args), std::get<1>(args->args), std::get<2>(args->args), + std::get<3>(args->args), std::get<4>(args->args)); +} + +bool FarEqual(const string &filename1, const string &filename2, + const string 
&arc_type, float delta = kDelta, + const string &begin_key = string(), + const string &end_key = string()); + +using FarExtractArgs = + std::tuple &, int32, const string &, + const string &, const string &, const string &, const string &>; + +template +void FarExtract(FarExtractArgs *args) { + fst::FarExtract(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args), std::get<5>(*args), + std::get<6>(*args)); +} + +void FarExtract(const std::vector &ifilenames, const string &arc_type, + int32 generate_filenames, const string &keys, + const string &key_separator, const string &range_delimiter, + const string &filename_prefix, const string &filename_suffix); + +using FarInfoArgs = std::tuple &, const string &, + const string &, const bool>; + +template +void FarInfo(FarInfoArgs *args) { + fst::FarInfo(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args)); +} + +void FarInfo(const std::vector &filenames, const string &arc_type, + const string &begin_key, const string &end_key, + const bool list_fsts); + +using GetFarInfoArgs = std::tuple &, const string &, + const string &, const bool, FarInfoData *>; + +template +void GetFarInfo(GetFarInfoArgs *args) { + fst::GetFarInfo(std::get<0>(*args), std::get<1>(*args), + std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args)); +} + +void GetFarInfo(const std::vector &filenames, const string &arc_type, + const string &begin_key, const string &end_key, + const bool list_fsts, FarInfoData *); + +using FarIsomorphicInnerArgs = std::tuple; + +using FarIsomorphicArgs = WithReturnValue; + +template +void FarIsomorphic(FarIsomorphicArgs *args) { + args->retval = fst::FarIsomorphic( + std::get<0>(args->args), std::get<1>(args->args), std::get<2>(args->args), + std::get<3>(args->args), std::get<4>(args->args)); +} + +bool FarIsomorphic(const string &filename1, const string &filename2, + const string &arc_type, float delta = kDelta, + const string &begin_key = 
string(), + const string &end_key = string()); + +struct FarPrintStringsArgs { + const std::vector &ifilenames; + const FarEntryType entry_type; + const FarTokenType token_type; + const string &begin_key; + const string &end_key; + const bool print_key; + const bool print_weight; + const string &symbols_fname; + const bool initial_symbols; + const int32 generate_filenames; + const string &filename_prefix; + const string &filename_suffix; + + FarPrintStringsArgs(const std::vector &ifilenames, + const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, + const bool print_weight, const string &symbols_fname, + const bool initial_symbols, + const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix) + : ifilenames(ifilenames), + entry_type(entry_type), + token_type(token_type), + begin_key(begin_key), + end_key(end_key), + print_key(print_key), + print_weight(print_weight), + symbols_fname(symbols_fname), + initial_symbols(initial_symbols), + generate_filenames(generate_filenames), + filename_prefix(filename_prefix), + filename_suffix(filename_suffix) {} +}; + +template +void FarPrintStrings(FarPrintStringsArgs *args) { + fst::FarPrintStrings( + args->ifilenames, args->entry_type, args->token_type, args->begin_key, + args->end_key, args->print_key, args->print_weight, args->symbols_fname, + args->initial_symbols, args->generate_filenames, args->filename_prefix, + args->filename_suffix); +} + +void FarPrintStrings(const std::vector &ifilenames, + const string &arc_type, const FarEntryType entry_type, + const FarTokenType token_type, const string &begin_key, + const string &end_key, const bool print_key, + const bool print_weight, const string &symbols_fname, + const bool initial_symbols, const int32 generate_filenames, + const string &filename_prefix, + const string &filename_suffix); + +} // namespace script +} // namespace fst + +#define 
REGISTER_FST_FAR_OPERATIONS(ArcType) \ + REGISTER_FST_OPERATION(FarCompileStrings, ArcType, FarCompileStringsArgs); \ + REGISTER_FST_OPERATION(FarCreate, ArcType, FarCreateArgs); \ + REGISTER_FST_OPERATION(FarEqual, ArcType, FarEqualArgs); \ + REGISTER_FST_OPERATION(FarExtract, ArcType, FarExtractArgs); \ + REGISTER_FST_OPERATION(FarInfo, ArcType, FarInfoArgs); \ + REGISTER_FST_OPERATION(FarIsomorphic, ArcType, FarIsomorphicArgs); \ + REGISTER_FST_OPERATION(FarPrintStrings, ArcType, FarPrintStringsArgs); \ + REGISTER_FST_OPERATION(GetFarInfo, ArcType, GetFarInfoArgs) + +#endif // FST_EXTENSIONS_FAR_FARSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/getters.h b/projects/llm_framework/include/fst/extensions/far/getters.h new file mode 100644 index 00000000..3dde4194 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/getters.h @@ -0,0 +1,30 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions for registering and invoking FAR main +// functions that support multiple and extensible arc types. 
+ +#ifndef FST_EXTENSIONS_FAR_GETTERS_H_ +#define FST_EXTENSIONS_FAR_GETTERS_H_ + +#include +#include + +namespace fst { +namespace script { + +FarType GetFarType(const string &str); + +bool GetFarEntryType(const string &str, FarEntryType *entry_type); + +bool GetFarTokenType(const string &str, FarTokenType *token_type); + +void ExpandArgs(int argc, char **argv, int *argcp, char ***argvp); + +} // namespace script + +string GetFarTypeString(FarType type); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/info.h b/projects/llm_framework/include/fst/extensions/far/info.h new file mode 100644 index 00000000..0391c1f4 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/info.h @@ -0,0 +1,147 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_FAR_INFO_H_ +#define FST_EXTENSIONS_FAR_INFO_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fst { + +template +void AccumulateStatesAndArcs(const Fst &fst, size_t *nstate, size_t *narc, + size_t *nfinal) { + for (StateIterator> siter(fst); !siter.Done(); + siter.Next(), ++(*nstate)) { + ArcIterator> aiter(fst, siter.Value()); + for (; !aiter.Done(); aiter.Next(), ++(*narc)) { + } + if (fst.Final(siter.Value()) != Arc::Weight::Zero()) ++(*nfinal); + } +} + +struct KeyInfo { + string key; + string type; + size_t nstate = 0; + size_t narc = 0; + size_t nfinal = 0; +}; + +struct FarInfoData { + std::vector key_infos; + string far_type; + string arc_type; + size_t nfst = 0; + size_t nstate = 0; + size_t narc = 0; + size_t nfinal = 0; + std::set fst_types; +}; + +template +void GetFarInfo(const std::vector &filenames, const string &begin_key, + const string &end_key, const bool list_fsts, + FarInfoData *far_info) { + *far_info = FarInfoData(); + std::unique_ptr> reader(FarReader::Open(filenames)); + if 
(!reader) { + LOG(ERROR) << "GetFarInfo: failed to create far reader."; + return; + } + if (!begin_key.empty()) reader->Find(begin_key); + + for (; !reader->Done(); reader->Next()) { + const auto &key = reader->GetKey(); + if (!end_key.empty() && end_key < key) break; + ++far_info->nfst; + const auto *fst = reader->GetFst(); + far_info->fst_types.insert(fst->Type()); + if (list_fsts) { + KeyInfo info; + info.key = key; + info.type = fst->Type(); + AccumulateStatesAndArcs(*fst, &info.nstate, &info.narc, &info.nfinal); + far_info->nstate += info.nstate; + far_info->narc += info.narc; + far_info->nfinal += info.nfinal; + far_info->key_infos.push_back(info); + } else { + AccumulateStatesAndArcs(*fst, &far_info->nstate, &far_info->narc, + &far_info->nfinal); + } + } + far_info->far_type = GetFarTypeString(reader->Type()); + far_info->arc_type = Arc::Type(); +} + +template +void FarInfo(const std::vector &filenames, const string &begin_key, + const string &end_key, const bool list_fsts) { + FarInfoData info; + GetFarInfo(filenames, begin_key, end_key, list_fsts, &info); + if (!list_fsts) { + std::cout << std::left << std::setw(50) << "far type" << info.far_type + << std::endl; + std::cout << std::left << std::setw(50) << "arc type" << Arc::Type() + << std::endl; + std::cout << std::left << std::setw(50) << "fst type"; + for (auto iter = info.fst_types.begin(); iter != info.fst_types.end(); + ++iter) { + if (iter != info.fst_types.begin()) std::cout << ","; + std::cout << *iter; + } + std::cout << std::endl; + std::cout << std::left << std::setw(50) << "# of FSTs" << info.nfst + << std::endl; + std::cout << std::left << std::setw(50) << "total # of states" + << info.nstate << std::endl; + std::cout << std::left << std::setw(50) << "total # of arcs" << info.narc + << std::endl; + std::cout << std::left << std::setw(50) << "total # of final states" + << info.nfinal << std::endl; + } else { + // FIXME(kbg): Grok, then document this. 
+ int wkey = 10; + int wtype = 10; + int wnstate = 14; + int wnarc = 12; + int wnfinal = 20; + for (const auto &key_info : info.key_infos) { + if (key_info.key.size() + 2 > wkey) wkey = key_info.key.size() + 2; + if (key_info.type.size() + 2 > wtype) wtype = key_info.type.size() + 2; + if (ceil(log10(key_info.nstate)) + 2 > wnstate) { + wnstate = ceil(log10(key_info.nstate)) + 2; + } + if (ceil(log10(key_info.narc)) + 2 > wnarc) { + wnarc = ceil(log10(key_info.narc)) + 2; + } + if (ceil(log10(key_info.nfinal)) + 2 > wnfinal) { + wnfinal = ceil(log10(key_info.nfinal)) + 2; + } + } + std::cout << std::left << std::setw(wkey) << "key" << std::setw(wtype) + << "type" << std::right << std::setw(wnstate) << "# of states" + << std::setw(wnarc) << "# of arcs" << std::setw(wnfinal) + << "# of final states" << std::endl; + for (const auto &key_info : info.key_infos) { + std::cout << std::left << std::setw(wkey) << key_info.key + << std::setw(wtype) << key_info.type << std::right + << std::setw(wnstate) << key_info.nstate << std::setw(wnarc) + << key_info.narc << std::setw(wnfinal) << key_info.nfinal + << std::endl; + } + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_INFO_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/isomorphic.h b/projects/llm_framework/include/fst/extensions/far/isomorphic.h new file mode 100644 index 00000000..1e6e9cb3 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/isomorphic.h @@ -0,0 +1,69 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_EXTENSIONS_FAR_ISOMORPHIC_H_ +#define FST_EXTENSIONS_FAR_ISOMORPHIC_H_ + +#include +#include + +#include +#include + +namespace fst { + +template +bool FarIsomorphic(const string &filename1, const string &filename2, + float delta = kDelta, const string &begin_key = string(), + const string &end_key = string()) { + std::unique_ptr> reader1(FarReader::Open(filename1)); + if (!reader1) { + LOG(ERROR) << "FarIsomorphic: Cannot open FAR file " << filename1; + return false; + } + std::unique_ptr> reader2(FarReader::Open(filename2)); + if (!reader2) { + LOG(ERROR) << "FarIsomorphic: Cannot open FAR file " << filename2; + return false; + } + if (!begin_key.empty()) { + bool find_begin1 = reader1->Find(begin_key); + bool find_begin2 = reader2->Find(begin_key); + if (!find_begin1 || !find_begin2) { + bool ret = !find_begin1 && !find_begin2; + if (!ret) { + VLOG(1) << "FarIsomorphic: Key " << begin_key << " missing from " + << (find_begin1 ? "second" : "first") << " archive."; + } + return ret; + } + } + for (; !reader1->Done() && !reader2->Done(); + reader1->Next(), reader2->Next()) { + const auto &key1 = reader1->GetKey(); + const auto &key2 = reader2->GetKey(); + if (!end_key.empty() && end_key < key1 && end_key < key2) return true; + if (key1 != key2) { + LOG(ERROR) << "FarIsomorphic: Mismatched keys " << key1 << " and " + << key2; + return false; + } + if (!Isomorphic(*(reader1->GetFst()), *(reader2->GetFst()), delta)) { + LOG(ERROR) << "FarIsomorphic: FSTs for key " << key1 + << " are not isomorphic"; + return false; + } + } + if (!reader1->Done() || !reader2->Done()) { + LOG(ERROR) << "FarIsomorphic: Key " + << (reader1->Done() ? reader2->GetKey() : reader1->GetKey()) + << " missing form " << (reader2->Done() ? 
"first" : "second") + << " archive"; + return false; + } + return true; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/print-strings.h b/projects/llm_framework/include/fst/extensions/far/print-strings.h new file mode 100644 index 00000000..dc428401 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/print-strings.h @@ -0,0 +1,105 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Outputs as strings the string FSTs in a finite-state archive. + +#ifndef FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ +#define FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include + +DECLARE_string(far_field_separator); + +namespace fst { + +template +void FarPrintStrings(const std::vector &ifilenames, + FarEntryType entry_type, FarTokenType far_token_type, + const string &begin_key, const string &end_key, + bool print_key, bool print_weight, + const string &symbols_fname, bool initial_symbols, + int32 generate_filenames, const string &filename_prefix, + const string &filename_suffix) { + StringTokenType token_type; + if (far_token_type == FTT_SYMBOL) { + token_type = StringTokenType::SYMBOL; + } else if (far_token_type == FTT_BYTE) { + token_type = StringTokenType::BYTE; + } else if (far_token_type == FTT_UTF8) { + token_type = StringTokenType::UTF8; + } else { + FSTERROR() << "FarPrintStrings: Unknown token type"; + return; + } + std::unique_ptr syms; + if (!symbols_fname.empty()) { + // TODO(kbg): Allow negative flag? 
+ const SymbolTableTextOptions opts(true); + syms.reset(SymbolTable::ReadText(symbols_fname, opts)); + if (!syms) { + LOG(ERROR) << "FarPrintStrings: Error reading symbol table " + << symbols_fname; + return; + } + } + std::unique_ptr> far_reader(FarReader::Open(ifilenames)); + if (!far_reader) return; + if (!begin_key.empty()) far_reader->Find(begin_key); + string okey; + int nrep = 0; + for (int i = 1; !far_reader->Done(); far_reader->Next(), ++i) { + const auto &key = far_reader->GetKey(); + if (!end_key.empty() && end_key < key) break; + if (okey == key) { + ++nrep; + } else { + nrep = 0; + } + okey = key; + const auto *fst = far_reader->GetFst(); + if (i == 1 && initial_symbols && !syms && fst->InputSymbols()) + syms.reset(fst->InputSymbols()->Copy()); + string str; + VLOG(2) << "Handling key: " << key; + StringPrinter string_printer(token_type, + syms ? syms.get() : fst->InputSymbols()); + string_printer(*fst, &str); + if (entry_type == FET_LINE) { + if (print_key) std::cout << key << FLAGS_far_field_separator[0]; + std::cout << str; + if (print_weight) + std::cout << FLAGS_far_field_separator[0] << ShortestDistance(*fst); + std::cout << std::endl; + } else if (entry_type == FET_FILE) { + std::stringstream sstrm; + if (generate_filenames) { + sstrm.fill('0'); + sstrm << std::right << std::setw(generate_filenames) << i; + } else { + sstrm << key; + if (nrep > 0) sstrm << "." 
<< nrep; + } + string filename; + filename = filename_prefix + sstrm.str() + filename_suffix; + std::ofstream ostrm(filename); + if (!ostrm) { + LOG(ERROR) << "FarPrintStrings: Can't open file: " << filename; + return; + } + ostrm << str; + if (token_type == StringTokenType::SYMBOL) ostrm << "\n"; + } + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_PRINT_STRINGS_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/script-impl.h b/projects/llm_framework/include/fst/extensions/far/script-impl.h new file mode 100644 index 00000000..a0586cc3 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/script-impl.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions for registering and invoking Far main +// functions that support multiple and extensible arc types. + +#ifndef FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ +#define FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ + +#include + +#include +namespace fst { +namespace script { + +string LoadArcTypeFromFar(const string &far_fname); + +string LoadArcTypeFromFst(const string &fst_fname); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_SCRIPT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/stlist.h b/projects/llm_framework/include/fst/extensions/far/stlist.h new file mode 100644 index 00000000..b155e17e --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/stlist.h @@ -0,0 +1,273 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A generic (string,type) list file format. +// +// This is a stripped-down version of STTable that does not support the Find() +// operation but that does support reading/writting from standard in/out. 
+ +#ifndef FST_EXTENSIONS_FAR_STLIST_H_ +#define FST_EXTENSIONS_FAR_STLIST_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fst { + +static constexpr int32 kSTListMagicNumber = 5656924; +static constexpr int32 kSTListFileVersion = 1; + +// String-type list writing class for object of type T using a functor Writer. +// The Writer functor must provide at least the following interface: +// +// struct Writer { +// void operator()(std::ostream &, const T &) const; +// }; +template +class STListWriter { + public: + explicit STListWriter(const string &filename) + : stream_(filename.empty() ? &std::cout : new std::ofstream( + filename, + std::ios_base::out | + std::ios_base::binary)), + error_(false) { + WriteType(*stream_, kSTListMagicNumber); + WriteType(*stream_, kSTListFileVersion); + if (!stream_) { + FSTERROR() << "STListWriter::STListWriter: Error writing to file: " + << filename; + error_ = true; + } + } + + static STListWriter *Create(const string &filename) { + return new STListWriter(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STListWriter::Add: Key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STListWriter::Add: Key out of order: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + WriteType(*stream_, key); + entry_writer_(*stream_, t); + } + + bool Error() const { return error_; } + + ~STListWriter() { + WriteType(*stream_, string()); + if (stream_ != &std::cout) delete stream_; + } + + private: + Writer entry_writer_; + std::ostream *stream_; // Output stream. + string last_key_; // Last key. + bool error_; + + STListWriter(const STListWriter &) = delete; + STListWriter &operator=(const STListWriter &) = delete; +}; + +// String-type list reading class for object of type T using a functor Reader. 
+// Reader must provide at least the following interface: +// +// struct Reader { +// T *operator()(std::istream &) const; +// }; +template +class STListReader { + public: + explicit STListReader(const std::vector &filenames) + : sources_(filenames), error_(false) { + streams_.resize(filenames.size(), 0); + bool has_stdin = false; + for (size_t i = 0; i < filenames.size(); ++i) { + if (filenames[i].empty()) { + if (!has_stdin) { + streams_[i] = &std::cin; + sources_[i] = "stdin"; + has_stdin = true; + } else { + FSTERROR() << "STListReader::STListReader: Cannot read multiple " + << "inputs from standard input"; + error_ = true; + return; + } + } else { + streams_[i] = new std::ifstream( + filenames[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "STListReader::STListReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + int32 magic_number = 0; + ReadType(*streams_[i], &magic_number); + int32 file_version = 0; + ReadType(*streams_[i], &file_version); + if (magic_number != kSTListMagicNumber) { + FSTERROR() << "STListReader::STListReader: Wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTListFileVersion) { + FSTERROR() << "STListReader::STListReader: Wrong file version: " + << filenames[i]; + error_ = true; + return; + } + string key; + ReadType(*streams_[i], &key); + if (!key.empty()) heap_.push(std::make_pair(key, i)); + if (!*streams_[i]) { + FSTERROR() << "STListReader: Error reading file: " << sources_[i]; + error_ = true; + return; + } + } + if (heap_.empty()) return; + const auto current = heap_.top().second; + entry_.reset(entry_reader_(*streams_[current])); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: Error reading entry for key " + << heap_.top().first << ", file " << sources_[current]; + error_ = true; + } + } + + ~STListReader() { + for (auto &stream : streams_) { + if (stream != &std::cin) delete stream; + } + 
} + + static STListReader *Open(const string &filename) { + std::vector filenames; + filenames.push_back(filename); + return new STListReader(filenames); + } + + static STListReader *Open(const std::vector &filenames) { + return new STListReader(filenames); + } + + void Reset() { + FSTERROR() << "STListReader::Reset: Operation not supported"; + error_ = true; + } + + bool Find(const string &key) { + FSTERROR() << "STListReader::Find: Operation not supported"; + error_ = true; + return false; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + auto current = heap_.top().second; + string key; + heap_.pop(); + ReadType(*(streams_[current]), &key); + if (!*streams_[current]) { + FSTERROR() << "STListReader: Error reading file: " << sources_[current]; + error_ = true; + return; + } + if (!key.empty()) heap_.push(std::make_pair(key, current)); + if (!heap_.empty()) { + current = heap_.top().second; + entry_.reset(entry_reader_(*streams_[current])); + if (!entry_ || !*streams_[current]) { + FSTERROR() << "STListReader: Error reading entry for key: " + << heap_.top().first << ", file: " << sources_[current]; + error_ = true; + } + } + } + + const string &GetKey() const { return heap_.top().first; } + + const T *GetEntry() const { return entry_.get(); } + + bool Error() const { return error_; } + + private: + Reader entry_reader_; // Read functor. + std::vector streams_; // Input streams. + std::vector sources_; // Corresponding filenames. + std::priority_queue< + std::pair, std::vector>, + std::greater>> heap_; // (Key, stream id) heap + mutable std::unique_ptr entry_; // The currently read entry. + bool error_; + + STListReader(const STListReader &) = delete; + STListReader &operator=(const STListReader &) = delete; +}; + +// String-type list header reading function, templated on the entry header type. 
+// The Header type must provide at least the following interface: +// +// struct Header { +// void Read(std::istream &strm, const string &filename); +// }; +template +bool ReadSTListHeader(const string &filename, Header *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTListHeader: Can't read header from standard input"; + return false; + } + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: Could not open file: " << filename; + return false; + } + int32 magic_number = 0; + ReadType(strm, &magic_number); + int32 file_version = 0; + ReadType(strm, &file_version); + if (magic_number != kSTListMagicNumber) { + LOG(ERROR) << "ReadSTListHeader: Wrong file type: " << filename; + return false; + } + if (file_version != kSTListFileVersion) { + LOG(ERROR) << "ReadSTListHeader: Wrong file version: " << filename; + return false; + } + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (!strm) { + LOG(ERROR) << "ReadSTListHeader: Error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTList(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STLIST_H_ diff --git a/projects/llm_framework/include/fst/extensions/far/sttable.h b/projects/llm_framework/include/fst/extensions/far/sttable.h new file mode 100644 index 00000000..2a01bb16 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/far/sttable.h @@ -0,0 +1,353 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// A generic string-to-type table file format. +// +// This is not meant as a generalization of SSTable. This is more of a simple +// replacement for SSTable in order to provide an open-source implementation +// of the FAR format for the external version of the FST library. 
+ +#ifndef FST_EXTENSIONS_FAR_STTABLE_H_ +#define FST_EXTENSIONS_FAR_STTABLE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +static constexpr int32 kSTTableMagicNumber = 2125656924; +static constexpr int32 kSTTableFileVersion = 1; + +// String-type table writing class for an object of type T using a functor +// Writer. The Writer functor must provide at least the following interface: +// +// struct Writer { +// void operator()(std::ostream &, const T &) const; +// }; +template +class STTableWriter { + public: + explicit STTableWriter(const string &filename) + : stream_(filename, std::ios_base::out | std::ios_base::binary), + error_(false) { + WriteType(stream_, kSTTableMagicNumber); + WriteType(stream_, kSTTableFileVersion); + if (stream_.fail()) { + FSTERROR() << "STTableWriter::STTableWriter: Error writing to file: " + << filename; + error_ = true; + } + } + + static STTableWriter *Create(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableWriter: Writing to standard out unsupported."; + return nullptr; + } + return new STTableWriter(filename); + } + + void Add(const string &key, const T &t) { + if (key == "") { + FSTERROR() << "STTableWriter::Add: Key empty: " << key; + error_ = true; + } else if (key < last_key_) { + FSTERROR() << "STTableWriter::Add: Key out of order: " << key; + error_ = true; + } + if (error_) return; + last_key_ = key; + positions_.push_back(stream_.tellp()); + WriteType(stream_, key); + entry_writer_(stream_, t); + } + + bool Error() const { return error_; } + + ~STTableWriter() { + WriteType(stream_, positions_); + WriteType(stream_, static_cast(positions_.size())); + } + + private: + Writer entry_writer_; + std::ofstream stream_; + std::vector positions_; // Position in file of each key-entry pair. + string last_key_; // Last key. 
+ bool error_; + + STTableWriter(const STTableWriter &) = delete; + STTableWriter &operator=(const STTableWriter &) = delete; +}; + +// String-type table reading class for object of type T using a functor Reader. +// Reader must provide at least the following interface: +// +// struct Reader { +// T *operator()(std::istream &) const; +// }; +// +template +class STTableReader { + public: + explicit STTableReader(const std::vector &filenames) + : sources_(filenames), error_(false) { + compare_.reset(new Compare(&keys_)); + keys_.resize(filenames.size()); + streams_.resize(filenames.size(), 0); + positions_.resize(filenames.size()); + for (size_t i = 0; i < filenames.size(); ++i) { + streams_[i] = new std::ifstream( + filenames[i], std::ios_base::in | std::ios_base::binary); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader::STTableReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + int32 magic_number = 0; + ReadType(*streams_[i], &magic_number); + int32 file_version = 0; + ReadType(*streams_[i], &file_version); + if (magic_number != kSTTableMagicNumber) { + FSTERROR() << "STTableReader::STTableReader: Wrong file type: " + << filenames[i]; + error_ = true; + return; + } + if (file_version != kSTTableFileVersion) { + FSTERROR() << "STTableReader::STTableReader: Wrong file version: " + << filenames[i]; + error_ = true; + return; + } + int64 num_entries; + streams_[i]->seekg(-static_cast(sizeof(int64)), std::ios_base::end); + ReadType(*streams_[i], &num_entries); + if (num_entries > 0) { + streams_[i]->seekg(-static_cast(sizeof(int64)) * (num_entries + 1), + std::ios_base::end); + positions_[i].resize(num_entries); + for (size_t j = 0; (j < num_entries) && (!streams_[i]->fail()); ++j) { + ReadType(*streams_[i], &(positions_[i][j])); + } + streams_[i]->seekg(positions_[i][0]); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader::STTableReader: Error reading file: " + << filenames[i]; + error_ = true; + return; + } + } + } + 
MakeHeap(); + } + + ~STTableReader() { + for (auto &stream : streams_) delete stream; + } + + static STTableReader *Open(const string &filename) { + if (filename.empty()) { + LOG(ERROR) << "STTableReader: Operation not supported on standard input"; + return nullptr; + } + std::vector filenames; + filenames.push_back(filename); + return new STTableReader(filenames); + } + + static STTableReader *Open(const std::vector &filenames) { + return new STTableReader(filenames); + } + + void Reset() { + if (error_) return; + for (size_t i = 0; i < streams_.size(); ++i) + streams_[i]->seekg(positions_[i].front()); + MakeHeap(); + } + + bool Find(const string &key) { + if (error_) return false; + for (size_t i = 0; i < streams_.size(); ++i) LowerBound(i, key); + MakeHeap(); + if (heap_.empty()) return false; + return keys_[current_] == key; + } + + bool Done() const { return error_ || heap_.empty(); } + + void Next() { + if (error_) return; + if (streams_[current_]->tellg() <= positions_[current_].back()) { + ReadType(*(streams_[current_]), &(keys_[current_])); + if (streams_[current_]->fail()) { + FSTERROR() << "STTableReader: Error reading file: " + << sources_[current_]; + error_ = true; + return; + } + std::push_heap(heap_.begin(), heap_.end(), *compare_); + } else { + heap_.pop_back(); + } + if (!heap_.empty()) PopHeap(); + } + + const string &GetKey() const { return keys_[current_]; } + + const T *GetEntry() const { return entry_.get(); } + + bool Error() const { return error_; } + + private: + // Comparison functor used to compare stream IDs in the heap. + struct Compare { + explicit Compare(const std::vector *keys) : keys(keys) {} + + bool operator()(size_t i, size_t j) const { + return (*keys)[i] > (*keys)[j]; + }; + + private: + const std::vector *keys; + }; + + // Positions the stream at the position corresponding to the lower bound for + // the specified key. 
+ void LowerBound(size_t id, const string &find_key) { + auto *strm = streams_[id]; + const auto &positions = positions_[id]; + if (positions.empty()) return; + size_t low = 0; + size_t high = positions.size() - 1; + while (low < high) { + size_t mid = (low + high) / 2; + strm->seekg(positions[mid]); + string key; + ReadType(*strm, &key); + if (key > find_key) { + high = mid; + } else if (key < find_key) { + low = mid + 1; + } else { + for (size_t i = mid; i > low; --i) { + strm->seekg(positions[i - 1]); + ReadType(*strm, &key); + if (key != find_key) { + strm->seekg(positions[i]); + return; + } + } + strm->seekg(positions[low]); + return; + } + } + strm->seekg(positions[low]); + } + + // Adds all streams to the heap. + void MakeHeap() { + heap_.clear(); + for (size_t i = 0; i < streams_.size(); ++i) { + if (positions_[i].empty()) continue; + ReadType(*streams_[i], &(keys_[i])); + if (streams_[i]->fail()) { + FSTERROR() << "STTableReader: Error reading file: " << sources_[i]; + error_ = true; + return; + } + heap_.push_back(i); + } + if (heap_.empty()) return; + std::make_heap(heap_.begin(), heap_.end(), *compare_); + PopHeap(); + } + + // Positions the stream with the lowest key at the top of the heap, sets + // current_ to the ID of that stream, and reads the current entry from that + // stream. + void PopHeap() { + std::pop_heap(heap_.begin(), heap_.end(), *compare_); + current_ = heap_.back(); + entry_.reset(entry_reader_(*streams_[current_])); + if (!entry_) error_ = true; + if (streams_[current_]->fail()) { + FSTERROR() << "STTableReader: Error reading entry for key: " + << keys_[current_] << ", file: " << sources_[current_]; + error_ = true; + } + } + + Reader entry_reader_; + std::vector streams_; // Input streams. + std::vector sources_; // Corresponding file names. + std::vector> positions_; // Index of positions. + std::vector keys_; // Lowest unread key for each stream. + std::vector heap_; // Heap containing ID of streams with unread keys. 
+ int64 current_; // ID of current stream to be read. + std::unique_ptr compare_; // Functor comparing stream IDs. + mutable std::unique_ptr entry_; // The currently read entry. + bool error_; +}; + +// String-type table header reading function template on the entry header type. +// The Header type must provide at least the following interface: +// +// struct Header { +// void Read(std::istream &istrm, const string &filename); +// }; +template +bool ReadSTTableHeader(const string &filename, Header *header) { + if (filename.empty()) { + LOG(ERROR) << "ReadSTTable: Can't read header from standard input"; + return false; + } + std::ifstream strm(filename, std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ReadSTTableHeader: Could not open file: " << filename; + return false; + } + int32 magic_number = 0; + ReadType(strm, &magic_number); + int32 file_version = 0; + ReadType(strm, &file_version); + if (magic_number != kSTTableMagicNumber) { + LOG(ERROR) << "ReadSTTableHeader: Wrong file type: " << filename; + return false; + } + if (file_version != kSTTableFileVersion) { + LOG(ERROR) << "ReadSTTableHeader: Wrong file version: " << filename; + return false; + } + int64 i = -1; + strm.seekg(-static_cast(sizeof(int64)), std::ios_base::end); + ReadType(strm, &i); // Reads number of entries + if (strm.fail()) { + LOG(ERROR) << "ReadSTTableHeader: Error reading file: " << filename; + return false; + } + if (i == 0) return true; // No entry header to read. + strm.seekg(-2 * static_cast(sizeof(int64)), std::ios_base::end); + ReadType(strm, &i); // Reads position for last entry in file. 
+ strm.seekg(i); + string key; + ReadType(strm, &key); + header->Read(strm, filename + ":" + key); + if (strm.fail()) { + LOG(ERROR) << "ReadSTTableHeader: Error reading file: " << filename; + return false; + } + return true; +} + +bool IsSTTable(const string &filename); + +} // namespace fst + +#endif // FST_EXTENSIONS_FAR_STTABLE_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h new file mode 100644 index 00000000..a6ac7279 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data-builder.h @@ -0,0 +1,1074 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ +#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include + +namespace fst { + +// Forward declaration +template +class FeatureGroupBuilder; + +// For logging purposes +inline string TranslateLabel(int64 label, const SymbolTable *syms); +template +string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms); +template +string JoinLabels(const std::vector *Dump(); + + private: + bool error_; + CompactSet all_output_labels_; + std::map> word_output_map_, word_feat_map_; + std::map> feat_groups_; + std::vector>> groups_; + size_t max_future_size_; + Label max_input_label_; + const SymbolTable *isyms_, *fsyms_, *osyms_; + + LinearFstDataBuilder(const LinearFstDataBuilder &) = delete; + LinearFstDataBuilder &operator=(const LinearFstDataBuilder &) = delete; +}; + +// Builds a LinearFstData tailored for a LinearClassifierFst. 
The +// major difference between an ordinary LinearFstData that works on +// taggers and a LinearFstData that works on classifiers is that +// feature groups are divided into sections by the prediction class +// label. For a prediction label `pred` and a logical group id +// `group`, the actual group id is `group * num_classes + pred - +// 1`. +// +// This layout saves us from recording output labels in each single +// FeatureGroup. Because there is no need for any delaying, stripping +// the output allows features with different shapes but using the same +// set of feature label mapping to reside in a single FeatureGroup. +template +class LinearClassifierFstDataBuilder { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Constructs a builder for a `num_classes`-class classifier, + // optinally with associated symbol tables for diagnostic + // output. The output labels (i.e. prediction) must be in the range + // of [1, num_classes]. + explicit LinearClassifierFstDataBuilder(size_t num_classes, + const SymbolTable *isyms = nullptr, + const SymbolTable *fsyms = nullptr, + const SymbolTable *osyms = nullptr) + : error_(false), + num_classes_(num_classes), + num_groups_(0), + builder_(isyms, fsyms, osyms) {} + + // Tests whether the builder has encountered any error. Similar to + // LinearFstDataBuilder<>::Error(). + bool Error() const { return error_; } + + // Same as LinearFstDataBuilder<>::AddWord(). + bool AddWord(Label word, const std::vector *Dump(); + + private: + std::vector builder_; +}; + +// Builds a single feature group. Usually used in +// `LinearFstDataBuilder::AddWeight()`. See that method for the +// constraints on grouping features. +template +class FeatureGroupBuilder { + public: + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Constructs a builder with the given future size. All features + // added to the group will have look-ahead windows of this size. 
+ FeatureGroupBuilder(size_t future_size, const SymbolTable *fsyms, + const SymbolTable *osyms) + : error_(false), future_size_(future_size), fsyms_(fsyms), osyms_(osyms) { + // This edge is special; see doc of class `FeatureGroup` on the + // details. + start_ = trie_.Insert(trie_.Root(), InputOutputLabel(kNoLabel, kNoLabel)); + } + + // Tests whether the builder has encountered any error. No operation + // is valid if the builder is already at error state. All other + // public methods should check this before any actual operations. + bool Error() const { return error_; } + + // Adds a feature weight with the given context. Returns true iff + // the weight is added. A weight is not added if it has ill-formed + // context involving start-, end-of-sentence marks. + // + // Note: `input` is the sequence of input + // features, instead of input labels themselves. `input` must be at + // least as long as `future_size`; `output` may be empty, but + // usually should be non-empty because an empty output context is + // useless in discriminative modelling. All labels in both `input` + // and `output` must be > 0 (this is checked in + // `LinearFstDataBuilder::AddWeight()`). See + // LinearFstDataBuilder<>::AddWeight for more details. + // + // This may fail if the input is smaller than the look-ahead window. + bool AddWeight(const std::vector *Dump(size_t max_future_size); + + private: + typedef typename FeatureGroup::InputOutputLabel InputOutputLabel; + typedef typename FeatureGroup::InputOutputLabelHash InputOutputLabelHash; + typedef typename FeatureGroup::WeightBackLink WeightBackLink; + // Nested trie topology uses more memory but we can traverse a + // node's children easily, which is required in `BuildBackLinks()`. + typedef NestedTrieTopology Topology; + typedef MutableTrie Trie; + + // Finds the first node with an arc with `label` following the + // back-off chain of `parent`. Returns the node index or + // `kNoTrieNodeId` when not found. 
The number of hops is stored in + // `hop` when it is not `nullptr`. + // + // This does not fail. + int FindFirstMatch(InputOutputLabel label, int parent, int *hop) const; + + // Links each node to its immediate back-off. root is linked to -1. + // + // This may fail when the unique immediate back-off constraint is + // violated. + void BuildBackLinks(); + + // Traces back on the back-chain for each node to multiply the + // weights from back-offs to the node itself. + // + // This does not fail. + void PreAccumulateWeights(); + + // Reconstruct the path from trie root to given node for logging. + bool TrieDfs(const Topology &topology, int cur, int target, + std::vector *path) const; + string TriePath(int node, const Topology &topology) const; + + bool error_; + size_t future_size_; + Trie trie_; + int start_; + const SymbolTable *fsyms_, *osyms_; + + FeatureGroupBuilder(const FeatureGroupBuilder &) = delete; + FeatureGroupBuilder &operator=(const FeatureGroupBuilder &) = delete; +}; + +// +// Implementation of methods in `LinearFstDataBuilder` +// +template +bool LinearFstDataBuilder::AddWord(Label word, + const std::vector::kStartOfSentence || + word == LinearFstData::kEndOfSentence) { + LOG(WARNING) << "Ignored: adding boundary label: " + << TranslateLabel(word, isyms_) + << "(start-of-sentence=" << LinearFstData::kStartOfSentence + << ", end-of-sentence=" << LinearFstData::kEndOfSentence + << ")"; + return false; + } + if (word <= 0) { + error_ = true; + FSTERROR() << "Word label must be > 0; got " << word; + return false; + } + if (word > max_input_label_) max_input_label_ = word; + // Make sure the word hasn't been added before + if (word_feat_map_.find(word) != word_feat_map_.end()) { + error_ = true; + FSTERROR() << "Input word " << TranslateLabel(word, isyms_) + << " is added twice"; + return false; + } + // Store features + std::set::AddWord( + Label word, const std::vector::kStartOfSentence || + output == LinearFstData::kEndOfSentence) { + LOG(WARNING) 
<< "Ignored: word = " << TranslateLabel(word, isyms_) + << ": adding boundary label as possible output: " << output + << "(start-of-sentence=" + << LinearFstData::kStartOfSentence + << ", end-of-sentence=" << LinearFstData::kEndOfSentence + << ")"; + continue; + } + if (output <= 0) { + error_ = true; + FSTERROR() << "Output label must be > 0; got " << output; + return false; + } + outputs->insert(output); + all_output_labels_.Insert(output); + } + return true; +} + +template +inline int LinearFstDataBuilder::AddGroup(size_t future_size) { + if (error_) { + FSTERROR() << "Calling LinearFstDataBuilder<>::AddGroup() at error state"; + return -1; + } + size_t ret = groups_.size(); + groups_.emplace_back(new FeatureGroupBuilder(future_size, fsyms_, osyms_)); + if (future_size > max_future_size_) max_future_size_ = future_size; + return ret; +} + +template +bool LinearFstDataBuilder::AddWeight(size_t group, + const std::vector::kStartOfSentence && + input[i - 1] != LinearFstData::kStartOfSentence) + start_in_middle = true; + if (input[i - 1] == LinearFstData::kEndOfSentence && + input[i] != LinearFstData::kEndOfSentence) + end_in_middle = true; + } + if (start_in_middle) { + LOG(WARNING) << "Ignored: start-of-sentence in the middle of the input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (end_in_middle) { + LOG(WARNING) << "Ignored: end-of-sentence in the middle of the input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + } + // Check well-formedness of boundary marks on the output. 
+ { + bool non_first_start = false, non_last_end = false; + for (int i = 1; i < output.size(); ++i) { + if (output[i] == LinearFstData::kStartOfSentence) + non_first_start = true; + if (output[i - 1] == LinearFstData::kEndOfSentence) + non_last_end = true; + } + if (non_first_start) { + LOG(WARNING) << "Ignored: start-of-sentence not appearing " + << "as the first label in the output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (non_last_end) { + LOG(WARNING) << "Ignored: end-of-sentence not appearing " + << "as the last label in the output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + } + + for (size_t i = 0; i < input.size(); ++i) { + Label feat = input[i]; + if (feat != LinearFstData::kStartOfSentence && + feat != LinearFstData::kEndOfSentence && feat <= 0) { + error_ = true; + FSTERROR() << "Feature label must be > 0; got " << feat; + return false; + } + feat_groups_[feat].insert(group); + } + for (size_t i = 0; i < output.size(); ++i) { + Label label = output[i]; + if (label != LinearFstData::kStartOfSentence && + label != LinearFstData::kEndOfSentence && label <= 0) { + error_ = true; + FSTERROR() << "Output label must be > 0; got " << label; + return false; + } + if (label != LinearFstData::kStartOfSentence && + label != LinearFstData::kEndOfSentence) + all_output_labels_.Insert(label); + } + + // Everything looks good at this point (more checks on the way in + // the feature group). Add this feature weight. 
+ bool added = groups_[group]->AddWeight(input, output, weight); + if (groups_[group]->Error()) { + error_ = true; + FSTERROR() << "FeatureGroupBuilder<>::AddWeight() failed"; + return false; + } + return added; +} + +template +LinearFstData *LinearFstDataBuilder::Dump() { + if (error_) { + FSTERROR() << "Calling LinearFstDataBuilder<>::Dump() at error state"; + return nullptr; + } + + std::unique_ptr> data(new LinearFstData()); + data->max_future_size_ = max_future_size_; + data->max_input_label_ = max_input_label_; + + // Feature groups; free builders after it's dumped. + data->groups_.resize(groups_.size()); + for (int group = 0; group != groups_.size(); ++group) { + FeatureGroup *new_group = groups_[group]->Dump(max_future_size_); + if (new_group == nullptr) { + error_ = true; + FSTERROR() << "Error in dumping group " << group; + return nullptr; + } + data->groups_[group].reset(new_group); + groups_[group].reset(); + VLOG(1) << "Group " << group << ": " << new_group->Stats(); + } + + // Per-group feature mapping + data->group_feat_map_.Init(data->NumGroups(), max_input_label_ + 1); + for (Label word = 1; word <= max_input_label_; ++word) { + typename std::map>::const_iterator it = + word_feat_map_.find(word); + if (it == word_feat_map_.end()) continue; + for (typename std::set::AddWord( + Label word, const std::vector::AddGroup() { + if (error_) { + FSTERROR() << "Calling LinearClassifierFstDataBuilder<>::AddGroup() at " + "error state"; + return -1; + } + for (int i = 0; i < num_classes_; ++i) builder_.AddGroup(0); + if (builder_.Error()) { + error_ = true; + return -1; + } + return num_groups_++; +} + +template +inline bool LinearClassifierFstDataBuilder::AddWeight( + size_t group, const std::vector *LinearClassifierFstDataBuilder::Dump() { + if (error_) { + FSTERROR() + << "Calling LinearClassifierFstDataBuilder<>::Dump() at error state"; + return nullptr; + } + LinearFstData *data = builder_.Dump(); + error_ = true; + return data; +} + +// +// 
Implementation of methods in `FeatureGroupBuilder` +// +template +bool FeatureGroupBuilder::AddWeight(const std::vector::kStartOfSentence) + ++num_input_start; + int num_output_start = 0; + while (num_output_start < output.size() && + output[num_output_start] == LinearFstData::kStartOfSentence) + ++num_output_start; + int num_input_end = 0; + for (int i = input.size() - 1; + i >= 0 && input[i] == LinearFstData::kEndOfSentence; --i) + ++num_input_end; + int num_output_end = 0; + for (int i = output.size() - 1; + i >= 0 && output[i] == LinearFstData::kEndOfSentence; --i) + ++num_output_end; + + DCHECK_LE(num_output_end, 1); + + if (input.size() - num_input_start < future_size_) { + LOG(WARNING) << "Ignored: start-of-sentence in the future!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, fsyms_); + return false; + } + if (num_input_start > 0 && + input.size() - future_size_ - num_input_start < + output.size() - num_output_start) { + LOG(WARNING) << "Ignored: matching start-of-sentence with actual output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (num_output_start > 0 && + input.size() - future_size_ - num_input_start > + output.size() - num_output_start) { + LOG(WARNING) << "Ignored: matching start-of-sentence with actual input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + // The following two require `num_output_end` <= 1. 
+ if (num_input_end > future_size_ && num_input_end - future_size_ != 1) { + LOG(WARNING) << "Ignored: matching end-of-sentence with actual output!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + if (num_output_end > 0 && + ((input.size() == future_size_ && future_size_ != num_input_end) || + (input.size() > future_size_ && + num_input_end != future_size_ + num_output_end))) { + LOG(WARNING) << "Ignored: matching end-of-sentence with actual input!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + // Check if the context has no other labels than boundary marks + // (such features are useless). + if (num_input_start + num_input_end == input.size() && + num_output_start + num_output_end == output.size()) { + LOG(WARNING) + << "Ignored: feature context consisting of only boundary marks!"; + LOG(WARNING) << "\tInput: " << JoinLabels(input, fsyms_); + LOG(WARNING) << "\tOutput: " << JoinLabels(output, osyms_); + return false; + } + + // Start point for insertion in the trie. Insert at `start_` iff the + // beginning of the context is non-consumed start-of-sentence. + int cur = (num_input_start == 0 && num_output_start <= future_size_) + ? trie_.Root() + : start_; + // Skip all input start-of-sentence marks + size_t ipos = num_input_start; + // Skip to keep at most `future_size_` start-of-sentence marks + size_t opos = + num_output_start <= future_size_ ? 0 : num_output_start - future_size_; + // Skip `num_output_end` end-of-sentence marks on both input and output + size_t iend = !input.empty() ? input.size() - num_output_end : 0, + oend = output.size() - num_output_end; + // Further, when output is empty, keep at most `future_size_` + // end-of-sentence marks on input. 
+ if (output.empty() && num_input_end > future_size_) + iend = input.size() - num_input_end + future_size_; + + // Actual feature context is (input[ipos:iend], output[opos:oend]). + + // Pad `kNoLabel` as don't cares on the shorter of actual `input` + // and `output`. + const size_t effective_input_size = iend - ipos, + effective_output_size = oend - opos; + if (effective_input_size > effective_output_size) { + for (size_t pad = effective_input_size - effective_output_size; pad != 0; + --pad, ++ipos) + cur = trie_.Insert(cur, InputOutputLabel(input[ipos], kNoLabel)); + } else if (effective_input_size < effective_output_size) { + for (size_t pad = effective_output_size - effective_input_size; pad != 0; + --pad, ++opos) + cur = trie_.Insert(cur, InputOutputLabel(kNoLabel, output[opos])); + } + CHECK_EQ(iend - ipos, oend - opos); + for (; ipos != iend; ++ipos, ++opos) + cur = trie_.Insert(cur, InputOutputLabel(input[ipos], output[opos])); + // We only need to attach final weight when there is an output + // end-of-sentence. When there is only end-of-sentence on the input, + // they are all consumed as the end-of-sentence paddings from + // `LinearFstImpl<>::ShiftBuffer()`. `LinearFstImpl<>::Expand()` + // and `LinearFstImpl<>::MatchInput()` ensures no other + // transition takes place after consuming the padding. 
+ if (num_output_end > 0 || (output.empty() && num_input_end > future_size_)) + trie_[cur].final_weight = Times(trie_[cur].final_weight, weight); + else + trie_[cur].weight = Times(trie_[cur].weight, weight); + + return true; +} + +template +FeatureGroup *FeatureGroupBuilder::Dump(size_t max_future_size) { + if (error_) { + FSTERROR() << "Calling FeatureGroupBuilder<>::PreAccumulateWeights() " + << "at error state"; + return nullptr; + } + + if (max_future_size < future_size_) { + error_ = true; + FSTERROR() << "max_future_size (= " << max_future_size + << ") is smaller the builder's future_size (= " << future_size_ + << ")"; + return nullptr; + } + + BuildBackLinks(); + if (error_) return nullptr; + PreAccumulateWeights(); // does not fail + + FeatureGroup *ret = + new FeatureGroup(max_future_size - future_size_, start_); + + // Walk around the trie to compute next states + ret->next_state_.resize(trie_.NumNodes()); + const Topology &topology = trie_.TrieTopology(); + for (int i = 0; i < topology.NumNodes(); ++i) { + int next = i; + while (next != topology.Root() && topology.ChildrenOf(next).empty() && + trie_[next].final_weight == + trie_[trie_[next].back_link].final_weight) + next = trie_[next].back_link; + ret->next_state_[i] = next; + } + + // Copy the trie + typename FeatureGroup::Trie store_trie(trie_); + ret->trie_.swap(store_trie); + + // Put the builder at error state to prevent repeated call of `Dump()`. + error_ = true; + return ret; +} + +template +int FeatureGroupBuilder::FindFirstMatch(InputOutputLabel label, int parent, + int *hop) const { + int hop_count = 0; + int ret = kNoTrieNodeId; + for (; parent >= 0; parent = trie_[parent].back_link, ++hop_count) { + int next = trie_.Find(parent, label); + if (next != kNoTrieNodeId) { + ret = next; + break; + } + } + if (hop != nullptr) *hop = hop_count; + return ret; +} + +template +void FeatureGroupBuilder::BuildBackLinks() { + // Breadth first search from the root. 
In the case where we only + // have the input label, the immedate back-off is simply the longest + // suffix of the current node that is also in the trie. For a node + // reached from its parent with label L, we can simply walk through + // the parent's back-off chain to find the first state with an arc + // of the same label L. The uniqueness is always + // guanranteed. However, in the case with both input and output + // labels, it is possible to back off by removing first labels from + // either side, which in general causes non-uniqueness. + + const Topology &topology = trie_.TrieTopology(); + std::queue q; // all enqueued or visited nodes have known links + + // Note: nodes have back link initialized to -1 in their + // constructor. + q.push(trie_.Root()); + while (!error_ && !q.empty()) { + int parent = q.front(); + q.pop(); + // Find links for every child + const typename Topology::NextMap &children = topology.ChildrenOf(parent); + for (typename Topology::NextMap::const_iterator eit = children.begin(); + eit != children.end(); ++eit) { + const std::pair &edge = *eit; + InputOutputLabel label = edge.first; + int child = edge.second; + if (label.input == kNoLabel || label.output == kNoLabel) { + // Label pairs from root to here all have one and only one + // `kNoLabel` on the same side; equivalent to the + // "longest-suffix" case. + trie_[child].back_link = + FindFirstMatch(label, trie_[parent].back_link, nullptr); + } else { + // Neither side is `kNoLabel` at this point, there are + // three possible ways to back-off: if the parent backs + // off to some context with only one side non-empty, the + // empty side may remain empty; or else an exact match of + // both sides is needed. Try to find all three possible + // backs and look for the closest one (in terms of hops + // along the parent's back-off chain). 
+ int only_input_hop, only_output_hop, full_hop; + int only_input_link = + FindFirstMatch(InputOutputLabel(label.input, kNoLabel), parent, + &only_input_hop), + only_output_link = + FindFirstMatch(InputOutputLabel(kNoLabel, label.output), parent, + &only_output_hop), + full_link = + FindFirstMatch(label, trie_[parent].back_link, &full_hop); + if (only_input_link != -1 && only_output_link != -1) { + error_ = true; + FSTERROR() << "Branching back-off chain:\n" + << "\tnode " << child << ": " << TriePath(child, topology) + << "\n" + << "\tcan back-off to node " << only_input_link << ": " + << TriePath(only_input_link, topology) << "\n" + << "\tcan back-off to node " << only_output_link << ": " + << TriePath(only_output_link, topology); + return; + } else if (full_link != -1) { + ++full_hop; + if (full_hop <= only_input_hop && full_hop <= only_output_hop) { + trie_[child].back_link = full_link; + } else { + error_ = true; + int problem_link = only_input_link != kNoTrieNodeId + ? only_input_link + : only_output_link; + CHECK_NE(problem_link, kNoTrieNodeId); + FSTERROR() << "Branching back-off chain:\n" + << "\tnode " << child << ": " + << TriePath(child, topology) << "\n" + << "\tcan back-off to node " << full_link << ": " + << TriePath(full_link, topology) << "\n" + << "tcan back-off to node " << problem_link << ": " + << TriePath(problem_link, topology); + return; + } + } else { + trie_[child].back_link = + only_input_link != -1 ? 
only_input_link : only_output_link; + } + } + if (error_) break; + // Point to empty context (root) when no back-off can be found + if (trie_[child].back_link == -1) trie_[child].back_link = 0; + q.push(child); + } + } +} + +template +void FeatureGroupBuilder::PreAccumulateWeights() { + std::vector visited(trie_.NumNodes(), false); + visited[trie_.Root()] = true; + + for (size_t i = 0; i != trie_.NumNodes(); ++i) { + std::stack back_offs; + for (int j = i; !visited[j]; j = trie_[j].back_link) back_offs.push(j); + while (!back_offs.empty()) { + int j = back_offs.top(); + back_offs.pop(); + WeightBackLink &node = trie_[j]; + node.weight = Times(node.weight, trie_[node.back_link].weight); + node.final_weight = + Times(node.final_weight, trie_[node.back_link].final_weight); + visited[j] = true; + } + } +} + +template +bool FeatureGroupBuilder::TrieDfs( + const Topology &topology, int cur, int target, + std::vector *path) const { + if (cur == target) return true; + const typename Topology::NextMap &children = topology.ChildrenOf(cur); + for (typename Topology::NextMap::const_iterator eit = children.begin(); + eit != children.end(); ++eit) { + const std::pair &edge = *eit; + path->push_back(edge.first); + if (TrieDfs(topology, edge.second, target, path)) return true; + path->pop_back(); + } + return false; +} + +template +string FeatureGroupBuilder::TriePath(int node, + const Topology &topology) const { + std::vector labels; + TrieDfs(topology, topology.Root(), node, &labels); + bool first = true; + std::ostringstream strm; + for (typename std::vector::const_iterator it = + labels.begin(); + it != labels.end(); ++it) { + InputOutputLabel i = *it; + if (first) + first = false; + else + strm << ", "; + strm << "(" << TranslateLabel(i.input, fsyms_) << ", " + << TranslateLabel(i.output, osyms_) << ")"; + } + return strm.str(); +} + +inline string TranslateLabel(int64 label, const SymbolTable *syms) { + string ret; + if (syms != nullptr) ret += syms->Find(label); + if 
(ret.empty()) { + std::ostringstream strm; + strm << '<' << label << '>'; + ret = strm.str(); + } + return ret; +} + +template +string JoinLabels(Iterator begin, Iterator end, const SymbolTable *syms) { + if (begin == end) return ""; + std::ostringstream strm; + bool first = true; + for (Iterator it = begin; it != end; ++it) { + if (first) + first = false; + else + strm << '|'; + strm << TranslateLabel(*it, syms); + } + return strm.str(); +} + +template +string JoinLabels(const std::vector::kStartOfSentence; + } else if (left && !right) { + // Can only be end + (*sequence)[i] = LinearFstData::kEndOfSentence; + } else if (!left && right) { + // Can only be start + (*sequence)[i] = LinearFstData::kStartOfSentence; + } else { + // !left && !right; can't really tell + ++unresolved; + } + } + return unresolved; +} + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_BUILDER_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h new file mode 100644 index 00000000..3b39c29d --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linear-fst-data.h @@ -0,0 +1,526 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Data structures for storing and looking up the actual feature weights. + +#ifndef FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_ +#define FST_EXTENSIONS_LINEAR_LINEAR_FST_DATA_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace fst { + +// Forward declarations +template +class LinearFstDataBuilder; +template +class FeatureGroup; + +// Immutable data storage of the feature weights in a linear +// model. Produces state tuples that represent internal states of a +// LinearTaggerFst. 
Object of this class can only be constructed via +// either `LinearFstDataBuilder::Dump()` or `LinearFstData::Read()` +// and usually used as refcount'd object shared across multiple +// `LinearTaggerFst` copies. +// +// TODO(wuke): more efficient trie implementation +template +class LinearFstData { + public: + friend class LinearFstDataBuilder; // For builder access + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + // Sentence boundary labels. Both of them are negative labels other + // than `kNoLabel`. + static const Label kStartOfSentence; + static const Label kEndOfSentence; + + // Constructs empty data; for non-trivial ways of construction see + // `Read()` and `LinearFstDataBuilder`. + LinearFstData() + : max_future_size_(0), max_input_label_(1), input_attribs_(1) {} + + // Appends the state tuple of the start state to `output`, where + // each tuple holds the node ids of a trie for each feature group. + void EncodeStartState(std::vector *Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + + private: + // Offsets in `output_pool_` + struct InputAttribute { + size_t output_begin, output_length; + + std::istream &Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + }; + + // Mapping from input label to per-group feature label + class GroupFeatureMap; + + // Translates the input label into input feature label of group + // `group`; returns `kNoLabel` when there is no feature for that + // group. 
+ Label FindFeature(size_t group, Label word) const; + + size_t max_future_size_; + Label max_input_label_; + std::vector>> groups_; + std::vector input_attribs_; + std::vector::kStartOfSentence = -3; +template +const typename A::Label LinearFstData::kEndOfSentence = -2; + +template +template +void LinearFstData::TakeTransition(Iterator buffer_end, + Iterator trie_state_begin, + Iterator trie_state_end, Label ilabel, + Label olabel, std::vector::GroupTransition(int group_id, + int trie_state, + Label ilabel, Label olabel, + Weight *weight) const { + Label group_ilabel = FindFeature(group_id, ilabel); + return groups_[group_id]->Walk(trie_state, group_ilabel, olabel, weight); +} + +template +template +inline typename A::Weight LinearFstData::FinalWeight( + Iterator trie_state_begin, Iterator trie_state_end) const { + DCHECK_EQ(trie_state_end - trie_state_begin, groups_.size()); + size_t group_id = 0; + Weight accum = Weight::One(); + for (Iterator it = trie_state_begin; it != trie_state_end; ++it, ++group_id) + accum = Times(accum, GroupFinalWeight(group_id, *it)); + return accum; +} + +template +inline std::pair::const_iterator, + typename std::vector::const_iterator> +LinearFstData::PossibleOutputLabels(Label word) const { + const InputAttribute &attrib = input_attribs_[word]; + if (attrib.output_length == 0) + return std::make_pair(output_set_.begin(), output_set_.end()); + else + return std::make_pair( + output_pool_.begin() + attrib.output_begin, + output_pool_.begin() + attrib.output_begin + attrib.output_length); +} + +template +inline LinearFstData *LinearFstData::Read(std::istream &strm) { // NOLINT + std::unique_ptr> data(new LinearFstData()); + ReadType(strm, &(data->max_future_size_)); + ReadType(strm, &(data->max_input_label_)); + // Feature groups + size_t num_groups = 0; + ReadType(strm, &num_groups); + data->groups_.resize(num_groups); + for (size_t i = 0; i < num_groups; ++i) + data->groups_[i].reset(FeatureGroup::Read(strm)); + // Other data + 
ReadType(strm, &(data->input_attribs_)); + ReadType(strm, &(data->output_pool_)); + ReadType(strm, &(data->output_set_)); + ReadType(strm, &(data->group_feat_map_)); + if (strm) { + return data.release(); + } else { + return nullptr; + } +} + +template +inline std::ostream &LinearFstData::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, max_future_size_); + WriteType(strm, max_input_label_); + // Feature groups + WriteType(strm, groups_.size()); + for (size_t i = 0; i < groups_.size(); ++i) { + groups_[i]->Write(strm); + } + // Other data + WriteType(strm, input_attribs_); + WriteType(strm, output_pool_); + WriteType(strm, output_set_); + WriteType(strm, group_feat_map_); + return strm; +} + +template +typename A::Label LinearFstData::FindFeature(size_t group, + Label word) const { + DCHECK(word > 0 || word == kStartOfSentence || word == kEndOfSentence); + if (word == kStartOfSentence || word == kEndOfSentence) + return word; + else + return group_feat_map_.Find(group, word); +} + +template +inline std::istream &LinearFstData::InputAttribute::Read( + std::istream &strm) { // NOLINT + ReadType(strm, &output_begin); + ReadType(strm, &output_length); + return strm; +} + +template +inline std::ostream &LinearFstData::InputAttribute::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, output_begin); + WriteType(strm, output_length); + return strm; +} + +// Forward declaration +template +class FeatureGroupBuilder; + +// An immutable grouping of features with similar context shape. Like +// `LinearFstData`, this can only be constructed via `Read()` or +// via its builder. +// +// Internally it uses a trie to store all feature n-grams and their +// weights. The label of a trie edge is a pair (feat, olabel) of +// labels. They can be either positive (ordinary label), `kNoLabel`, +// `kStartOfSentence`, or `kEndOfSentence`. 
`kNoLabel` usually means +// matching anything, with one exception: from the root of the trie, +// there is a special (kNoLabel, kNoLabel) that leads to the implicit +// start-of-sentence state. This edge is never actually matched +// (`FindFirstMatch()` ensures this). +template +class FeatureGroup { + public: + friend class FeatureGroupBuilder; // for builder access + + typedef typename A::Label Label; + typedef typename A::Weight Weight; + + int Start() const { return start_; } + + // Finds destination node from `cur` by consuming `ilabel` and + // `olabel`. The transition weight is multiplied onto `weight`. + int Walk(int cur, Label ilabel, Label olabel, Weight *weight) const; + + // Returns the final weight of the current trie state. Only valid if + // the state is already known to be part of a final state (see + // `LinearFstData<>::CanBeFinal()`). + Weight FinalWeight(int trie_state) const { + return trie_[trie_state].final_weight; + } + + static FeatureGroup *Read(std::istream &strm) { // NOLINT + size_t delay; + ReadType(strm, &delay); + int start; + ReadType(strm, &start); + Trie trie; + ReadType(strm, &trie); + std::unique_ptr> ret(new FeatureGroup(delay, start)); + ret->trie_.swap(trie); + ReadType(strm, &ret->next_state_); + if (strm) { + return ret.release(); + } else { + return nullptr; + } + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, delay_); + WriteType(strm, start_); + WriteType(strm, trie_); + WriteType(strm, next_state_); + return strm; + } + + size_t Delay() const { return delay_; } + + string Stats() const; + + private: + // Label along the arcs on the trie. `kNoLabel` means anything + // (non-negative label) can match; both sides holding `kNoLabel` + // is not allowed; otherwise the label is > 0 (enforced by + // `LinearFstDataBuilder::AddWeight()`). 
+ struct InputOutputLabel; + struct InputOutputLabelHash; + + // Data to be stored on the trie + struct WeightBackLink { + int back_link; + Weight weight, final_weight; + + WeightBackLink() + : back_link(kNoTrieNodeId), + weight(Weight::One()), + final_weight(Weight::One()) {} + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &back_link); + ReadType(strm, &weight); + ReadType(strm, &final_weight); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, back_link); + WriteType(strm, weight); + WriteType(strm, final_weight); + return strm; + } + }; + + typedef FlatTrieTopology Topology; + typedef MutableTrie Trie; + + explicit FeatureGroup(size_t delay, int start) + : delay_(delay), start_(start) {} + + // Finds the first node with an arc with `label` following the + // back-off chain of `parent`. Returns the node index or + // `kNoTrieNodeId` when not found. + int FindFirstMatch(InputOutputLabel label, int parent) const; + + size_t delay_; + int start_; + Trie trie_; + // Where to go after hitting this state. When we reach a state with + // no child and with no additional final weight (i.e. its final + // weight is the same as its back-off), we can immediately go to its + // back-off state. 
+ std::vector next_state_; + + FeatureGroup(const FeatureGroup &) = delete; + FeatureGroup &operator=(const FeatureGroup &) = delete; +}; + +template +struct FeatureGroup::InputOutputLabel { + Label input, output; + + InputOutputLabel(Label i = kNoLabel, Label o = kNoLabel) + : input(i), output(o) {} + + bool operator==(InputOutputLabel that) const { + return input == that.input && output == that.output; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &input); + ReadType(strm, &output); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, input); + WriteType(strm, output); + return strm; + } +}; + +template +struct FeatureGroup::InputOutputLabelHash { + size_t operator()(InputOutputLabel label) const { + return static_cast(label.input * 7853 + label.output); + } +}; + +template +int FeatureGroup::Walk(int cur, Label ilabel, Label olabel, + Weight *weight) const { + // Note: user of this method need to ensure `ilabel` and `olabel` + // are valid (e.g. see DCHECKs in + // `LinearFstData<>::TakeTransition()` and + // `LinearFstData<>::FindFeature()`). + int next; + if (ilabel == LinearFstData::kStartOfSentence) { + // An observed start-of-sentence only occurs in the beginning of + // the input, when this feature group is delayed (i.e. there is + // another feature group with a larger future size). The actual + // input hasn't arrived so stay at the start state. 
+ DCHECK_EQ(cur, start_); + next = start_; + } else { + // First, try exact match + next = FindFirstMatch(InputOutputLabel(ilabel, olabel), cur); + // Then try with don't cares + if (next == kNoTrieNodeId) + next = FindFirstMatch(InputOutputLabel(ilabel, kNoLabel), cur); + if (next == kNoTrieNodeId) + next = FindFirstMatch(InputOutputLabel(kNoLabel, olabel), cur); + // All failed, go to empty context + if (next == kNoTrieNodeId) next = trie_.Root(); + *weight = Times(*weight, trie_[next].weight); + next = next_state_[next]; + } + return next; +} + +template +inline int FeatureGroup::FindFirstMatch(InputOutputLabel label, + int parent) const { + if (label.input == kNoLabel && label.output == kNoLabel) + return kNoTrieNodeId; // very important; see class doc. + for (; parent != kNoTrieNodeId; parent = trie_[parent].back_link) { + int next = trie_.Find(parent, label); + if (next != kNoTrieNodeId) return next; + } + return kNoTrieNodeId; +} + +template +inline string FeatureGroup::Stats() const { + std::ostringstream strm; + int num_states = 2; + for (int i = 2; i < next_state_.size(); ++i) + num_states += i == next_state_[i]; + strm << trie_.NumNodes() << " node(s); " << num_states << " state(s)"; + return strm.str(); +} + +template +class LinearFstData::GroupFeatureMap { + public: + GroupFeatureMap() {} + + void Init(size_t num_groups, size_t num_words) { + num_groups_ = num_groups; + pool_.clear(); + pool_.resize(num_groups * num_words, kNoLabel); + } + + Label Find(size_t group_id, Label ilabel) const { + return pool_[IndexOf(group_id, ilabel)]; + } + + bool Set(size_t group_id, Label ilabel, Label feat) { + size_t i = IndexOf(group_id, ilabel); + if (pool_[i] != kNoLabel && pool_[i] != feat) { + FSTERROR() << "Feature group " << group_id + << " already has feature for word " << ilabel; + return false; + } + pool_[i] = feat; + return true; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &num_groups_); + ReadType(strm, &pool_); + return 
strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, num_groups_); + WriteType(strm, pool_); + return strm; + } + + private: + size_t IndexOf(size_t group_id, Label ilabel) const { + return ilabel * num_groups_ + group_id; + } + + size_t num_groups_; + // `pool_[ilabel * num_groups_ + group_id]` is the feature active + // for group `group_id` with input `ilabel` + std::vector { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename Collection::SetIterator NGramIterator; + + // Constructs an empty FST by default. + LinearTaggerFstImpl() + : CacheImpl(CacheOptions()), + data_(std::make_shared>()), + delay_(0) { + SetType("linear-tagger"); + } + + // Constructs the FST with given data storage and symbol + // tables. + // + // TODO(wuke): when there is no constraint on output we can delay + // less than `data->MaxFutureSize` positions. + LinearTaggerFstImpl(const LinearFstData *data, const SymbolTable *isyms, + const SymbolTable *osyms, CacheOptions opts) + : CacheImpl(opts), data_(data), delay_(data->MaxFutureSize()) { + SetType("linear-tagger"); + SetProperties(kILabelSorted, kFstProperties); + SetInputSymbols(isyms); + SetOutputSymbols(osyms); + ReserveStubSpace(); + } + + // Copy by sharing the underlying data storage. 
+ LinearTaggerFstImpl(const LinearTaggerFstImpl &impl) + : CacheImpl(impl), data_(impl.data_), delay_(impl.delay_) { + SetType("linear-tagger"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + ReserveStubSpace(); + } + + StateId Start() { + if (!HasStart()) { + StateId start = FindStartState(); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + state_stub_.clear(); + FillState(s, &state_stub_); + if (CanBeFinal(state_stub_)) + SetFinal(s, data_->FinalWeight(InternalBegin(state_stub_), + InternalEnd(state_stub_))); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new + // destination states as needed. + void Expand(StateId s); + + // Appends to `arcs` all out-going arcs from state `s` that matches `label` as + // the input label. 
+ void MatchInput(StateId s, Label ilabel, std::vector *arcs); + + static LinearTaggerFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + bool Write(std::ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader header; + header.SetStart(kNoStateId); + WriteHeader(strm, opts, kFileVersion, &header); + data_->Write(strm); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + private: + static const int kMinFileVersion; + static const int kFileVersion; + + // A collection of functions to access parts of the state tuple. A + // state tuple is a vector of `Label`s with two parts: + // [buffer] [internal]. + // + // - [buffer] is a buffer of observed input labels with length + // `delay_`. `LinearFstData::kStartOfSentence` + // (resp. `LinearFstData::kEndOfSentence`) are used as + // paddings when the buffer has fewer than `delay_` elements, which + // can only appear as the prefix (resp. suffix) of the buffer. + // + // - [internal] is the internal state tuple for `LinearFstData` + typename std::vector::kStartOfSentence); + // Append internal states + data_->EncodeStartState(&state_stub_); + return FindState(state_stub_); + } + + // Tests whether the buffer in `(begin, end)` is empty. + bool IsEmptyBuffer(typename std::vector::kEndOfSentence => + // buffer[i+x] == LinearFstData::kEndOfSentence + // - buffer[i] == LinearFstData::kStartOfSentence => + // buffer[i-x] == LinearFstData::kStartOfSentence + return delay_ == 0 || *(end - 1) == LinearFstData::kStartOfSentence || + *begin == LinearFstData::kEndOfSentence; + } + + // Tests whether the given state tuple can be a final state. A state + // is final iff there is no observed input in the buffer. 
+ bool CanBeFinal(const std::vector::kMinFileVersion = 1; + +template +const int LinearTaggerFstImpl::kFileVersion = 1; + +template +inline typename A::Label LinearTaggerFstImpl::ShiftBuffer( + const std::vector::kEndOfSentence); + if (delay_ == 0) { + DCHECK_GT(ilabel, 0); + return ilabel; + } else { + (*next_stub_)[BufferEnd(*next_stub_) - next_stub_->begin() - 1] = ilabel; + return *BufferBegin(state); + } +} + +template +inline A LinearTaggerFstImpl::MakeArc(const std::vector::kEndOfSentence); + DCHECK(olabel > 0 || olabel == LinearFstData::kStartOfSentence); + Weight weight(Weight::One()); + data_->TakeTransition(BufferEnd(state), InternalBegin(state), + InternalEnd(state), ilabel, olabel, next_stub_, + &weight); + StateId nextstate = FindState(*next_stub_); + // Restore `next_stub_` to its size before the call + next_stub_->resize(delay_); + // In the actual arc, we use epsilons instead of boundaries. + return A(ilabel == LinearFstData::kEndOfSentence ? 0 : ilabel, + olabel == LinearFstData::kStartOfSentence ? 0 : olabel, weight, + nextstate); +} + +template +inline void LinearTaggerFstImpl::ExpandArcs(StateId s, + const std::vector::kStartOfSentence) { + // This happens when input is shorter than `delay_`. + PushArc(s, MakeArc(state, ilabel, LinearFstData::kStartOfSentence, + next_stub_)); + } else { + std::pair::const_iterator, + typename std::vector::const_iterator> range = + data_->PossibleOutputLabels(obs_ilabel); + for (typename std::vector::const_iterator it = + range.first; + it != range.second; ++it) + PushArc(s, MakeArc(state, ilabel, *it, next_stub_)); + } +} + +// TODO(wuke): this has much in duplicate with `ExpandArcs()` +template +inline void LinearTaggerFstImpl::AppendArcs(StateId /*s*/, + const std::vector::kStartOfSentence) { + // This happens when input is shorter than `delay_`. 
+ arcs->push_back( + MakeArc(state, ilabel, LinearFstData::kStartOfSentence, next_stub_)); + } else { + std::pair::const_iterator, + typename std::vector::const_iterator> range = + data_->PossibleOutputLabels(obs_ilabel); + for (typename std::vector::const_iterator it = + range.first; + it != range.second; ++it) + arcs->push_back(MakeArc(state, ilabel, *it, next_stub_)); + } +} + +template +void LinearTaggerFstImpl::Expand(StateId s) { + VLOG(3) << "Expand " << s; + state_stub_.clear(); + FillState(s, &state_stub_); + + // Precompute the first `delay_ - 1` elements in the buffer of + // next states, which are identical for different input/output. + next_stub_.clear(); + next_stub_.resize(delay_); + if (delay_ > 0) + std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), + next_stub_.begin()); + + // Epsilon transition for flushing out the next observed input + if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) + ExpandArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_); + + // Non-epsilon input when we haven't flushed + if (delay_ == 0 || + *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) + for (Label ilabel = data_->MinInputLabel(); + ilabel <= data_->MaxInputLabel(); ++ilabel) + ExpandArcs(s, state_stub_, ilabel, &next_stub_); + + SetArcs(s); +} + +template +void LinearTaggerFstImpl::MatchInput(StateId s, Label ilabel, + std::vector *arcs) { + state_stub_.clear(); + FillState(s, &state_stub_); + + // Precompute the first `delay_ - 1` elements in the buffer of + // next states, which are identical for different input/output. 
+ next_stub_.clear(); + next_stub_.resize(delay_); + if (delay_ > 0) + std::copy(BufferBegin(state_stub_) + 1, BufferEnd(state_stub_), + next_stub_.begin()); + + if (ilabel == 0) { + // Epsilon transition for flushing out the next observed input + if (!IsEmptyBuffer(BufferBegin(state_stub_), BufferEnd(state_stub_))) + AppendArcs(s, state_stub_, LinearFstData::kEndOfSentence, &next_stub_, + arcs); + } else { + // Non-epsilon input when we haven't flushed + if (delay_ == 0 || + *(BufferEnd(state_stub_) - 1) != LinearFstData::kEndOfSentence) + AppendArcs(s, state_stub_, ilabel, &next_stub_, arcs); + } +} + +template +inline LinearTaggerFstImpl *LinearTaggerFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { // NOLINT + std::unique_ptr> impl(new LinearTaggerFstImpl()); + FstHeader header; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { + return nullptr; + } + impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); + if (!impl->data_) { + return nullptr; + } + impl->delay_ = impl->data_->MaxFutureSize(); + impl->ReserveStubSpace(); + return impl.release(); +} + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. 
+template +class LinearTaggerFst : public ImplToFst> { + public: + friend class ArcIterator>; + friend class StateIterator>; + friend class LinearFstMatcherTpl>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef DefaultCacheStore Store; + typedef typename Store::State State; + using Impl = internal::LinearTaggerFstImpl; + + LinearTaggerFst() : ImplToFst(std::make_shared()) {} + + explicit LinearTaggerFst(LinearFstData *data, + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + CacheOptions opts = CacheOptions()) + : ImplToFst(std::make_shared(data, isyms, osyms, opts)) {} + + explicit LinearTaggerFst(const Fst &fst) + : ImplToFst(std::make_shared()) { + LOG(FATAL) << "LinearTaggerFst: no constructor from arbitrary FST."; + } + + // See Fst<>::Copy() for doc. + LinearTaggerFst(const LinearTaggerFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this LinearTaggerFst. See Fst<>::Copy() for further doc. 
+ LinearTaggerFst *Copy(bool safe = false) const override { + return new LinearTaggerFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new LinearFstMatcherTpl>(this, match_type); + } + + static LinearTaggerFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + static LinearTaggerFst *Read(std::istream &in, // NOLINT + const FstReadOptions &opts) { + auto *impl = Impl::Read(in, opts); + return impl ? new LinearTaggerFst(std::shared_ptr(impl)) : nullptr; + } + + bool Write(const string &filename) const override { + if (!filename.empty()) { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearTaggerFst::Write: Can't open file: " << filename; + return false; + } + return Write(strm, FstWriteOptions(filename)); + } else { + return Write(std::cout, FstWriteOptions("standard output")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + explicit LinearTaggerFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + void operator=(const LinearTaggerFst &fst) = delete; +}; + +// Specialization for LinearTaggerFst. 
+template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const LinearTaggerFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for LinearTaggerFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const LinearTaggerFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void LinearTaggerFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +namespace internal { + +// Implementation class for on-the-fly generated LinearClassifierFst with +// special optimization in matching. +template +class LinearClassifierFstImpl : public CacheImpl { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::WriteHeader; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef typename Collection::SetIterator NGramIterator; + + // Constructs an empty FST by default. + LinearClassifierFstImpl() + : CacheImpl(CacheOptions()), + data_(std::make_shared>()) { + SetType("linear-classifier"); + num_classes_ = 0; + num_groups_ = 0; + } + + // Constructs the FST with given data storage, number of classes and + // symbol tables. 
+ LinearClassifierFstImpl(const LinearFstData *data, size_t num_classes, + const SymbolTable *isyms, const SymbolTable *osyms, + CacheOptions opts) + : CacheImpl(opts), + data_(data), + num_classes_(num_classes), + num_groups_(data_->NumGroups() / num_classes_) { + SetType("linear-classifier"); + SetProperties(kILabelSorted, kFstProperties); + SetInputSymbols(isyms); + SetOutputSymbols(osyms); + ReserveStubSpace(); + } + + // Copy by sharing the underlying data storage. + LinearClassifierFstImpl(const LinearClassifierFstImpl &impl) + : CacheImpl(impl), + data_(impl.data_), + num_classes_(impl.num_classes_), + num_groups_(impl.num_groups_) { + SetType("linear-classifier"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + ReserveStubSpace(); + } + + StateId Start() { + if (!HasStart()) { + StateId start = FindStartState(); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + state_stub_.clear(); + FillState(s, &state_stub_); + SetFinal(s, FinalWeight(state_stub_)); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new + // destination states as needed. + void Expand(StateId s); + + // Appends to `arcs` all out-going arcs from state `s` that matches + // `label` as the input label. 
+ void MatchInput(StateId s, Label ilabel, std::vector *arcs); + + static LinearClassifierFstImpl *Read(std::istream &strm, + const FstReadOptions &opts); + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const { + FstHeader header; + header.SetStart(kNoStateId); + WriteHeader(strm, opts, kFileVersion, &header); + data_->Write(strm); + WriteType(strm, num_classes_); + if (!strm) { + LOG(ERROR) << "LinearClassifierFst::Write: Write failed: " << opts.source; + return false; + } + return true; + } + + private: + static const int kMinFileVersion; + static const int kFileVersion; + + // A collection of functions to access parts of the state tuple. A + // state tuple is a vector of `Label`s with two parts: + // [prediction] [internal]. + // + // - [prediction] is a single label of the predicted class. A state + // must have a positive class label, unless it is the start state. + // + // - [internal] is the internal state tuple for `LinearFstData` of + // the given class; or kNoTrieNodeId's if in start state. 
+ Label &Prediction(std::vector &) = delete; +}; + +template +const int LinearClassifierFstImpl::kMinFileVersion = 0; + +template +const int LinearClassifierFstImpl::kFileVersion = 0; + +template +void LinearClassifierFstImpl::Expand(StateId s) { + VLOG(3) << "Expand " << s; + state_stub_.clear(); + FillState(s, &state_stub_); + next_stub_.clear(); + next_stub_.resize(1 + num_groups_); + + if (IsStartState(state_stub_)) { + // Make prediction + for (Label pred = 1; pred <= num_classes_; ++pred) { + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); + PushArc(s, A(0, pred, Weight::One(), FindState(next_stub_))); + } + } else { + Label pred = Prediction(state_stub_); + DCHECK_GT(pred, 0); + DCHECK_LE(pred, num_classes_); + for (Label ilabel = data_->MinInputLabel(); + ilabel <= data_->MaxInputLabel(); ++ilabel) { + Prediction(next_stub_) = pred; + Weight weight = Weight::One(); + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = + data_->GroupTransition(GroupId(pred, i), InternalAt(state_stub_, i), + ilabel, pred, &weight); + PushArc(s, A(ilabel, 0, weight, FindState(next_stub_))); + } + } + + SetArcs(s); +} + +template +void LinearClassifierFstImpl::MatchInput(StateId s, Label ilabel, + std::vector *arcs) { + state_stub_.clear(); + FillState(s, &state_stub_); + next_stub_.clear(); + next_stub_.resize(1 + num_groups_); + + if (IsStartState(state_stub_)) { + // Make prediction if `ilabel` is epsilon. 
+ if (ilabel == 0) { + for (Label pred = 1; pred <= num_classes_; ++pred) { + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupStartState(GroupId(pred, i)); + arcs->push_back(A(0, pred, Weight::One(), FindState(next_stub_))); + } + } + } else if (ilabel != 0) { + Label pred = Prediction(state_stub_); + Weight weight = Weight::One(); + Prediction(next_stub_) = pred; + for (int i = 0; i < num_groups_; ++i) + InternalAt(next_stub_, i) = data_->GroupTransition( + GroupId(pred, i), InternalAt(state_stub_, i), ilabel, pred, &weight); + arcs->push_back(A(ilabel, 0, weight, FindState(next_stub_))); + } +} + +template +inline LinearClassifierFstImpl *LinearClassifierFstImpl::Read( + std::istream &strm, const FstReadOptions &opts) { + std::unique_ptr> impl( + new LinearClassifierFstImpl()); + FstHeader header; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &header)) { + return nullptr; + } + impl->data_ = std::shared_ptr>(LinearFstData::Read(strm)); + if (!impl->data_) { + return nullptr; + } + ReadType(strm, &impl->num_classes_); + if (!strm) { + return nullptr; + } + impl->num_groups_ = impl->data_->NumGroups() / impl->num_classes_; + if (impl->num_groups_ * impl->num_classes_ != impl->data_->NumGroups()) { + FSTERROR() << "Total number of feature groups is not a multiple of the " + "number of classes: num groups = " + << impl->data_->NumGroups() + << ", num classes = " << impl->num_classes_; + return nullptr; + } + impl->ReserveStubSpace(); + return impl.release(); +} + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. 
+template +class LinearClassifierFst + : public ImplToFst> { + public: + friend class ArcIterator>; + friend class StateIterator>; + friend class LinearFstMatcherTpl>; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef typename A::StateId StateId; + typedef DefaultCacheStore Store; + typedef typename Store::State State; + using Impl = internal::LinearClassifierFstImpl; + + LinearClassifierFst() : ImplToFst(std::make_shared()) {} + + explicit LinearClassifierFst(LinearFstData *data, size_t num_classes, + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + CacheOptions opts = CacheOptions()) + : ImplToFst( + std::make_shared(data, num_classes, isyms, osyms, opts)) {} + + explicit LinearClassifierFst(const Fst &fst) + : ImplToFst(std::make_shared()) { + LOG(FATAL) << "LinearClassifierFst: no constructor from arbitrary FST."; + } + + // See Fst<>::Copy() for doc. + LinearClassifierFst(const LinearClassifierFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this LinearClassifierFst. See Fst<>::Copy() for further doc. 
+ LinearClassifierFst *Copy(bool safe = false) const override { + return new LinearClassifierFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new LinearFstMatcherTpl>(this, match_type); + } + + static LinearClassifierFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "LinearClassifierFst::Read: Can't open file: " + << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + static LinearClassifierFst *Read(std::istream &in, + const FstReadOptions &opts) { + auto *impl = Impl::Read(in, opts); + return impl ? new LinearClassifierFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(const string &filename) const override { + if (!filename.empty()) { + std::ofstream strm(filename, + std::ios_base::out | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "ProdLmFst::Write: Can't open file: " << filename; + return false; + } + return Write(strm, FstWriteOptions(filename)); + } else { + return Write(std::cout, FstWriteOptions("standard output")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + explicit LinearClassifierFst(std::shared_ptr impl) + : ImplToFst(impl) {} + + void operator=(const LinearClassifierFst &fst) = delete; +}; + +// Specialization for LinearClassifierFst. 
+template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const LinearClassifierFst &fst) + : CacheStateIterator>(fst, + fst.GetMutableImpl()) {} +}; + +// Specialization for LinearClassifierFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const LinearClassifierFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void LinearClassifierFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Specialized Matcher for LinearFsts. This matcher only supports +// matching from the input side. This is intentional because comparing +// the scores of different input sequences with the same output +// sequence is meaningless in a discriminative model. +template +class LinearFstMatcherTpl : public MatcherBase { + public: + typedef typename F::Arc Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + typedef F FST; + + // This makes a copy of the FST. + LinearFstMatcherTpl(const FST &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + match_type_(match_type), + s_(kNoStateId), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + cur_arc_(0), + error_(false) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_OUTPUT: + case MATCH_NONE: + break; + default: + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This doesn't copy the FST. 
+ LinearFstMatcherTpl(const FST *fst, MatchType match_type) + : fst_(*fst), + match_type_(match_type), + s_(kNoStateId), + current_loop_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + cur_arc_(0), + error_(false) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_OUTPUT: + case MATCH_NONE: + break; + default: + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + LinearFstMatcherTpl(const LinearFstMatcherTpl &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + match_type_(matcher.match_type_), + s_(kNoStateId), + current_loop_(false), + loop_(matcher.loop_), + cur_arc_(0), + error_(matcher.error_) {} + + LinearFstMatcherTpl *Copy(bool safe = false) const override { + return new LinearFstMatcherTpl(*this, safe); + } + + MatchType Type(bool /*test*/) const override { + // `MATCH_INPUT` is the only valid type + return match_type_ == MATCH_INPUT ? match_type_ : MATCH_NONE; + } + + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + // `MATCH_INPUT` is the only valid type + if (match_type_ != MATCH_INPUT) { + FSTERROR() << "LinearFstMatcherTpl: Bad match type"; + error_ = true; + } + loop_.nextstate = s; + } + + bool Find(Label label) final { + if (error_) { + current_loop_ = false; + return false; + } + current_loop_ = label == 0; + if (label == kNoLabel) label = 0; + arcs_.clear(); + cur_arc_ = 0; + fst_.GetMutableImpl()->MatchInput(s_, label, &arcs_); + return current_loop_ || !arcs_.empty(); + } + + bool Done() const final { + return !(current_loop_ || cur_arc_ < arcs_.size()); + } + + const Arc &Value() const final { + return current_loop_ ? 
loop_ : arcs_[cur_arc_]; + } + + void Next() final { + if (current_loop_) + current_loop_ = false; + else + ++cur_arc_; + } + + ssize_t Priority(StateId s) final { return kRequirePriority; } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { + if (error_) props |= kError; + return props; + } + + uint32 Flags() const override { return kRequireMatch; } + + private: + std::unique_ptr owned_fst_; + const FST &fst_; + MatchType match_type_; // Type of match to perform. + StateId s_; // Current state. + bool current_loop_; // Current arc is the implicit loop. + Arc loop_; // For non-consuming symbols. + // All out-going arcs matching the label in last Find() call. + std::vector arcs_; + size_t cur_arc_; // Index to the arc that `Value()` should return. + bool error_; // Error encountered. +}; + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LINEAR_FST_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/linearscript.h b/projects/llm_framework/include/fst/extensions/linear/linearscript.h new file mode 100644 index 00000000..54106d20 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/linearscript.h @@ -0,0 +1,391 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ +#define FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +DECLARE_string(delimiter); +DECLARE_string(empty_symbol); +DECLARE_string(start_symbol); +DECLARE_string(end_symbol); +DECLARE_bool(classifier); + +namespace fst { +namespace script { +typedef std::tuple + LinearCompileArgs; + +bool ValidateDelimiter(); +bool ValidateEmptySymbol(); + +// Returns the proper label given the symbol. 
For symbols other than +// `FLAGS_start_symbol` or `FLAGS_end_symbol`, looks up the symbol +// table to decide the label. Depending on whether +// `FLAGS_start_symbol` and `FLAGS_end_symbol` are identical, it +// either returns `kNoLabel` for later processing or decides the label +// right away. +template +inline typename Arc::Label LookUp(const string &str, SymbolTable *syms) { + if (str == FLAGS_start_symbol) + return str == FLAGS_end_symbol ? kNoLabel + : LinearFstData::kStartOfSentence; + else if (str == FLAGS_end_symbol) + return LinearFstData::kEndOfSentence; + else + return syms->AddSymbol(str); +} + +// Splits `str` with `delim` as the delimiter and stores the labels in +// `output`. +template +void SplitAndPush(const string &str, const char delim, SymbolTable *syms, + std::vector *output) { + if (str == FLAGS_empty_symbol) return; + std::istringstream strm(str); + string buf; + while (std::getline(strm, buf, delim)) + output->push_back(LookUp(buf, syms)); +} + +// Like `std::replace_copy` but returns the number of modifications +template +size_t ReplaceCopy(InputIterator first, InputIterator last, + OutputIterator result, const T &old_value, + const T &new_value) { + size_t changes = 0; + while (first != last) { + if (*first == old_value) { + *result = new_value; + ++changes; + } else { + *result = *first; + } + ++first; + ++result; + } + return changes; +} + +template +bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT + SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, + typename Arc::Label *word, + std::vector *feature_labels, + std::vector *possible_labels, + size_t *num_line); + +template +bool GetModelRecord(const string &model, std::istream &strm, // NOLINT + SymbolTable *fsyms, SymbolTable *osyms, + std::vector *input_labels, + std::vector *output_labels, + typename Arc::Weight *weight, size_t *num_line); + +// Reads in vocabulary file. 
Each line is in the following format +// +// word features [ possible output ] +// +// where features and possible output are `FLAGS_delimiter`-delimited lists of +// tokens +template +void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms, + SymbolTable *osyms, LinearFstDataBuilder *builder) { + std::ifstream in(vocab); + if (!in) LOG(FATAL) << "Can't open file: " << vocab; + size_t num_line = 0, num_added = 0; + std::vector fields; + std::vector feature_labels, possible_labels; + typename Arc::Label word; + while (GetVocabRecord(vocab, in, isyms, fsyms, osyms, &word, + &feature_labels, &possible_labels, &num_line)) { + if (word == kNoLabel) { + LOG(WARNING) << "Ignored: boundary word: " << fields[0]; + continue; + } + if (possible_labels.empty()) + num_added += builder->AddWord(word, feature_labels); + else + num_added += builder->AddWord(word, feature_labels, possible_labels); + } + VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from " + << vocab; +} + +template +void AddVocab(const string &vocab, SymbolTable *isyms, SymbolTable *fsyms, + SymbolTable *osyms, + LinearClassifierFstDataBuilder *builder) { + std::ifstream in(vocab); + if (!in) LOG(FATAL) << "Can't open file: " << vocab; + size_t num_line = 0, num_added = 0; + std::vector fields; + std::vector feature_labels, possible_labels; + typename Arc::Label word; + while (GetVocabRecord(vocab, in, isyms, fsyms, osyms, &word, + &feature_labels, &possible_labels, &num_line)) { + if (!possible_labels.empty()) + LOG(FATAL) + << "Classifier vocabulary should not have possible output constraint"; + if (word == kNoLabel) { + LOG(WARNING) << "Ignored: boundary word: " << fields[0]; + continue; + } + num_added += builder->AddWord(word, feature_labels); + } + VLOG(1) << "Read " << num_added << " words in " << num_line << " lines from " + << vocab; +} + +// Reads in model file. The first line is an integer designating the +// size of future window in the input sequences. 
After this, each line +// is in the following format +// +// input sequence output sequence weight +// +// input sequence is a `FLAGS_delimiter`-delimited sequence of feature +// labels (see `AddVocab()`) . output sequence is a +// `FLAGS_delimiter`-delimited sequence of output labels where the +// last label is the output of the feature position before the history +// boundary. +template +void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms, + LinearFstDataBuilder *builder) { + std::ifstream in(model); + if (!in) LOG(FATAL) << "Can't open file: " << model; + string line; + std::getline(in, line); + if (!in) LOG(FATAL) << "Empty file: " << model; + size_t future_size; + { + std::istringstream strm(line); + strm >> future_size; + if (!strm) LOG(FATAL) << "Can't read future size: " << model; + } + size_t num_line = 1, num_added = 0; + const int group = builder->AddGroup(future_size); + VLOG(1) << "Group " << group << ": from " << model << "; future size is " + << future_size << "."; + // Add the rest of lines as a single feature group + std::vector fields; + std::vector input_labels, output_labels; + typename Arc::Weight weight; + while (GetModelRecord(model, in, fsyms, osyms, &input_labels, + &output_labels, &weight, &num_line)) { + if (output_labels.empty()) + LOG(FATAL) << "Empty output sequence in source " << model << ", line " + << num_line; + + const typename Arc::Label marks[] = {LinearFstData::kStartOfSentence, + LinearFstData::kEndOfSentence}; + + std::vector copy_input(input_labels.size()), + copy_output(output_labels.size()); + for (int i = 0; i < 2; ++i) { + for (int j = 0; j < 2; ++j) { + size_t num_input_changes = + ReplaceCopy(input_labels.begin(), input_labels.end(), + copy_input.begin(), kNoLabel, marks[i]); + size_t num_output_changes = + ReplaceCopy(output_labels.begin(), output_labels.end(), + copy_output.begin(), kNoLabel, marks[j]); + if ((num_input_changes > 0 || i == 0) && + (num_output_changes > 0 || j == 0)) + num_added 
+= + builder->AddWeight(group, copy_input, copy_output, weight); + } + } + } + VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in " + << num_line << " lines."; +} + +template +void AddModel(const string &model, SymbolTable *fsyms, SymbolTable *osyms, + LinearClassifierFstDataBuilder *builder) { + std::ifstream in(model); + if (!in) LOG(FATAL) << "Can't open file: " << model; + string line; + std::getline(in, line); + if (!in) LOG(FATAL) << "Empty file: " << model; + size_t future_size; + { + std::istringstream strm(line); + strm >> future_size; + if (!strm) LOG(FATAL) << "Can't read future size: " << model; + } + if (future_size != 0) + LOG(FATAL) << "Classifier model must have future size = 0; got " + << future_size << " from " << model; + size_t num_line = 1, num_added = 0; + const int group = builder->AddGroup(); + VLOG(1) << "Group " << group << ": from " << model << "; future size is " + << future_size << "."; + // Add the rest of lines as a single feature group + std::vector fields; + std::vector input_labels, output_labels; + typename Arc::Weight weight; + while (GetModelRecord(model, in, fsyms, osyms, &input_labels, + &output_labels, &weight, &num_line)) { + if (output_labels.size() != 1) + LOG(FATAL) << "Output not a single label in source " << model << ", line " + << num_line; + + const typename Arc::Label marks[] = {LinearFstData::kStartOfSentence, + LinearFstData::kEndOfSentence}; + + typename Arc::Label pred = output_labels[0]; + + std::vector copy_input(input_labels.size()); + for (int i = 0; i < 2; ++i) { + size_t num_input_changes = + ReplaceCopy(input_labels.begin(), input_labels.end(), + copy_input.begin(), kNoLabel, marks[i]); + if (num_input_changes > 0 || i == 0) + num_added += builder->AddWeight(group, copy_input, pred, weight); + } + } + VLOG(1) << "Group " << group << ": read " << num_added << " weight(s) in " + << num_line << " lines."; +} + +void SplitByWhitespace(const string &str, std::vector *out); +int 
ScanNumClasses(char **models, int models_length); + +template +void LinearCompileTpl(LinearCompileArgs *args) { + const string &epsilon_symbol = std::get<0>(*args); + const string &unknown_symbol = std::get<1>(*args); + const string &vocab = std::get<2>(*args); + char **models = std::get<3>(*args); + const int models_length = std::get<4>(*args); + const string &out = std::get<5>(*args); + const string &save_isymbols = std::get<6>(*args); + const string &save_fsymbols = std::get<7>(*args); + const string &save_osymbols = std::get<8>(*args); + + SymbolTable isyms, // input (e.g. word tokens) + osyms, // output (e.g. tags) + fsyms; // feature (e.g. word identity, suffix, etc.) + isyms.AddSymbol(epsilon_symbol); + osyms.AddSymbol(epsilon_symbol); + fsyms.AddSymbol(epsilon_symbol); + isyms.AddSymbol(unknown_symbol); + + VLOG(1) << "start-of-sentence label is " + << LinearFstData::kStartOfSentence; + VLOG(1) << "end-of-sentence label is " << LinearFstData::kEndOfSentence; + + if (FLAGS_classifier) { + int num_classes = ScanNumClasses(models, models_length); + LinearClassifierFstDataBuilder builder(num_classes, &isyms, &fsyms, + &osyms); + + AddVocab(vocab, &isyms, &fsyms, &osyms, &builder); + for (int i = 0; i < models_length; ++i) + AddModel(models[i], &fsyms, &osyms, &builder); + + LinearClassifierFst fst(builder.Dump(), num_classes, &isyms, &osyms); + fst.Write(out); + } else { + LinearFstDataBuilder builder(&isyms, &fsyms, &osyms); + + AddVocab(vocab, &isyms, &fsyms, &osyms, &builder); + for (int i = 0; i < models_length; ++i) + AddModel(models[i], &fsyms, &osyms, &builder); + + LinearTaggerFst fst(builder.Dump(), &isyms, &osyms); + fst.Write(out); + } + + if (!save_isymbols.empty()) isyms.WriteText(save_isymbols); + if (!save_fsymbols.empty()) fsyms.WriteText(save_fsymbols); + if (!save_osymbols.empty()) osyms.WriteText(save_osymbols); +} + +void LinearCompile(const string &arc_type, const string &epsilon_symbol, + const string &unknown_symbol, const string &vocab, 
+ char **models, int models_len, const string &out, + const string &save_isymbols, const string &save_fsymbols, + const string &save_osymbols); + +template +bool GetVocabRecord(const string &vocab, std::istream &strm, // NOLINT + SymbolTable *isyms, SymbolTable *fsyms, SymbolTable *osyms, + typename Arc::Label *word, + std::vector *feature_labels, + std::vector *possible_labels, + size_t *num_line) { + string line; + if (!std::getline(strm, line)) return false; + ++(*num_line); + + std::vector fields; + SplitByWhitespace(line, &fields); + if (fields.size() != 3) + LOG(FATAL) << "Wrong number of fields in source " << vocab << ", line " + << num_line; + + feature_labels->clear(); + possible_labels->clear(); + + *word = LookUp(fields[0], isyms); + + const char delim = FLAGS_delimiter[0]; + SplitAndPush(fields[1], delim, fsyms, feature_labels); + SplitAndPush(fields[2], delim, osyms, possible_labels); + + return true; +} + +template +bool GetModelRecord(const string &model, std::istream &strm, // NOLINT + SymbolTable *fsyms, SymbolTable *osyms, + std::vector *input_labels, + std::vector *output_labels, + typename Arc::Weight *weight, size_t *num_line) { + string line; + if (!std::getline(strm, line)) return false; + ++(*num_line); + + std::vector fields; + SplitByWhitespace(line, &fields); + if (fields.size() != 3) + LOG(FATAL) << "Wrong number of fields in source " << model << ", line " + << num_line; + + input_labels->clear(); + output_labels->clear(); + + const char delim = FLAGS_delimiter[0]; + SplitAndPush(fields[0], delim, fsyms, input_labels); + SplitAndPush(fields[1], delim, osyms, output_labels); + + *weight = StrToWeight(fields[2], model, *num_line); + + GuessStartOrEnd(input_labels, kNoLabel); + GuessStartOrEnd(output_labels, kNoLabel); + + return true; +} +} // namespace script +} // namespace fst + +#define REGISTER_FST_LINEAR_OPERATIONS(Arc) \ + REGISTER_FST_OPERATION(LinearCompileTpl, Arc, LinearCompileArgs); + +#endif // 
FST_EXTENSIONS_LINEAR_LINEARSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h b/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h new file mode 100644 index 00000000..1b5d2eaf --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/loglinear-apply.h @@ -0,0 +1,77 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ +#define FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Applies a FST model as a discriminative model to weighted input +// `ifst`. `A` is an arc type with tropical weight of all the +// input/output FSTs. +// +// In general, consider `ifst` an unnormalized probability +// distribution between its input X and output Y, P(X, Y); and `lfst` +// a group of unnormalized probability distributions of all its output +// Z for every input Y, Q(Z|Y). `normalize` controls whether Q is +// normalized for every Y before chaining with P(X, Y). I.e., for a +// path (X, Y, Z) in `ofst` (where Y is hidden), +// +// - When `normalize` is true, its weight is P(X, Y) Q(Z|Y) / sum_z Q(z|Y); +// - When `normalize` is false, its weight is P(X, Y) Q(Z|Y). +template +void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, + bool normalize = true) { + LogLinearApply(ifst, lfst, ofst, normalize); +} + +// This version gives finer control over the arc type (`B`) to be used +// in normalization. `B` is an arc type with log weight (e.g. `LogArc` +// or `Log64Arc`). 
+template +void LogLinearApply(const Fst &ifst, const Fst &lfst, MutableFst *ofst, + bool normalize = true) { + if (normalize) { + VectorFst unnormalized_ofst, rescored_ifsa; + Compose(ifst, lfst, &unnormalized_ofst); + { + VectorFst tropical_ifsa(unnormalized_ofst); + Project(&tropical_ifsa, PROJECT_INPUT); + { + VectorFst minimal_log_ifsa; + { + VectorFst log_ifsa; + ArcMap(tropical_ifsa, &log_ifsa, WeightConvertMapper()); + RmEpsilon(&log_ifsa); + Determinize(log_ifsa, &minimal_log_ifsa); + } + Minimize(&minimal_log_ifsa); + ArcMap(&minimal_log_ifsa, InvertWeightMapper()); + ArcMap(minimal_log_ifsa, &tropical_ifsa, WeightConvertMapper()); + } + ArcSort(&tropical_ifsa, OLabelCompare()); + Compose(tropical_ifsa, ifst, &rescored_ifsa); + } + ArcSort(&rescored_ifsa, OLabelCompare()); + Compose(rescored_ifsa, unnormalized_ofst, ofst); + } else { + Compose(ifst, lfst, ofst); + } +} + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_LOGLINEAR_APPLY_H_ diff --git a/projects/llm_framework/include/fst/extensions/linear/trie.h b/projects/llm_framework/include/fst/extensions/linear/trie.h new file mode 100644 index 00000000..b5ddb3ad --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/linear/trie.h @@ -0,0 +1,444 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_LINEAR_TRIE_H_ +#define FST_EXTENSIONS_LINEAR_TRIE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +const int kNoTrieNodeId = -1; + +// Forward declarations of all available trie topologies. 
+template +class NestedTrieTopology; +template +class FlatTrieTopology; + +// A pair of parent node id and label, part of a trie edge +template +struct ParentLabel { + int parent; + L label; + + ParentLabel() {} + ParentLabel(int p, L l) : parent(p), label(l) {} + + bool operator==(const ParentLabel &that) const { + return parent == that.parent && label == that.label; + } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &parent); + ReadType(strm, &label); + return strm; + } + + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, parent); + WriteType(strm, label); + return strm; + } +}; + +template +struct ParentLabelHash { + size_t operator()(const ParentLabel &pl) const { + return static_cast(pl.parent * 7853 + H()(pl.label)); + } +}; + +// The trie topology in a nested tree of hash maps; allows efficient +// iteration over children of a specific node. +template +class NestedTrieTopology { + public: + typedef L Label; + typedef H Hash; + typedef std::unordered_map NextMap; + + class const_iterator { + public: + typedef std::forward_iterator_tag iterator_category; + typedef std::pair, int> value_type; + typedef std::ptrdiff_t difference_type; + typedef const value_type *pointer; + typedef const value_type &reference; + + friend class NestedTrieTopology; + + const_iterator() : ptr_(nullptr), cur_node_(kNoTrieNodeId), cur_edge_() {} + + reference operator*() { + UpdateStub(); + return stub_; + } + pointer operator->() { + UpdateStub(); + return &stub_; + } + + const_iterator &operator++(); + const_iterator &operator++(int); // NOLINT + + bool operator==(const const_iterator &that) const { + return ptr_ == that.ptr_ && cur_node_ == that.cur_node_ && + cur_edge_ == that.cur_edge_; + } + bool operator!=(const const_iterator &that) const { + return !(*this == that); + } + + private: + const_iterator(const NestedTrieTopology *ptr, int cur_node) + : ptr_(ptr), cur_node_(cur_node) { + SetProperCurEdge(); + } + + void 
SetProperCurEdge() { + if (cur_node_ < ptr_->NumNodes()) + cur_edge_ = ptr_->nodes_[cur_node_]->begin(); + else + cur_edge_ = ptr_->nodes_[0]->begin(); + } + + void UpdateStub() { + stub_.first = ParentLabel(cur_node_, cur_edge_->first); + stub_.second = cur_edge_->second; + } + + const NestedTrieTopology *ptr_; + int cur_node_; + typename NextMap::const_iterator cur_edge_; + value_type stub_; + }; + + NestedTrieTopology(); + NestedTrieTopology(const NestedTrieTopology &that); + ~NestedTrieTopology(); + void swap(NestedTrieTopology &that); + NestedTrieTopology &operator=(const NestedTrieTopology &that); + bool operator==(const NestedTrieTopology &that) const; + bool operator!=(const NestedTrieTopology &that) const; + + int Root() const { return 0; } + size_t NumNodes() const { return nodes_.size(); } + int Insert(int parent, const L &label); + int Find(int parent, const L &label) const; + const NextMap &ChildrenOf(int parent) const { return *nodes_[parent]; } + + std::istream &Read(std::istream &strm); // NOLINT + std::ostream &Write(std::ostream &strm) const; // NOLINT + + const_iterator begin() const { return const_iterator(this, 0); } + const_iterator end() const { return const_iterator(this, NumNodes()); } + + private: + std::vector + nodes_; // Use pointers to avoid copying the maps when the + // vector grows +}; + +template +NestedTrieTopology::NestedTrieTopology() { + nodes_.push_back(new NextMap); +} + +template +NestedTrieTopology::NestedTrieTopology(const NestedTrieTopology &that) { + nodes_.reserve(that.nodes_.size()); + for (size_t i = 0; i < that.nodes_.size(); ++i) { + NextMap *node = that.nodes_[i]; + nodes_.push_back(new NextMap(*node)); + } +} + +template +NestedTrieTopology::~NestedTrieTopology() { + for (size_t i = 0; i < nodes_.size(); ++i) { + NextMap *node = nodes_[i]; + delete node; + } +} + +// TODO(wuke): std::swap compatibility +template +inline void NestedTrieTopology::swap(NestedTrieTopology &that) { + nodes_.swap(that.nodes_); +} + 
+template +inline NestedTrieTopology &NestedTrieTopology::operator=( + const NestedTrieTopology &that) { + NestedTrieTopology copy(that); + swap(copy); + return *this; +} + +template +inline bool NestedTrieTopology::operator==( + const NestedTrieTopology &that) const { + if (NumNodes() != that.NumNodes()) return false; + for (int i = 0; i < NumNodes(); ++i) + if (ChildrenOf(i) != that.ChildrenOf(i)) return false; + return true; +} + +template +inline bool NestedTrieTopology::operator!=( + const NestedTrieTopology &that) const { + return !(*this == that); +} + +template +inline int NestedTrieTopology::Insert(int parent, const L &label) { + int ret = Find(parent, label); + if (ret == kNoTrieNodeId) { + ret = NumNodes(); + (*nodes_[parent])[label] = ret; + nodes_.push_back(new NextMap); + } + return ret; +} + +template +inline int NestedTrieTopology::Find(int parent, const L &label) const { + typename NextMap::const_iterator it = nodes_[parent]->find(label); + return it == nodes_[parent]->end() ? 
kNoTrieNodeId : it->second; +} + +template +inline std::istream &NestedTrieTopology::Read( + std::istream &strm) { // NOLINT + NestedTrieTopology new_trie; + size_t num_nodes; + if (!ReadType(strm, &num_nodes)) return strm; + for (size_t i = 1; i < num_nodes; ++i) new_trie.nodes_.push_back(new NextMap); + for (size_t i = 0; i < num_nodes; ++i) ReadType(strm, new_trie.nodes_[i]); + if (strm) swap(new_trie); + return strm; +} + +template +inline std::ostream &NestedTrieTopology::Write( + std::ostream &strm) const { // NOLINT + WriteType(strm, NumNodes()); + for (size_t i = 0; i < NumNodes(); ++i) WriteType(strm, *nodes_[i]); + return strm; +} + +template +inline typename NestedTrieTopology::const_iterator + &NestedTrieTopology::const_iterator::operator++() { + ++cur_edge_; + if (cur_edge_ == ptr_->nodes_[cur_node_]->end()) { + ++cur_node_; + while (cur_node_ < ptr_->NumNodes() && ptr_->nodes_[cur_node_]->empty()) + ++cur_node_; + SetProperCurEdge(); + } + return *this; +} + +template +inline typename NestedTrieTopology::const_iterator + &NestedTrieTopology::const_iterator::operator++(int) { // NOLINT + const_iterator save(*this); + ++(*this); + return save; +} + +// The trie topology in a single hash map; only allows iteration over +// all the edges in arbitrary order. 
+template +class FlatTrieTopology { + private: + typedef std::unordered_map, int, ParentLabelHash> + NextMap; + + public: + // Iterator over edges as std::pair, int> + typedef typename NextMap::const_iterator const_iterator; + typedef L Label; + typedef H Hash; + + FlatTrieTopology() {} + FlatTrieTopology(const FlatTrieTopology &that) : next_(that.next_) {} + template + explicit FlatTrieTopology(const T &that); + + // TODO(wuke): std::swap compatibility + void swap(FlatTrieTopology &that) { next_.swap(that.next_); } + + bool operator==(const FlatTrieTopology &that) const { + return next_ == that.next_; + } + bool operator!=(const FlatTrieTopology &that) const { + return !(*this == that); + } + + int Root() const { return 0; } + size_t NumNodes() const { return next_.size() + 1; } + int Insert(int parent, const L &label); + int Find(int parent, const L &label) const; + + std::istream &Read(std::istream &strm) { // NOLINT + return ReadType(strm, &next_); + } + std::ostream &Write(std::ostream &strm) const { // NOLINT + return WriteType(strm, next_); + } + + const_iterator begin() const { return next_.begin(); } + const_iterator end() const { return next_.end(); } + + private: + NextMap next_; +}; + +template +template +FlatTrieTopology::FlatTrieTopology(const T &that) + : next_(that.begin(), that.end()) {} + +template +inline int FlatTrieTopology::Insert(int parent, const L &label) { + int ret = Find(parent, label); + if (ret == kNoTrieNodeId) { + ret = NumNodes(); + next_[ParentLabel(parent, label)] = ret; + } + return ret; +} + +template +inline int FlatTrieTopology::Find(int parent, const L &label) const { + typename NextMap::const_iterator it = + next_.find(ParentLabel(parent, label)); + return it == next_.end() ? kNoTrieNodeId : it->second; +} + +// A collection of implementations of the trie data structure. The key +// is a sequence of type `L` which must be hashable. The value is of +// `V` which must be default constructible and copyable. 
In addition, +// a value object is stored for each node in the trie therefore +// copying `V` should be cheap. +// +// One can access the store values with an integer node id, using the +// [] operator. A valid node id can be obtained by the following ways: +// +// 1. Using the `Root()` method to get the node id of the root. +// +// 2. Iterating through 0 to `NumNodes() - 1`. The node ids are dense +// so every integer in this range is a valid node id. +// +// 3. Using the node id returned from a successful `Insert()` or +// `Find()` call. +// +// 4. Iterating over the trie edges with an `EdgeIterator` and using +// the node ids returned from its `Parent()` and `Child()` methods. +// +// Below is an example of inserting keys into the trie: +// +// const string words[] = {"hello", "health", "jello"}; +// Trie dict; +// for (auto word : words) { +// int cur = dict.Root(); +// for (char c : word) { +// cur = dict.Insert(cur, c); +// } +// dict[cur] = true; +// } +// +// And the following is an example of looking up the longest prefix of +// a string using the trie constructed above: +// +// string query = "healed"; +// size_t prefix_length = 0; +// int cur = dict.Find(dict.Root(), query[prefix_length]); +// while (prefix_length < query.size() && +// cur != Trie::kNoNodeId) { +// ++prefix_length; +// cur = dict.Find(cur, query[prefix_length]); +// } +template +class MutableTrie { + public: + template + friend class MutableTrie; + + typedef L Label; + typedef V Value; + typedef T Topology; + + // Constructs a trie with only the root node. + MutableTrie() {} + + // Conversion from another trie of a possiblly different + // topology. The underlying topology must supported conversion. 
+ template + explicit MutableTrie(const MutableTrie &that) + : topology_(that.topology_), values_(that.values_) {} + + // TODO(wuke): std::swap compatibility + void swap(MutableTrie &that) { + topology_.swap(that.topology_); + values_.swap(that.values_); + } + + int Root() const { return topology_.Root(); } + size_t NumNodes() const { return topology_.NumNodes(); } + + // Inserts an edge with given `label` at node `parent`. Returns the + // child node id. If the node already exists, returns the node id + // right away. + int Insert(int parent, const L &label) { + int ret = topology_.Insert(parent, label); + values_.resize(NumNodes()); + return ret; + } + + // Finds the node id of the node from `parent` via `label`. Returns + // `kNoTrieNodeId` when such a node does not exist. + int Find(int parent, const L &label) const { + return topology_.Find(parent, label); + } + + const T &TrieTopology() const { return topology_; } + + // Accesses the value stored for the given node. + V &operator[](int node_id) { return values_[node_id]; } + const V &operator[](int node_id) const { return values_[node_id]; } + + // Comparison by content + bool operator==(const MutableTrie &that) const { + return topology_ == that.topology_ && values_ == that.values_; + } + + bool operator!=(const MutableTrie &that) const { return !(*this == that); } + + std::istream &Read(std::istream &strm) { // NOLINT + ReadType(strm, &topology_); + ReadType(strm, &values_); + return strm; + } + std::ostream &Write(std::ostream &strm) const { // NOLINT + WriteType(strm, topology_); + WriteType(strm, values_); + return strm; + } + + private: + T topology_; + std::vector values_; +}; + +} // namespace fst + +#endif // FST_EXTENSIONS_LINEAR_TRIE_H_ diff --git a/projects/llm_framework/include/fst/extensions/mpdt/compose.h b/projects/llm_framework/include/fst/extensions/mpdt/compose.h new file mode 100644 index 00000000..47714e37 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/mpdt/compose.h @@ 
-0,0 +1,267 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Compose an MPDT and an FST. + +#ifndef FST_EXTENSIONS_MPDT_COMPOSE_H_ +#define FST_EXTENSIONS_MPDT_COMPOSE_H_ + +#include + +#include +#include +#include + +namespace fst { + +template +class MPdtParenFilter { + public: + using FST1 = typename Filter::FST1; + using FST2 = typename Filter::FST2; + using Arc = typename Filter::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + + using StackId = StateId; + using ParenStack = internal::MPdtStack; + using FilterState1 = typename Filter::FilterState; + using FilterState2 = IntegerFilterState; + using FilterState = PairFilterState; + + MPdtParenFilter(const FST1 &fst1, const FST2 &fst2, + Matcher1 *matcher1 = nullptr, Matcher2 *matcher2 = nullptr, + const std::vector> *parens = nullptr, + const std::vector(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + const ParenStack &GetStack() const { return GetImpl()->GetStack(); } + + const PdtStateTable &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + void operator=(const MPdtExpandFst &) = delete; +}; + +// Specialization for MPdtExpandFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const MPdtExpandFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for MPdtExpandFst. 
+template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const MPdtExpandFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s); + } +}; + +template +inline void MPdtExpandFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +struct MPdtExpandOptions { + bool connect; + bool keep_parentheses; + + explicit MPdtExpandOptions(bool connect = true, bool keep_parentheses = false) + : connect(connect), keep_parentheses(keep_parentheses) {} +}; + +// Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST. +// This version writes the expanded PDT to a mutable FST. In the MPDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// an MPDT, the parens for each stack must balance on a path. The open-close +// parenthesis label pair sets are passed using the parens argument, and the +// assignment of those pairs to stacks is passed using the assignments argument. +// The expansion enforces the parenthesis constraints. The MPDT must be +// expandable as an FST. +template +void Expand(const Fst &ifst, + const std::vector< + std::pair> &parens, + const std::vector &assignments, + MutableFst *ofst, const MPdtExpandOptions &opts) { + MPdtExpandFstOptions eopts; + eopts.gc_limit = 0; + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = MPdtExpandFst(ifst, parens, assignments, eopts); + if (opts.connect) Connect(ofst); +} + +// Expands a multi-pushdown transducer (MPDT) encoded as an FST into an FST. +// This version writes the expanded PDT to a mutable FST. In the MPDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// an MPDT, the parens for each stack must balance on a path. 
The open-close +// parenthesis label pair sets are passed using the parens argument, and the +// assignment of those pairs to stacks is passed using the assignments argument. +// The expansion enforces the parenthesis constraints. The MPDT must be +// expandable as an FST. +template +void Expand(const Fst &ifst, + const std::vector> &parens, + const std::vector &assignments, + MutableFst *ofst, bool connect = true, + bool keep_parentheses = false) { + const MPdtExpandOptions opts(connect, keep_parentheses); + Expand(ifst, parens, assignments, ofst, opts); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_MPDT_EXPAND_H_ diff --git a/projects/llm_framework/include/fst/extensions/mpdt/info.h b/projects/llm_framework/include/fst/extensions/mpdt/info.h new file mode 100644 index 00000000..512fcfa8 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/mpdt/info.h @@ -0,0 +1,190 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Prints information about an MPDT. + +#ifndef FST_EXTENSIONS_MPDT_INFO_H_ +#define FST_EXTENSIONS_MPDT_INFO_H_ + +#include +#include + +#include +#include + +namespace fst { + +// Compute various information about MPDTs, helper class for mpdtinfo.cc. 
+template +class MPdtInfo { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MPdtInfo(const Fst &fst, + const std::vector> &parens, + const std::vector { + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::WriteHeader; + + friend class ArcIterator>; + friend class NGramFstMatcher; + + public: + using FstImpl::InputSymbols; + using FstImpl::SetProperties; + using FstImpl::Properties; + + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + NGramFstImpl() { + SetType("ngram"); + SetInputSymbols(nullptr); + SetOutputSymbols(nullptr); + SetProperties(kStaticProperties); + } + + NGramFstImpl(const Fst &fst, std::vector *order_out); + + explicit NGramFstImpl(const Fst &fst) : NGramFstImpl(fst, nullptr) {} + + NGramFstImpl(const NGramFstImpl &other) { + FSTERROR() << "Copying NGramFst Impls is not supported, use safe = false."; + SetProperties(kError, kError); + } + + ~NGramFstImpl() override { + if (owned_) { + delete[] data_; + } + } + + static NGramFstImpl *Read(std::istream &strm, // NOLINT + const FstReadOptions &opts) { + NGramFstImpl *impl = new NGramFstImpl(); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return 0; + uint64 num_states, num_futures, num_final; + const size_t offset = + sizeof(num_states) + sizeof(num_futures) + sizeof(num_final); + // Peek at num_states and num_futures to see how much more needs to be read. 
+ strm.read(reinterpret_cast(&num_states), sizeof(num_states)); + strm.read(reinterpret_cast(&num_futures), sizeof(num_futures)); + strm.read(reinterpret_cast(&num_final), sizeof(num_final)); + size_t size = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(size); + char *data = reinterpret_cast(data_region->mutable_data()); + // Copy num_states, num_futures and num_final back into data. + memcpy(data, reinterpret_cast(&num_states), sizeof(num_states)); + memcpy(data + sizeof(num_states), reinterpret_cast(&num_futures), + sizeof(num_futures)); + memcpy(data + sizeof(num_states) + sizeof(num_futures), + reinterpret_cast(&num_final), sizeof(num_final)); + strm.read(data + offset, size - offset); + if (strm.fail()) { + delete impl; + return nullptr; + } + impl->Init(data, false, data_region); + return impl; + } + + bool Write(std::ostream &strm, // NOLINT + const FstWriteOptions &opts) const { + FstHeader hdr; + hdr.SetStart(Start()); + hdr.SetNumStates(num_states_); + WriteHeader(strm, opts, kFileVersion, &hdr); + strm.write(data_, StorageSize()); + return !strm.fail(); + } + + StateId Start() const { return start_; } + + Weight Final(StateId state) const { + if (final_index_.Get(state)) { + return final_probs_[final_index_.Rank1(state)]; + } else { + return Weight::Zero(); + } + } + + size_t NumArcs(StateId state, NGramFstInst *inst = nullptr) const { + if (inst == nullptr) { + const std::pair zeros = + (state == 0) ? select_root_ : future_index_.Select0s(state); + return zeros.second - zeros.first - 1; + } + SetInstFuture(state, inst); + return inst->num_futures_ + ((state == 0) ? 0 : 1); + } + + size_t NumInputEpsilons(StateId state) const { + // State 0 has no parent, thus no backoff. 
+ if (state == 0) return 0; + return 1; + } + + size_t NumOutputEpsilons(StateId state) const { + return NumInputEpsilons(state); + } + + StateId NumStates() const { return num_states_; } + + void InitStateIterator(StateIteratorData *data) const { + data->base = 0; + data->nstates = num_states_; + } + + static size_t Storage(uint64 num_states, uint64 num_futures, + uint64 num_final) { + uint64 b64; + Weight weight; + Label label; + size_t offset = + sizeof(num_states) + sizeof(num_futures) + sizeof(num_final); + offset += + sizeof(b64) * (BitmapIndex::StorageSize(num_states * 2 + 1) + + BitmapIndex::StorageSize(num_futures + num_states + 1) + + BitmapIndex::StorageSize(num_states)); + offset += (num_states + 1) * sizeof(label) + num_futures * sizeof(label); + // Pad for alignemnt, see + // http://en.wikipedia.org/wiki/Data_structure_alignment#Computing_padding + offset = (offset + sizeof(weight) - 1) & ~(sizeof(weight) - 1); + offset += (num_states + 1) * sizeof(weight) + num_final * sizeof(weight) + + (num_futures + 1) * sizeof(weight); + return offset; + } + + void SetInstFuture(StateId state, NGramFstInst *inst) const { + if (inst->state_ != state) { + inst->state_ = state; + const std::pair zeros = future_index_.Select0s(state); + inst->num_futures_ = zeros.second - zeros.first - 1; + inst->offset_ = future_index_.Rank1(zeros.first + 1); + } + } + + void SetInstNode(NGramFstInst *inst) const { + if (inst->node_state_ != inst->state_) { + inst->node_state_ = inst->state_; + inst->node_ = context_index_.Select1(inst->state_); + } + } + + void SetInstContext(NGramFstInst *inst) const { + SetInstNode(inst); + if (inst->context_state_ != inst->state_) { + inst->context_state_ = inst->state_; + inst->context_.clear(); + size_t node = inst->node_; + while (node != 0) { + inst->context_.push_back(context_words_[context_index_.Rank1(node)]); + node = context_index_.Select1(context_index_.Rank0(node) - 1); + } + } + } + + // Access to the underlying representation + 
const char *GetData(size_t *data_size) const { + *data_size = StorageSize(); + return data_; + } + + void Init(const char *data, bool owned, MappedFile *file = nullptr); + + const std::vector *inst) const { + SetInstFuture(s, inst); + SetInstContext(inst); + return inst->context_; + } + + size_t StorageSize() const { + return Storage(num_states_, num_futures_, num_final_); + } + + void GetStates(const std::vector::GetStates( + const std::vector; + + public: + typedef A Arc; + typedef typename A::StateId StateId; + typedef typename A::Label Label; + typedef typename A::Weight Weight; + typedef internal::NGramFstImpl Impl; + + explicit NGramFst(const Fst &dst) + : ImplToExpandedFst(std::make_shared(dst, nullptr)) {} + + NGramFst(const Fst &fst, std::vector *order_out) + : ImplToExpandedFst(std::make_shared(fst, order_out)) {} + + // Because the NGramFstImpl is a const stateless data structure, there + // is never a need to do anything beside copy the reference. + NGramFst(const NGramFst &fst, bool safe = false) + : ImplToExpandedFst(fst, false) {} + + NGramFst() : ImplToExpandedFst(std::make_shared()) {} + + // Non-standard constructor to initialize NGramFst directly from data. + NGramFst(const char *data, bool owned) + : ImplToExpandedFst(std::make_shared()) { + GetMutableImpl()->Init(data, owned, nullptr); + } + + // Get method that gets the data associated with Init(). + const char *GetData(size_t *data_size) const { + return GetImpl()->GetData(data_size); + } + + const std::vector *Copy(bool safe = false) const override { + return new NGramFst(*this, safe); + } + + static NGramFst *Read(std::istream &strm, const FstReadOptions &opts) { + Impl *impl = Impl::Read(strm, opts); + return impl ? 
new NGramFst(std::shared_ptr(impl)) : nullptr; + } + + static NGramFst *Read(const string &filename) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm.good()) { + LOG(ERROR) << "NGramFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + inline void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + inline void InitArcIterator(StateId s, + ArcIteratorData *data) const override; + + MatcherBase *InitMatcher(MatchType match_type) const override { + return new NGramFstMatcher(this, match_type); + } + + size_t StorageSize() const { return GetImpl()->StorageSize(); } + + static bool HasRequiredProps(const Fst &fst) { + static const auto props = + kAcceptor | kIDeterministic | kILabelSorted | kIEpsilons | kAccessible; + return fst.Properties(props, true) == props; + } + + static bool HasRequiredStructure(const Fst &fst) { + if (!HasRequiredProps(fst)) { + return false; + } + typename A::StateId unigram = fst.Start(); + while (true) { // Follows epsilon arc chain to find unigram state. + if (unigram == fst::kNoStateId) return false; // No unigram state. + typename fst::ArcIterator> aiter(fst, unigram); + if (aiter.Done() || aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + aiter.Next(); + } + // Other requirement: all states other than unigram an epsilon arc. 
+ for (fst::StateIterator> siter(fst); !siter.Done(); + siter.Next()) { + const typename A::StateId &state = siter.Value(); + fst::ArcIterator> aiter(fst, state); + if (state != unigram) { + if (aiter.Done()) return false; + if (aiter.Value().ilabel != 0) return false; + aiter.Next(); + if (!aiter.Done() && aiter.Value().ilabel == 0) return false; + } + } + return true; + } + + private: + using ImplToExpandedFst>::GetImpl; + using ImplToExpandedFst>::GetMutableImpl; + + explicit NGramFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + mutable NGramFstInst inst_; +}; + +template +inline void NGramFst::InitArcIterator(StateId s, + ArcIteratorData *data) const { + GetImpl()->SetInstFuture(s, &inst_); + GetImpl()->SetInstNode(&inst_); + data->base = new ArcIterator>(*this, s); +} + +namespace internal { + +template +NGramFstImpl::NGramFstImpl(const Fst &fst, + std::vector *order_out) { + typedef A Arc; + typedef typename Arc::Label Label; + typedef typename Arc::Weight Weight; + typedef typename Arc::StateId StateId; + SetType("ngram"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + SetProperties(kStaticProperties); + + // Check basic requirements for an OpenGrm language model Fst. + if (!NGramFst::HasRequiredProps(fst)) { + FSTERROR() << "NGramFst only accepts OpenGrm language models as input"; + SetProperties(kError, kError); + return; + } + + int64 num_states = CountStates(fst); + Label *context = new Label[num_states]; + + // Find the unigram state by starting from the start state, following + // epsilons. 
+ StateId unigram = fst.Start(); + while (1) { + if (unigram == kNoStateId) { + FSTERROR() << "Could not identify unigram state"; + SetProperties(kError, kError); + return; + } + ArcIterator> aiter(fst, unigram); + if (aiter.Done()) { + LOG(WARNING) << "Unigram state " << unigram << " has no arcs."; + break; + } + if (aiter.Value().ilabel != 0) break; + unigram = aiter.Value().nextstate; + } + + // Each state's context is determined by the subtree it is under from the + // unigram state. + std::queue> label_queue; + std::vector visited(num_states); + // Force an epsilon link to the start state. + label_queue.push(std::make_pair(fst.Start(), 0)); + for (ArcIterator> aiter(fst, unigram); !aiter.Done(); aiter.Next()) { + label_queue.push( + std::make_pair(aiter.Value().nextstate, aiter.Value().ilabel)); + } + // investigate states in breadth first fashion to assign context words. + while (!label_queue.empty()) { + std::pair &now = label_queue.front(); + if (!visited[now.first]) { + context[now.first] = now.second; + visited[now.first] = true; + for (ArcIterator> aiter(fst, now.first); !aiter.Done(); + aiter.Next()) { + const Arc &arc = aiter.Value(); + if (arc.ilabel != 0) { + label_queue.push(std::make_pair(arc.nextstate, now.second)); + } + } + } + label_queue.pop(); + } + visited.clear(); + + // The arc from the start state should be assigned an epsilon to put it + // in front of the all other labels (which makes Start state 1 after + // unigram which is state 0). + context[fst.Start()] = 0; + + // Build the tree of contexts fst by reversing the epsilon arcs from fst. 
+ VectorFst context_fst; + uint64 num_final = 0; + for (int i = 0; i < num_states; ++i) { + if (fst.Final(i) != Weight::Zero()) { + ++num_final; + } + context_fst.SetFinal(context_fst.AddState(), fst.Final(i)); + } + context_fst.SetStart(unigram); + context_fst.SetInputSymbols(fst.InputSymbols()); + context_fst.SetOutputSymbols(fst.OutputSymbols()); + int64 num_context_arcs = 0; + int64 num_futures = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + const StateId &state = siter.Value(); + num_futures += fst.NumArcs(state) - fst.NumInputEpsilons(state); + ArcIterator> aiter(fst, state); + if (!aiter.Done()) { + const Arc &arc = aiter.Value(); + // this arc goes from state to arc.nextstate, so create an arc from + // arc.nextstate to state to reverse it. + if (arc.ilabel == 0) { + context_fst.AddArc(arc.nextstate, Arc(context[state], context[state], + arc.weight, state)); + num_context_arcs++; + } + } + } + if (num_context_arcs != context_fst.NumStates() - 1) { + FSTERROR() << "Number of contexts arcs != number of states - 1"; + SetProperties(kError, kError); + return; + } + if (context_fst.NumStates() != num_states) { + FSTERROR() << "Number of contexts != number of states"; + SetProperties(kError, kError); + return; + } + int64 context_props = + context_fst.Properties(kIDeterministic | kILabelSorted, true); + if (!(context_props & kIDeterministic)) { + FSTERROR() << "Input Fst is not structured properly"; + SetProperties(kError, kError); + return; + } + if (!(context_props & kILabelSorted)) { + ArcSort(&context_fst, ILabelCompare()); + } + + delete[] context; + + uint64 b64; + Weight weight; + Label label = kNoLabel; + const size_t storage = Storage(num_states, num_futures, num_final); + MappedFile *data_region = MappedFile::Allocate(storage); + char *data = reinterpret_cast(data_region->mutable_data()); + memset(data, 0, storage); + size_t offset = 0; + memcpy(data + offset, reinterpret_cast(&num_states), + sizeof(num_states)); + offset += 
sizeof(num_states); + memcpy(data + offset, reinterpret_cast(&num_futures), + sizeof(num_futures)); + offset += sizeof(num_futures); + memcpy(data + offset, reinterpret_cast(&num_final), + sizeof(num_final)); + offset += sizeof(num_final); + uint64 *context_bits = reinterpret_cast(data + offset); + offset += BitmapIndex::StorageSize(num_states * 2 + 1) * sizeof(b64); + uint64 *future_bits = reinterpret_cast(data + offset); + offset += + BitmapIndex::StorageSize(num_futures + num_states + 1) * sizeof(b64); + uint64 *final_bits = reinterpret_cast(data + offset); + offset += BitmapIndex::StorageSize(num_states) * sizeof(b64); + Label *context_words = reinterpret_cast::Init(const char *data, bool owned, + MappedFile *data_region) { + if (owned_) { + delete[] data_; + } + data_region_.reset(data_region); + owned_ = owned; + data_ = data; + size_t offset = 0; + num_states_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_states_); + num_futures_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_futures_); + num_final_ = *(reinterpret_cast(data_ + offset)); + offset += sizeof(num_final_); + uint64 bits; + size_t context_bits = num_states_ * 2 + 1; + size_t future_bits = num_futures_ + num_states_ + 1; + context_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(context_bits) * sizeof(bits); + future_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(future_bits) * sizeof(bits); + final_ = reinterpret_cast(data_ + offset); + offset += BitmapIndex::StorageSize(num_states_) * sizeof(bits); + context_words_ = reinterpret_cast(data_ + offset); + offset += (num_states_ + 1) * sizeof(*context_words_); + future_words_ = reinterpret_cast(data_ + offset); + offset += num_futures_ * sizeof(*future_words_); + offset = (offset + sizeof(*backoff_) - 1) & ~(sizeof(*backoff_) - 1); + backoff_ = reinterpret_cast(data_ + offset); + offset += (num_states_ + 1) * sizeof(*backoff_); + final_probs_ = reinterpret_cast(data_ 
+ offset); + offset += num_final_ * sizeof(*final_probs_); + future_probs_ = reinterpret_cast(data_ + offset); + + context_index_.BuildIndex(context_, context_bits); + future_index_.BuildIndex(future_, future_bits); + final_index_.BuildIndex(final_, num_states_); + + select_root_ = context_index_.Select0s(0); + if (context_index_.Rank1(0) != 0 || select_root_.first != 1 || + context_index_.Get(2) == false) { + FSTERROR() << "Malformed file"; + SetProperties(kError, kError); + return; + } + root_children_ = context_words_ + context_index_.Rank1(2); + start_ = 1; +} + +template +inline typename A::StateId NGramFstImpl::Transition( + const std::vector { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + // This makes a copy of the FST. + NGramFstMatcher(const NGramFst &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + inst_(fst_.inst_), + match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + // This doesn't copy the FST. + NGramFstMatcher(const NGramFst *fst, MatchType match_type) + : fst_(*fst), + inst_(fst_.inst_), + match_type_(match_type), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + // This makes a copy of the FST. 
+ NGramFstMatcher(const NGramFstMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + inst_(matcher.inst_), + match_type_(matcher.match_type_), + current_loop_(false), + loop_(kNoLabel, 0, A::Weight::One(), kNoStateId) { + if (match_type_ == MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + } + + NGramFstMatcher *Copy(bool safe = false) const override { + return new NGramFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return match_type_; } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { return props; } + + void SetState(StateId s) final { + fst_.GetImpl()->SetInstFuture(s, &inst_); + current_loop_ = false; + } + + bool Find(Label label) final { + const Label nolabel = kNoLabel; + done_ = true; + if (label == 0 || label == nolabel) { + if (label == 0) { + current_loop_ = true; + loop_.nextstate = inst_.state_; + } + // The unigram state has no epsilon arc. 
+ if (inst_.state_ != 0) { + arc_.ilabel = arc_.olabel = 0; + fst_.GetImpl()->SetInstNode(&inst_); + arc_.nextstate = fst_.GetImpl()->context_index_.Rank1( + fst_.GetImpl()->context_index_.Select1( + fst_.GetImpl()->context_index_.Rank0(inst_.node_) - 1)); + arc_.weight = fst_.GetImpl()->backoff_[inst_.state_]; + done_ = false; + } + } else { + current_loop_ = false; + const Label *start = fst_.GetImpl()->future_words_ + inst_.offset_; + const Label *end = start + inst_.num_futures_; + const Label *search = std::lower_bound(start, end, label); + if (search != end && *search == label) { + size_t state = search - start; + arc_.ilabel = arc_.olabel = label; + arc_.weight = fst_.GetImpl()->future_probs_[inst_.offset_ + state]; + fst_.GetImpl()->SetInstContext(&inst_); + arc_.nextstate = fst_.GetImpl()->Transition(inst_.context_, label); + done_ = false; + } + } + return !Done(); + } + + bool Done() const final { return !current_loop_ && done_; } + + const Arc &Value() const final { return (current_loop_) ? loop_ : arc_; } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + done_ = true; + } + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + std::unique_ptr> owned_fst_; + const NGramFst &fst_; + NGramFstInst inst_; + MatchType match_type_; // Supplied by caller + bool done_; + Arc arc_; + bool current_loop_; // Current arc is the implicit loop + Arc loop_; +}; + +/*****************************************************************************/ +// Specialization for NGramFst; see generic version in fst.h +// for sample usage (but use the ProdLmFst type!). This version +// should inline. 
+template +class StateIterator> : public StateIteratorBase { + public: + typedef typename A::StateId StateId; + + explicit StateIterator(const NGramFst &fst) + : s_(0), num_states_(fst.NumStates()) {} + + bool Done() const final { return s_ >= num_states_; } + + StateId Value() const final { return s_; } + + void Next() final { ++s_; } + + void Reset() final { s_ = 0; } + + private: + StateId s_; + StateId num_states_; +}; + +/*****************************************************************************/ +template +class ArcIterator> : public ArcIteratorBase { + public: + typedef A Arc; + typedef typename A::Label Label; + typedef typename A::StateId StateId; + typedef typename A::Weight Weight; + + ArcIterator(const NGramFst &fst, StateId state) + : lazy_(~0), impl_(fst.GetImpl()), i_(0), flags_(kArcValueFlags) { + inst_ = fst.inst_; + impl_->SetInstFuture(state, &inst_); + impl_->SetInstNode(&inst_); + } + + bool Done() const final { + return i_ >= + ((inst_.node_ == 0) ? inst_.num_futures_ : inst_.num_futures_ + 1); + } + + const Arc &Value() const final { + bool eps = (inst_.node_ != 0 && i_ == 0); + StateId state = (inst_.node_ == 0) ? i_ : i_ - 1; + if (flags_ & lazy_ & (kArcILabelValue | kArcOLabelValue)) { + arc_.ilabel = arc_.olabel = + eps ? 0 : impl_->future_words_[inst_.offset_ + state]; + lazy_ &= ~(kArcILabelValue | kArcOLabelValue); + } + if (flags_ & lazy_ & kArcNextStateValue) { + if (eps) { + arc_.nextstate = + impl_->context_index_.Rank1(impl_->context_index_.Select1( + impl_->context_index_.Rank0(inst_.node_) - 1)); + } else { + if (lazy_ & kArcNextStateValue) { + impl_->SetInstContext(&inst_); // first time only. + } + arc_.nextstate = impl_->Transition( + inst_.context_, impl_->future_words_[inst_.offset_ + state]); + } + lazy_ &= ~kArcNextStateValue; + } + if (flags_ & lazy_ & kArcWeightValue) { // Weight is the backoff_ entry for the epsilon arc, otherwise the future_probs_ entry. + arc_.weight = eps ?
impl_->backoff_[inst_.state_] + : impl_->future_probs_[inst_.offset_ + state]; + lazy_ &= ~kArcWeightValue; + } + return arc_; + } + + void Next() final { + ++i_; + lazy_ = ~0; + } + + size_t Position() const final { return i_; } + + void Reset() final { + i_ = 0; + lazy_ = ~0; + } + + void Seek(size_t a) final { + if (i_ != a) { + i_ = a; + lazy_ = ~0; + } + } + + uint32 Flags() const final { return flags_; } + + void SetFlags(uint32 flags, uint32 mask) final { + flags_ &= ~mask; + flags_ |= (flags & kArcValueFlags); + } + + private: + mutable Arc arc_; + mutable uint32 lazy_; + const internal::NGramFstImpl *impl_; // Borrowed reference. + mutable NGramFstInst inst_; + + size_t i_; // Current arc position (see Position() and Seek()). + uint32 flags_; // Arc fields Value() is asked to fill in (see SetFlags()). +}; + +} // namespace fst +#endif // FST_EXTENSIONS_NGRAM_NGRAM_FST_H_ diff --git a/projects/llm_framework/include/fst/extensions/ngram/nthbit.h b/projects/llm_framework/include/fst/extensions/ngram/nthbit.h new file mode 100644 index 00000000..1e6ec635 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/ngram/nthbit.h @@ -0,0 +1,49 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_NGRAM_NTHBIT_H_ +#define FST_EXTENSIONS_NGRAM_NTHBIT_H_ + +#include +#include + +#ifdef __BMI2__ +// PDEP requires BMI2. + +// Returns the position (0-63) of the r-th 1 bit in v (r is 1-based, counted from the least significant bit). +// 1 <= r <= CountOnes(v) <= 64. Therefore, v must not be 0. +inline uint32 nth_bit(uint64 v, uint32 r) { + // PDEP example from https://stackoverflow.com/a/27453505 + return __builtin_ctzll(_pdep_u64(uint64{1} << (r - 1), v)); +} + +#else // !defined(__BMI2__) + +extern const uint32 nth_bit_bit_offset[]; + +// Returns the position (0-63) of the r-th 1 bit in v (r is 1-based, counted from the least significant bit). +// 1 <= r <= CountOnes(v) <= 64. Therefore, v must not be 0.
+inline uint32 nth_bit(uint64 v, uint32 r) { + uint32 shift = 0; + uint32 c = __builtin_popcount(v & 0xffffffff); + uint32 mask = -(r > c); + r -= c & mask; + shift += (32 & mask); + + c = __builtin_popcount((v >> shift) & 0xffff); + mask = -(r > c); + r -= c & mask; + shift += (16 & mask); + + c = __builtin_popcount((v >> shift) & 0xff); + mask = -(r > c); + r -= c & mask; + shift += (8 & mask); + + return shift + + ((nth_bit_bit_offset[(v >> shift) & 0xff] >> ((r - 1) << 2)) & 0xf); +} + +#endif // !defined(__BMI2__) + +#endif // FST_EXTENSIONS_NGRAM_NTHBIT_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/collection.h b/projects/llm_framework/include/fst/extensions/pdt/collection.h new file mode 100644 index 00000000..ae34aba7 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/collection.h @@ -0,0 +1,107 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to store a collection of ordered (multi-)sets with elements of type T. + +#ifndef FST_EXTENSIONS_PDT_COLLECTION_H_ +#define FST_EXTENSIONS_PDT_COLLECTION_H_ + +#include +#include + +#include +#include + +namespace fst { + +// Stores a collection of non-empty, ordered (multi-)sets with elements of type +// T. A default constructor, operator==, and an STL-style hash functor must be +// defined on the elements. Provides a signed integer ID (of type I) for each +// unique set. The IDs are allocated starting from 0 in order. +template +class Collection { + public: + struct Node { // Trie node.
+ I node_id; // Root is kNoNodeId; + T element; + + Node() : node_id(kNoNodeId), element(T()) {} + + Node(I i, const T &t) : node_id(i), element(t) {} + + bool operator==(const Node &n) const { + return n.node_id == node_id && n.element == element; + } + }; + + struct NodeHash { // Combines node_id with the element hash scaled by kPrime. + size_t operator()(const Node &n) const { + static constexpr auto kPrime = 7853; + return n.node_id + hash_(n.element) * kPrime; + } + }; + + using NodeTable = CompactHashBiTable; + + class SetIterator { // Walks one stored set by following node_id links until kNoNodeId. + public: + SetIterator(I id, Node node, NodeTable *node_table) + : id_(id), node_(node), node_table_(node_table) {} + + bool Done() const { return id_ == kNoNodeId; } + + const T &Element() const { return node_.element; } + + void Next() { + id_ = node_.node_id; + if (id_ != kNoNodeId) node_ = node_table_->FindEntry(id_); + } + + private: + I id_; // Iterator set node ID. + Node node_; // Iterator set node. + NodeTable *node_table_; + }; + + Collection() {} + + // Looks up integer ID from ordered multi-set, and if it doesn't exist and + // insert is true, then adds it. Otherwise returns -1. Elements are processed from last to first; an empty vector yields kNoNodeId. + I FindId(const std::vector &set, bool insert = true) { + I node_id = kNoNodeId; + for (ssize_t i = set.size() - 1; i >= 0; --i) { + Node node(node_id, set[i]); + node_id = node_table_.FindId(node, insert); + if (node_id == -1) break; + } + return node_id; + } + + // Finds ordered (multi-)set given integer ID. Returns set iterator to + // traverse result.
+ SetIterator FindSet(I id) { + if (id < 0 || id >= node_table_.Size()) { + return SetIterator(kNoNodeId, Node(kNoNodeId, T()), &node_table_); + } else { + return SetIterator(id, node_table_.FindEntry(id), &node_table_); + } + } + + I Size() const { return node_table_.Size(); } + + private: + static constexpr I kNoNodeId = -1; + static const std::hash hash_; + + NodeTable node_table_; +}; + +template +constexpr I Collection::kNoNodeId; + +template +const std::hash Collection::hash_ = {}; + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COLLECTION_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/compose.h b/projects/llm_framework/include/fst/extensions/pdt/compose.h new file mode 100644 index 00000000..525d613a --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/compose.h @@ -0,0 +1,493 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Composes a PDT and an FST. + +#ifndef FST_EXTENSIONS_PDT_COMPOSE_H_ +#define FST_EXTENSIONS_PDT_COMPOSE_H_ + +#include + +#include +#include + +namespace fst { + +// Returns paren arcs for Find(kNoLabel). +constexpr uint32 kParenList = 0x00000001; + +// Returns a kNoLabel loop for Find(paren). +constexpr uint32 kParenLoop = 0x00000002; + +// This class is a matcher that treats parens as multi-epsilon labels. +// It is most efficient if the parens are in a range non-overlapping with +// the non-paren labels. +template +class ParenMatcher { + public: + using FST = F; + using M = SortedMatcher; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST.
+ ParenMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), match_type_(match_type), flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + // This doesn't copy the FST. + ParenMatcher(const FST *fst, MatchType match_type, + uint32 flags = (kParenLoop | kParenList)) + : matcher_(fst, match_type), match_type_(match_type), flags_(flags) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + // This makes a copy of the FST. + ParenMatcher(const ParenMatcher &matcher, bool safe = false) + : matcher_(matcher.matcher_, safe), + match_type_(matcher.match_type_), + flags_(matcher.flags_), + open_parens_(matcher.open_parens_), + close_parens_(matcher.close_parens_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ParenMatcher *Copy(bool safe = false) const { + return new ParenMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_.Type(test); } + + void SetState(StateId s) { // Also makes the implicit paren loop target s. + matcher_.SetState(s); + loop_.nextstate = s; + } + + bool Find(Label match_label); + + bool Done() const { return done_; } + + const Arc &Value() const { return paren_loop_ ?
loop_ : matcher_.Value(); } + + void Next(); + + Weight Final(StateId s) { return matcher_.Final(s); } + + ssize_t Priority(StateId s) { return matcher_.Priority(s); } + + const FST &GetFst() const { return matcher_.GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_.Properties(props); } + + uint32 Flags() const { return matcher_.Flags(); } + + void AddOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Insert(label); + } + } + + void AddCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Insert(label); + } + } + + void RemoveOpenParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad open paren label: 0"; + } else { + open_parens_.Erase(label); + } + } + + void RemoveCloseParen(Label label) { + if (label == 0) { + FSTERROR() << "ParenMatcher: Bad close paren label: 0"; + } else { + close_parens_.Erase(label); + } + } + + void ClearOpenParens() { open_parens_.Clear(); } + + void ClearCloseParens() { close_parens_.Clear(); } + + bool IsOpenParen(Label label) const { return open_parens_.Member(label); } + + bool IsCloseParen(Label label) const { return close_parens_.Member(label); } + + private: + // Advances matcher to next open paren, returning true if it exists. + bool NextOpenParen(); + + // Advances matcher to next close paren, returning true if it exists. + bool NextCloseParen(); + + M matcher_; + MatchType match_type_; // Type of match to perform. + uint32 flags_; // kParenList and/or kParenLoop bits tested by Find(). + // Open paren label set. + CompactSet open_parens_; + // Close paren label set. + CompactSet close_parens_; + bool open_paren_list_; // Matching open paren list? + bool close_paren_list_; // Matching close paren list? + bool paren_loop_; // Current arc is the implicit paren loop? + mutable Arc loop_; // For non-consuming symbols. + bool done_; // Matching done?
+ + ParenMatcher &operator=(const ParenMatcher &) = delete; +}; + +template +inline bool ParenMatcher::Find(Label match_label) { + open_paren_list_ = false; + close_paren_list_ = false; + paren_loop_ = false; + done_ = false; + // Returns all parenthesis arcs. + if (match_label == kNoLabel && (flags_ & kParenList)) { + if (open_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(open_parens_.LowerBound()); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return true; + } + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return true; + } + } + // Returns the implicit paren loop. + if (match_label > 0 && (flags_ & kParenLoop) && + (IsOpenParen(match_label) || IsCloseParen(match_label))) { + paren_loop_ = true; + return true; + } + // Returns all other labels. + if (matcher_.Find(match_label)) return true; + done_ = true; + return false; +} + +template +inline void ParenMatcher::Next() { + if (paren_loop_) { + paren_loop_ = false; + done_ = true; + } else if (open_paren_list_) { + matcher_.Next(); + open_paren_list_ = NextOpenParen(); + if (open_paren_list_) return; + if (close_parens_.LowerBound() != kNoLabel) { + matcher_.LowerBound(close_parens_.LowerBound()); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + } + done_ = !matcher_.Find(kNoLabel); // No paren arcs left; continue with the wrapped matcher's kNoLabel matches. + } else if (close_paren_list_) { + matcher_.Next(); + close_paren_list_ = NextCloseParen(); + if (close_paren_list_) return; + done_ = !matcher_.Find(kNoLabel); + } else { + matcher_.Next(); + done_ = matcher_.Done(); + } +} + +// Advances matcher to next open paren, returning true if it exists. +template +inline bool ParenMatcher::NextOpenParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ?
matcher_.Value().ilabel + : matcher_.Value().olabel; + if (label > open_parens_.UpperBound()) return false; + if (IsOpenParen(label)) return true; + } + return false; +} + +// Advances matcher to next close paren, returning true if it exists. +template +inline bool ParenMatcher::NextCloseParen() { + for (; !matcher_.Done(); matcher_.Next()) { + Label label = match_type_ == MATCH_INPUT ? matcher_.Value().ilabel + : matcher_.Value().olabel; + if (label > close_parens_.UpperBound()) return false; + if (IsCloseParen(label)) return true; + } + return false; +} + +template +class ParenFilter { + public: + using FST1 = typename Filter::FST1; + using FST2 = typename Filter::FST2; + using Arc = typename Filter::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Matcher1 = typename Filter::Matcher1; + using Matcher2 = typename Filter::Matcher2; + + using StackId = StateId; + using ParenStack = PdtStack; + using FilterState1 = typename Filter::FilterState; + using FilterState2 = IntegerFilterState; + using FilterState = PairFilterState; + + ParenFilter(const FST1 &fst1, const FST2 &fst2, Matcher1 *matcher1 = nullptr, + Matcher2 *matcher2 = nullptr, + const std::vector> *parens = nullptr, + bool expand = false, bool keep_parens = true) + : filter_(fst1, fst2, matcher1, matcher2), // NOTE(review): parens_ is copy-initialized from *parens next, and the constructor body push_back()s every pair again, so parens_ appears to hold each pair twice - confirm this is intended. + parens_(parens ?
*parens : std::vector>()), + expand_(expand), + keep_parens_(keep_parens), + fs_(FilterState::NoState()), + stack_(parens_), + paren_id_(-1) { + if (parens) { + for (const auto &pair : *parens) { + parens_.push_back(pair); + GetMatcher1()->AddOpenParen(pair.first); + GetMatcher2()->AddOpenParen(pair.first); + if (!expand_) { + GetMatcher1()->AddCloseParen(pair.second); + GetMatcher2()->AddCloseParen(pair.second); + } + } + } + } + + ParenFilter(const ParenFilter &filter, bool safe = false) + : filter_(filter.filter_, safe), + parens_(filter.parens_), + expand_(filter.expand_), + keep_parens_(filter.keep_parens_), + fs_(FilterState::NoState()), + stack_(filter.parens_), + paren_id_(-1) {} + + FilterState Start() const { + return FilterState(filter_.Start(), FilterState2(0)); + } + + void SetState(StateId s1, StateId s2, const FilterState &fs) { + fs_ = fs; + filter_.SetState(s1, s2, fs_.GetState1()); + if (!expand_) return; + ssize_t paren_id = stack_.Top(fs.GetState2().GetState()); + if (paren_id != paren_id_) { + if (paren_id_ != -1) { + GetMatcher1()->RemoveCloseParen(parens_[paren_id_].second); + GetMatcher2()->RemoveCloseParen(parens_[paren_id_].second); + } + paren_id_ = paren_id; + if (paren_id_ != -1) { + GetMatcher1()->AddCloseParen(parens_[paren_id_].second); + GetMatcher2()->AddCloseParen(parens_[paren_id_].second); + } + } + } + + FilterState FilterArc(Arc *arc1, Arc *arc2) const { + const auto fs1 = filter_.FilterArc(arc1, arc2); + const auto &fs2 = fs_.GetState2(); + if (fs1 == FilterState1::NoState()) return FilterState::NoState(); + if (arc1->olabel == kNoLabel && arc2->ilabel) { // arc2 parentheses. + if (keep_parens_) { + arc1->ilabel = arc2->ilabel; + } else if (arc2->ilabel) { // NOTE(review): always true here; the enclosing condition already requires arc2->ilabel to be non-zero. + arc2->olabel = arc1->ilabel; + } + return FilterParen(arc2->ilabel, fs1, fs2); + } else if (arc2->ilabel == kNoLabel && arc1->olabel) { // arc1 parentheses.
+ if (keep_parens_) { + arc2->olabel = arc1->olabel; + } else { + arc1->ilabel = arc2->olabel; + } + return FilterParen(arc1->olabel, fs1, fs2); + } else { + return FilterState(fs1, fs2); + } + } + + void FilterFinal(Weight *w1, Weight *w2) const { + if (fs_.GetState2().GetState() != 0) *w1 = Weight::Zero(); + filter_.FilterFinal(w1, w2); + } + + // Returns respective matchers; ownership stays with filter. + + Matcher1 *GetMatcher1() { return filter_.GetMatcher1(); } + + Matcher2 *GetMatcher2() { return filter_.GetMatcher2(); } + + uint64 Properties(uint64 iprops) const { + return filter_.Properties(iprops) & kILabelInvariantProperties & + kOLabelInvariantProperties; + } + + private: + const FilterState FilterParen(Label label, const FilterState1 &fs1, + const FilterState2 &fs2) const { + if (!expand_) return FilterState(fs1, fs2); // Without expansion the stack state is passed through unchanged. + const auto stack_id = stack_.Find(fs2.GetState(), label); + if (stack_id < 0) { // Negative ID: stack_.Find() reported no stack transition for this label. + return FilterState::NoState(); + } else { + return FilterState(fs1, FilterState2(stack_id)); + } + } + + Filter filter_; + std::vector> parens_; + bool expand_; // Expands to FST? + bool keep_parens_; // Retains parentheses in output? + FilterState fs_; // Current filter state. + mutable ParenStack stack_; + ssize_t paren_id_; // Last stack_.Top() value seen by SetState() in expand mode, or -1. +}; + +// Class to set up composition options for PDT composition. Default is to take +// the PDT as the first composition argument.
+template +class PdtComposeFstOptions + : public ComposeFstOptions< + Arc, ParenMatcher>, + ParenFilter>>>> { + public: + using Label = typename Arc::Label; + using PdtMatcher = ParenMatcher>; + using PdtFilter = ParenFilter>; + + using ComposeFstOptions::matcher1; + using ComposeFstOptions::matcher2; + using ComposeFstOptions::filter; + + PdtComposeFstOptions(const Fst &ifst1, + const std::vector> &parens, + const Fst &ifst2, bool expand = false, + bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenList); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenLoop); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand, + keep_parens); + } +}; + +// Class to set up composition options for PDT with FST composition. +// Specialization is for the FST as the first composition argument. +template +class PdtComposeFstOptions + : public ComposeFstOptions< + Arc, ParenMatcher>, + ParenFilter>>>> { + public: + using Label = typename Arc::Label; + using PdtMatcher = ParenMatcher>; + using PdtFilter = ParenFilter>; + + using ComposeFstOptions::matcher1; + using ComposeFstOptions::matcher2; + using ComposeFstOptions::filter; + + PdtComposeFstOptions(const Fst &ifst1, const Fst &ifst2, + const std::vector> &parens, + bool expand = false, bool keep_parens = true) { + matcher1 = new PdtMatcher(ifst1, MATCH_OUTPUT, kParenLoop); + matcher2 = new PdtMatcher(ifst2, MATCH_INPUT, kParenList); + filter = new PdtFilter(ifst1, ifst2, matcher1, matcher2, &parens, expand, + keep_parens); + } +}; + +enum PdtComposeFilter { + PAREN_FILTER, // Bar-Hillel construction; keeps parentheses. + EXPAND_FILTER, // Bar-Hillel + expansion; removes parentheses. + EXPAND_PAREN_FILTER, // Bar-Hillel + expansion; keeps parentheses. +}; + +struct PdtComposeOptions { + bool connect; // Connect output? + PdtComposeFilter filter_type; // Pre-defined filter to use.
+ + explicit PdtComposeOptions(bool connect = true, + PdtComposeFilter filter_type = PAREN_FILTER) + : connect(connect), filter_type(filter_type) {} +}; + +// Composes a pushdown transducer (PDT) encoded as an FST (1st arg) and an FST +// (2nd arg) with the result also a PDT encoded as an FST (3rd arg). In the +// PDTs, some transitions are labeled with open or close parentheses. To be +// interpreted as a PDT, the parens must balance on a path (see PdtExpand()). +// The open-close parenthesis label pairs are passed using the parens argument. +template +void Compose(const Fst &ifst1, + const std::vector< + std::pair> &parens, + const Fst &ifst2, MutableFst *ofst, + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; // Both EXPAND_* filter types expand. + bool keep_parens = opts.filter_type != EXPAND_FILTER; // Only EXPAND_FILTER drops parentheses. + PdtComposeFstOptions copts(ifst1, parens, ifst2, expand, + keep_parens); + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + if (opts.connect) Connect(ofst); +} + +// Composes an FST (1st arg) and a pushdown transducer (PDT) encoded as an FST +// (2nd arg) with the result also a PDT encoded as an FST (3rd arg). In the +// PDTs, some transitions are labeled with open or close parentheses. To be +// interpreted as a PDT, the parens must balance on a path (see ExpandFst()). +// The open-close parenthesis label pairs are passed using the parens argument.
+template +void Compose(const Fst &ifst1, const Fst &ifst2, + const std::vector< + std::pair> &parens, + MutableFst *ofst, + const PdtComposeOptions &opts = PdtComposeOptions()) { + bool expand = opts.filter_type != PAREN_FILTER; + bool keep_parens = opts.filter_type != EXPAND_FILTER; + PdtComposeFstOptions copts(ifst1, ifst2, parens, expand, + keep_parens); + copts.gc_limit = 0; + *ofst = ComposeFst(ifst1, ifst2, copts); + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/expand.h b/projects/llm_framework/include/fst/extensions/pdt/expand.h new file mode 100644 index 00000000..eeee781e --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/expand.h @@ -0,0 +1,933 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expands a PDT to an FST. + +#ifndef FST_EXTENSIONS_PDT_EXPAND_H_ +#define FST_EXTENSIONS_PDT_EXPAND_H_ + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +template +struct PdtExpandFstOptions : public CacheOptions { + bool keep_parentheses; // If false, PdtExpandFstImpl::ExpandState() relabels stack push/pop arcs with epsilon. + PdtStack *stack; // Optional; when null, PdtExpandFstImpl allocates and owns its own stack. + PdtStateTable *state_table; // Optional; when null, PdtExpandFstImpl allocates and owns its own table. + + explicit PdtExpandFstOptions( + const CacheOptions &opts = CacheOptions(), bool keep_parentheses = false, + PdtStack *stack = nullptr, + PdtStateTable *state_table = + nullptr) + : CacheOptions(opts), + keep_parentheses(keep_parentheses), + stack(stack), + state_table(state_table) {} +}; + +namespace internal { + +// Implementation class for PdtExpandFst.
+template +class PdtExpandFstImpl : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using StateTuple = PdtStateTuple; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + PdtExpandFstImpl(const Fst &fst, + const std::vector> &parens, + const PdtExpandFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + stack_(opts.stack ? opts.stack : new PdtStack(parens)), + state_table_(opts.state_table ? opts.state_table + : new PdtStateTable()), + own_stack_(opts.stack == 0), + own_state_table_(opts.state_table == 0), + keep_parentheses_(opts.keep_parentheses) { + SetType("expand"); + const auto props = fst.Properties(kFstProperties, false); + SetProperties(PdtExpandProperties(props), kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + PdtExpandFstImpl(const PdtExpandFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + stack_(new PdtStack(*impl.stack_)), + state_table_(new PdtStateTable()), + own_stack_(true), + own_state_table_(true), + keep_parentheses_(impl.keep_parentheses_) { + SetType("expand"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + ~PdtExpandFstImpl() override { + if (own_stack_) delete stack_; + if (own_state_table_) delete state_table_; + } + + StateId Start() { + if (!HasStart()) { + const auto s = fst_->Start(); + if (s == kNoStateId) return kNoStateId; + StateTuple tuple(s, 0); + const auto start =
state_table_->FindState(tuple); + SetStart(start); + } + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) { + const auto &tuple = state_table_->Tuple(s); + const auto weight = fst_->Final(tuple.state_id); + if (weight != Weight::Zero() && tuple.stack_id == 0) + SetFinal(s, weight); + else + SetFinal(s, Weight::Zero()); + } + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) ExpandState(s); + return CacheImpl::NumOutputEpsilons(s); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) ExpandState(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void ExpandState(StateId s) { + StateTuple tuple = state_table_->Tuple(s); + for (ArcIterator> aiter(*fst_, tuple.state_id); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + const auto stack_id = stack_->Find(tuple.stack_id, arc.ilabel); + if (stack_id == -1) { // Non-matching close parenthesis. + continue; + } else if ((stack_id != tuple.stack_id) && !keep_parentheses_) { + // Stack push/pop: both arc labels are replaced by epsilon (0). + arc.ilabel = 0; + arc.olabel = 0; + } + StateTuple ntuple(arc.nextstate, stack_id); + arc.nextstate = state_table_->FindState(ntuple); + PushArc(s, arc); + } + SetArcs(s); + } + + const PdtStack &GetStack() const { return *stack_; } + + const PdtStateTable &GetStateTable() const { + return *state_table_; + } + + private: + // Properties for an expanded PDT.
+ inline uint64 PdtExpandProperties(uint64 inprops) { + return inprops & (kAcceptor | kAcyclic | kInitialAcyclic | kUnweighted); + } + + std::unique_ptr> fst_; + PdtStack *stack_; + PdtStateTable *state_table_; + bool own_stack_; + bool own_state_table_; + bool keep_parentheses_; +}; + +} // namespace internal + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version is a delayed FST. In the PDT, some transitions are labeled with open +// or close parentheses. To be interpreted as a PDT, the parens must balance on +// a path. The open-close parenthesis label pairs are passed using the parens +// argument. The expansion enforces the parenthesis constraints. The PDT must be +// expandable as an FST. +// +// This class attaches interface to implementation and handles reference +// counting, delegating most methods to ImplToFst. +template +class PdtExpandFst : public ImplToFst> { + public: + using Arc = A; + + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::PdtExpandFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + PdtExpandFst(const Fst &fst, + const std::vector> &parens) + : ImplToFst( + std::make_shared(fst, parens, PdtExpandFstOptions())) {} + + PdtExpandFst(const Fst &fst, + const std::vector> &parens, + const PdtExpandFstOptions &opts) + : ImplToFst(std::make_shared(fst, parens, opts)) {} + + // See Fst<>::Copy() for doc. + PdtExpandFst(const PdtExpandFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this PdtExpandFst. See Fst<>::Copy() for further doc.
+ PdtExpandFst *Copy(bool safe = false) const override { + return new PdtExpandFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + const PdtStack &GetStack() const { + return GetImpl()->GetStack(); + } + + const PdtStateTable &GetStateTable() const { + return GetImpl()->GetStateTable(); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + void operator=(const PdtExpandFst &) = delete; +}; + +// Specialization for PdtExpandFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const PdtExpandFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for PdtExpandFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const PdtExpandFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->ExpandState(s); // Expands s on demand when HasArcs(s) is false. + } +}; + +template +inline void PdtExpandFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// PrunedExpand prunes the delayed expansion of a pushdown transducer (PDT) +// encoded as an FST into an FST. In the PDT, some transitions are labeled with +// open or close parentheses. To be interpreted as a PDT, the parens must +// balance on a path. The open-close parenthesis label pairs are passed +// using the parens argument. The expansion enforces the parenthesis +// constraints. +// +// The algorithm works by visiting the delayed ExpandFst using a shortest-stack +// first queue discipline and relies on the shortest-distance information +// computed using a reverse shortest-path call to perform the pruning.
+// +// The algorithm maintains the same state ordering between the ExpandFst being +// visited (efst_) and the result of pruning written into the MutableFst (ofst_) +// to improve readability. +template +class PdtPrunedExpand { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StackId = StateId; + using Stack = PdtStack; + using StateTable = PdtStateTable; + using SetIterator = typename internal::PdtBalanceData::SetIterator; + + // Constructor taking as input a PDT specified by an input FST and a vector + // of parentheses. The keep_parentheses argument specifies whether parentheses + // are replaced by epsilons or not during the expansion. The cache options are + // passed to the underlying ExpandFst. + PdtPrunedExpand(const Fst &ifst, + const std::vector> &parens, + bool keep_parentheses = false, + const CacheOptions &opts = CacheOptions()) + : ifst_(ifst.Copy()), + keep_parentheses_(keep_parentheses), + stack_(parens), + efst_(ifst, parens, + PdtExpandFstOptions(opts, true, &stack_, &state_table_)), + queue_(state_table_, stack_, stack_length_, distance_, fdistance_), + error_(false) { + Reverse(*ifst_, parens, &rfst_); + VectorFst path; + reverse_shortest_path_.reset(new PdtShortestPath>( + rfst_, parens, + PdtShortestPathOptions>(true, false))); + reverse_shortest_path_->ShortestPath(&path); + error_ = (path.Properties(kError, true) == kError); + balance_data_.reset(reverse_shortest_path_->GetBalanceData()->Reverse( + rfst_.NumStates(), 10, -1)); + InitCloseParenMultimap(parens); + } + + bool Error() const { return error_; } + + // Expands and prunes the input PDT according to the provided weight + // threshold, writing the result into an output mutable FST.
+ void Expand(MutableFst *ofst, const Weight &threshold); + + private: + static constexpr uint8 kEnqueued = 0x01; + static constexpr uint8 kExpanded = 0x02; + static constexpr uint8 kSourceState = 0x04; + + // Comparison functor used by the queue: + // + // 1. States corresponding to shortest stack first, and + // 2. for stacks of matching length, reverse lexicographic order is used, and + // 3. for states with the same stack, shortest-first order is used. + class StackCompare { + public: + StackCompare(const StateTable &state_table, const Stack &stack, + const std::vector &stack_length, + const std::vector &distance, + const std::vector &fdistance) + : state_table_(state_table), + stack_(stack), + stack_length_(stack_length), + distance_(distance), + fdistance_(fdistance) {} + + bool operator()(StateId s1, StateId s2) const { + auto si1 = state_table_.Tuple(s1).stack_id; + auto si2 = state_table_.Tuple(s2).stack_id; + if (stack_length_[si1] < stack_length_[si2]) return true; + if (stack_length_[si1] > stack_length_[si2]) return false; + // If stack IDs are equal, use A*. + if (si1 == si2) { + return less_(Distance(s1), Distance(s2)); + } + // If lengths are equal, uses reverse lexicographic order. + for (; si1 != si2; si1 = stack_.Pop(si1), si2 = stack_.Pop(si2)) { + if (stack_.Top(si1) < stack_.Top(si2)) return true; + if (stack_.Top(si1) > stack_.Top(si2)) return false; + } + return false; + } + + private: + Weight Distance(StateId s) const { + return (s < distance_.size()) && (s < fdistance_.size()) + ? 
Times(distance_[s], fdistance_[s]) + : Weight::Zero(); + } + + const StateTable &state_table_; + const Stack &stack_; + const std::vector &stack_length_; + const std::vector &distance_; + const std::vector &fdistance_; + const NaturalLess less_; + }; + + class ShortestStackFirstQueue + : public ShortestFirstQueue { + public: + ShortestStackFirstQueue(const PdtStateTable &state_table, + const Stack &stack, + const std::vector &stack_length, + const std::vector &distance, + const std::vector &fdistance) + : ShortestFirstQueue(StackCompare( + state_table, stack, stack_length, distance, fdistance)) {} + }; + + void InitCloseParenMultimap( + const std::vector> &parens); + + Weight DistanceToDest(StateId source, StateId dest) const; + + uint8 Flags(StateId s) const; + + void SetFlags(StateId s, uint8 flags, uint8 mask); + + Weight Distance(StateId s) const; + + void SetDistance(StateId s, Weight weight); + + Weight FinalDistance(StateId s) const; + + void SetFinalDistance(StateId s, Weight weight); + + StateId SourceState(StateId s) const; + + void SetSourceState(StateId s, StateId p); + + void AddStateAndEnqueue(StateId s); + + void Relax(StateId s, const Arc &arc, Weight weight); + + bool PruneArc(StateId s, const Arc &arc); + + void ProcStart(); + + void ProcFinal(StateId s); + + bool ProcNonParen(StateId s, const Arc &arc, bool add_arc); + + bool ProcOpenParen(StateId s, const Arc &arc, StackId si, StackId nsi); + + bool ProcCloseParen(StateId s, const Arc &arc); + + void ProcDestStates(StateId s, StackId si); + + // Input PDT. + std::unique_ptr> ifst_; + // Reversed PDT. + VectorFst rfst_; + // Keep parentheses in ofst? + const bool keep_parentheses_; + // State table for efst_. + StateTable state_table_; + // Stack trie. + Stack stack_; + // Expanded PDT. + PdtExpandFst efst_; + // Length of stack for given stack ID. + std::vector stack_length_; + // Distance from initial state in efst_/ofst. + std::vector distance_; + // Distance to final states in efst_/ofst. 
+ std::vector fdistance_; + // Queue used to visit efst_. + ShortestStackFirstQueue queue_; + // Construction time failure? + bool error_; + // Status flags for states in efst_/ofst. + std::vector flags_; + // PDT source state for each expanded state. + std::vector sources_; + // Shortest path for rfst_. + std::unique_ptr>> + reverse_shortest_path_; + std::unique_ptr> balance_data_; + // Maps open paren arcs to balancing close paren arcs. + typename PdtShortestPath>::CloseParenMultimap + close_paren_multimap_; + MutableFst *ofst_; // Output FST. + Weight limit_; // Weight limit. + + // Maps a state s in ifst (i.e., the source of a closed paranthesis matching + // the top of current_stack_id_ to final states in efst_. + std::unordered_map dest_map_; + // Stack ID of the states currently at the top of the queue, i.e., the states + // currently being popped and processed. + StackId current_stack_id_; + ssize_t current_paren_id_; // Paren ID at top of current stack. + ssize_t cached_stack_id_; + StateId cached_source_; + // The set of pairs of destination states and weights to final states for the + // source state cached_source_ and the paren ID cached_paren_id_; i.e., the + // set of source states of a closed parenthesis with paren ID cached_paren_id + // balancing an incoming open parenthesis with paren ID cached_paren_id_ in + // state cached_source_. + std::forward_list> cached_dest_list_; + NaturalLess less_; +}; + +// Initializes close paren multimap, mapping pairs (s, paren_id) to all the arcs +// out of s labeled with close parenthese for paren_id. 
+template +void PdtPrunedExpand::InitCloseParenMultimap( + const std::vector> &parens) { + std::unordered_map paren_map; + for (size_t i = 0; i < parens.size(); ++i) { + const auto &pair = parens[i]; + paren_map[pair.first] = i; + paren_map[pair.second] = i; + } + for (StateIterator> siter(*ifst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + for (ArcIterator> aiter(*ifst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + const auto it = paren_map.find(arc.ilabel); + if (it == paren_map.end()) continue; + if (arc.ilabel == parens[it->second].second) { // Close paren. + const internal::ParenState key(it->second, s); + close_paren_multimap_.emplace(key, arc); + } + } + } +} + +// Returns the weight of the shortest balanced path from source to dest +// in ifst_; dest must be the source state of a close paren arc. +template +typename Arc::Weight PdtPrunedExpand::DistanceToDest(StateId source, + StateId dest) const { + using SearchState = + typename PdtShortestPath>::SearchState; + const SearchState ss(source + 1, dest + 1); + const auto distance = + reverse_shortest_path_->GetShortestPathData().Distance(ss); + VLOG(2) << "D(" << source << ", " << dest << ") =" << distance; + return distance; +} + +// Returns the flags for state s in ofst_. +template +uint8 PdtPrunedExpand::Flags(StateId s) const { + return s < flags_.size() ? flags_[s] : 0; +} + +// Modifies the flags for state s in ofst_. +template +void PdtPrunedExpand::SetFlags(StateId s, uint8 flags, uint8 mask) { + while (flags_.size() <= s) flags_.push_back(0); + flags_[s] &= ~mask; + flags_[s] |= flags & mask; +} + +// Returns the shortest distance from the initial state to s in ofst_. +template +typename Arc::Weight PdtPrunedExpand::Distance(StateId s) const { + return s < distance_.size() ? distance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from the initial state to s in ofst_. 
+template +void PdtPrunedExpand::SetDistance(StateId s, Weight weight) { + while (distance_.size() <= s) distance_.push_back(Weight::Zero()); + distance_[s] = std::move(weight); +} + +// Returns the shortest distance from s to the final states in ofst_. +template +typename Arc::Weight PdtPrunedExpand::FinalDistance(StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); +} + +// Sets the shortest distance from s to the final states in ofst_. +template +void PdtPrunedExpand::SetFinalDistance(StateId s, Weight weight) { + while (fdistance_.size() <= s) fdistance_.push_back(Weight::Zero()); + fdistance_[s] = std::move(weight); +} + +// Returns the PDT source state of state s in ofst_. +template +typename Arc::StateId PdtPrunedExpand::SourceState(StateId s) const { + return s < sources_.size() ? sources_[s] : kNoStateId; +} + +// Sets the PDT source state of state s in ofst_ to state p'in ifst_. +template +void PdtPrunedExpand::SetSourceState(StateId s, StateId p) { + while (sources_.size() <= s) sources_.push_back(kNoStateId); + sources_[s] = p; +} + +// Adds state s of efst_ to ofst_ and inserts it in the queue, modifying the +// flags for s accordingly. +template +void PdtPrunedExpand::AddStateAndEnqueue(StateId s) { + if (!(Flags(s) & (kEnqueued | kExpanded))) { + while (ofst_->NumStates() <= s) ofst_->AddState(); + queue_.Enqueue(s); + SetFlags(s, kEnqueued, kEnqueued); + } else if (Flags(s) & kEnqueued) { + queue_.Update(s); + } + // TODO(allauzen): Check everything is fine when kExpanded? +} + +// Relaxes arc out of state s in ofst_ as follows: +// +// 1. If the distance to s times the weight of arc is smaller than +// the currently stored distance for arc.nextstate, updates +// Distance(arc.nextstate) with a new estimate +// 2. If fd is less than the currently stored distance from arc.nextstate to the +// final state, updates with new estimate. 
+template +void PdtPrunedExpand::Relax(StateId s, const Arc &arc, Weight fd) { + const auto nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) { + SetDistance(arc.nextstate, nd); + SetSourceState(arc.nextstate, SourceState(s)); + } + if (less_(fd, FinalDistance(arc.nextstate))) { + SetFinalDistance(arc.nextstate, fd); + } + VLOG(2) << "Relax: " << s << ", d[s] = " << Distance(s) << ", to " + << arc.nextstate << ", d[ns] = " << Distance(arc.nextstate) + << ", nd = " << nd; +} + +// Returns whether the arc out of state s in efst needs pruned. +template +bool PdtPrunedExpand::PruneArc(StateId s, const Arc &arc) { + VLOG(2) << "Prune ?"; + auto fd = Weight::Zero(); + if ((cached_source_ != SourceState(s)) || + (cached_stack_id_ != current_stack_id_)) { + cached_source_ = SourceState(s); + cached_stack_id_ = current_stack_id_; + cached_dest_list_.clear(); + if (cached_source_ != ifst_->Start()) { + for (auto set_iter = + balance_data_->Find(current_paren_id_, cached_source_); + !set_iter.Done(); set_iter.Next()) { + auto dest = set_iter.Element(); + const auto it = dest_map_.find(dest); + cached_dest_list_.push_front(*it); + } + } else { + // TODO(allauzen): queue discipline should prevent this from ever + // happening. + // Replace by a check. + cached_dest_list_.push_front( + std::make_pair(rfst_.Start() - 1, Weight::One())); + } + } + for (auto it = cached_dest_list_.begin(); it != cached_dest_list_.end(); + ++it) { + const auto d = + DistanceToDest(state_table_.Tuple(arc.nextstate).state_id, it->first); + fd = Plus(fd, Times(d, it->second)); + } + Relax(s, arc, fd); + return less_(limit_, Times(Distance(s), Times(arc.weight, fd))); +} + +// Adds start state of efst_ to ofst_, enqueues it, and initializes the distance +// data structures. 
+template +void PdtPrunedExpand::ProcStart() { + const auto s = efst_.Start(); + AddStateAndEnqueue(s); + ofst_->SetStart(s); + SetSourceState(s, ifst_->Start()); + current_stack_id_ = 0; + current_paren_id_ = -1; + stack_length_.push_back(0); + const auto r = rfst_.Start() - 1; + cached_source_ = ifst_->Start(); + cached_stack_id_ = 0; + cached_dest_list_.push_front(std::make_pair(r, Weight::One())); + const PdtStateTuple tuple(r, 0); + SetFinalDistance(state_table_.FindState(tuple), Weight::One()); + SetDistance(s, Weight::One()); + const auto d = DistanceToDest(ifst_->Start(), r); + SetFinalDistance(s, d); + VLOG(2) << d; +} + +// Makes s final in ofst_ if shortest accepting path ending in s is below +// threshold. +template +void PdtPrunedExpand::ProcFinal(StateId s) { + const auto weight = efst_.Final(s); + if (weight == Weight::Zero()) return; + if (less_(limit_, Times(Distance(s), weight))) return; + ofst_->SetFinal(s, weight); +} + +// Returns true when an arc (or meta-arc) leaving state s in efst_ is below the +// threshold. When add_arc is true, arc is added to ofst_. +template +bool PdtPrunedExpand::ProcNonParen(StateId s, const Arc &arc, + bool add_arc) { + VLOG(2) << "ProcNonParen: " << s << " to " << arc.nextstate << ", " + << arc.ilabel << ":" << arc.olabel << " / " << arc.weight + << ", add_arc = " << (add_arc ? "true" : "false"); + if (PruneArc(s, arc)) return false; + if (add_arc) ofst_->AddArc(s, arc); + AddStateAndEnqueue(arc.nextstate); + return true; +} + +// Processes an open paren arc leaving state s in ofst_. When the arc is labeled +// with an open paren, +// +// 1. Considers each (shortest) balanced path starting in s by taking the arc +// and ending by a close paren balancing the open paren of as a meta-arc, +// processing and pruning each meta-arc as a non-paren arc, inserting its +// destination to the queue; +// 2. 
if at least one of these meta-arcs has not been pruned, adds the +// destination of arc to ofst_ as a new source state for the stack ID nsi, and +// inserts it in the queue. +template +bool PdtPrunedExpand::ProcOpenParen(StateId s, const Arc &arc, StackId si, + StackId nsi) { + // Updates the stack length when needed. + while (stack_length_.size() <= nsi) stack_length_.push_back(-1); + if (stack_length_[nsi] == -1) stack_length_[nsi] = stack_length_[si] + 1; + const auto ns = arc.nextstate; + VLOG(2) << "Open paren: " << s << "(" << state_table_.Tuple(s).state_id + << ") to " << ns << "(" << state_table_.Tuple(ns).state_id << ")"; + bool proc_arc = false; + auto fd = Weight::Zero(); + const auto paren_id = stack_.ParenId(arc.ilabel); + std::forward_list sources; + for (auto set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(ns).state_id); + !set_iter.Done(); set_iter.Next()) { + sources.push_front(set_iter.Element()); + } + for (const auto source : sources) { + VLOG(2) << "Close paren source: " << source; + const internal::ParenState paren_state(paren_id, source); + for (auto it = close_paren_multimap_.find(paren_state); + it != close_paren_multimap_.end() && paren_state == it->first; ++it) { + auto meta_arc = it->second; + const PdtStateTuple tuple(meta_arc.nextstate, si); + meta_arc.nextstate = state_table_.FindState(tuple); + const auto state_id = state_table_.Tuple(ns).state_id; + const auto d = DistanceToDest(state_id, source); + VLOG(2) << state_id << ", " << source; + VLOG(2) << "Meta arc weight = " << arc.weight << " Times " << d + << " Times " << meta_arc.weight; + meta_arc.weight = Times(arc.weight, Times(d, meta_arc.weight)); + proc_arc |= ProcNonParen(s, meta_arc, false); + fd = Plus( + fd, + Times(Times(DistanceToDest(state_table_.Tuple(ns).state_id, source), + it->second.weight), + FinalDistance(meta_arc.nextstate))); + } + } + if (proc_arc) { + VLOG(2) << "Proc open paren " << s << " to " << arc.nextstate; + ofst_->AddArc( + s, 
keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + AddStateAndEnqueue(arc.nextstate); + const auto nd = Times(Distance(s), arc.weight); + if (less_(nd, Distance(arc.nextstate))) SetDistance(arc.nextstate, nd); + // FinalDistance not necessary for source state since pruning decided using + // meta-arcs above. But this is a problem with A*, hence the following. + if (less_(fd, FinalDistance(arc.nextstate))) + SetFinalDistance(arc.nextstate, fd); + SetFlags(arc.nextstate, kSourceState, kSourceState); + } + return proc_arc; +} + +// Checks that shortest path through close paren arc in efst_ is below +// threshold, and if so, adds it to ofst_. +template +bool PdtPrunedExpand::ProcCloseParen(StateId s, const Arc &arc) { + const auto weight = + Times(Distance(s), Times(arc.weight, FinalDistance(arc.nextstate))); + if (less_(limit_, weight)) return false; + ofst_->AddArc(s, + keep_parentheses_ ? arc : Arc(0, 0, arc.weight, arc.nextstate)); + return true; +} + +// When state s in ofst_ is a source state for stack ID si, identifies all the +// corresponding possible destination states, that is, all the states in ifst_ +// that have an outgoing close paren arc balancing the incoming open paren taken +// to get to s. For each such state t, computes the shortest distance from (t, +// si) to the final states in ofst_. Stores this information in dest_map_. +template +void PdtPrunedExpand::ProcDestStates(StateId s, StackId si) { + if (!(Flags(s) & kSourceState)) return; + if (si != current_stack_id_) { + dest_map_.clear(); + current_stack_id_ = si; + current_paren_id_ = stack_.Top(current_stack_id_); + VLOG(2) << "StackID " << si << " dequeued for first time"; + } + // TODO(allauzen): clean up source state business; rename current function to + // ProcSourceState. 
+ SetSourceState(s, state_table_.Tuple(s).state_id); + const auto paren_id = stack_.Top(si); + for (auto set_iter = + balance_data_->Find(paren_id, state_table_.Tuple(s).state_id); + !set_iter.Done(); set_iter.Next()) { + const auto dest_state = set_iter.Element(); + if (dest_map_.find(dest_state) != dest_map_.end()) continue; + auto dest_weight = Weight::Zero(); + internal::ParenState paren_state(paren_id, dest_state); + for (auto it = close_paren_multimap_.find(paren_state); + it != close_paren_multimap_.end() && paren_state == it->first; ++it) { + const auto &arc = it->second; + const PdtStateTuple tuple(arc.nextstate, + stack_.Pop(si)); + dest_weight = + Plus(dest_weight, + Times(arc.weight, FinalDistance(state_table_.FindState(tuple)))); + } + dest_map_[dest_state] = dest_weight; + VLOG(2) << "State " << dest_state << " is a dest state for stack ID " << si + << " with weight " << dest_weight; + } +} + +// Expands and prunes the input PDT, writing the result in ofst. +template +void PdtPrunedExpand::Expand(MutableFst *ofst, + const typename Arc::Weight &threshold) { + ofst_ = ofst; + if (error_) { + ofst_->SetProperties(kError, kError); + return; + } + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst_->InputSymbols()); + ofst_->SetOutputSymbols(ifst_->OutputSymbols()); + limit_ = Times(DistanceToDest(ifst_->Start(), rfst_.Start() - 1), threshold); + flags_.clear(); + ProcStart(); + while (!queue_.Empty()) { + const auto s = queue_.Head(); + queue_.Dequeue(); + SetFlags(s, kExpanded, kExpanded | kEnqueued); + VLOG(2) << s << " dequeued!"; + ProcFinal(s); + StackId stack_id = state_table_.Tuple(s).stack_id; + ProcDestStates(s, stack_id); + for (ArcIterator> aiter(efst_, s); !aiter.Done(); + aiter.Next()) { + const auto &arc = aiter.Value(); + const auto nextstack_id = state_table_.Tuple(arc.nextstate).stack_id; + if (stack_id == nextstack_id) { + ProcNonParen(s, arc, true); + } else if (stack_id == stack_.Pop(nextstack_id)) { + ProcOpenParen(s, arc, stack_id, 
nextstack_id); + } else { + ProcCloseParen(s, arc); + } + } + VLOG(2) << "d[" << s << "] = " << Distance(s) << ", fd[" << s + << "] = " << FinalDistance(s); + } +} + +// Expand functions. + +template +struct PdtExpandOptions { + using Weight = typename Arc::Weight; + + bool connect; + bool keep_parentheses; + Weight weight_threshold; + + PdtExpandOptions(bool connect = true, bool keep_parentheses = false, + Weight weight_threshold = Weight::Zero()) + : connect(connect), + keep_parentheses(keep_parentheses), + weight_threshold(std::move(weight_threshold)) {} +}; + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version writes the expanded PDT to a mutable FST. In the PDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// a PDT, the parens must balance on a path. The open-close parenthesis label +// pairs are passed using the parens argument. Expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. +template +void Expand( + const Fst &ifst, + const std::vector> + &parens, + MutableFst *ofst, const PdtExpandOptions &opts) { + PdtExpandFstOptions eopts; + eopts.gc_limit = 0; + if (opts.weight_threshold == Arc::Weight::Zero()) { + eopts.keep_parentheses = opts.keep_parentheses; + *ofst = PdtExpandFst(ifst, parens, eopts); + } else { + PdtPrunedExpand pruned_expand(ifst, parens, opts.keep_parentheses); + pruned_expand.Expand(ofst, opts.weight_threshold); + } + if (opts.connect) Connect(ofst); +} + +// Expands a pushdown transducer (PDT) encoded as an FST into an FST. This +// version writes the expanded PDT result to a mutable FST. In the PDT, some +// transitions are labeled with open or close parentheses. To be interpreted as +// a PDT, the parens must balance on a path. The open-close parenthesis label +// pairs are passed using the parents argument. Expansion enforces the +// parenthesis constraints. The PDT must be expandable as an FST. 
+template +void Expand(const Fst &ifst, + const std::vector> + &parens, MutableFst *ofst, bool connect = true, + bool keep_parentheses = false) { + const PdtExpandOptions opts(connect, keep_parentheses); + Expand(ifst, parens, ofst, opts); +} + +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_EXPAND_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/getters.h b/projects/llm_framework/include/fst/extensions/pdt/getters.h new file mode 100644 index 00000000..69dd150d --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/getters.h @@ -0,0 +1,22 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_EXTENSIONS_PDT_GETTERS_H_ +#define FST_EXTENSIONS_PDT_GETTERS_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +bool GetPdtComposeFilter(const string &str, PdtComposeFilter *cf); + +bool GetPdtParserType(const string &str, PdtParserType *pt); + +} // namespace script +} // namespace fst + +#endif // FST_EXTENSIONS_PDT_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/extensions/pdt/info.h b/projects/llm_framework/include/fst/extensions/pdt/info.h new file mode 100644 index 00000000..3de54772 --- /dev/null +++ b/projects/llm_framework/include/fst/extensions/pdt/info.h @@ -0,0 +1,152 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Prints information about a PDT. + +#ifndef FST_EXTENSIONS_PDT_INFO_H_ +#define FST_EXTENSIONS_PDT_INFO_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +// Compute various information about PDTs. 
+template +class PdtInfo { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + PdtInfo(const Fst &fst, + const std::vector> &parents); + + const string &FstType() const { return fst_type_; } + + const string &ArcType() const { return Arc::Type(); } + + int64 NumStates() const { return nstates_; } + + int64 NumArcs() const { return narcs_; } + + int64 NumOpenParens() const { return nopen_parens_; } + + int64 NumCloseParens() const { return nclose_parens_; } + + int64 NumUniqueOpenParens() const { return nuniq_open_parens_; } + + int64 NumUniqueCloseParens() const { return nuniq_close_parens_; } + + int64 NumOpenParenStates() const { return nopen_paren_states_; } + + int64 NumCloseParenStates() const { return nclose_paren_states_; } + + private: + string fst_type_; + int64 nstates_; + int64 narcs_; + int64 nopen_parens_; + int64 nclose_parens_; + int64 nuniq_open_parens_; + int64 nuniq_close_parens_; + int64 nopen_paren_states_; + int64 nclose_paren_states_; +}; + +template +PdtInfo::PdtInfo( + const Fst &fst, + const std::vector> + &parens) + : fst_type_(fst.Type()), + nstates_(0), + narcs_(0), + nopen_parens_(0), + nclose_parens_(0), + nuniq_open_parens_(0), + nuniq_close_parens_(0), + nopen_paren_states_(0), + nclose_paren_states_(0) { + std::unordered_map paren_map; + std::unordered_set> +class CompactFst; + +template +class ConstFst; + +template +class EditFst; + +template +class ExpandedFst; + +template +class Fst; + +template +class MutableFst; + +template > +class VectorState; + +template > +class VectorFst; + +template +class DefaultReplaceStateTable; + +// On-the-fly operations. 
+ +template +class ArcSortFst; + +template +class ClosureFst; + +template > +class ComposeFst; + +template +class ConcatFst; + +template +class DeterminizeFst; + +template +class DifferenceFst; + +template +class IntersectFst; + +template +class InvertFst; + +template +class ArcMapFst; + +template +class ProjectFst; + +template +class RandGenFst; + +template +class RelabelFst; + +template , + class Store = DefaultCacheStore> +class ReplaceFst; + +template +class RmEpsilonFst; + +template +class UnionFst; + +// Heap. + +template +class Heap; + +// Compactors. + +template +class AcceptorCompactor; + +template +class StringCompactor; + +template +class UnweightedAcceptorCompactor; + +template +class UnweightedCompactor; + +template +class WeightedStringCompactor; + +// Compact FSTs. + +template +using CompactStringFst = CompactFst, U>; + +template +using CompactWeightedStringFst = + CompactFst, U>; + +template +using CompactAcceptorFst = CompactFst, U>; + +template +using CompactUnweightedFst = CompactFst, U>; + +template +using CompactUnweightedAcceptorFst = + CompactFst, U>; + +// StdArc aliases for FSTs. + +using StdConstFst = ConstFst; +using StdExpandedFst = ExpandedFst; +using StdFst = Fst; +using StdMutableFst = MutableFst; +using StdVectorFst = VectorFst; + +// StdArc aliases for on-the-fly operations. + +template +using StdArcSortFst = ArcSortFst; + +using StdClosureFst = ClosureFst; + +using StdComposeFst = ComposeFst; + +using StdConcatFst = ConcatFst; + +using StdDeterminizeFst = DeterminizeFst; + +using StdDifferenceFst = DifferenceFst; + +using StdIntersectFst = IntersectFst; + +using StdInvertFst = InvertFst; + +using StdProjectFst = ProjectFst; + +using StdRelabelFst = RelabelFst; + +using StdReplaceFst = ReplaceFst; + +using StdRmEpsilonFst = RmEpsilonFst; + +using StdUnionFst = UnionFst; + +// Filter states. 
+ +template +class IntegerFilterState; + +using CharFilterState = IntegerFilterState; + +using ShortFilterState = IntegerFilterState; // NOLINT + +using IntFilterState = IntegerFilterState; + +// Matchers and filters. + +template +class Matcher; + +template +class NullComposeFilter; + +template +class TrivialComposeFilter; + +template +class SequenceComposeFilter; + +template +class AltSequenceComposeFilter; + +template +class MatchComposeFilter; + +template +class NoMatchComposeFilter; + +} // namespace fst + +#endif // FST_FST_DECL_H_ diff --git a/projects/llm_framework/include/fst/fst.h b/projects/llm_framework/include/fst/fst.h new file mode 100644 index 00000000..20e6bb3c --- /dev/null +++ b/projects/llm_framework/include/fst/fst.h @@ -0,0 +1,1007 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST abstract base class definition, state and arc iterator interface, and +// suggested base implementation. + +#ifndef FST_FST_H_ +#define FST_FST_H_ + +#include + +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + + +DECLARE_bool(fst_align); + +namespace fst { + +bool IsFstHeader(std::istream &, const string &); + +class FstHeader; + +template +struct StateIteratorData; + +template +struct ArcIteratorData; + +template +class MatcherBase; + +struct FstReadOptions { + // FileReadMode(s) are advisory, there are many conditions than prevent a + // file from being mapped, READ mode will be selected in these cases with + // a warning indicating why it was chosen. + enum FileReadMode { READ, MAP }; + + string source; // Where you're reading from. + const FstHeader *header; // Pointer to FST header; if non-zero, use + // this info (don't read a stream header). 
+ const SymbolTable *isymbols; // Pointer to input symbols; if non-zero, use + // this info (read and skip stream isymbols) + const SymbolTable *osymbols; // Pointer to output symbols; if non-zero, use + // this info (read and skip stream osymbols) + FileReadMode mode; // Read or map files (advisory, if possible) + bool read_isymbols; // Read isymbols, if any (default: true). + bool read_osymbols; // Read osymbols, if any (default: true). + + explicit FstReadOptions(const string &source = "", + const FstHeader *header = nullptr, + const SymbolTable *isymbols = nullptr, + const SymbolTable *osymbols = nullptr); + + explicit FstReadOptions(const string &source, const SymbolTable *isymbols, + const SymbolTable *osymbols = nullptr); + + // Helper function to convert strings FileReadModes into their enum value. + static FileReadMode ReadMode(const string &mode); + + // Outputs a debug string for the FstReadOptions object. + string DebugString() const; +}; + +struct FstWriteOptions { + string source; // Where you're writing to. + bool write_header; // Write the header? + bool write_isymbols; // Write input symbols? + bool write_osymbols; // Write output symbols? + bool align; // Write data aligned (may fail on pipes)? + bool stream_write; // Avoid seek operations in writing. + + explicit FstWriteOptions(const string &source = "", + bool write_header = true, bool write_isymbols = true, + bool write_osymbols = true, + bool align = FLAGS_fst_align, + bool stream_write = false) + : source(source), + write_header(write_header), + write_isymbols(write_isymbols), + write_osymbols(write_osymbols), + align(align), + stream_write(stream_write) {} +}; + +// Header class. +// +// This is the recommended file header representation. + +class FstHeader { + public: + enum { + HAS_ISYMBOLS = 0x1, // Has input symbol table. + HAS_OSYMBOLS = 0x2, // Has output symbol table. + IS_ALIGNED = 0x4, // Memory-aligned (where appropriate). 
+ } Flags; + + FstHeader() : version_(0), flags_(0), properties_(0), start_(-1), + numstates_(0), numarcs_(0) {} + + const string &FstType() const { return fsttype_; } + + const string &ArcType() const { return arctype_; } + + int32 Version() const { return version_; } + + int32 GetFlags() const { return flags_; } + + uint64 Properties() const { return properties_; } + + int64 Start() const { return start_; } + + int64 NumStates() const { return numstates_; } + + int64 NumArcs() const { return numarcs_; } + + void SetFstType(const string &type) { fsttype_ = type; } + + void SetArcType(const string &type) { arctype_ = type; } + + void SetVersion(int32 version) { version_ = version; } + + void SetFlags(int32 flags) { flags_ = flags; } + + void SetProperties(uint64 properties) { properties_ = properties; } + + void SetStart(int64 start) { start_ = start; } + + void SetNumStates(int64 numstates) { numstates_ = numstates; } + + void SetNumArcs(int64 numarcs) { numarcs_ = numarcs; } + + bool Read(std::istream &strm, const string &source, + bool rewind = false); + + bool Write(std::ostream &strm, const string &source) const; + + // Outputs a debug string for the FstHeader object. + string DebugString() const; + + private: + string fsttype_; // E.g. "vector". + string arctype_; // E.g. "standard". + int32 version_; // Type version number. + int32 flags_; // File format bits. + uint64 properties_; // FST property bits. + int64 start_; // Start state. + int64 numstates_; // # of states. + int64 numarcs_; // # of arcs. +}; + +// Specifies matcher action. +enum MatchType { + MATCH_INPUT = 1, // Match input label. + MATCH_OUTPUT = 2, // Match output label. + MATCH_BOTH = 3, // Match input or output label. + MATCH_NONE = 4, // Match nothing. + MATCH_UNKNOWN = 5 +}; // Otherwise, match type unknown. + +constexpr int kNoLabel = -1; // Not a valid label. +constexpr int kNoStateId = -1; // Not a valid state ID. 
+
+// A generic FST, templated on the arc definition, with common-denominator
+// methods (use StateIterator and ArcIterator to iterate over its states and
+// arcs).
+template <class A>
+class Fst {
+ public:
+  using Arc = A;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  virtual ~Fst() {}
+
+  // Initial state.
+  virtual StateId Start() const = 0;
+
+  // State's final weight.
+  virtual Weight Final(StateId) const = 0;
+
+  // State's arc count.
+  virtual size_t NumArcs(StateId) const = 0;
+
+  // State's input epsilon count.
+  virtual size_t NumInputEpsilons(StateId) const = 0;
+
+  // State's output epsilon count.
+  virtual size_t NumOutputEpsilons(StateId) const = 0;
+
+  // Property bits. If test = false, return stored properties bits for mask
+  // (some possibly unknown); if test = true, return property bits for mask
+  // (computing o.w. unknown).
+  virtual uint64 Properties(uint64 mask, bool test) const = 0;
+
+  // FST type name.
+  virtual const string &Type() const = 0;
+
+  // Gets a copy of this Fst. The copying behaves as follows:
+  //
+  // (1) The copying is constant time if safe = false or if safe = true
+  // and is on an otherwise unaccessed FST.
+  //
+  // (2) If safe = true, the copy is thread-safe in that the original
+  // and copy can be safely accessed (but not necessarily mutated) by
+  // separate threads. For some FST types, 'Copy(true)' should only be
+  // called on an FST that has not otherwise been accessed. Behavior is
+  // otherwise undefined.
+  //
+  // (3) If a MutableFst is copied and then mutated, then the original is
+  // unmodified and vice versa (often by a copy-on-write on the initial
+  // mutation, which may not be constant time).
+  virtual Fst *Copy(bool safe = false) const = 0;
+
+  // Reads an FST from an input stream; returns nullptr on error.
+  static Fst *Read(std::istream &strm, const FstReadOptions &opts) {
+    FstReadOptions ropts(opts);
+    FstHeader hdr;
+    if (ropts.header) {
+      hdr = *opts.header;
+    } else {
+      if (!hdr.Read(strm, opts.source)) return nullptr;
+      ropts.header = &hdr;
+    }
+    const auto &fst_type = hdr.FstType();
+    const auto reader = FstRegister<Arc>::GetRegister()->GetReader(fst_type);
+    if (!reader) {
+      LOG(ERROR) << "Fst::Read: Unknown FST type " << fst_type
+                 << " (arc type = " << Arc::Type() << "): " << ropts.source;
+      return nullptr;
+    }
+    return reader(strm, ropts);
+  }
+
+  // Reads an FST from a file; returns nullptr on error. An empty filename
+  // results in reading from standard input.
+  static Fst *Read(const string &filename) {
+    if (!filename.empty()) {
+      std::ifstream strm(filename,
+                         std::ios_base::in | std::ios_base::binary);
+      if (!strm) {
+        LOG(ERROR) << "Fst::Read: Can't open file: " << filename;
+        return nullptr;
+      }
+      return Read(strm, FstReadOptions(filename));
+    } else {
+      return Read(std::cin, FstReadOptions("standard input"));
+    }
+  }
+
+  // Writes an FST to an output stream; returns false on error.
+  virtual bool Write(std::ostream &strm, const FstWriteOptions &opts) const {
+    LOG(ERROR) << "Fst::Write: No write stream method for " << Type()
+               << " FST type";
+    return false;
+  }
+
+  // Writes an FST to a file; returns false on error; an empty filename
+  // results in writing to standard output.
+  virtual bool Write(const string &filename) const {
+    LOG(ERROR) << "Fst::Write: No write filename method for " << Type()
+               << " FST type";
+    return false;
+  }
+
+  // Returns input label symbol table; return nullptr if not specified.
+  virtual const SymbolTable *InputSymbols() const = 0;
+
+  // Return output label symbol table; return nullptr if not specified.
+  virtual const SymbolTable *OutputSymbols() const = 0;
+
+  // For generic state iterator construction (not normally called directly by
+  // users). Does not copy the FST.
+  virtual void InitStateIterator(StateIteratorData<Arc> *data) const = 0;
+
+  // For generic arc iterator construction (not normally called directly by
+  // users). Does not copy the FST.
+  virtual void InitArcIterator(StateId s, ArcIteratorData<Arc> *data) const = 0;
+
+  // For generic matcher construction (not normally called directly by users).
+  // Does not copy the FST.
+  virtual MatcherBase<Arc> *InitMatcher(MatchType match_type) const;
+
+ protected:
+  bool WriteFile(const string &filename) const {
+    if (!filename.empty()) {
+      std::ofstream strm(filename,
+                         std::ios_base::out | std::ios_base::binary);
+      if (!strm) {
+        LOG(ERROR) << "Fst::Write: Can't open file: " << filename;
+        return false;
+      }
+      bool val = Write(strm, FstWriteOptions(filename));
+      if (!val) LOG(ERROR) << "Fst::Write failed: " << filename;
+      return val;
+    } else {
+      return Write(std::cout, FstWriteOptions("standard output"));
+    }
+  }
+};
+
+// A useful alias when using StdArc.
+using StdFst = Fst<StdArc>;
+
+// State and arc iterator definitions.
+//
+// State iterator interface templated on the Arc definition; used for
+// StateIterator specializations returned by the InitStateIterator FST method.
+template <class Arc>
+class StateIteratorBase {
+ public:
+  using StateId = typename Arc::StateId;
+
+  virtual ~StateIteratorBase() {}
+
+  // End of iterator?
+  virtual bool Done() const = 0;
+  // Returns current state (when !Done()).
+  virtual StateId Value() const = 0;
+  // Advances to next state (when !Done()).
+  virtual void Next() = 0;
+  // Resets to initial condition.
+  virtual void Reset() = 0;
+};
+
+// StateIterator initialization data.
+
+template <class Arc>
+struct StateIteratorData {
+  using StateId = typename Arc::StateId;
+
+  // Specialized iterator if non-zero.
+  StateIteratorBase<Arc> *base;
+  // Otherwise, the total number of states.
+  StateId nstates;
+
+  StateIteratorData() : base(nullptr), nstates(0) {}
+
+  StateIteratorData(const StateIteratorData &) = delete;
+  StateIteratorData &operator=(const StateIteratorData &) = delete;
+};
+
+// Generic state iterator, templated on the FST definition (a wrapper
+// around a pointer to a specific one). Here is a typical use:
+//
+//   for (StateIterator<StdFst> siter(fst);
+//        !siter.Done();
+//        siter.Next()) {
+//     StateId s = siter.Value();
+//     ...
+//   }
+// There is no copying of the FST.
+template <class FST>
+class StateIterator {
+ public:
+  using Arc = typename FST::Arc;
+  using StateId = typename Arc::StateId;
+
+  explicit StateIterator(const FST &fst) : s_(0) {
+    fst.InitStateIterator(&data_);
+  }
+
+  ~StateIterator() { delete data_.base; }
+
+  bool Done() const {
+    return data_.base ? data_.base->Done() : s_ >= data_.nstates;
+  }
+
+  StateId Value() const { return data_.base ? data_.base->Value() : s_; }
+
+  void Next() {
+    if (data_.base) {
+      data_.base->Next();
+    } else {
+      ++s_;
+    }
+  }
+
+  void Reset() {
+    if (data_.base) {
+      data_.base->Reset();
+    } else {
+      s_ = 0;
+    }
+  }
+
+ private:
+  StateIteratorData<Arc> data_;
+  StateId s_;
+};
+
+// Flags to control the behavior on an arc iterator.
+static constexpr uint32 kArcILabelValue =
+    0x0001;  // Value() gives valid ilabel.
+static constexpr uint32 kArcOLabelValue = 0x0002;  // " " " olabel.
+static constexpr uint32 kArcWeightValue = 0x0004;  // " " " weight.
+static constexpr uint32 kArcNextStateValue =
+    0x0008;  // " " " nextstate.
+static constexpr uint32 kArcNoCache = 0x0010;  // No need to cache arcs.
+
+static constexpr uint32 kArcValueFlags =
+    kArcILabelValue | kArcOLabelValue | kArcWeightValue | kArcNextStateValue;
+
+static constexpr uint32 kArcFlags = kArcValueFlags | kArcNoCache;
+
+// Arc iterator interface, templated on the arc definition; used for arc
+// iterator specializations that are returned by the InitArcIterator FST method.
+template <class Arc>
+class ArcIteratorBase {
+ public:
+  using StateId = typename Arc::StateId;
+
+  virtual ~ArcIteratorBase() {}
+
+  // End of iterator?
+  virtual bool Done() const = 0;
+  // Returns current arc (when !Done()).
+  virtual const Arc &Value() const = 0;
+  // Advances to next arc (when !Done()).
+  virtual void Next() = 0;
+  // Returns current position.
+  virtual size_t Position() const = 0;
+  // Returns to initial condition.
+  virtual void Reset() = 0;
+  // Advances to arbitrary arc by position.
+  virtual void Seek(size_t) = 0;
+  // Returns current behavioral flags
+  virtual uint32 Flags() const = 0;
+  // Sets behavioral flags.
+  virtual void SetFlags(uint32, uint32) = 0;
+};
+
+// ArcIterator initialization data.
+template <class Arc>
+struct ArcIteratorData {
+  ArcIteratorData()
+      : base(nullptr), arcs(nullptr), narcs(0), ref_count(nullptr) {}
+
+  ArcIteratorData(const ArcIteratorData &) = delete;
+
+  ArcIteratorData &operator=(const ArcIteratorData &) = delete;
+
+  ArcIteratorBase<Arc> *base;  // Specialized iterator if non-zero.
+  const Arc *arcs;             // O.w. arcs pointer
+  size_t narcs;                // ... and arc count.
+  int *ref_count;              // ... and reference count if non-zero.
+};
+
+// Generic arc iterator, templated on the FST definition (a wrapper around a
+// pointer to a specific one). Here is a typical use:
+//
+//   for (ArcIterator<StdFst> aiter(fst, s);
+//        !aiter.Done();
+//        aiter.Next()) {
+//     const StdArc &arc = aiter.Value();
+//     ...
+//   }
+// There is no copying of the FST.
+template <class FST>
+class ArcIterator {
+ public:
+  using Arc = typename FST::Arc;
+  using StateId = typename Arc::StateId;
+
+  ArcIterator(const FST &fst, StateId s) : i_(0) {
+    fst.InitArcIterator(s, &data_);
+  }
+
+  explicit ArcIterator(const ArcIteratorData<Arc> &data) : data_(data), i_(0) {
+    if (data_.ref_count) ++(*data_.ref_count);
+  }
+
+  ~ArcIterator() {
+    if (data_.base) {
+      delete data_.base;
+    } else if (data_.ref_count) {
+      --(*data_.ref_count);
+    }
+  }
+
+  bool Done() const {
+    return data_.base ? data_.base->Done() : i_ >= data_.narcs;
+  }
+
+  const Arc &Value() const {
+    return data_.base ? data_.base->Value() : data_.arcs[i_];
+  }
+
+  void Next() {
+    if (data_.base) {
+      data_.base->Next();
+    } else {
+      ++i_;
+    }
+  }
+
+  void Reset() {
+    if (data_.base) {
+      data_.base->Reset();
+    } else {
+      i_ = 0;
+    }
+  }
+
+  void Seek(size_t a) {
+    if (data_.base) {
+      data_.base->Seek(a);
+    } else {
+      i_ = a;
+    }
+  }
+
+  size_t Position() const { return data_.base ? data_.base->Position() : i_; }
+
+  uint32 Flags() const {
+    if (data_.base) {
+      return data_.base->Flags();
+    } else {
+      return kArcValueFlags;
+    }
+  }
+
+  void SetFlags(uint32 flags, uint32 mask) {
+    if (data_.base) data_.base->SetFlags(flags, mask);
+  }
+
+ private:
+  ArcIteratorData<Arc> data_;
+  size_t i_;
+};
+
+}  // namespace fst
+
+// ArcIterator placement operator new and destroy function; new needs to be in
+// the global namespace.
+
+template <class FST>
+void *operator new(size_t size,
+                   fst::MemoryPool<fst::ArcIterator<FST>> *pool) {
+  return pool->Allocate();
+}
+
+namespace fst {
+
+template <class FST>
+void Destroy(ArcIterator<FST> *aiter, MemoryPool<ArcIterator<FST>> *pool) {
+  if (aiter) {
+    aiter->~ArcIterator<FST>();
+    pool->Free(aiter);
+  }
+}
+
+// Matcher definitions.
+
+template <class Arc>
+MatcherBase<Arc> *Fst<Arc>::InitMatcher(MatchType match_type) const {
+  return nullptr;  // One should just use the default matcher.
+}
+
+// FST accessors, useful in high-performance applications.
+
+namespace internal {
+
+// General case, requires non-abstract, 'final' methods. Use for inlining.
+
+template <class F>
+inline typename F::Arc::Weight Final(const F &fst, typename F::Arc::StateId s) {
+  return fst.F::Final(s);
+}
+
+template <class F>
+inline ssize_t NumArcs(const F &fst, typename F::Arc::StateId s) {
+  return fst.F::NumArcs(s);
+}
+
+template <class F>
+inline ssize_t NumInputEpsilons(const F &fst, typename F::Arc::StateId s) {
+  return fst.F::NumInputEpsilons(s);
+}
+
+template <class F>
+inline ssize_t NumOutputEpsilons(const F &fst, typename F::Arc::StateId s) {
+  return fst.F::NumOutputEpsilons(s);
+}
+
+// Fst<Arc> case, abstract methods.
+
+template <class Arc>
+inline typename Arc::Weight Final(const Fst<Arc> &fst,
+                                  typename Arc::StateId s) {
+  return fst.Final(s);
+}
+
+template <class Arc>
+inline size_t NumArcs(const Fst<Arc> &fst, typename Arc::StateId s) {
+  return fst.NumArcs(s);
+}
+
+template <class Arc>
+inline size_t NumInputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
+  return fst.NumInputEpsilons(s);
+}
+
+template <class Arc>
+inline size_t NumOutputEpsilons(const Fst<Arc> &fst, typename Arc::StateId s) {
+  return fst.NumOutputEpsilons(s);
+}
+
+// FST implementation base.
+//
+// This is the recommended FST implementation base class. It will handle
+// reference counts, property bits, type information and symbols.
+//
+// Users are discouraged, but not prohibited, from subclassing this outside the
+// FST library.
+template <class Arc>
+class FstImpl {
+ public:
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  FstImpl() : properties_(0), type_("null") {}
+
+  FstImpl(const FstImpl &impl)
+      : properties_(impl.properties_),
+        type_(impl.type_),
+        isymbols_(impl.isymbols_ ? impl.isymbols_->Copy() : nullptr),
+        osymbols_(impl.osymbols_ ? impl.osymbols_->Copy() : nullptr) {}
+
+  FstImpl(FstImpl &&impl) noexcept;
+
+  virtual ~FstImpl() {}
+
+  FstImpl &operator=(const FstImpl &impl) {
+    properties_ = impl.properties_;
+    type_ = impl.type_;
+    isymbols_.reset(impl.isymbols_ ? impl.isymbols_->Copy() : nullptr);
+    osymbols_.reset(impl.osymbols_ ? impl.osymbols_->Copy() : nullptr);
+    return *this;
+  }
+
+  FstImpl &operator=(FstImpl &&impl) noexcept;
+
+  const string &Type() const { return type_; }
+
+  void SetType(const string &type) { type_ = type; }
+
+  virtual uint64 Properties() const { return properties_; }
+
+  virtual uint64 Properties(uint64 mask) const { return properties_ & mask; }
+
+  void SetProperties(uint64 props) {
+    properties_ &= kError;  // kError can't be cleared.
+    properties_ |= props;
+  }
+
+  void SetProperties(uint64 props, uint64 mask) {
+    properties_ &= ~mask | kError;  // kError can't be cleared.
+    properties_ |= props & mask;
+  }
+
+  // Allows (only) setting error bit on const FST implementations.
+  void SetProperties(uint64 props, uint64 mask) const {
+    if (mask != kError) {
+      FSTERROR() << "FstImpl::SetProperties() const: Can only set kError";
+    }
+    properties_ |= kError;
+  }
+
+  const SymbolTable *InputSymbols() const { return isymbols_.get(); }
+
+  const SymbolTable *OutputSymbols() const { return osymbols_.get(); }
+
+  SymbolTable *InputSymbols() { return isymbols_.get(); }
+
+  SymbolTable *OutputSymbols() { return osymbols_.get(); }
+
+  void SetInputSymbols(const SymbolTable *isyms) {
+    isymbols_.reset(isyms ? isyms->Copy() : nullptr);
+  }
+
+  void SetOutputSymbols(const SymbolTable *osyms) {
+    osymbols_.reset(osyms ? osyms->Copy() : nullptr);
+  }
+
+  // Reads header and symbols from input stream, initializes FST, and returns
+  // the header. If opts.header is non-null, skips reading and uses the option
+  // value instead. If opts.[io]symbols is non-null, reads in (if present), but
+  // uses the option value.
+  bool ReadHeader(std::istream &strm, const FstReadOptions &opts,
+                  int min_version, FstHeader *hdr);
+
+  // Writes header and symbols to output stream. If opts.header is false, skips
+  // writing header. If opts.[io]symbols is false, skips writing those symbols.
+  // This method is needed for implementations that implement Write methods.
+  void WriteHeader(std::ostream &strm, const FstWriteOptions &opts,
+                   int version, FstHeader *hdr) const {
+    if (opts.write_header) {
+      hdr->SetFstType(type_);
+      hdr->SetArcType(Arc::Type());
+      hdr->SetVersion(version);
+      hdr->SetProperties(properties_);
+      int32 file_flags = 0;
+      if (isymbols_ && opts.write_isymbols) {
+        file_flags |= FstHeader::HAS_ISYMBOLS;
+      }
+      if (osymbols_ && opts.write_osymbols) {
+        file_flags |= FstHeader::HAS_OSYMBOLS;
+      }
+      if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
+      hdr->SetFlags(file_flags);
+      hdr->Write(strm, opts.source);
+    }
+    if (isymbols_ && opts.write_isymbols) isymbols_->Write(strm);
+    if (osymbols_ && opts.write_osymbols) osymbols_->Write(strm);
+  }
+
+  // Writes out header and symbols to output stream. If opts.header is false,
+  // skips writing header. If opts.[io]symbols is false, skips writing those
+  // symbols. `type` is the FST type being written. This method is used in the
+  // cross-type serialization methods Fst::WriteFst.
+  static void WriteFstHeader(const Fst<Arc> &fst, std::ostream &strm,
+                             const FstWriteOptions &opts, int version,
+                             const string &type, uint64 properties,
+                             FstHeader *hdr) {
+    if (opts.write_header) {
+      hdr->SetFstType(type);
+      hdr->SetArcType(Arc::Type());
+      hdr->SetVersion(version);
+      hdr->SetProperties(properties);
+      int32 file_flags = 0;
+      if (fst.InputSymbols() && opts.write_isymbols) {
+        file_flags |= FstHeader::HAS_ISYMBOLS;
+      }
+      if (fst.OutputSymbols() && opts.write_osymbols) {
+        file_flags |= FstHeader::HAS_OSYMBOLS;
+      }
+      if (opts.align) file_flags |= FstHeader::IS_ALIGNED;
+      hdr->SetFlags(file_flags);
+      hdr->Write(strm, opts.source);
+    }
+    if (fst.InputSymbols() && opts.write_isymbols) {
+      fst.InputSymbols()->Write(strm);
+    }
+    if (fst.OutputSymbols() && opts.write_osymbols) {
+      fst.OutputSymbols()->Write(strm);
+    }
+  }
+
+  // In serialization routines where the header cannot be written until after
+  // the machine has been serialized, this routine can be called to seek to the
+  // beginning of the file and rewrite the header with updated fields. It
+  // repositions the file pointer back at the end of the file. Returns true on
+  // success, false on failure.
+  static bool UpdateFstHeader(const Fst<Arc> &fst, std::ostream &strm,
+                              const FstWriteOptions &opts, int version,
+                              const string &type, uint64 properties,
+                              FstHeader *hdr, size_t header_offset) {
+    strm.seekp(header_offset);
+    if (!strm) {
+      LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
+      return false;
+    }
+    WriteFstHeader(fst, strm, opts, version, type, properties, hdr);
+    if (!strm) {
+      LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
+      return false;
+    }
+    strm.seekp(0, std::ios_base::end);
+    if (!strm) {
+      LOG(ERROR) << "Fst::UpdateFstHeader: Write failed: " << opts.source;
+      return false;
+    }
+    return true;
+  }
+
+ protected:
+  mutable uint64 properties_;  // Property bits.
+
+ private:
+  string type_;  // Unique name of FST class.
+  std::unique_ptr<SymbolTable> isymbols_;
+  std::unique_ptr<SymbolTable> osymbols_;
+};
+
+template <class Arc>
+inline FstImpl<Arc>::FstImpl(FstImpl<Arc> &&) noexcept = default;
+
+template <class Arc>
+inline FstImpl<Arc> &FstImpl<Arc>::operator=(
+    FstImpl<Arc> &&) noexcept = default;
+
+template <class Arc>
+bool FstImpl<Arc>::ReadHeader(std::istream &strm, const FstReadOptions &opts,
+                              int min_version, FstHeader *hdr) {
+  if (opts.header) {
+    *hdr = *opts.header;
+  } else if (!hdr->Read(strm, opts.source)) {
+    return false;
+  }
+  if (FLAGS_v >= 2) {
+    LOG(INFO) << "FstImpl::ReadHeader: source: " << opts.source
+              << ", fst_type: " << hdr->FstType()
+              << ", arc_type: " << Arc::Type()
+              << ", version: " << hdr->Version()
+              << ", flags: " << hdr->GetFlags();
+  }
+  if (hdr->FstType() != type_) {
+    LOG(ERROR) << "FstImpl::ReadHeader: FST not of type " << type_
+               << ": " << opts.source;
+    return false;
+  }
+  if (hdr->ArcType() != Arc::Type()) {
+    LOG(ERROR) << "FstImpl::ReadHeader: Arc not of type " << Arc::Type()
+               << ": " << opts.source;
+    return false;
+  }
+  if (hdr->Version() < min_version) {
+    LOG(ERROR) << "FstImpl::ReadHeader: Obsolete " << type_
+               << " FST version: " << opts.source;
+    return false;
+  }
+  properties_ = hdr->Properties();
+  if (hdr->GetFlags() & FstHeader::HAS_ISYMBOLS) {
+    isymbols_.reset(SymbolTable::Read(strm, opts.source));
+  }
+  // Deletes input symbol table.
+  if (!opts.read_isymbols) SetInputSymbols(nullptr);
+  if (hdr->GetFlags() & FstHeader::HAS_OSYMBOLS) {
+    osymbols_.reset(SymbolTable::Read(strm, opts.source));
+  }
+  // Deletes output symbol table.
+  if (!opts.read_osymbols) SetOutputSymbols(nullptr);
+  if (opts.isymbols) {
+    isymbols_.reset(opts.isymbols->Copy());
+  }
+  if (opts.osymbols) {
+    osymbols_.reset(opts.osymbols->Copy());
+  }
+  return true;
+}
+
+}  // namespace internal
+
+template <class Arc>
+uint64 TestProperties(const Fst<Arc> &fst, uint64 mask, uint64 *known);
+
+// This is a helper class template useful for attaching an FST interface to
+// its implementation, handling reference counting.
+template <class Impl, class FST = Fst<typename Impl::Arc>>
+class ImplToFst : public FST {
+ public:
+  using Arc = typename Impl::Arc;
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+
+  StateId Start() const override { return impl_->Start(); }
+
+  Weight Final(StateId s) const override { return impl_->Final(s); }
+
+  size_t NumArcs(StateId s) const override { return impl_->NumArcs(s); }
+
+  size_t NumInputEpsilons(StateId s) const override {
+    return impl_->NumInputEpsilons(s);
+  }
+
+  size_t NumOutputEpsilons(StateId s) const override {
+    return impl_->NumOutputEpsilons(s);
+  }
+
+  uint64 Properties(uint64 mask, bool test) const override {
+    if (test) {
+      uint64 knownprops, testprops = TestProperties(*this, mask, &knownprops);
+      impl_->SetProperties(testprops, knownprops);
+      return testprops & mask;
+    } else {
+      return impl_->Properties(mask);
+    }
+  }
+
+  const string &Type() const override { return impl_->Type(); }
+
+  const SymbolTable *InputSymbols() const override {
+    return impl_->InputSymbols();
+  }
+
+  const SymbolTable *OutputSymbols() const override {
+    return impl_->OutputSymbols();
+  }
+
+ protected:
+  explicit ImplToFst(std::shared_ptr<Impl> impl) : impl_(std::move(impl)) {}
+
+  // This constructor presumes there is a copy constructor for the
+  // implementation.
+  ImplToFst(const ImplToFst &fst, bool safe) {
+    if (safe) {
+      impl_ = std::make_shared<Impl>(*(fst.impl_));
+    } else {
+      impl_ = fst.impl_;
+    }
+  }
+
+  ImplToFst() = delete;
+
+  ImplToFst(const ImplToFst &fst) : impl_(fst.impl_) {}
+
+  ImplToFst(ImplToFst &&fst) noexcept
+      : impl_(std::move(fst.impl_)) {
+    fst.impl_ = std::make_shared<Impl>();
+  }
+
+  ImplToFst &operator=(const ImplToFst &fst) {
+    impl_ = fst.impl_;
+    return *this;
+  }
+
+  ImplToFst &operator=(ImplToFst &&fst) noexcept {
+    if (this != &fst) {
+      impl_ = std::move(fst.impl_);
+      fst.impl_ = std::make_shared<Impl>();
+    }
+    return *this;
+  }
+
+  // Returns raw pointers to the shared object.
+  const Impl *GetImpl() const { return impl_.get(); }
+
+  Impl *GetMutableImpl() const { return impl_.get(); }
+
+  // Returns a ref-counted smart pointer to the implementation.
+  std::shared_ptr<Impl> GetSharedImpl() const { return impl_; }
+
+  bool Unique() const { return impl_.use_count() == 1; }
+
+  void SetImpl(std::shared_ptr<Impl> impl) { impl_ = std::move(impl); }
+
+ private:
+  template <class IFST, class OFST>
+  friend void Cast(const IFST &ifst, OFST *ofst);
+
+  std::shared_ptr<Impl> impl_;
+};
+
+// Converts FSTs by casting their implementations, where this makes sense
+// (which excludes implementations with weight-dependent virtual methods).
+// Must be a friend of the FST classes involved (currently the concrete FSTs:
+// ConstFst, CompactFst, and VectorFst). This can only be safely used for arc
+// types that have identical storage characteristics. As with an FST
+// copy constructor and Copy() method, this is a constant time operation
+// (but subject to copy-on-write if it is a MutableFst and modified).
+template <class IFST, class OFST>
+void Cast(const IFST &ifst, OFST *ofst) {
+  using OImpl = typename OFST::Impl;
+  ofst->impl_ = std::shared_ptr<OImpl>(ifst.impl_,
+      reinterpret_cast<OImpl *>(ifst.impl_.get()));
+}
+
+// FST serialization.
+
+template <class Arc>
+string FstToString(const Fst<Arc> &fst,
+                   const FstWriteOptions &options =
+                       FstWriteOptions("FstToString")) {
+  std::ostringstream ostrm;
+  fst.Write(ostrm, options);
+  return ostrm.str();
+}
+
+template <class Arc>
+void FstToString(const Fst<Arc> &fst, string *result) {
+  *result = FstToString(fst);
+}
+
+template <class Arc>
+void FstToString(const Fst<Arc> &fst, string *result,
+                 const FstWriteOptions &options) {
+  *result = FstToString(fst, options);
+}
+
+template <class Arc>
+Fst<Arc> *StringToFst(const string &s) {
+  std::istringstream istrm(s);
+  return Fst<Arc>::Read(istrm, FstReadOptions("StringToFst"));
+}
+
+}  // namespace fst
+
+#endif  // FST_FST_H_
diff --git a/projects/llm_framework/include/fst/fstlib.h b/projects/llm_framework/include/fst/fstlib.h
new file mode 100644
index 00000000..e8b1c3a1
--- /dev/null
+++ b/projects/llm_framework/include/fst/fstlib.h
@@ -0,0 +1,130 @@
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// This is a library for constructing, combining, optimizing, and searching
+// "weighted finite-state transducers" (FSTs). Weighted finite-state transducers
+// are automata where each transition has an input label, an output label, and a
+// weight. The more familiar finite-state acceptor is represented as a
+// transducer with each transition's input and output the same. Finite-state
+// acceptors are used to represent sets of strings (specifically, "regular" or
+// "rational sets"); finite-state transducers are used to represent binary
+// relations between pairs of strings (specifically, "rational transductions").
+// The weights can be used to represent the cost of taking a particular
+// transition.
+// +// In this library, transducers are templated on the Arc (transition) +// definition, which allows changing the label, weight, and state ID sets. +// Labels and state IDs are restricted to signed integral types but the weight +// can be an arbitrary type whose members satisfy certain algebraic ("semiring") +// properties. +// +// This convenience file includes all other FST header files. + +#ifndef FST_FSTLIB_H_ +#define FST_FSTLIB_H_ + + +// Abstract FST classes. +#include +#include +#include + +// Concrete FST classes. +#include +#include +#include +#include + +// FST algorithms and delayed FST classes. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Weights. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Auxiliary classes for composition. +#include +#include +#include +#include +#include +#include + +// Data structures. +#include +#include +#include +#include + +// Miscellaneous. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +#endif // FST_FSTLIB_H_ diff --git a/projects/llm_framework/include/fst/generic-register.h b/projects/llm_framework/include/fst/generic-register.h new file mode 100644 index 00000000..ea6b8fe1 --- /dev/null +++ b/projects/llm_framework/include/fst/generic-register.h @@ -0,0 +1,126 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+
+#ifndef FST_GENERIC_REGISTER_H_
+#define FST_GENERIC_REGISTER_H_
+
+#include <fst/compat.h>
+#ifndef FST_NO_DYNAMIC_LINKING
+#include <dlfcn.h>
+#endif
+#include <map>
+#include <string>
+
+#include <fst/types.h>
+#include <fst/log.h>
+
+// Generic class representing a globally-stored correspondence between
+// objects of KeyType and EntryType.
+//
+// KeyType must:
+//
+// * be such as can be stored as a key in a std::map<>.
+// * be concatenable with a const char* with the + operator
+//   (or you must subclass and redefine LoadEntryFromSharedObject)
+//
+// EntryType must be default constructible.
+//
+// The third template parameter should be the type of a subclass of this class
+// (think CRTP). This is to allow GetRegister() to instantiate and return an
+// object of the appropriate type.
+
+namespace fst {
+
+template <class KeyType, class EntryType, class RegisterType>
+class GenericRegister {
+ public:
+  using Key = KeyType;
+  using Entry = EntryType;
+
+  static RegisterType *GetRegister() {
+    static auto reg = new RegisterType;
+    return reg;
+  }
+
+  void SetEntry(const KeyType &key, const EntryType &entry) {
+    MutexLock l(&register_lock_);
+    register_table_.insert(std::make_pair(key, entry));
+  }
+
+  EntryType GetEntry(const KeyType &key) const {
+    const auto *entry = LookupEntry(key);
+    if (entry) {
+      return *entry;
+    } else {
+      return LoadEntryFromSharedObject(key);
+    }
+  }
+
+  virtual ~GenericRegister() {}
+
+ protected:
+  // Override this if you want to be able to load missing definitions from
+  // shared object files.
+  virtual EntryType LoadEntryFromSharedObject(const KeyType &key) const {
+#ifdef FST_NO_DYNAMIC_LINKING
+    return EntryType();
+#else
+    const auto so_filename = ConvertKeyToSoFilename(key);
+    void *handle = dlopen(so_filename.c_str(), RTLD_LAZY);
+    if (handle == nullptr) {
+      LOG(ERROR) << "GenericRegister::GetEntry: " << dlerror();
+      return EntryType();
+    }
+#ifdef RUN_MODULE_INITIALIZERS
+    RUN_MODULE_INITIALIZERS();
+#endif
+    // We assume that the DSO constructs a static object in its global scope
+    // that does the registration. Thus we need only load it, not call any
+    // methods.
+    const auto *entry = this->LookupEntry(key);
+    if (entry == nullptr) {
+      LOG(ERROR) << "GenericRegister::GetEntry: "
+                 << "lookup failed in shared object: " << so_filename;
+      return EntryType();
+    }
+    return *entry;
+#endif  // FST_NO_DYNAMIC_LINKING
+  }
+
+  // Override this to define how to turn a key into an SO filename.
+  virtual string ConvertKeyToSoFilename(const KeyType &key) const = 0;
+
+  virtual const EntryType *LookupEntry(const KeyType &key) const {
+    MutexLock l(&register_lock_);
+    const auto it = register_table_.find(key);
+    if (it != register_table_.end()) {
+      return &it->second;
+    } else {
+      return nullptr;
+    }
+  }
+
+ private:
+  mutable Mutex register_lock_;
+  std::map<KeyType, EntryType> register_table_;
+};
+
+// Generic register-er class capable of creating new register entries in the
+// given RegisterType template parameter. This type must define types Key and
+// Entry, and have appropriate static GetRegister() and instance SetEntry()
+// functions. An easy way to accomplish this is to have RegisterType be the
+// type of a subclass of GenericRegister.
+template <class RegisterType>
+class GenericRegisterer {
+ public:
+  using Key = typename RegisterType::Key;
+  using Entry = typename RegisterType::Entry;
+
+  GenericRegisterer(Key key, Entry entry) {
+    RegisterType::GetRegister()->SetEntry(key, entry);
+  }
+};
+
+}  // namespace fst
+
+#endif  // FST_GENERIC_REGISTER_H_
diff --git a/projects/llm_framework/include/fst/heap.h b/projects/llm_framework/include/fst/heap.h
new file mode 100644
index 00000000..041a4bb9
--- /dev/null
+++ b/projects/llm_framework/include/fst/heap.h
@@ -0,0 +1,168 @@
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// Implementation of a heap as in STL, but allows tracking positions in heap
+// using a key. The key can be used to do an in-place update of values in the
+// heap.
+
+#ifndef FST_HEAP_H_
+#define FST_HEAP_H_
+
+#include <utility>
+#include <vector>
+
+#include <fst/compat.h>
+namespace fst {
+
+// A templated heap implementation that supports in-place update of values.
+//
+// The templated heap implementation is a little different from the STL
+// priority_queue and the *_heap operations in STL. This heap supports
+// indexing of values in the heap via an associated key.
+//
+// Each value is internally associated with a key which is returned to the
+// calling functions on heap insert. This key can be used to later update
+// the specific value in the heap.
+//
+// T: the element type of the heap. It can be POD, Data or a pointer to Data.
+// Compare: comparison functor for determining min-heapness.
+template <class T, class Compare>
+class Heap {
+ public:
+  using Value = T;
+
+  static constexpr int kNoKey = -1;
+
+  // Initializes with a specific comparator.
+  explicit Heap(Compare comp = Compare()) : comp_(comp), size_(0) {}
+
+  // Inserts a value into the heap.
+  int Insert(const Value &value) {
+    if (size_ < values_.size()) {
+      values_[size_] = value;
+      pos_[key_[size_]] = size_;
+    } else {
+      values_.push_back(value);
+      pos_.push_back(size_);
+      key_.push_back(size_);
+    }
+    ++size_;
+    return Insert(value, size_ - 1);
+  }
+
+  // Updates a value at position given by the key. The pos_ array is first
+  // indexed by the key. The position gives the position in the heap array.
+  // Once we have the position we can then use the standard heap operations
+  // to calculate the parent and child positions.
+  void Update(int key, const Value &value) {
+    const auto i = pos_[key];
+    const bool is_better = comp_(value, values_[Parent(i)]);
+    values_[i] = value;
+    if (is_better) {
+      Insert(value, i);
+    } else {
+      Heapify(i);
+    }
+  }
+
+  // Returns the least value.
+  Value Pop() {
+    Value top = values_.front();
+    Swap(0, size_-1);
+    size_--;
+    Heapify(0);
+    return top;
+  }
+
+  // Returns the least value w.r.t. the comparison function from the
+  // heap.
+  const Value &Top() const { return values_.front(); }
+
+  // Returns the element for the given key.
+  const Value &Get(int key) const { return values_[pos_[key]]; }
+
+  // Checks if the heap is empty.
+  bool Empty() const { return size_ == 0; }
+
+  void Clear() { size_ = 0; }
+
+  int Size() const { return size_; }
+
+  void Reserve(int size) {
+    values_.reserve(size);
+    pos_.reserve(size);
+    key_.reserve(size);
+  }
+
+  const Compare &GetCompare() const { return comp_; }
+
+ private:
+  // The following private routines are used in a supportive role
+  // for managing the heap and keeping the heap properties.
+
+  // Computes left child of parent.
+  static int Left(int i) {
+    return 2 * (i + 1) - 1;  // 0 -> 1, 1 -> 3
+  }
+
+  // Computes right child of parent.
+  static int Right(int i) {
+    return 2 * (i + 1);  // 0 -> 2, 1 -> 4
+  }
+
+  // Given a child computes parent.
+  static int Parent(int i) {
+    return (i - 1) / 2;  // 0 -> 0, 1 -> 0, 2 -> 0, 3 -> 1, 4 -> 1, ...
+  }
+
+  // Swaps a child and parent. Use to move element up/down tree. Note the use of
+  // a little trick here. When we swap we need to swap:
+  //
+  // - the value
+  // - the associated keys
+  // - the position of the value in the heap
+  void Swap(int j, int k) {
+    const auto tkey = key_[j];
+    pos_[key_[j] = key_[k]] = j;
+    pos_[key_[k] = tkey] = k;
+    using std::swap;
+    swap(values_[j], values_[k]);
+  }
+
+  // Heapifies the subtree rooted at index i.
+  void Heapify(int i) {
+    const auto l = Left(i);
+    const auto r = Right(i);
+    auto largest = (l < size_ && comp_(values_[l], values_[i])) ? l : i;
+    if (r < size_ && comp_(values_[r], values_[largest])) largest = r;
+    if (largest != i) {
+      Swap(i, largest);
+      Heapify(largest);
+    }
+  }
+
+  // Inserts (updates) element at subtree rooted at index i.
+  int Insert(const Value &value, int i) {
+    int p;
+    while (i > 0 && !comp_(values_[p = Parent(i)], value)) {
+      Swap(i, p);
+      i = p;
+    }
+    return key_[i];
+  }
+
+ private:
+  const Compare comp_;
+
+  std::vector<int> pos_;
+  std::vector<int> key_;
+  std::vector<T> values_;
+  int size_;
+};
+
+template <class T, class Compare>
+constexpr int Heap<T, Compare>::kNoKey;
+
+}  // namespace fst
+
+#endif  // FST_HEAP_H_
diff --git a/projects/llm_framework/include/fst/icu.h b/projects/llm_framework/include/fst/icu.h
new file mode 100644
index 00000000..5459da24
--- /dev/null
+++ b/projects/llm_framework/include/fst/icu.h
@@ -0,0 +1,129 @@
+// See www.openfst.org for extensive documentation on this weighted
+// finite-state transducer library.
+//
+// This library implements an unrestricted Thompson/Pike UTF-8 parser and
+// serializer. UTF-8 is a restricted subset of this byte stream encoding. For
+// a description of the encoding details, see:
+//
+// http://en.wikipedia.org/wiki/UTF-8
+
+#ifndef FST_ICU_H_
+#define FST_ICU_H_
+
+#include
+#include
+
+#include
+
+namespace fst {
+
+// Trivial function to copy bytestrings into vectors of labels, truncating
+// if necessary. It is possible to use this sensibly with as little as 8 bits
+// of Label precision. This returns `true` deterministically for compatibility.
+template +bool ByteStringToLabels(const string &str, std::vector { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ComposeFst::CreateBase; + using ComposeFst::CreateBase1; + using ComposeFst::Properties; + + IntersectFst(const Fst &fst1, const Fst &fst2, + const CacheOptions &opts = CacheOptions()) + : ComposeFst(CreateBase(fst1, fst2, opts)) { + const bool acceptors = + fst1.Properties(kAcceptor, true) && fst2.Properties(kAcceptor, true); + if (!acceptors) { + FSTERROR() << "IntersectFst: Input FSTs are not acceptors"; + GetMutableImpl()->SetProperties(kError); + } + } + + template + IntersectFst(const Fst &fst1, const Fst &fst2, + const IntersectFstOptions &opts) + : ComposeFst(CreateBase1(fst1, fst2, opts)) { + const bool acceptors = + fst1.Properties(kAcceptor, true) && fst2.Properties(kAcceptor, true); + if (!acceptors) { + FSTERROR() << "IntersectFst: input FSTs are not acceptors"; + GetMutableImpl()->SetProperties(kError); + } + } + + // See Fst<>::Copy() for doc. + IntersectFst(const IntersectFst &fst, bool safe = false) + : ComposeFst(fst, safe) {} + + // Get a copy of this IntersectFst. See Fst<>::Copy() for further doc. + IntersectFst *Copy(bool safe = false) const override { + return new IntersectFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for IntersectFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const IntersectFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for IntersectFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const IntersectFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdIntersectFst = IntersectFst; + +// Computes the intersection (Hadamard product) of two FSAs. 
This version +// writes the intersection to an output MutableFst. Only strings that are in +// both automata are retained in the result. +// +// The two arguments must be acceptors. One of the arguments must be +// label-sorted. +// +// Complexity: same as Compose. +// +// Caveats: same as Compose. +template +void Intersect(const Fst &ifst1, const Fst &ifst2, + MutableFst *ofst, + const IntersectOptions &opts = IntersectOptions()) { + using M = Matcher>; + // In each case, we cache only the last state for fastest copy. + switch (opts.filter_type) { + case AUTO_FILTER: { + CacheOptions nopts; + nopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, nopts); + break; + } + case SEQUENCE_FILTER: { + IntersectFstOptions iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case ALT_SEQUENCE_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case MATCH_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case NO_MATCH_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case NULL_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + case TRIVIAL_FILTER: { + IntersectFstOptions> iopts; + iopts.gc_limit = 0; + *ofst = IntersectFst(ifst1, ifst2, iopts); + break; + } + } + if (opts.connect) Connect(ofst); +} + +} // namespace fst + +#endif // FST_INTERSECT_H_ diff --git a/projects/llm_framework/include/fst/interval-set.h b/projects/llm_framework/include/fst/interval-set.h new file mode 100644 index 00000000..0942ea00 --- /dev/null +++ b/projects/llm_framework/include/fst/interval-set.h @@ -0,0 +1,398 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to represent and operate on sets of intervals.
+ +#ifndef FST_INTERVAL_SET_H_ +#define FST_INTERVAL_SET_H_ + +#include +#include +#include + + +#include + + +namespace fst { + +// Half-open integral interval [a, b) of signed integers of type T. +template +struct IntInterval { + T begin; + T end; + + IntInterval() : begin(-1), end(-1) {} + + IntInterval(T begin, T end) : begin(begin), end(end) {} + + bool operator<(const IntInterval &i) const { + return begin < i.begin || (begin == i.begin && end > i.end); + } + + bool operator==(const IntInterval &i) const { + return begin == i.begin && end == i.end; + } + + bool operator!=(const IntInterval &i) const { + return begin != i.begin || end != i.end; + } + + std::istream &Read(std::istream &strm) { + T n; + ReadType(strm, &n); + begin = n; + ReadType(strm, &n); + end = n; + return strm; + } + + std::ostream &Write(std::ostream &strm) const { + T n = begin; + WriteType(strm, n); + n = end; + WriteType(strm, n); + return strm; + } +}; + +// Stores IntIntervals in a vector. In addition, keeps the count of points in +// all intervals. 
+template +class VectorIntervalStore { + public: + using Interval = IntInterval; + using Iterator = typename std::vector::const_iterator; + + VectorIntervalStore() : count_(-1) {} + + std::vector *MutableIntervals() { return &intervals_; } + + const Interval *Intervals() const { return intervals_.data(); } + + T Size() const { return intervals_.size(); } + + T Count() const { return count_; } + + void SetCount(T count) { count_ = count; } + + void Clear() { + intervals_.clear(); + count_ = 0; + } + + Iterator begin() const { return intervals_.begin(); } + + Iterator end() const { return intervals_.end(); } + + std::istream &Read(std::istream &strm) { + ReadType(strm, &intervals_); + return ReadType(strm, &count_); + } + + std::ostream &Write(std::ostream &strm) const { + WriteType(strm, intervals_); + return WriteType(strm, count_); + } + + private: + std::vector intervals_; + T count_; +}; + +// Stores and operates on a set of half-open integral intervals [a, b) +// of signed integers of type T. +template > +class IntervalSet { + public: + using Interval = IntInterval; + + template + explicit IntervalSet(A... args) : intervals_(args...) {} + + // Returns the interval set as a vector. + std::vector *MutableIntervals() { + return intervals_.MutableIntervals(); + } + + // Returns a pointer to an array of Size() elements. + const Interval *Intervals() const { return intervals_.Intervals(); } + + bool Empty() const { return Size() == 0; } + + T Size() const { return intervals_.Size(); } + + // Number of points in the intervals (undefined if not normalized). + T Count() const { return intervals_.Count(); } + + void Clear() { intervals_.Clear(); } + + // Adds an interval set to the set. The result may not be normalized. + void Union(const IntervalSet &iset) { + intervals_.MutableIntervals()->insert(intervals_.MutableIntervals()->end(), + iset.intervals_.begin(), + iset.intervals_.end()); + } + + // Requires intervals be normalized. 
+ bool Member(T value) const { + const Interval interval(value, value); + auto lb = std::lower_bound(intervals_.begin(), intervals_.end(), interval); + if (lb == intervals_.begin()) return false; + return (--lb)->end > value; + } + + // Requires intervals be normalized. + bool operator==(const IntervalSet &iset) const { + return Size() == iset.Size() && + std::equal(intervals_.begin(), intervals_.end(), + iset.intervals_.begin()); + } + + // Requires intervals be normalized. + bool operator!=(const IntervalSet &iset) const { + return Size() != iset.Size() || + !std::equal(intervals_.begin(), intervals_.end(), + iset.intervals_.begin()); + } + + bool Singleton() const { + return Size() == 1 && + intervals_.begin()->begin + 1 == intervals_.begin()->end; + } + + // Sorts, collapses overlapping and adjacent intervals, and sets count. + void Normalize(); + + // Intersects an interval set with the set. Requires intervals be normalized. + // The result is normalized. + void Intersect(const IntervalSet &iset, + IntervalSet *oset) const; + + // Complements the set w.r.t [0, maxval). Requires intervals be normalized. + // The result is normalized. + void Complement(T maxval, IntervalSet *oset) const; + + // Subtract an interval set from the set. Requires intervals be normalized. + // The result is normalized. + void Difference(const IntervalSet &iset, + IntervalSet *oset) const; + + // Determines if an interval set overlaps with the set. Requires intervals be + // normalized. + bool Overlaps(const IntervalSet &iset) const; + + // Determines if an interval set overlaps with the set but neither is + // contained in the other. Requires intervals be normalized. + bool StrictlyOverlaps(const IntervalSet &iset) const; + + // Determines if an interval set is contained within the set. Requires + // intervals be normalized.
+ bool Contains(const IntervalSet &iset) const; + + std::istream &Read(std::istream &strm) { return intervals_.Read(strm); } + + std::ostream &Write(std::ostream &strm) const { + return intervals_.Write(strm); + } + + typename Store::Iterator begin() const { return intervals_.begin(); } + + typename Store::Iterator end() const { return intervals_.end(); } + + private: + Store intervals_; +}; + +// Sorts, collapses overlapping and adjacent intervals, and sets count. +template +void IntervalSet::Normalize() { + auto &intervals = *intervals_.MutableIntervals(); + std::sort(intervals.begin(), intervals.end()); + T count = 0; + T size = 0; + for (T i = 0; i < intervals.size(); ++i) { + auto &inti = intervals[i]; + if (inti.begin == inti.end) continue; + for (T j = i + 1; j < intervals.size(); ++j) { + auto &intj = intervals[j]; + if (intj.begin > inti.end) break; + if (intj.end > inti.end) inti.end = intj.end; + ++i; + } + count += inti.end - inti.begin; + intervals[size++] = inti; + } + intervals.resize(size); + intervals_.SetCount(count); +} + +// Intersects an interval set with the set. Requires intervals be normalized. +// The result is normalized. +template +void IntervalSet::Intersect(const IntervalSet &iset, + IntervalSet *oset) const { + auto *ointervals = oset->MutableIntervals(); + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + ointervals->clear(); + T count = 0; + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { + ++it1; + } else if (it2->end <= it1->begin) { + ++it2; + } else { + ointervals->emplace_back(std::max(it1->begin, it2->begin), + std::min(it1->end, it2->end)); + count += ointervals->back().end - ointervals->back().begin; + if (it1->end < it2->end) { + ++it1; + } else { + ++it2; + } + } + } + oset->intervals_.SetCount(count); +} + +// Complements the set w.r.t [0, maxval). Requires intervals be normalized. +// The result is normalized. 
+template +void IntervalSet::Complement(T maxval, + IntervalSet *oset) const { + auto *ointervals = oset->MutableIntervals(); + ointervals->clear(); + T count = 0; + Interval interval; + interval.begin = 0; + for (auto it = intervals_.begin(); it != intervals_.end(); ++it) { + interval.end = std::min(it->begin, maxval); + if ((interval.begin) < (interval.end)) { + ointervals->push_back(interval); + count += interval.end - interval.begin; + } + interval.begin = it->end; + } + interval.end = maxval; + if ((interval.begin) < (interval.end)) { + ointervals->push_back(interval); + count += interval.end - interval.begin; + } + oset->intervals_.SetCount(count); +} + +// Subtract an interval set from the set. Requires intervals be normalized. +// The result is normalized. +template +void IntervalSet::Difference(const IntervalSet &iset, + IntervalSet *oset) const { + if (Empty()) { + oset->MutableIntervals()->clear(); + oset->intervals_.SetCount(0); + } else { + IntervalSet cset; + iset.Complement(intervals_.Intervals()[intervals_.Size() - 1].end, &cset); + Intersect(cset, oset); + } +} + +// Determines if an interval set overlaps with the set. Requires intervals be +// normalized. +template +bool IntervalSet::Overlaps(const IntervalSet &iset) const { + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { + ++it1; + } else if (it2->end <= it1->begin) { + ++it2; + } else { + return true; + } + } + return false; +} + +// Determines if an interval set overlaps with the set but neither is contained +// in the other. Requires intervals be normalized. +template +bool IntervalSet::StrictlyOverlaps( + const IntervalSet &iset) const { + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + bool only1 = false; // Point in intervals_ but not intervals. + bool only2 = false; // Point in intervals but not intervals_. 
+ bool overlap = false; // Point in both intervals_ and intervals. + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if (it1->end <= it2->begin) { // no overlap - it1 first + only1 = true; + ++it1; + } else if (it2->end <= it1->begin) { // no overlap - it2 first + only2 = true; + ++it2; + } else if (it2->begin == it1->begin && it2->end == it1->end) { // equals + overlap = true; + ++it1; + ++it2; + } else if (it2->begin <= it1->begin && it2->end >= it1->end) { // 1 c 2 + only2 = true; + overlap = true; + ++it1; + } else if (it1->begin <= it2->begin && it1->end >= it2->end) { // 2 c 1 + only1 = true; + overlap = true; + ++it2; + } else { // Strict overlap. + only1 = true; + only2 = true; + overlap = true; + } + if (only1 == true && only2 == true && overlap == true) return true; + } + if (it1 != intervals_.end()) only1 = true; + if (it2 != iset.intervals_.end()) only2 = true; + return only1 == true && only2 == true && overlap == true; +} + +// Determines if an interval set is contained within the set. Requires intervals +// be normalized. +template +bool IntervalSet::Contains(const IntervalSet &iset) const { + if (iset.Count() > Count()) return false; + auto it1 = intervals_.begin(); + auto it2 = iset.intervals_.begin(); + while (it1 != intervals_.end() && it2 != iset.intervals_.end()) { + if ((it1->end) <= (it2->begin)) { // No overlap; it1 first. + ++it1; + } else if ((it2->begin) < (it1->begin) || + (it2->end) > (it1->end)) { // No C. 
+ return false; + } else if (it2->end == it1->end) { + ++it1; + ++it2; + } else { + ++it2; + } + } + return it2 == iset.intervals_.end(); +} + +template +std::ostream &operator<<(std::ostream &strm, const IntervalSet &s) { + strm << "{"; + for (T i = 0; i < s.Size(); ++i) { + if (i > 0) { + strm << ","; + } + const auto &interval = s.Intervals()[i]; + strm << "[" << interval.begin << "," << interval.end << ")"; + } + strm << "}"; + return strm; +} + +} // namespace fst + +#endif // FST_INTERVAL_SET_H_ diff --git a/projects/llm_framework/include/fst/invert.h b/projects/llm_framework/include/fst/invert.h new file mode 100644 index 00000000..bd243c62 --- /dev/null +++ b/projects/llm_framework/include/fst/invert.h @@ -0,0 +1,139 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to invert an FST. + +#ifndef FST_INVERT_H_ +#define FST_INVERT_H_ + +#include +#include + + +namespace fst { + +// Mapper to implement inversion of an arc. +template +struct InvertMapper { + using FromArc = A; + using ToArc = A; + + InvertMapper() {} + + ToArc operator()(const FromArc &arc) const { + return ToArc(arc.olabel, arc.ilabel, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { + return MAP_NO_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_CLEAR_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return InvertProperties(props); + } +}; + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(1) +// +// where V is the number of states and E is the number of arcs. +template +inline void Invert(const Fst &ifst, MutableFst *ofst) { + std::unique_ptr input( + ifst.InputSymbols() ? 
ifst.InputSymbols()->Copy() : nullptr); + std::unique_ptr output( + ifst.OutputSymbols() ? ifst.OutputSymbols()->Copy() : nullptr); + ArcMap(ifst, ofst, InvertMapper()); + ofst->SetInputSymbols(output.get()); + ofst->SetOutputSymbols(input.get()); +} + +// Destructive variant of the above. +template +inline void Invert(MutableFst *fst) { + std::unique_ptr input( + fst->InputSymbols() ? fst->InputSymbols()->Copy() : nullptr); + std::unique_ptr output( + fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr); + ArcMap(fst, InvertMapper()); + fst->SetInputSymbols(output.get()); + fst->SetOutputSymbols(input.get()); +} + +// Inverts the transduction corresponding to an FST by exchanging the +// FST's input and output labels. This version is a delayed FST. +// +// Complexity: +// +// Time: O(v + e) +// Space: O(1) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time and space to visit an input state or arc is assumed and +// exclusive of caching. +template +class InvertFst : public ArcMapFst> { + public: + using Arc = A; + + using Mapper = InvertMapper; + using Impl = internal::ArcMapFstImpl>; + + explicit InvertFst(const Fst &fst) + : ArcMapFst(fst, Mapper()) { + GetMutableImpl()->SetOutputSymbols(fst.InputSymbols()); + GetMutableImpl()->SetInputSymbols(fst.OutputSymbols()); + } + + // See Fst<>::Copy() for doc. + InvertFst(const InvertFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Get a copy of this InvertFst. See Fst<>::Copy() for further doc. + InvertFst *Copy(bool safe = false) const override { + return new InvertFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for InvertFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const InvertFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for InvertFst.
+template +class ArcIterator> + : public ArcIterator>> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const InvertFst &fst, StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdInvertFst = InvertFst; + +} // namespace fst + +#endif // FST_INVERT_H_ diff --git a/projects/llm_framework/include/fst/isomorphic.h b/projects/llm_framework/include/fst/isomorphic.h new file mode 100644 index 00000000..b100b0a6 --- /dev/null +++ b/projects/llm_framework/include/fst/isomorphic.h @@ -0,0 +1,183 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to test two FSTs are isomorphic, i.e., they are equal up to a state +// and arc re-ordering. FSTs should be deterministic when viewed as +// unweighted automata. + +#ifndef FST_ISOMORPHIC_H_ +#define FST_ISOMORPHIC_H_ + +#include +#include +#include +#include + +#include + +#include + + +namespace fst { +namespace internal { + +// Orders weights for equality checking. +template ::value>::type * = nullptr> +bool WeightCompare(const Weight &w1, const Weight &w2, float delta, + bool *error) { + return NaturalLess()(w1, w2); +} + +template ::value>::type * = nullptr> +bool WeightCompare(const Weight &w1, const Weight &w2, float delta, + bool *error) { + // No natural order; use hash. + const auto q1 = w1.Quantize(delta); + const auto q2 = w2.Quantize(delta); + auto n1 = q1.Hash(); + auto n2 = q2.Hash(); + // Hash not unique; very unlikely to happen. + if (n1 == n2 && q1 != q2) { + VLOG(1) << "Isomorphic: Weight hash collision"; + *error = true; + } + return n1 < n2; +} + +template +class Isomorphism { + using StateId = typename Arc::StateId; + + public: + Isomorphism(const Fst &fst1, const Fst &fst2, float delta) + : fst1_(fst1.Copy()), + fst2_(fst2.Copy()), + delta_(delta), + error_(false), + comp_(delta, &error_) {} + + // Checks if input FSTs are isomorphic. 
+ bool IsIsomorphic() { + if (fst1_->Start() == kNoStateId && fst2_->Start() == kNoStateId) { + return true; + } + if (fst1_->Start() == kNoStateId || fst2_->Start() == kNoStateId) { + return false; + } + PairState(fst1_->Start(), fst2_->Start()); + while (!queue_.empty()) { + const auto &pr = queue_.front(); + if (!IsIsomorphicState(pr.first, pr.second)) return false; + queue_.pop_front(); + } + return true; + } + + bool Error() const { return error_; } + + private: + // Orders arcs for equality checking. + class ArcCompare { + public: + ArcCompare(float delta, bool *error) : delta_(delta), error_(error) {} + + bool operator()(const Arc &arc1, const Arc &arc2) const { + if (arc1.ilabel < arc2.ilabel) return true; + if (arc1.ilabel > arc2.ilabel) return false; + if (arc1.olabel < arc2.olabel) return true; + if (arc1.olabel > arc2.olabel) return false; + return WeightCompare(arc1.weight, arc2.weight, delta_, error_); + } + + private: + float delta_; + bool *error_; + }; + + // Maintains state correspondences and queue. + bool PairState(StateId s1, StateId s2) { + if (state_pairs_.size() <= s1) state_pairs_.resize(s1 + 1, kNoStateId); + if (state_pairs_[s1] == s2) { + return true; // already seen this pair + } else if (state_pairs_[s1] != kNoStateId) { + return false; // s1 already paired with another s2 + } + state_pairs_[s1] = s2; + queue_.push_back(std::make_pair(s1, s2)); + return true; + } + + // Checks if state pair is isomorphic + bool IsIsomorphicState(StateId s1, StateId s2); + + std::unique_ptr> fst1_; + std::unique_ptr> fst2_; + float delta_; // Weight equality delta. + std::vector arcs1_; // For sorting arcs on FST1. + std::vector arcs2_; // For sorting arcs on FST2. + std::vector state_pairs_; // Maintains state correspondences. + std::list> queue_; // Queue of state pairs. + bool error_; // Error flag. 
+ ArcCompare comp_; +}; + +template +bool Isomorphism::IsIsomorphicState(StateId s1, StateId s2) { + if (!ApproxEqual(fst1_->Final(s1), fst2_->Final(s2), delta_)) return false; + auto narcs1 = fst1_->NumArcs(s1); + auto narcs2 = fst2_->NumArcs(s2); + if (narcs1 != narcs2) return false; + ArcIterator> aiter1(*fst1_, s1); + ArcIterator> aiter2(*fst2_, s2); + arcs1_.clear(); + arcs1_.reserve(narcs1); + arcs2_.clear(); + arcs2_.reserve(narcs2); + for (; !aiter1.Done(); aiter1.Next(), aiter2.Next()) { + arcs1_.push_back(aiter1.Value()); + arcs2_.push_back(aiter2.Value()); + } + std::sort(arcs1_.begin(), arcs1_.end(), comp_); + std::sort(arcs2_.begin(), arcs2_.end(), comp_); + for (size_t i = 0; i < arcs1_.size(); ++i) { + const auto &arc1 = arcs1_[i]; + const auto &arc2 = arcs2_[i]; + if (arc1.ilabel != arc2.ilabel) return false; + if (arc1.olabel != arc2.olabel) return false; + if (!ApproxEqual(arc1.weight, arc2.weight, delta_)) return false; + if (!PairState(arc1.nextstate, arc2.nextstate)) return false; + if (i > 0) { // Checks for non-determinism. + const auto &arc0 = arcs1_[i - 1]; + if (arc1.ilabel == arc0.ilabel && arc1.olabel == arc0.olabel && + ApproxEqual(arc1.weight, arc0.weight, delta_)) { + VLOG(1) << "Isomorphic: Non-determinism as an unweighted automaton"; + error_ = true; + return false; + } + } + } + return true; +} + +} // namespace internal + +// Tests if two FSTs have the same states and arcs up to a reordering. +// Inputs should be deterministic when viewed as unweighted automata.
+template +bool Isomorphic(const Fst &fst1, const Fst &fst2, + float delta = kDelta) { + internal::Isomorphism iso(fst1, fst2, delta); + bool result = iso.IsIsomorphic(); + if (iso.Error()) { + FSTERROR() << "Isomorphic: Cannot determine if inputs are isomorphic"; + return false; + } else { + return result; + } +} + +} // namespace fst + +#endif // FST_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/label-reachable.h b/projects/llm_framework/include/fst/label-reachable.h new file mode 100644 index 00000000..f3d7f2bc --- /dev/null +++ b/projects/llm_framework/include/fst/label-reachable.h @@ -0,0 +1,511 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to determine if a non-epsilon label can be read as the first +// non-epsilon symbol along some path from a given state. + +#ifndef FST_LABEL_REACHABLE_H_ +#define FST_LABEL_REACHABLE_H_ + +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Stores shareable data for label reachable class copies. 
+template +class LabelReachableData { + public: + using LabelIntervalSet = IntervalSet *fst, C *mapper) { + ArcMap(fst, mapper); +} + +template +void Map(MutableFst *fst, C mapper) { + ArcMap(fst, mapper); +} + +template +void Map(const Fst &ifst, MutableFst *ofst, C *mapper) { + ArcMap(ifst, ofst, mapper); +} + +template +void Map(const Fst &ifst, MutableFst *ofst, C mapper) { + ArcMap(ifst, ofst, mapper); +} + +using MapFstOptions = ArcMapFstOptions; + +template +class MapFst : public ArcMapFst { + public: + using FromArc = A; + using ToArc = B; + + using StateId = typename ToArc::StateId; + using Weight = typename ToArc::Weight; + + using State = CacheState; + + MapFst(const Fst &fst, const C &mapper, const MapFstOptions &opts) + : ArcMapFst(fst, mapper, opts) {} + + MapFst(const Fst &fst, C *mapper, const MapFstOptions &opts) + : ArcMapFst(fst, mapper, opts) {} + + MapFst(const Fst &fst, const C &mapper) + : ArcMapFst(fst, mapper) {} + + MapFst(const Fst &fst, C *mapper) : ArcMapFst(fst, mapper) {} + + // See Fst<>::Copy() for doc. + MapFst(const MapFst &fst, bool safe = false) + : ArcMapFst(fst, safe) {} + + // Get a copy of this MapFst. See Fst<>::Copy() for further doc. + MapFst *Copy(bool safe = false) const override { + return new MapFst(*this, safe); + } +}; + +// Specialization for MapFst. +template +class StateIterator> + : public StateIterator> { + public: + explicit StateIterator(const ArcMapFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for MapFst. +template +class ArcIterator> : public ArcIterator> { + public: + ArcIterator(const ArcMapFst &fst, typename A::StateId s) + : ArcIterator>(fst, s) {} +}; + +// For backwards compatibility only; use IdentityArcMapper otherwise. 
+template +struct IdentityMapper { + using FromArc = A; + using ToArc = A; + + ToArc operator()(const FromArc &arc) const { return arc; } + + constexpr MapFinalAction FinalAction() const { return MAP_NO_SUPERFINAL; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { return props; } +}; + +} // namespace fst + +#endif // FST_MAP_H_ diff --git a/projects/llm_framework/include/fst/mapped-file.h b/projects/llm_framework/include/fst/mapped-file.h new file mode 100644 index 00000000..adb33c28 --- /dev/null +++ b/projects/llm_framework/include/fst/mapped-file.h @@ -0,0 +1,81 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_MAPPED_FILE_H_ +#define FST_MAPPED_FILE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { + +// A memory region is a simple abstraction for allocated memory or data from +// memory-mapped files. If mmap is null, then data represents an owned region +// of size bytes. Otherwise, mmap and size refer to the mapping and data is a +// casted pointer to a region contained within [mmap, mmap + size). If size is +// 0, then mmap and data refer to a block of memory managed externally by some +// other allocator. The offset is used when allocating memory to providing +// padding for alignment. +struct MemoryRegion { + void *data; + void *mmap; + size_t size; + int offset; +}; + +class MappedFile { + public: + ~MappedFile(); + + void *mutable_data() const { return region_.data; } + + const void *data() const { return region_.data; } + + // Returns a MappedFile object that contains the contents of the input stream + // strm starting from the current file position with size bytes. The memorymap + // bool is advisory, and Map will default to allocating and reading. 
The + // source argument needs to contain the filename that was used to open the + // input stream. + static MappedFile *Map(std::istream *istrm, bool memorymap, + const string &source, size_t size); + + // Returns a MappedFile object that contains the contents of the file referred + // to by the file descriptor starting from pos with size bytes. If the + // memory mapping fails, nullptr is returned. In contrast to Map(), this + // factory function does not backoff to allocating and reading. + static MappedFile *MapFromFileDescriptor(int fd, int pos, size_t size); + + // Creates a MappedFile object with a new[]'ed block of memory of size. The + // align argument can be used to specify a desired block alignment. + // This is RECOMMENDED FOR INTERNAL USE ONLY as it may change in future + // releases. + static MappedFile *Allocate(size_t size, int align = kArchAlignment); + + // Creates a MappedFile object pointing to a borrowed reference to data. This + // block of memory is not owned by the MappedFile object and will not be + // freed. This is RECOMMENDED FOR INTERNAL USE ONLY, may change in future + // releases. + static MappedFile *Borrow(void *data); + + // Alignment required for mapping structures in bytes. Regions of memory that + // are not aligned upon a 128-bit boundary are read from the file instead. + // This is consistent with the alignment boundary set in ConstFst and + // CompactFst. + static constexpr int kArchAlignment = 16; + + static constexpr size_t kMaxReadChunk = 256 * 1024 * 1024; // 256 MB. 
+ + private: + // Private: only members of this class can construct from a region. + explicit MappedFile(const MemoryRegion &region); + + MemoryRegion region_; + // Non-copyable. + MappedFile(const MappedFile &) = delete; + MappedFile &operator=(const MappedFile &) = delete; +}; +} // namespace fst + +#endif // FST_MAPPED_FILE_H_ diff --git a/projects/llm_framework/include/fst/matcher-fst.h b/projects/llm_framework/include/fst/matcher-fst.h new file mode 100644 index 00000000..61e95820 --- /dev/null +++ b/projects/llm_framework/include/fst/matcher-fst.h @@ -0,0 +1,347 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to add a matcher to an FST. + +#ifndef FST_MATCHER_FST_H_ +#define FST_MATCHER_FST_H_ + +#include +#include + +#include +#include +#include + + +namespace fst { + +// Writeable matchers have the same interface as Matchers (as defined in +// matcher.h) along with the following additional methods: +// +// template +// class Matcher { +// public: +// using FST = F; +// ... +// using MatcherData = ...; // Initialization data. +// +// // Constructor with additional argument for external initialization data; +// // matcher increments its reference count on construction and decrements +// // the reference count, and deletes once the reference count has reached +// // zero. +// Matcher(const FST &fst, MatchType type, MatcherData *data); +// +// // Returns pointer to initialization data that can be passed to a Matcher +// // constructor. +// MatcherData *GetData() const; +// }; + +// The matcher initialization data class must also provide the following +// interface: +// +// class MatcherData { +// public: +// // Required copy constructor. +// MatcherData(const MatcherData &); +// +// // Required I/O methods. +// static MatcherData *Read(std::istream &istrm, const FstReadOptions &opts); +// bool Write(std::ostream &ostrm, const FstWriteOptions &opts) const; +// }; + +// Trivial (no-op) MatcherFst initializer functor.
+template +class NullMatcherFstInit { + public: + using MatcherData = typename M::MatcherData; + using Data = AddOnPair; + using Impl = internal::AddOnImpl; + + explicit NullMatcherFstInit(std::shared_ptr *) {} +}; + +// Class adding a matcher to an FST type. Creates a new FST whose name is given +// by N. An optional functor Init can be used to initialize the FST. The Data +// template parameter allows the user to select the type of the add-on. +template < + class F, class M, const char *Name, class Init = NullMatcherFstInit, + class Data = AddOnPair> +class MatcherFst : public ImplToExpandedFst> { + public: + using FST = F; + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + using FstMatcher = M; + using MatcherData = typename FstMatcher::MatcherData; + + using Impl = internal::AddOnImpl; + using D = Data; + + friend class StateIterator>; + friend class ArcIterator>; + + MatcherFst() : ImplToExpandedFst(std::make_shared(FST(), Name)) {} + + explicit MatcherFst(const FST &fst, std::shared_ptr data = nullptr) + : ImplToExpandedFst(data ? CreateImpl(fst, Name, data) + : CreateDataAndImpl(fst, Name)) {} + + explicit MatcherFst(const Fst &fst) + : ImplToExpandedFst(CreateDataAndImpl(fst, Name)) {} + + // See Fst<>::Copy() for doc. + MatcherFst(const MatcherFst &fst, + bool safe = false) + : ImplToExpandedFst(fst, safe) {} + + // Get a copy of this MatcherFst. See Fst<>::Copy() for further doc. + MatcherFst *Copy( + bool safe = false) const override { + return new MatcherFst(*this, safe); + } + + // Read a MatcherFst from an input stream; return nullptr on error + static MatcherFst *Read( + std::istream &strm, const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? 
new MatcherFst( + std::shared_ptr(impl)) + : nullptr; + } + + // Read a MatcherFst from a file; return nullptr on error + // Empty filename reads from standard input + static MatcherFst *Read( + const string &filename) { + auto *impl = ImplToExpandedFst::Read(filename); + return impl ? new MatcherFst( + std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return GetImpl()->Write(strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + void InitStateIterator(StateIteratorData *data) const override { + return GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + return GetImpl()->InitArcIterator(s, data); + } + + FstMatcher *InitMatcher(MatchType match_type) const override { + return new FstMatcher(&GetFst(), match_type, GetSharedData(match_type)); + } + + const FST &GetFst() const { return GetImpl()->GetFst(); } + + const Data *GetAddOn() const { return GetImpl()->GetAddOn(); } + + std::shared_ptr GetSharedAddOn() const { + return GetImpl()->GetSharedAddOn(); + } + + const MatcherData *GetData(MatchType match_type) const { + const auto *data = GetAddOn(); + return match_type == MATCH_INPUT ? data->First() : data->Second(); + } + + std::shared_ptr GetSharedData(MatchType match_type) const { + const auto *data = GetAddOn(); + return match_type == MATCH_INPUT ? 
data->SharedFirst() + : data->SharedSecond(); + } + + protected: + using ImplToFst>::GetImpl; + + static std::shared_ptr CreateDataAndImpl(const FST &fst, + const string &name) { + FstMatcher imatcher(fst, MATCH_INPUT); + FstMatcher omatcher(fst, MATCH_OUTPUT); + return CreateImpl(fst, name, + std::make_shared(imatcher.GetSharedData(), + omatcher.GetSharedData())); + } + + static std::shared_ptr CreateDataAndImpl(const Fst &fst, + const string &name) { + FST result(fst); + return CreateDataAndImpl(result, name); + } + + static std::shared_ptr CreateImpl(const FST &fst, const string &name, + std::shared_ptr data) { + auto impl = std::make_shared(fst, name); + impl->SetAddOn(data); + Init init(&impl); + return impl; + } + + explicit MatcherFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + private: + MatcherFst &operator=(const MatcherFst &) = delete; +}; + +// Specialization for MatcherFst. +template +class StateIterator> + : public StateIterator { + public: + explicit StateIterator(const MatcherFst &fst) + : StateIterator(fst.GetImpl()->GetFst()) {} +}; + +// Specialization for MatcherFst. +template +class ArcIterator> : public ArcIterator { + public: + using StateId = typename FST::Arc::StateId; + + ArcIterator(const MatcherFst &fst, + typename FST::Arc::StateId s) + : ArcIterator(fst.GetImpl()->GetFst(), s) {} +}; + +// Specialization for MatcherFst. 
+template +class Matcher> { + public: + using FST = MatcherFst; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + + Matcher(const FST &fst, MatchType match_type) + : matcher_(fst.InitMatcher(match_type)) {} + + Matcher(const Matcher &matcher) : matcher_(matcher.matcher_->Copy()) {} + + Matcher *Copy() const { return new Matcher(*this); } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { matcher_->SetState(s); } + + bool Find(Label label) { return matcher_->Find(label); } + + bool Done() const { return matcher_->Done(); } + + const Arc &Value() const { return matcher_->Value(); } + + void Next() { matcher_->Next(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + private: + std::unique_ptr matcher_; +}; + +// Specialization for MatcherFst. +template +class LookAheadMatcher> { + public: + using FST = MatcherFst; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + LookAheadMatcher(const FST &fst, MatchType match_type) + : matcher_(fst.InitMatcher(match_type)) {} + + LookAheadMatcher(const LookAheadMatcher &matcher, bool safe = false) + : matcher_(matcher.matcher_->Copy(safe)) {} + + // General matcher methods. 
+ LookAheadMatcher *Copy(bool safe = false) const { + return new LookAheadMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId s) { matcher_->SetState(s); } + + bool Find(Label label) { return matcher_->Find(label); } + + bool Done() const { return matcher_->Done(); } + + const Arc &Value() const { return matcher_->Value(); } + + void Next() { matcher_->Next(); } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + uint32 Flags() const { return matcher_->Flags(); } + + bool LookAheadLabel(Label label) const { + return matcher_->LookAheadLabel(label); + } + + bool LookAheadFst(const Fst &fst, StateId s) { + return matcher_->LookAheadFst(fst, s); + } + + Weight LookAheadWeight() const { return matcher_->LookAheadWeight(); } + + bool LookAheadPrefix(Arc *arc) const { + return matcher_->LookAheadPrefix(arc); + } + + void InitLookAheadFst(const Fst &fst, bool copy = false) { + matcher_->InitLookAheadFst(fst, copy); + } + + private: + std::unique_ptr matcher_; +}; + +// Useful aliases when using StdArc. 
+ +extern const char arc_lookahead_fst_type[]; + +using StdArcLookAheadFst = + MatcherFst, + ArcLookAheadMatcher>>, + arc_lookahead_fst_type>; + +extern const char ilabel_lookahead_fst_type[]; +extern const char olabel_lookahead_fst_type[]; + +constexpr auto ilabel_lookahead_flags = + kInputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; + +constexpr auto olabel_lookahead_flags = + kOutputLookAheadMatcher | kLookAheadWeight | kLookAheadPrefix | + kLookAheadEpsilons | kLookAheadNonEpsilonPrefix; + +using StdILabelLookAheadFst = MatcherFst< + ConstFst, + LabelLookAheadMatcher>, + ilabel_lookahead_flags, FastLogAccumulator>, + ilabel_lookahead_fst_type, LabelLookAheadRelabeler>; + +using StdOLabelLookAheadFst = MatcherFst< + ConstFst, + LabelLookAheadMatcher>, + olabel_lookahead_flags, FastLogAccumulator>, + olabel_lookahead_fst_type, LabelLookAheadRelabeler>; + +} // namespace fst + +#endif // FST_MATCHER_FST_H_ diff --git a/projects/llm_framework/include/fst/matcher.h b/projects/llm_framework/include/fst/matcher.h new file mode 100644 index 00000000..d9528d68 --- /dev/null +++ b/projects/llm_framework/include/fst/matcher.h @@ -0,0 +1,1575 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes to allow matching labels leaving FST states. + +#ifndef FST_MATCHER_H_ +#define FST_MATCHER_H_ + +#include +#include +#include +#include + +#include + +#include // for all internal FST accessors. + + +namespace fst { + +// Matchers find and iterate through requested labels at FST states. In the +// simplest form, these are just some associative map or search keyed on labels. +// More generally, they may implement matching special labels that represent +// sets of labels such as sigma (all), rho (rest), or phi (fail). 
The Matcher +// interface is: +// +// template <class F> +// class Matcher { +// public: +// using FST = F; +// using Arc = typename FST::Arc; +// using Label = typename Arc::Label; +// using StateId = typename Arc::StateId; +// using Weight = typename Arc::Weight; +// +// // Required constructors. Note: +// // -- the constructors that copy the FST arg are useful for +// // letting the matcher manage the FST through copies +// // (esp with 'safe' copies); e.g. ComposeFst depends on this. +// // -- the constructor that does not copy is useful when +// // the FST is mutated during the lifetime of the matcher +// // (o.w. the matcher would have its own unmutated deep copy). +// +// // This makes a copy of the FST. +// Matcher(const FST &fst, MatchType type); +// // This doesn't copy the FST. +// Matcher(const FST *fst, MatchType type); +// // This makes a copy of the FST. +// // See Copy() below. +// Matcher(const Matcher &matcher, bool safe = false); +// +// // If safe = true, the copy is thread-safe. See Fst<>::Copy() for +// // further doc. +// Matcher *Copy(bool safe = false) const override; +// +// // Returns the match type that can be provided (depending on compatibility +// of the input FST). It is either the requested match type, MATCH_NONE, or +// MATCH_UNKNOWN. If test is false, costly testing is avoided, but +// MATCH_UNKNOWN may be returned. If test is true, a definite answer is +// returned, but may involve more costly computation (e.g., visiting the FST). +// MatchType Type(bool test) const override; +// +// // Specifies the current state. +// void SetState(StateId s) final; +// +// // Finds matches to a label at the current state, returning true if a match +// // is found. kNoLabel matches any non-consuming transitions, e.g., epsilon +// // transitions, which do not require a matching symbol. +// bool Find(Label label) final; +// +// // Iterator methods. Note that initially and after SetState() these have +// undefined behavior until Find() is called. 
+// +// bool Done() const final; +// +// const Arc &Value() const final; +// +// void Next() final; +// +// // Returns final weight of a state. +// Weight Final(StateId) const final; +// +// // Indicates preference for being the side used for matching in +// // composition. If the value is kRequirePriority, then it is +// // mandatory that it be used. Calling this method without passing the +// // current state of the matcher invalidates the state of the matcher. +// ssize_t Priority(StateId s) final; +// +// // This specifies the known FST properties as viewed from this matcher. It +// // takes as argument the input FST's known properties. +// uint64 Properties(uint64 props) const override; +// +// // Returns matcher flags. +// uint32 Flags() const override; +// +// // Returns matcher FST. +// const FST &GetFst() const override; +// }; + +// Basic matcher flags. + +// Matcher needs to be used as the matching side in composition for +// at least one state (has kRequirePriority). +constexpr uint32 kRequireMatch = 0x00000001; + +// Flags used for basic matchers (see also lookahead.h). +constexpr uint32 kMatcherFlags = kRequireMatch; + +// Matcher priority that is mandatory. +constexpr ssize_t kRequirePriority = -1; + +// Matcher interface, templated on the Arc definition; used for matcher +// specializations that are returned by the InitMatcher FST method. +template +class MatcherBase { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + virtual ~MatcherBase() {} + + // Virtual interface. 
+ + virtual MatcherBase *Copy(bool safe = false) const = 0; + virtual MatchType Type(bool) const = 0; + virtual void SetState(StateId) = 0; + virtual bool Find(Label) = 0; + virtual bool Done() const = 0; + virtual const Arc &Value() const = 0; + virtual void Next() = 0; + virtual const Fst &GetFst() const = 0; + virtual uint64 Properties(uint64) const = 0; + + // Trivial implementations that can be used by derived classes. Full + // devirtualization is expected for any derived class marked final. + virtual uint32 Flags() const { return 0; } + + virtual Weight Final(StateId s) const { return internal::Final(GetFst(), s); } + + virtual ssize_t Priority(StateId s) { return internal::NumArcs(GetFst(), s); } +}; + +// A matcher that expects sorted labels on the side to be matched. +// If match_type == MATCH_INPUT, epsilons match the implicit self-loop +// Arc(kNoLabel, 0, Weight::One(), current_state) as well as any +// actual epsilon transitions. If match_type == MATCH_OUTPUT, then +// Arc(0, kNoLabel, Weight::One(), current_state) is instead matched. +template +class SortedMatcher : public MatcherBase { + public: + using FST = F; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MatcherBase::Flags; + using MatcherBase::Properties; + + // Labels >= binary_label will be searched for by binary search; + // o.w. linear search is used. + // This makes a copy of the FST. + SortedMatcher(const FST &fst, MatchType match_type, Label binary_label = 1) + : SortedMatcher(fst.Copy(), match_type, binary_label) { + owned_fst_.reset(&fst_); + } + + // Labels >= binary_label will be searched for by binary search; + // o.w. linear search is used. + // This doesn't copy the FST. 
+ SortedMatcher(const FST *fst, MatchType match_type, Label binary_label = 1) + : fst_(*fst), + state_(kNoStateId), + aiter_(nullptr), + match_type_(match_type), + binary_label_(binary_label), + match_label_(kNoLabel), + narcs_(0), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false), + aiter_pool_(1) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + std::swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "SortedMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + SortedMatcher(const SortedMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + state_(kNoStateId), + aiter_(nullptr), + match_type_(matcher.match_type_), + binary_label_(matcher.binary_label_), + match_label_(kNoLabel), + narcs_(0), + loop_(matcher.loop_), + error_(matcher.error_), + aiter_pool_(1) {} + + ~SortedMatcher() override { Destroy(aiter_, &aiter_pool_); } + + SortedMatcher *Copy(bool safe = false) const override { + return new SortedMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if (match_type_ == MATCH_NONE) return match_type_; + const auto true_prop = + match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; + const auto false_prop = + match_type_ == MATCH_INPUT ? 
kNotILabelSorted : kNotOLabelSorted; + const auto props = fst_.Properties(true_prop | false_prop, test); + if (props & true_prop) { + return match_type_; + } else if (props & false_prop) { + return MATCH_NONE; + } else { + return MATCH_UNKNOWN; + } + } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + if (match_type_ == MATCH_NONE) { + FSTERROR() << "SortedMatcher: Bad match type"; + error_ = true; + } + Destroy(aiter_, &aiter_pool_); + aiter_ = new (&aiter_pool_) ArcIterator(fst_, s); + aiter_->SetFlags(kArcNoCache, kArcNoCache); + narcs_ = internal::NumArcs(fst_, s); + loop_.nextstate = s; + } + + bool Find(Label match_label) final { + exact_match_ = true; + if (error_) { + current_loop_ = false; + match_label_ = kNoLabel; + return false; + } + current_loop_ = match_label == 0; + match_label_ = match_label == kNoLabel ? 0 : match_label; + if (Search()) { + return true; + } else { + return current_loop_; + } + } + + // Positions matcher to the first position where inserting match_label would + // maintain the sort order. + void LowerBound(Label label) { + exact_match_ = false; + current_loop_ = false; + if (error_) { + match_label_ = kNoLabel; + return; + } + match_label_ = label; + Search(); + } + + // After Find(), returns false if no more exact matches. + // After LowerBound(), returns false if no more arcs. + bool Done() const final { + if (current_loop_) return false; + if (aiter_->Done()) return true; + if (!exact_match_) return false; + aiter_->SetFlags(match_type_ == MATCH_INPUT ? 
+ kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + return GetLabel() != match_label_; + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); + return aiter_->Value(); + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + aiter_->Next(); + } + } + + Weight Final(StateId s) const final { + return MatcherBase::Final(s); + } + + ssize_t Priority(StateId s) final { + return MatcherBase::Priority(s); + } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops | (error_ ? kError : 0); + } + + size_t Position() const { return aiter_ ? aiter_->Position() : 0; } + + private: + Label GetLabel() const { + const auto &arc = aiter_->Value(); + return match_type_ == MATCH_INPUT ? arc.ilabel : arc.olabel; + } + + bool BinarySearch(); + bool LinearSearch(); + bool Search(); + + std::unique_ptr owned_fst_; // FST ptr if owned. + const FST &fst_; // FST for matching. + StateId state_; // Matcher state. + ArcIterator *aiter_; // Iterator for current state. + MatchType match_type_; // Type of match to perform. + Label binary_label_; // Least label for binary search. + Label match_label_; // Current label to be matched. + size_t narcs_; // Current state arc count. + Arc loop_; // For non-consuming symbols. + bool current_loop_; // Current arc is the implicit loop. + bool exact_match_; // Exact match or lower bound? + bool error_; // Error encountered? + MemoryPool> aiter_pool_; // Pool of arc iterators. +}; + +// Returns true iff match to match_label_. The arc iterator is positioned at the +// lower bound, that is, the first element greater than or equal to +// match_label_, or the end if all elements are less than match_label_. +// If multiple elements are equal to the `match_label_`, returns the rightmost +// one. 
+template +inline bool SortedMatcher::BinarySearch() { + size_t size = narcs_; + if (size == 0) { + return false; + } + size_t high = size - 1; + while (size > 1) { + const size_t half = size / 2; + const size_t mid = high - half; + aiter_->Seek(mid); + if (GetLabel() >= match_label_) { + high = mid; + } + size -= half; + } + aiter_->Seek(high); + const auto label = GetLabel(); + if (label == match_label_) { + return true; + } + if (label < match_label_) { + aiter_->Next(); + } + return false; +} + +// Returns true iff match to match_label_, positioning arc iterator at lower +// bound. +template +inline bool SortedMatcher::LinearSearch() { + for (aiter_->Reset(); !aiter_->Done(); aiter_->Next()) { + const auto label = GetLabel(); + if (label == match_label_) return true; + if (label > match_label_) break; + } + return false; +} + +// Returns true iff match to match_label_, positioning arc iterator at lower +// bound. +template +inline bool SortedMatcher::Search() { + aiter_->SetFlags(match_type_ == MATCH_INPUT ? + kArcILabelValue : kArcOLabelValue, + kArcValueFlags); + if (match_label_ >= binary_label_) { + return BinarySearch(); + } else { + return LinearSearch(); + } +} + +// A matcher that stores labels in a per-state hash table populated upon the +// first visit to that state. Sorting is not required. Treatment of +// epsilons are the same as with SortedMatcher. +template +class HashMatcher : public MatcherBase { + public: + using FST = F; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using MatcherBase::Flags; + using MatcherBase::Final; + using MatcherBase::Priority; + + // This makes a copy of the FST. + HashMatcher(const FST &fst, MatchType match_type) + : HashMatcher(fst.Copy(), match_type) { + owned_fst_.reset(&fst_); + } + + // This doesn't copy the FST. 
+ HashMatcher(const FST *fst, MatchType match_type) + : fst_(*fst), + state_(kNoStateId), + match_type_(match_type), + loop_(kNoLabel, 0, Weight::One(), kNoStateId), + error_(false), + state_table_(std::make_shared()) { + switch (match_type_) { + case MATCH_INPUT: + case MATCH_NONE: + break; + case MATCH_OUTPUT: + std::swap(loop_.ilabel, loop_.olabel); + break; + default: + FSTERROR() << "HashMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + } + + // This makes a copy of the FST. + HashMatcher(const HashMatcher &matcher, bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + state_(kNoStateId), + match_type_(matcher.match_type_), + loop_(matcher.loop_), + error_(matcher.error_), + state_table_( + safe ? std::make_shared() : matcher.state_table_) {} + + HashMatcher *Copy(bool safe = false) const override { + return new HashMatcher(*this, safe); + } + + // The argument is ignored as there are no relevant properties to test. + MatchType Type(bool test) const override { return match_type_; } + + void SetState(StateId s) final; + + bool Find(Label label) final { + current_loop_ = label == 0; + if (label == 0) { + Search(label); + return true; + } + if (label == kNoLabel) label = 0; + return Search(label); + } + + bool Done() const final { + if (current_loop_) return false; + return label_it_ == label_end_; + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + aiter_->Seek(label_it_->second); + return aiter_->Value(); + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + } else { + ++label_it_; + } + } + + const FST &GetFst() const override { return fst_; } + + uint64 Properties(uint64 inprops) const override { + return inprops | (error_ ? kError : 0); + } + + private: + Label GetLabel() const { + const auto &arc = aiter_->Value(); + return match_type_ == MATCH_INPUT ? 
arc.ilabel : arc.olabel; + } + + bool Search(Label match_label); + + using LabelTable = std::unordered_multimap; + using StateTable = std::unordered_map>; + + std::unique_ptr owned_fst_; // ptr to FST if owned. + const FST &fst_; // FST for matching. + StateId state_; // Matcher state. + MatchType match_type_; + Arc loop_; // The implicit loop itself. + bool current_loop_; // Is the current arc the implicit loop? + bool error_; // Error encountered? + std::unique_ptr> aiter_; + std::shared_ptr state_table_; // Table from state to label table. + LabelTable *label_table_; // Pointer to current state's label table. + typename LabelTable::iterator label_it_; // Position for label. + typename LabelTable::iterator label_end_; // Position for last label + 1. +}; + +template +void HashMatcher::SetState(typename FST::Arc::StateId s) { + if (state_ == s) return; + // Resets everything for the state. + state_ = s; + loop_.nextstate = state_; + aiter_.reset(new ArcIterator(fst_, state_)); + if (match_type_ == MATCH_NONE) { + FSTERROR() << "HashMatcher: Bad match type"; + error_ = true; + } + // Attempts to insert a new label table. + auto it_and_success = state_table_->emplace( + state_, std::unique_ptr(new LabelTable())); + // Sets instance's pointer to the label table for this state. + label_table_ = it_and_success.first->second.get(); + // If it already exists, no additional work is done and we simply return. + if (!it_and_success.second) return; + // Otherwise, populate this new table. + // Populates the label table. + label_table_->reserve(internal::NumArcs(fst_, state_)); + const auto aiter_flags = + (match_type_ == MATCH_INPUT ? 
kArcILabelValue : kArcOLabelValue) | + kArcNoCache; + aiter_->SetFlags(aiter_flags, kArcFlags); + for (; !aiter_->Done(); aiter_->Next()) { + label_table_->emplace(GetLabel(), aiter_->Position()); + } + aiter_->SetFlags(kArcValueFlags, kArcValueFlags); +} + +template +inline bool HashMatcher::Search(typename FST::Arc::Label match_label) { + auto range = label_table_->equal_range(match_label); + label_it_ = range.first; + label_end_ = range.second; + if (label_it_ == label_end_) return false; + aiter_->Seek(label_it_->second); + return true; +} + +// Specifies whether we rewrite both the input and output sides during matching. +enum MatcherRewriteMode { + MATCHER_REWRITE_AUTO = 0, // Rewrites both sides iff acceptor. + MATCHER_REWRITE_ALWAYS, + MATCHER_REWRITE_NEVER +}; + +// For any requested label that doesn't match at a state, this matcher +// considers the *unique* transition that matches the label 'phi_label' +// (phi = 'fail'), and recursively looks for a match at its +// destination. When 'phi_loop' is true, if no match is found but a +// phi self-loop is found, then the phi transition found is returned +// with the phi_label rewritten as the requested label (both sides if +// an acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'phi_label'). If 'phi_label' is +// kNoLabel, this special matching is not done. PhiMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by PhiMatcher. The user can instead pass in this +// object; in that case, PhiMatcher takes its ownership. +// Phi non-determinism not supported. No non-consuming symbols other +// than epsilon supported with the underlying template argument matcher. 
+template +class PhiMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + PhiMatcher(const FST &fst, MatchType match_type, Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + phi_label_(phi_label), + state_(kNoStateId), + phi_loop_(phi_loop), + error_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "PhiMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. + PhiMatcher(const FST *fst, MatchType match_type, Label phi_label = kNoLabel, + bool phi_loop = true, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : PhiMatcher(*fst, match_type, phi_label, phi_loop, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + + // This makes a copy of the FST. 
+ PhiMatcher(const PhiMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + phi_label_(matcher.phi_label_), + rewrite_both_(matcher.rewrite_both_), + state_(kNoStateId), + phi_loop_(matcher.phi_loop_), + error_(matcher.error_) {} + + PhiMatcher *Copy(bool safe = false) const override { + return new PhiMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + matcher_->SetState(s); + state_ = s; + has_phi_ = phi_label_ != kNoLabel; + } + + bool Find(Label match_label) final; + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if ((phi_match_ == kNoLabel) && (phi_weight_ == Weight::One())) { + return matcher_->Value(); + } else if (phi_match_ == 0) { // Virtual epsilon loop. + phi_arc_ = Arc(kNoLabel, 0, Weight::One(), state_); + if (match_type_ == MATCH_OUTPUT) { + std::swap(phi_arc_.ilabel, phi_arc_.olabel); + } + return phi_arc_; + } else { + phi_arc_ = matcher_->Value(); + phi_arc_.weight = Times(phi_weight_, phi_arc_.weight); + if (phi_match_ != kNoLabel) { // Phi loop match. + if (rewrite_both_) { + if (phi_arc_.ilabel == phi_label_) phi_arc_.ilabel = phi_match_; + if (phi_arc_.olabel == phi_label_) phi_arc_.olabel = phi_match_; + } else if (match_type_ == MATCH_INPUT) { + phi_arc_.ilabel = phi_match_; + } else { + phi_arc_.olabel = phi_match_; + } + } + return phi_arc_; + } + } + + void Next() final { matcher_->Next(); } + + Weight Final(StateId s) const final { + auto weight = matcher_->Final(s); + if (phi_label_ == kNoLabel || weight != Weight::Zero()) { + return weight; + } + weight = Weight::One(); + matcher_->SetState(s); + while (matcher_->Final(s) == Weight::Zero()) { + if (!matcher_->Find(phi_label_ == 0 ? 
-1 : phi_label_)) break; + weight = Times(weight, matcher_->Value().weight); + if (s == matcher_->Value().nextstate) { + return Weight::Zero(); // Does not follow phi self-loops. + } + s = matcher_->Value().nextstate; + matcher_->SetState(s); + } + weight = Times(weight, matcher_->Final(s)); + return weight; + } + + ssize_t Priority(StateId s) final { + if (phi_label_ != kNoLabel) { + matcher_->SetState(s); + const bool has_phi = matcher_->Find(phi_label_ == 0 ? -1 : phi_label_); + return has_phi ? kRequirePriority : matcher_->Priority(s); + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (phi_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label PhiLabel() const { return phi_label_; } + + private: + mutable std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label phi_label_; // Label that represents the phi transition. + bool rewrite_both_; // Rewrite both sides when both are phi_label_? + bool has_phi_; // Are there possibly phis at the current state? + Label phi_match_; // Current label that matches phi loop. + mutable Arc phi_arc_; // Arc to return. + StateId state_; // Matcher state. + Weight phi_weight_; // Product of the weights of phi transitions taken. + bool phi_loop_; // When true, phi self-loop are allowed and treated + // as rho (required for Aho-Corasick). + bool error_; // Error encountered? 
+ + PhiMatcher &operator=(const PhiMatcher &) = delete; +}; + +template +inline bool PhiMatcher::Find(Label label) { + if (label == phi_label_ && phi_label_ != kNoLabel && phi_label_ != 0) { + FSTERROR() << "PhiMatcher::Find: bad label (phi): " << phi_label_; + error_ = true; + return false; + } + matcher_->SetState(state_); + phi_match_ = kNoLabel; + phi_weight_ = Weight::One(); + // If phi_label_ == 0, there are no more true epsilon arcs. + if (phi_label_ == 0) { + if (label == kNoLabel) { + return false; + } + if (label == 0) { // but a virtual epsilon loop needs to be returned. + if (!matcher_->Find(kNoLabel)) { + return matcher_->Find(0); + } else { + phi_match_ = 0; + return true; + } + } + } + if (!has_phi_ || label == 0 || label == kNoLabel) { + return matcher_->Find(label); + } + auto s = state_; + while (!matcher_->Find(label)) { + // Look for phi transition (if phi_label_ == 0, we need to look + // for -1 to avoid getting the virtual self-loop) + if (!matcher_->Find(phi_label_ == 0 ? 
-1 : phi_label_)) return false; + if (phi_loop_ && matcher_->Value().nextstate == s) { + phi_match_ = label; + return true; + } + phi_weight_ = Times(phi_weight_, matcher_->Value().weight); + s = matcher_->Value().nextstate; + matcher_->Next(); + if (!matcher_->Done()) { + FSTERROR() << "PhiMatcher: Phi non-determinism not supported"; + error_ = true; + } + matcher_->SetState(s); + } + return true; +} + +template +inline uint64 PhiMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoIEpsilons; + } + if (rewrite_both_) { + return outprops & + ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kODeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (phi_label_ == 0) { + outprops &= ~kEpsilons | ~kIEpsilons | ~kOEpsilons; + outprops |= kNoEpsilons | kNoOEpsilons; + } + if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kIDeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. + FSTERROR() << "PhiMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// For any requested label that doesn't match at a state, this matcher +// considers all transitions that match the label 'rho_label' (rho = +// 'rest'). 
Each such rho transition found is returned with the +// rho_label rewritten as the requested label (both sides if an +// acceptor, or if 'rewrite_both' is true and both input and output +// labels of the found transition are 'rho_label'). If 'rho_label' is +// kNoLabel, this special matching is not done. RhoMatcher is +// templated itself on a matcher, which is used to perform the +// underlying matching. By default, the underlying matcher is +// constructed by RhoMatcher. The user can instead pass in this +// object; in that case, RhoMatcher takes its ownership. +// No non-consuming symbols other than epsilon supported with +// the underlying template argument matcher. +template +class RhoMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + RhoMatcher(const FST &fst, MatchType match_type, Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + rho_label_(rho_label), + error_(false), + state_(kNoStateId), + has_rho_(false) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "RhoMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (rho_label == 0) { + FSTERROR() << "RhoMatcher: 0 cannot be used as rho_label"; + rho_label_ = kNoLabel; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. 
+ RhoMatcher(const FST *fst, MatchType match_type, Label rho_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : RhoMatcher(*fst, match_type, rho_label, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + // This makes a copy of the FST. + RhoMatcher(const RhoMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + rho_label_(matcher.rho_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_), + state_(kNoStateId), + has_rho_(false) {} + + RhoMatcher *Copy(bool safe = false) const override { + return new RhoMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + matcher_->SetState(s); + has_rho_ = rho_label_ != kNoLabel; + } + + bool Find(Label label) final { + if (label == rho_label_ && rho_label_ != kNoLabel) { + FSTERROR() << "RhoMatcher::Find: bad label (rho)"; + error_ = true; + return false; + } + if (matcher_->Find(label)) { + rho_match_ = kNoLabel; + return true; + } else if (has_rho_ && label != 0 && label != kNoLabel && + (has_rho_ = matcher_->Find(rho_label_))) { + rho_match_ = label; + return true; + } else { + return false; + } + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if (rho_match_ == kNoLabel) { + return matcher_->Value(); + } else { + rho_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (rho_arc_.ilabel == rho_label_) rho_arc_.ilabel = rho_match_; + if (rho_arc_.olabel == rho_label_) rho_arc_.olabel = rho_match_; + } else if (match_type_ == MATCH_INPUT) { + rho_arc_.ilabel = rho_match_; + } else { + rho_arc_.olabel = rho_match_; + } + return rho_arc_; + } + } + + void Next() final { matcher_->Next(); } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { + 
state_ = s; + matcher_->SetState(s); + has_rho_ = matcher_->Find(rho_label_); + if (has_rho_) { + return kRequirePriority; + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (rho_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label RhoLabel() const { return rho_label_; } + + private: + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label rho_label_; // Label that represents the rho transition + bool rewrite_both_; // Rewrite both sides when both are rho_label_? + Label rho_match_; // Current label that matches rho transition. + mutable Arc rho_arc_; // Arc to return when rho match. + bool error_; // Error encountered? + StateId state_; // Matcher state. + bool has_rho_; // Are there possibly rhos at the current state? +}; + +template +inline uint64 RhoMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (match_type_ == MATCH_INPUT) { + if (rewrite_both_) { + return outprops & + ~(kODeterministic | kNonODeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kODeterministic | kAcceptor | kString | kILabelSorted | + kNotILabelSorted); + } + } else if (match_type_ == MATCH_OUTPUT) { + if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kString | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted); + } else { + return outprops & + ~(kIDeterministic | kAcceptor | kString | kOLabelSorted | + kNotOLabelSorted); + } + } else { + // Shouldn't ever get here. 
+ FSTERROR() << "RhoMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// For any requested label, this matcher considers all transitions +// that match the label 'sigma_label' (sigma = "any"), and this in +// additions to transitions with the requested label. Each such sigma +// transition found is returned with the sigma_label rewritten as the +// requested label (both sides if an acceptor, or if 'rewrite_both' is +// true and both input and output labels of the found transition are +// 'sigma_label'). If 'sigma_label' is kNoLabel, this special +// matching is not done. SigmaMatcher is templated itself on a +// matcher, which is used to perform the underlying matching. By +// default, the underlying matcher is constructed by SigmaMatcher. +// The user can instead pass in this object; in that case, +// SigmaMatcher takes its ownership. No non-consuming symbols other +// than epsilon supported with the underlying template argument matcher. +template +class SigmaMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + SigmaMatcher(const FST &fst, MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : matcher_(matcher ? 
matcher : new M(fst, match_type)), + match_type_(match_type), + sigma_label_(sigma_label), + error_(false), + state_(kNoStateId) { + if (match_type == MATCH_BOTH) { + FSTERROR() << "SigmaMatcher: Bad match type"; + match_type_ = MATCH_NONE; + error_ = true; + } + if (sigma_label == 0) { + FSTERROR() << "SigmaMatcher: 0 cannot be used as sigma_label"; + sigma_label_ = kNoLabel; + error_ = true; + } + if (rewrite_mode == MATCHER_REWRITE_AUTO) { + rewrite_both_ = fst.Properties(kAcceptor, true); + } else if (rewrite_mode == MATCHER_REWRITE_ALWAYS) { + rewrite_both_ = true; + } else { + rewrite_both_ = false; + } + } + + // This doesn't copy the FST. + SigmaMatcher(const FST *fst, MatchType match_type, + Label sigma_label = kNoLabel, + MatcherRewriteMode rewrite_mode = MATCHER_REWRITE_AUTO, + M *matcher = nullptr) + : SigmaMatcher(*fst, match_type, sigma_label, rewrite_mode, + matcher ? matcher : new M(fst, match_type)) { } + + // This makes a copy of the FST. + SigmaMatcher(const SigmaMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + sigma_label_(matcher.sigma_label_), + rewrite_both_(matcher.rewrite_both_), + error_(matcher.error_), + state_(kNoStateId) {} + + SigmaMatcher *Copy(bool safe = false) const override { + return new SigmaMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { + if (state_ == s) return; + state_ = s; + matcher_->SetState(s); + has_sigma_ = + (sigma_label_ != kNoLabel) ? 
matcher_->Find(sigma_label_) : false; + } + + bool Find(Label match_label) final { + match_label_ = match_label; + if (match_label == sigma_label_ && sigma_label_ != kNoLabel) { + FSTERROR() << "SigmaMatcher::Find: bad label (sigma)"; + error_ = true; + return false; + } + if (matcher_->Find(match_label)) { + sigma_match_ = kNoLabel; + return true; + } else if (has_sigma_ && match_label != 0 && match_label != kNoLabel && + matcher_->Find(sigma_label_)) { + sigma_match_ = match_label; + return true; + } else { + return false; + } + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { + if (sigma_match_ == kNoLabel) { + return matcher_->Value(); + } else { + sigma_arc_ = matcher_->Value(); + if (rewrite_both_) { + if (sigma_arc_.ilabel == sigma_label_) sigma_arc_.ilabel = sigma_match_; + if (sigma_arc_.olabel == sigma_label_) sigma_arc_.olabel = sigma_match_; + } else if (match_type_ == MATCH_INPUT) { + sigma_arc_.ilabel = sigma_match_; + } else { + sigma_arc_.olabel = sigma_match_; + } + return sigma_arc_; + } + } + + void Next() final { + matcher_->Next(); + if (matcher_->Done() && has_sigma_ && (sigma_match_ == kNoLabel) && + (match_label_ > 0)) { + matcher_->Find(sigma_label_); + sigma_match_ = match_label_; + } + } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { + if (sigma_label_ != kNoLabel) { + SetState(s); + return has_sigma_ ? 
kRequirePriority : matcher_->Priority(s); + } else { + return matcher_->Priority(s); + } + } + + const FST &GetFst() const override { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const override; + + uint32 Flags() const override { + if (sigma_label_ == kNoLabel || match_type_ == MATCH_NONE) { + return matcher_->Flags(); + } + return matcher_->Flags() | kRequireMatch; + } + + Label SigmaLabel() const { return sigma_label_; } + + private: + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + Label sigma_label_; // Label that represents the sigma transition. + bool rewrite_both_; // Rewrite both sides when both are sigma_label_? + bool has_sigma_; // Are there sigmas at the current state? + Label sigma_match_; // Current label that matches sigma transition. + mutable Arc sigma_arc_; // Arc to return when sigma match. + Label match_label_; // Label being matched. + bool error_; // Error encountered? + StateId state_; // Matcher state. +}; + +template +inline uint64 SigmaMatcher::Properties(uint64 inprops) const { + auto outprops = matcher_->Properties(inprops); + if (error_) outprops |= kError; + if (match_type_ == MATCH_NONE) { + return outprops; + } else if (rewrite_both_) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kILabelSorted | kNotILabelSorted | + kOLabelSorted | kNotOLabelSorted | kString); + } else if (match_type_ == MATCH_INPUT) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kILabelSorted | kNotILabelSorted | kString | + kAcceptor); + } else if (match_type_ == MATCH_OUTPUT) { + return outprops & + ~(kIDeterministic | kNonIDeterministic | kODeterministic | + kNonODeterministic | kOLabelSorted | kNotOLabelSorted | kString | + kAcceptor); + } else { + // Shouldn't ever get here. + FSTERROR() << "SigmaMatcher: Bad match type: " << match_type_; + return 0; + } +} + +// Flags for MultiEpsMatcher. 
+ +// Return multi-epsilon arcs for Find(kNoLabel). +const uint32 kMultiEpsList = 0x00000001; + +// Return a kNolabel loop for Find(multi_eps). +const uint32 kMultiEpsLoop = 0x00000002; + +// MultiEpsMatcher: allows treating multiple non-0 labels as +// non-consuming labels in addition to 0 that is always +// non-consuming. Precise behavior controlled by 'flags' argument. By +// default, the underlying matcher is constructed by +// MultiEpsMatcher. The user can instead pass in this object; in that +// case, MultiEpsMatcher takes its ownership iff 'own_matcher' is +// true. +template +class MultiEpsMatcher { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST (w/o 'matcher' arg). + MultiEpsMatcher(const FST &fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = nullptr, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + Init(match_type); + } + + // This doesn't copy the FST. + MultiEpsMatcher(const FST *fst, MatchType match_type, + uint32 flags = (kMultiEpsLoop | kMultiEpsList), + M *matcher = nullptr, bool own_matcher = true) + : matcher_(matcher ? matcher : new M(fst, match_type)), + flags_(flags), + own_matcher_(matcher ? own_matcher : true) { + Init(match_type); + } + + // This makes a copy of the FST. 
+ MultiEpsMatcher(const MultiEpsMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + flags_(matcher.flags_), + own_matcher_(true), + multi_eps_labels_(matcher.multi_eps_labels_), + loop_(matcher.loop_) { + loop_.nextstate = kNoStateId; + } + + ~MultiEpsMatcher() { + if (own_matcher_) delete matcher_; + } + + MultiEpsMatcher *Copy(bool safe = false) const { + return new MultiEpsMatcher(*this, safe); + } + + MatchType Type(bool test) const { return matcher_->Type(test); } + + void SetState(StateId state) { + matcher_->SetState(state); + loop_.nextstate = state; + } + + bool Find(Label label); + + bool Done() const { return done_; } + + const Arc &Value() const { return current_loop_ ? loop_ : matcher_->Value(); } + + void Next() { + if (!current_loop_) { + matcher_->Next(); + done_ = matcher_->Done(); + if (done_ && multi_eps_iter_ != multi_eps_labels_.End()) { + ++multi_eps_iter_; + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) { + ++multi_eps_iter_; + } + if (multi_eps_iter_ != multi_eps_labels_.End()) { + done_ = false; + } else { + done_ = !matcher_->Find(kNoLabel); + } + } + } else { + done_ = true; + } + } + + const FST &GetFst() const { return matcher_->GetFst(); } + + uint64 Properties(uint64 props) const { return matcher_->Properties(props); } + + const M *GetMatcher() const { return matcher_; } + + Weight Final(StateId s) const { return matcher_->Final(s); } + + uint32 Flags() const { return matcher_->Flags(); } + + ssize_t Priority(StateId s) { return matcher_->Priority(s); } + + void AddMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Insert(label); + } + } + + void RemoveMultiEpsLabel(Label label) { + if (label == 0) { + FSTERROR() << "MultiEpsMatcher: Bad multi-eps label: 0"; + } else { + multi_eps_labels_.Erase(label); + } + } + + void ClearMultiEpsLabels() { multi_eps_labels_.Clear(); } 
+ + private: + void Init(MatchType match_type) { + if (match_type == MATCH_INPUT) { + loop_.ilabel = kNoLabel; + loop_.olabel = 0; + } else { + loop_.ilabel = 0; + loop_.olabel = kNoLabel; + } + loop_.weight = Weight::One(); + loop_.nextstate = kNoStateId; + } + + M *matcher_; + uint32 flags_; + bool own_matcher_; // Does this class delete the matcher? + + // Multi-eps label set. + CompactSet multi_eps_labels_; + typename CompactSet::const_iterator multi_eps_iter_; + + bool current_loop_; // Current arc is the implicit loop? + mutable Arc loop_; // For non-consuming symbols. + bool done_; // Matching done? + + MultiEpsMatcher &operator=(const MultiEpsMatcher &) = delete; +}; + +template +inline bool MultiEpsMatcher::Find(Label label) { + multi_eps_iter_ = multi_eps_labels_.End(); + current_loop_ = false; + bool ret; + if (label == 0) { + ret = matcher_->Find(0); + } else if (label == kNoLabel) { + if (flags_ & kMultiEpsList) { + // Returns all non-consuming arcs (including epsilon). + multi_eps_iter_ = multi_eps_labels_.Begin(); + while ((multi_eps_iter_ != multi_eps_labels_.End()) && + !matcher_->Find(*multi_eps_iter_)) { + ++multi_eps_iter_; + } + if (multi_eps_iter_ != multi_eps_labels_.End()) { + ret = true; + } else { + ret = matcher_->Find(kNoLabel); + } + } else { + // Returns all epsilon arcs. + ret = matcher_->Find(kNoLabel); + } + } else if ((flags_ & kMultiEpsLoop) && + multi_eps_labels_.Find(label) != multi_eps_labels_.End()) { + // Returns implicit loop. + current_loop_ = true; + ret = true; + } else { + ret = matcher_->Find(label); + } + done_ = !ret; + return ret; +} + +// This class discards any implicit matches (e.g., the implicit epsilon +// self-loops in the SortedMatcher). Matchers are most often used in +// composition/intersection where the implicit matches are needed +// e.g. for epsilon processing. 
However, if a matcher is simply being +// used to look-up explicit label matches, this class saves the user +// from having to check for and discard the unwanted implicit matches +// themselves. +template +class ExplicitMatcher : public MatcherBase { + public: + using FST = typename M::FST; + using Arc = typename FST::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST. + ExplicitMatcher(const FST &fst, MatchType match_type, M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + error_(false) {} + + // This doesn't copy the FST. + ExplicitMatcher(const FST *fst, MatchType match_type, M *matcher = nullptr) + : matcher_(matcher ? matcher : new M(fst, match_type)), + match_type_(match_type), + error_(false) {} + + // This makes a copy of the FST. + ExplicitMatcher(const ExplicitMatcher &matcher, bool safe = false) + : matcher_(new M(*matcher.matcher_, safe)), + match_type_(matcher.match_type_), + error_(matcher.error_) {} + + ExplicitMatcher *Copy(bool safe = false) const override { + return new ExplicitMatcher(*this, safe); + } + + MatchType Type(bool test) const override { return matcher_->Type(test); } + + void SetState(StateId s) final { matcher_->SetState(s); } + + bool Find(Label label) final { + matcher_->Find(label); + CheckArc(); + return !Done(); + } + + bool Done() const final { return matcher_->Done(); } + + const Arc &Value() const final { return matcher_->Value(); } + + void Next() final { + matcher_->Next(); + CheckArc(); + } + + Weight Final(StateId s) const final { return matcher_->Final(s); } + + ssize_t Priority(StateId s) final { return matcher_->Priority(s); } + + const FST &GetFst() const final { return matcher_->GetFst(); } + + uint64 Properties(uint64 inprops) const override { + return matcher_->Properties(inprops); + } + + const M *GetMatcher() const { return matcher_.get(); } + + 
uint32 Flags() const override { return matcher_->Flags(); } + + private: + // Checks current arc if available and explicit. If not available, stops. If + // not explicit, checks next ones. + void CheckArc() { + for (; !matcher_->Done(); matcher_->Next()) { + const auto label = match_type_ == MATCH_INPUT ? matcher_->Value().ilabel + : matcher_->Value().olabel; + if (label != kNoLabel) return; + } + } + + std::unique_ptr matcher_; + MatchType match_type_; // Type of match requested. + bool error_; // Error encountered? +}; + +// Generic matcher, templated on the FST definition. +// +// Here is a typical use: +// +// Matcher matcher(fst, MATCH_INPUT); +// matcher.SetState(state); +// if (matcher.Find(label)) +// for (; !matcher.Done(); matcher.Next()) { +// auto &arc = matcher.Value(); +// ... +// } +template +class Matcher { + public: + using FST = F; + using Arc = typename F::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // This makes a copy of the FST. + Matcher(const FST &fst, MatchType match_type) + : owned_fst_(fst.Copy()), + base_(owned_fst_->InitMatcher(match_type)) { + if (!base_) base_.reset(new SortedMatcher(owned_fst_.get(), + match_type)); + } + + // This doesn't copy the FST. + Matcher(const FST *fst, MatchType match_type) + : base_(fst->InitMatcher(match_type)) { + if (!base_) base_.reset(new SortedMatcher(fst, match_type)); + } + + // This makes a copy of the FST. + Matcher(const Matcher &matcher, bool safe = false) + : base_(matcher.base_->Copy(safe)) { } + + // Takes ownership of the provided matcher. 
+ explicit Matcher(MatcherBase *base_matcher) + : base_(base_matcher) { } + + Matcher *Copy(bool safe = false) const { + return new Matcher(*this, safe); + } + + MatchType Type(bool test) const { return base_->Type(test); } + + void SetState(StateId s) { base_->SetState(s); } + + bool Find(Label label) { return base_->Find(label); } + + bool Done() const { return base_->Done(); } + + const Arc &Value() const { return base_->Value(); } + + void Next() { base_->Next(); } + + const FST &GetFst() const { + return static_cast(base_->GetFst()); + } + + uint64 Properties(uint64 props) const { return base_->Properties(props); } + + Weight Final(StateId s) const { return base_->Final(s); } + + uint32 Flags() const { return base_->Flags() & kMatcherFlags; } + + ssize_t Priority(StateId s) { return base_->Priority(s); } + + private: + std::unique_ptr owned_fst_; + std::unique_ptr> base_; +}; + +} // namespace fst + +#endif // FST_MATCHER_H_ diff --git a/projects/llm_framework/include/fst/memory.h b/projects/llm_framework/include/fst/memory.h new file mode 100644 index 00000000..c1f0bddc --- /dev/null +++ b/projects/llm_framework/include/fst/memory.h @@ -0,0 +1,443 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST memory utilities. + +#ifndef FST_MEMORY_H_ +#define FST_MEMORY_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Default block allocation size. +constexpr int kAllocSize = 64; + +// Minimum number of allocations per block. +constexpr int kAllocFit = 4; + +// Base class for MemoryArena that allows (e.g.) MemoryArenaCollection to +// easily manipulate collections of variously sized arenas. 
+class MemoryArenaBase { + public: + virtual ~MemoryArenaBase() {} + virtual size_t Size() const = 0; +}; + +namespace internal { + +// Allocates 'size' unintialized memory chunks of size object_size from +// underlying blocks of (at least) size 'block_size * object_size'. +// All blocks are freed when this class is deleted. Result of allocate() will +// be aligned to object_size. +template +class MemoryArenaImpl : public MemoryArenaBase { + public: + enum { kObjectSize = object_size }; + + explicit MemoryArenaImpl(size_t block_size = kAllocSize) + : block_size_(block_size * kObjectSize), block_pos_(0) { + blocks_.emplace_front(new char[block_size_]); + } + + void *Allocate(size_t size) { + const auto byte_size = size * kObjectSize; + if (byte_size * kAllocFit > block_size_) { + // Large block; adds new large block. + auto *ptr = new char[byte_size]; + blocks_.emplace_back(ptr); + return ptr; + } + if (block_pos_ + byte_size > block_size_) { + // Doesn't fit; adds new standard block. + auto *ptr = new char[block_size_]; + block_pos_ = 0; + blocks_.emplace_front(ptr); + } + // Fits; uses current block. + auto *ptr = blocks_.front().get() + block_pos_; + block_pos_ += byte_size; + return ptr; + } + + size_t Size() const override { return kObjectSize; } + + private: + const size_t block_size_; // Default block size in bytes. + size_t block_pos_; // Current position in block in bytes. + std::list> blocks_; // List of allocated blocks. +}; + +} // namespace internal + +template +using MemoryArena = internal::MemoryArenaImpl; + +// Base class for MemoryPool that allows (e.g.) MemoryPoolCollection to easily +// manipulate collections of variously sized pools. +class MemoryPoolBase { + public: + virtual ~MemoryPoolBase() {} + virtual size_t Size() const = 0; +}; + +namespace internal { + +// Allocates and frees initially uninitialized memory chunks of size +// object_size. 
Keeps an internal list of freed chunks that are reused (as is) +// on the next allocation if available. Chunks are constructed in blocks of size +// 'pool_size'. +template +class MemoryPoolImpl : public MemoryPoolBase { + public: + enum { kObjectSize = object_size }; + + struct Link { + char buf[kObjectSize]; + Link *next; + }; + + explicit MemoryPoolImpl(size_t pool_size) + : mem_arena_(pool_size), free_list_(nullptr) {} + + void *Allocate() { + if (free_list_ == nullptr) { + auto *link = static_cast(mem_arena_.Allocate(1)); + link->next = nullptr; + return link; + } else { + auto *link = free_list_; + free_list_ = link->next; + return link; + } + } + + void Free(void *ptr) { + if (ptr) { + auto *link = static_cast(ptr); + link->next = free_list_; + free_list_ = link; + } + } + + size_t Size() const override { return kObjectSize; } + + private: + MemoryArena mem_arena_; + Link *free_list_; + + MemoryPoolImpl(const MemoryPoolImpl &) = delete; + MemoryPoolImpl &operator=(const MemoryPoolImpl &) = delete; +}; + +} // namespace internal + +// Allocates and frees initially uninitialized memory chunks of size sizeof(T). +// All memory is freed when the class is deleted. The result of Allocate() will +// be suitably memory-aligned. Combined with placement operator new and destroy +// functions for the T class, this can be used to improve allocation efficiency. +// See nlp/fst/lib/visit.h (global new) and nlp/fst/lib/dfs-visit.h (class new) +// for examples. +template +class MemoryPool : public internal::MemoryPoolImpl { + public: + // 'pool_size' specifies the size of the initial pool and how it is extended. + MemoryPool(size_t pool_size = kAllocSize) + : internal::MemoryPoolImpl(pool_size) {} +}; + +// Stores a collection of memory arenas. +class MemoryArenaCollection { + public: + // 'block_size' specifies the block size of the arenas. 
+ explicit MemoryArenaCollection(size_t block_size = kAllocSize) + : block_size_(block_size), ref_count_(1) {} + + template + MemoryArena *Arena() { + if (sizeof(T) >= arenas_.size()) arenas_.resize(sizeof(T) + 1); + MemoryArenaBase *arena = arenas_[sizeof(T)].get(); + if (arena == nullptr) { + arena = new MemoryArena(block_size_); + arenas_[sizeof(T)].reset(arena); + } + return static_cast *>(arena); + } + + size_t BlockSize() const { return block_size_; } + + size_t RefCount() const { return ref_count_; } + + size_t IncrRefCount() { return ++ref_count_; } + + size_t DecrRefCount() { return --ref_count_; } + + private: + size_t block_size_; + size_t ref_count_; + std::vector> arenas_; +}; + +// Stores a collection of memory pools +class MemoryPoolCollection { + public: + // 'pool_size' specifies the size of initial pool and how it is extended. + explicit MemoryPoolCollection(size_t pool_size = kAllocSize) + : pool_size_(pool_size), ref_count_(1) {} + + template + MemoryPool *Pool() { + if (sizeof(T) >= pools_.size()) pools_.resize(sizeof(T) + 1); + MemoryPoolBase *pool = pools_[sizeof(T)].get(); + if (pool == nullptr) { + pool = new MemoryPool(pool_size_); + pools_[sizeof(T)].reset(pool); + } + return static_cast *>(pool); + } + + size_t PoolSize() const { return pool_size_; } + + size_t RefCount() const { return ref_count_; } + + size_t IncrRefCount() { return ++ref_count_; } + + size_t DecrRefCount() { return --ref_count_; } + + private: + size_t pool_size_; + size_t ref_count_; + std::vector> pools_; +}; + +// STL allocator using memory arenas. Memory is allocated from underlying +// blocks of size 'block_size * sizeof(T)'. Memory is freed only when all +// objects using this allocator are destroyed and there is otherwise no reuse +// (unlike PoolAllocator). 
+// +// This allocator has object-local state so it should not be used with splicing +// or swapping operations between objects created with different allocators nor +// should it be used if copies must be thread-safe. The result of allocate() +// will be suitably memory-aligned. +template +class BlockAllocator { + public: + using Allocator = std::allocator; + using size_type = typename Allocator::size_type; + using difference_type = typename Allocator::difference_type; + using pointer = typename Allocator::pointer; + using const_pointer = typename Allocator::const_pointer; + using reference = typename Allocator::reference; + using const_reference = typename Allocator::const_reference; + using value_type = typename Allocator::value_type; + + template + struct rebind { + using other = BlockAllocator; + }; + + explicit BlockAllocator(size_t block_size = kAllocSize) + : arenas_(new MemoryArenaCollection(block_size)) {} + + BlockAllocator(const BlockAllocator &arena_alloc) + : arenas_(arena_alloc.Arenas()) { + Arenas()->IncrRefCount(); + } + + template + explicit BlockAllocator(const BlockAllocator &arena_alloc) + : arenas_(arena_alloc.Arenas()) { + Arenas()->IncrRefCount(); + } + + ~BlockAllocator() { + if (Arenas()->DecrRefCount() == 0) delete Arenas(); + } + + pointer address(reference ref) const { return Allocator().address(ref); } + + const_pointer address(const_reference ref) const { + return Allocator().address(ref); + } + + size_type max_size() const { return Allocator().max_size(); } + + template + void construct(U *p, Args &&... 
args) { + Allocator().construct(p, std::forward(args)...); + } + + void destroy(pointer p) { Allocator().destroy(p); } + + pointer allocate(size_type n, const void *hint = nullptr) { + if (n * kAllocFit <= kAllocSize) { + return static_cast(Arena()->Allocate(n)); + } else { + return Allocator().allocate(n, hint); + } + } + + void deallocate(pointer p, size_type n) { + if (n * kAllocFit > kAllocSize) Allocator().deallocate(p, n); + } + + MemoryArenaCollection *Arenas() const { return arenas_; } + + private: + MemoryArena *Arena() { return arenas_->Arena(); } + + MemoryArenaCollection *arenas_; + + BlockAllocator operator=(const BlockAllocator &); +}; + +template +bool operator==(const BlockAllocator &alloc1, + const BlockAllocator &alloc2) { + return false; +} + +template +bool operator!=(const BlockAllocator &alloc1, + const BlockAllocator &alloc2) { + return true; +} + +// STL allocator using memory pools. Memory is allocated from underlying +// blocks of size 'block_size * sizeof(T)'. Keeps an internal list of freed +// chunks thare are reused on the next allocation. +// +// This allocator has object-local state so it should not be used with splicing +// or swapping operations between objects created with different allocators nor +// should it be used if copies must be thread-safe. The result of allocate() +// will be suitably memory-aligned. 
+template +class PoolAllocator { /* allocate(n) draws from size-classed pools for n <= 64 (classes 1, 2, 4, 8, 16, 32, 64) and from std::allocator above that; deallocate(p, n) selects the pool from n with the same thresholds. */ + public: + using Allocator = std::allocator; + using size_type = typename Allocator::size_type; + using difference_type = typename Allocator::difference_type; + using pointer = typename Allocator::pointer; + using const_pointer = typename Allocator::const_pointer; + using reference = typename Allocator::reference; + using const_reference = typename Allocator::const_reference; + using value_type = typename Allocator::value_type; + + template + struct rebind { + using other = PoolAllocator; + }; + + explicit PoolAllocator(size_t pool_size = kAllocSize) + : pools_(new MemoryPoolCollection(pool_size)) {} + + PoolAllocator(const PoolAllocator &pool_alloc) + : pools_(pool_alloc.Pools()) { + Pools()->IncrRefCount(); + } + + template + explicit PoolAllocator(const PoolAllocator &pool_alloc) + : pools_(pool_alloc.Pools()) { + Pools()->IncrRefCount(); + } + + ~PoolAllocator() { + if (Pools()->DecrRefCount() == 0) delete Pools(); + } + + pointer address(reference ref) const { return Allocator().address(ref); } + + const_pointer address(const_reference ref) const { + return Allocator().address(ref); + } + + size_type max_size() const { return Allocator().max_size(); } + + template + void construct(U *p, Args &&... 
args) { + Allocator().construct(p, std::forward(args)...); + } + + void destroy(pointer p) { Allocator().destroy(p); } + + pointer allocate(size_type n, const void *hint = nullptr) { + if (n == 1) { + return static_cast(Pool<1>()->Allocate()); + } else if (n == 2) { + return static_cast(Pool<2>()->Allocate()); + } else if (n <= 4) { + return static_cast(Pool<4>()->Allocate()); + } else if (n <= 8) { + return static_cast(Pool<8>()->Allocate()); + } else if (n <= 16) { + return static_cast(Pool<16>()->Allocate()); + } else if (n <= 32) { + return static_cast(Pool<32>()->Allocate()); + } else if (n <= 64) { + return static_cast(Pool<64>()->Allocate()); + } else { + return Allocator().allocate(n, hint); + } + } + + void deallocate(pointer p, size_type n) { + if (n == 1) { + Pool<1>()->Free(p); + } else if (n == 2) { + Pool<2>()->Free(p); + } else if (n <= 4) { + Pool<4>()->Free(p); + } else if (n <= 8) { + Pool<8>()->Free(p); + } else if (n <= 16) { + Pool<16>()->Free(p); + } else if (n <= 32) { + Pool<32>()->Free(p); + } else if (n <= 64) { + Pool<64>()->Free(p); + } else { + Allocator().deallocate(p, n); + } + } + + MemoryPoolCollection *Pools() const { return pools_; } + + private: + template + struct TN { + T buf[n]; + }; + + template + MemoryPool> *Pool() { + return pools_->Pool>(); + } + + MemoryPoolCollection *pools_; + + PoolAllocator operator=(const PoolAllocator &); +}; + +template +bool operator==(const PoolAllocator &alloc1, + const PoolAllocator &alloc2) { + return false; +} + +template +bool operator!=(const PoolAllocator &alloc1, + const PoolAllocator &alloc2) { + return true; +} + +} // namespace fst + +#endif // FST_MEMORY_H_ diff --git a/projects/llm_framework/include/fst/minimize.h b/projects/llm_framework/include/fst/minimize.h new file mode 100644 index 00000000..9f17a22a --- /dev/null +++ b/projects/llm_framework/include/fst/minimize.h @@ -0,0 +1,568 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state 
transducer library. +// +// Functions and classes to minimize an FST. + +#ifndef FST_MINIMIZE_H_ +#define FST_MINIMIZE_H_ + +#include + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { +namespace internal { + +// Comparator for creating partition; orders states by final-weight hash, then by number of arcs, then arc-by-arc on (ilabel, partition class of nextstate). +template +class StateComparator { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + StateComparator(const Fst &fst, const Partition &partition) + : fst_(fst), partition_(partition) {} + + // Compares state x with state y based on sort criteria; returns true iff x orders before y (false when no criterion distinguishes them). + bool operator()(const StateId x, const StateId y) const { + // Checks for final state equivalence (compares hashes of the final weights). + const auto xfinal = fst_.Final(x).Hash(); + const auto yfinal = fst_.Final(y).Hash(); + if (xfinal < yfinal) { + return true; + } else if (xfinal > yfinal) { + return false; + } + // Checks for number of arcs. + if (fst_.NumArcs(x) < fst_.NumArcs(y)) return true; + if (fst_.NumArcs(x) > fst_.NumArcs(y)) return false; + // If the number of arcs is equal, checks for arc match (input label, then partition class of the destination). 
+ for (ArcIterator> aiter1(fst_, x), aiter2(fst_, y); + !aiter1.Done() && !aiter2.Done(); aiter1.Next(), aiter2.Next()) { + const auto &arc1 = aiter1.Value(); + const auto &arc2 = aiter2.Value(); + if (arc1.ilabel < arc2.ilabel) return true; + if (arc1.ilabel > arc2.ilabel) return false; + if (partition_.ClassId(arc1.nextstate) < + partition_.ClassId(arc2.nextstate)) + return true; + if (partition_.ClassId(arc1.nextstate) > + partition_.ClassId(arc2.nextstate)) + return false; + } + return false; + } + + private: + const Fst &fst_; + const Partition &partition_; +}; + +// Computes equivalence classes for cyclic unweighted acceptors. 
For cyclic +// minimization we use the classic Hopcroft minimization algorithm, which has +// complexity O(E log V) where E is the number of arcs and V is the number of +// states. +// +// For more information, see: +// +// Hopcroft, J. 1971. An n Log n algorithm for minimizing states in a finite +// automaton. Ms, Stanford University. +// +// Note: the original presentation of the paper was for a finite automaton (== +// deterministic, unweighted acceptor), but we also apply it to the +// nondeterministic case, where it is also applicable as long as the semiring is +// idempotent (if the semiring is not idempotent, there are some complexities +// in keeping track of the weight when there are multiple arcs to states that +// will be merged, and we don't deal with this). +template +class CyclicMinimizer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using ClassId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using RevArc = ReverseArc; + + explicit CyclicMinimizer(const ExpandedFst &fst) { + Initialize(fst); + Compute(fst); + } + + const Partition &GetPartition() const { return P_; } + + private: + // StateILabelHasher is a hashing object that computes a hash-function + // of an FST state that depends only on the set of ilabels on arcs leaving + // the state [note: it assumes that the arcs are ilabel-sorted]. + // In order to work correctly for non-deterministic automata, multiple + // instances of the same ilabel count the same as a single instance. 
+ class StateILabelHasher { + public: + explicit StateILabelHasher(const Fst &fst) : fst_(fst) {} + + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + + size_t operator()(const StateId s) { + const size_t p1 = 7603; + const size_t p2 = 433024223; + size_t result = p2; + size_t current_ilabel = kNoLabel; + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + Label this_ilabel = aiter.Value().ilabel; + if (this_ilabel != current_ilabel) { // Ignores repeats (only adjacent equal ilabels are detected). + result = p1 * result + this_ilabel; + current_ilabel = this_ilabel; + } + } + return result; + } + + private: + const Fst &fst_; + }; + + class ArcIterCompare { + public: + explicit ArcIterCompare(const Partition &partition) + : partition_(partition) {} + + ArcIterCompare(const ArcIterCompare &comp) : partition_(comp.partition_) {} + + // Compares two iterators based on their input labels; with '>' the std::priority_queue built on it (ArcIterQueue) surfaces the smallest ilabel first. + bool operator()(const ArcIterator> *x, + const ArcIterator> *y) const { + const auto &xarc = x->Value(); + const auto &yarc = y->Value(); + return xarc.ilabel > yarc.ilabel; + } + + private: + const Partition &partition_; + }; + + using ArcIterQueue = + std::priority_queue> *, + std::vector> *>, + ArcIterCompare>; + + private: + // Prepartitions the space into equivalence classes. We ensure that final and + // non-final states always go into different equivalence classes, and we use + // class StateILabelHasher to make sure that most of the time, states with + // different sets of ilabels on arcs leaving them, go to different partitions. + // Note: for the O(n) guarantees we don't rely on the goodness of this + // hashing function---it just provides a bonus speedup. 
+ void PrePartition(const ExpandedFst &fst) { + VLOG(5) << "PrePartition"; + StateId next_class = 0; + auto num_states = fst.NumStates(); + // Allocates a temporary vector to store the initial class mappings, so that + // we can allocate the classes all at once. 
+ std::vector state_to_initial_class(num_states); + { + // We maintain two maps from hash-value to class---one for final states + // (final-prob == One()) and one for non-final states + // (final-prob == Zero()). We are processing unweighted acceptors, so these + // are the only two possible values. + using HashToClassMap = std::unordered_map; + HashToClassMap hash_to_class_nonfinal; + HashToClassMap hash_to_class_final; + StateILabelHasher hasher(fst); + for (StateId s = 0; s < num_states; ++s) { + size_t hash = hasher(s); + HashToClassMap &this_map = + (fst.Final(s) != Weight::Zero() ? hash_to_class_final + : hash_to_class_nonfinal); + // Avoids two map lookups by using 'insert' instead of 'find'. + auto p = this_map.insert(std::make_pair(hash, next_class)); + state_to_initial_class[s] = p.second ? next_class++ : p.first->second; + } + // Lets the unordered_maps go out of scope before we allocate the classes, + // to reduce the maximum amount of memory used. + } + P_.AllocateClasses(next_class); + for (StateId s = 0; s < num_states; ++s) { + P_.Add(s, state_to_initial_class[s]); + } + for (StateId c = 0; c < next_class; ++c) L_.Enqueue(c); + VLOG(5) << "Initial Partition: " << P_.NumClasses(); + } + + // Creates inverse transition Tr_ = rev(fst), loops over states in FST and + // splits on final, creating two blocks in the partition corresponding to + // final, non-final. + void Initialize(const ExpandedFst &fst) { + // Constructs Tr. + Reverse(fst, &Tr_); + ILabelCompare ilabel_comp; + ArcSort(&Tr_, ilabel_comp); + // Tells the partition how many elements to allocate. The first state in + // Tr_ is the super-final state. + P_.Initialize(Tr_.NumStates() - 1); + // Prepares initial partition. + PrePartition(fst); + // Allocates arc iterator queue. + ArcIterCompare comp(P_); + aiter_queue_.reset(new ArcIterQueue(comp)); + } + // Partitions all classes with destination C. 
+ void Split(ClassId C) { + // Prepares priority queue: opens arc iterator for each state in C, and + // inserts into priority queue. State s of the input FST is state s + 1 of Tr_. + for (PartitionIterator siter(P_, C); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + if (Tr_.NumArcs(s + 1)) { + aiter_queue_->push(new ArcIterator>(Tr_, s + 1)); + } + } + // Now pops arc iterator from queue, splits entering equivalence class, and + // re-inserts updated iterator into queue. + Label prev_label = -1; + while (!aiter_queue_->empty()) { + std::unique_ptr>> aiter(aiter_queue_->top()); + aiter_queue_->pop(); + if (aiter->Done()) continue; + const auto &arc = aiter->Value(); + auto from_state = aiter->Value().nextstate - 1; + auto from_label = arc.ilabel; + if (prev_label != from_label) P_.FinalizeSplit(&L_); + auto from_class = P_.ClassId(from_state); + if (P_.ClassSize(from_class) > 1) P_.SplitOn(from_state); + prev_label = from_label; + aiter->Next(); + if (!aiter->Done()) aiter_queue_->push(aiter.release()); + } + P_.FinalizeSplit(&L_); + } + + // Main loop for Hopcroft minimization. + void Compute(const Fst &fst) { + // Processes active classes (FIFO, or FILO). + while (!L_.Empty()) { + const auto C = L_.Head(); + L_.Dequeue(); + Split(C); // Splits on C, all labels in C. + } + } + + private: + // Partitioning of states into equivalence classes. + Partition P_; + // Set of active classes to be processed in partition P. + Queue L_; + // Reversed transition function (built by Reverse() and ilabel-sorted in Initialize()). + VectorFst Tr_; + // Priority queue of open arc iterators for all states in the splitter + // equivalence class. + std::unique_ptr aiter_queue_; +}; + +// Computes equivalence classes for acyclic FST. +// +// Complexity: +// +// O(E) +// +// where E is the number of arcs. +// +// For more information, see: +// +// Revuz, D. 1992. Minimization of acyclic deterministic automata in linear +// time. Theoretical Computer Science 92(1): 181-189. 
+template +class AcyclicMinimizer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using ClassId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit AcyclicMinimizer(const ExpandedFst &fst) { + Initialize(fst); + Refine(fst); + } + + const Partition &GetPartition() { return partition_; } + + private: + // DFS visitor to compute the height (distance) to final state. + class HeightVisitor { + public: + HeightVisitor() : max_height_(0), num_states_(0) {} + + // Invoked before DFS visit. + void InitVisit(const Fst &fst) {} + + // Invoked when state is discovered (2nd arg is DFS tree root). + bool InitState(StateId s, StateId root) { + // Extends height array, initializing new entries to -1 (not yet computed; FinishState turns a remaining -1 into 0). + for (StateId i = height_.size(); i <= s; ++i) height_.push_back(-1); + if (s >= num_states_) num_states_ = s + 1; + return true; + } + + // Invoked when tree arc examined (to undiscovered state). + bool TreeArc(StateId s, const Arc &arc) { return true; } + + // Invoked when back arc examined (to unfinished state). + bool BackArc(StateId s, const Arc &arc) { return true; } + + // Invoked when forward or cross arc examined (to finished state). + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + if (height_[arc.nextstate] + 1 > height_[s]) { + height_[s] = height_[arc.nextstate] + 1; + } + return true; + } + + // Invoked when state finished (parent is kNoStateId for tree root). + void FinishState(StateId s, StateId parent, const Arc *parent_arc) { + if (height_[s] == -1) height_[s] = 0; + const auto h = height_[s] + 1; + if (parent >= 0) { + if (h > height_[parent]) height_[parent] = h; + if (h > max_height_) max_height_ = h; + } + } + + // Invoked after DFS visit. 
+ void FinishVisit() {} + + size_t max_height() const { return max_height_; } + + const std::vector &height() const { return height_; } + + size_t num_states() const { return num_states_; } + + private: + std::vector height_; + size_t max_height_; + size_t num_states_; + }; + + private: + // Clusters states according to height (distance to final state). + void Initialize(const Fst &fst) { + // Computes height (distance to final state). + HeightVisitor hvisitor; + DfsVisit(fst, &hvisitor); + // Creates initial partition based on height. + partition_.Initialize(hvisitor.num_states()); + partition_.AllocateClasses(hvisitor.max_height() + 1); + const auto &hstates = hvisitor.height(); + for (StateId s = 0; s < hstates.size(); ++s) partition_.Add(s, hstates[s]); + } + + // Refines states based on arc sort (out degree, arc equivalence). + void Refine(const Fst &fst) { + using EquivalenceMap = std::map>; + StateComparator comp(fst, partition_); + // Starts with tail (height = 0). + auto height = partition_.NumClasses(); + for (StateId h = 0; h < height; ++h) { + EquivalenceMap equiv_classes(comp); + // Sorts states within equivalence class. + PartitionIterator siter(partition_, h); + equiv_classes[siter.Value()] = h; + for (siter.Next(); !siter.Done(); siter.Next()) { + auto insert_result = + equiv_classes.insert(std::make_pair(siter.Value(), kNoStateId)); + if (insert_result.second) { + insert_result.first->second = partition_.AddClass(); + } + } + // Creates refined partition. + for (siter.Reset(); !siter.Done();) { + const auto s = siter.Value(); + const auto old_class = partition_.ClassId(s); + const auto new_class = equiv_classes[s]; + // A move operation can invalidate the iterator, so we first update + // the iterator to the next element before we move the current element + // out of the list. 
+ siter.Next(); + if (old_class != new_class) partition_.Move(s, new_class); + } + } + } + + private: + Partition partition_; +}; + +// Given a partition and a Mutable FST, merges states of Fst in place (i.e., +// destructively). Merging works by taking the first state in a class of the +// partition to be the representative state for the class. Each arc is then +// reconnected to this state. All states in the class are merged by adding +// their arcs to the representative state. +template +void MergeStates(const Partition &partition, + MutableFst *fst) { + using StateId = typename Arc::StateId; + std::vector state_map(partition.NumClasses()); + for (StateId i = 0; i < partition.NumClasses(); ++i) { + PartitionIterator siter(partition, i); + state_map[i] = siter.Value(); // First state in class i (its representative). + } + // Relabels destination states. + for (StateId c = 0; c < partition.NumClasses(); ++c) { + for (PartitionIterator siter(partition, c); !siter.Done(); + siter.Next()) { + const auto s = siter.Value(); + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.nextstate = state_map[partition.ClassId(arc.nextstate)]; + if (s == state_map[c]) { // For the first state, just sets destination. + aiter.SetValue(arc); + } else { + fst->AddArc(state_map[c], std::move(arc)); + } + } + } + } + fst->SetStart(state_map[partition.ClassId(fst->Start())]); + Connect(fst); +} + +template +void AcceptorMinimize(MutableFst *fst, + bool allow_acyclic_minimization = true) { + if (!(fst->Properties(kAcceptor | kUnweighted, true) == + (kAcceptor | kUnweighted))) { + FSTERROR() << "FST is not an unweighted acceptor"; + fst->SetProperties(kError, kError); + return; + } + // Connects FST before minimization, handles disconnected states. + Connect(fst); + if (fst->NumStates() == 0) return; + if (allow_acyclic_minimization && fst->Properties(kAcyclic, true)) { + // Acyclic minimization (Revuz). 
+ VLOG(2) << "Acyclic minimization"; + ArcSort(fst, ILabelCompare()); + AcyclicMinimizer minimizer(*fst); + MergeStates(minimizer.GetPartition(), fst); + } else { + // Either the FST has cycles, or it's generated from non-deterministic input + // (which the Revuz algorithm can't handle), so use the cyclic minimization + // algorithm of Hopcroft. + VLOG(2) << "Cyclic minimization"; + CyclicMinimizer> minimizer(*fst); + MergeStates(minimizer.GetPartition(), fst); + } + // Merges in appropriate semiring. + ArcUniqueMapper mapper(*fst); + StateMap(fst, mapper); +} + +} // namespace internal + +// In place minimization of deterministic weighted automata and transducers, +// and also non-deterministic ones if they use an idempotent semiring. +// For transducers, if the 'sfst' argument is not null, the algorithm +// produces a compact factorization of the minimal transducer. +// +// In the acyclic deterministic case, we use an algorithm from Revuz that is +// linear in the number of arcs (edges) in the machine. +// +// In the cyclic or non-deterministic case, we use the classical Hopcroft +// minimization (which was presented for the deterministic case but which +// also works for non-deterministic FSTs); this has complexity O(e log v). Non-deterministic input is rejected (kError is set) unless allow_nondet is true. +// +template +void Minimize(MutableFst *fst, MutableFst *sfst = nullptr, + float delta = kShortestDelta, bool allow_nondet = false) { + using Weight = typename Arc::Weight; + const auto props = fst->Properties( + kAcceptor | kIDeterministic | kWeighted | kUnweighted, true); + bool allow_acyclic_minimization; + if (props & kIDeterministic) { + allow_acyclic_minimization = true; + } else { + // Our approach to minimization of non-deterministic FSTs will only work in + // idempotent semirings---for non-deterministic inputs, a state could have + // multiple transitions to states that will get merged, and we'd have to + // sum their weights. The algorithm doesn't handle that. 
+ if (!(Weight::Properties() & kIdempotent)) { + fst->SetProperties(kError, kError); + FSTERROR() << "Cannot minimize a non-deterministic FST over a " + "non-idempotent semiring"; + return; + } else if (!allow_nondet) { + fst->SetProperties(kError, kError); + FSTERROR() << "Refusing to minimize a non-deterministic FST with " + << "allow_nondet = false"; + return; + } + // The Revuz algorithm won't work for nondeterministic inputs, so if the + // input is nondeterministic, we'll have to pass a bool saying not to use + // that algorithm. We check at this level rather than in AcceptorMinimize(), + // because it's possible that the FST at this level could be deterministic, + // but a harmless type of non-determinism could be introduced by Encode() + // (thanks to kEncodeWeights, if the FST has epsilons and has a final + // weight with weights equal to some epsilon arc.) + allow_acyclic_minimization = false; + } + if (!(props & kAcceptor)) { // Transducer (weighted or unweighted). + VectorFst> gfst; + ArcMap(*fst, &gfst, ToGallicMapper()); + fst->DeleteStates(); + gfst.SetProperties(kAcceptor, kAcceptor); + Push(&gfst, REWEIGHT_TO_INITIAL, delta); + ArcMap(&gfst, QuantizeMapper>(delta)); + EncodeMapper> encoder( + kEncodeLabels | kEncodeWeights, ENCODE); + Encode(&gfst, &encoder); + internal::AcceptorMinimize(&gfst, allow_acyclic_minimization); + Decode(&gfst, encoder); + if (!sfst) { + FactorWeightFst, + GallicFactor> + fwfst(gfst); + std::unique_ptr osyms( + fst->OutputSymbols() ? fst->OutputSymbols()->Copy() : nullptr); + ArcMap(fwfst, fst, FromGallicMapper()); + fst->SetOutputSymbols(osyms.get()); + } else { + sfst->SetOutputSymbols(fst->OutputSymbols()); + GallicToNewSymbolsMapper mapper(sfst); + ArcMap(gfst, fst, &mapper); + fst->SetOutputSymbols(sfst->InputSymbols()); + } + } else if (props & kWeighted) { // Weighted acceptor. 
+ Push(fst, REWEIGHT_TO_INITIAL, delta); + ArcMap(fst, QuantizeMapper(delta)); + EncodeMapper encoder(kEncodeLabels | kEncodeWeights, ENCODE); + Encode(fst, &encoder); + internal::AcceptorMinimize(fst, allow_acyclic_minimization); + Decode(fst, encoder); + } else { // Unweighted acceptor. + internal::AcceptorMinimize(fst, allow_acyclic_minimization); + } +} + +} // namespace fst + +#endif // FST_MINIMIZE_H_ diff --git a/projects/llm_framework/include/fst/mutable-fst.h b/projects/llm_framework/include/fst/mutable-fst.h new file mode 100644 index 00000000..9031770d --- /dev/null +++ b/projects/llm_framework/include/fst/mutable-fst.h @@ -0,0 +1,398 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Expanded FST augmented with mutators; interface class definition and +// mutable arc iterator interface. + +#ifndef FST_MUTABLE_FST_H_ +#define FST_MUTABLE_FST_H_ + +#include +#include + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +template +struct MutableArcIteratorData; + +// Abstract interface for an expanded FST which also supports mutation +// operations. To modify arcs, use MutableArcIterator. +template +class MutableFst : public ExpandedFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + virtual MutableFst &operator=(const Fst &fst) = 0; + + MutableFst &operator=(const MutableFst &fst) { + return operator=(static_cast &>(fst)); + } + + // Sets the initial state. + virtual void SetStart(StateId) = 0; + + // Sets a state's final weight. + virtual void SetFinal(StateId, Weight) = 0; + + // Sets property bits w.r.t. mask. + virtual void SetProperties(uint64 props, uint64 mask) = 0; + + // Adds a state and returns its ID. + virtual StateId AddState() = 0; + + // Adds an arc to state. + virtual void AddArc(StateId, const Arc &arc) = 0; + + // Adds an arc (passed by rvalue reference) to state. 
Allows subclasses + // to optionally implement move semantics. Defaults to lvalue overload. + virtual void AddArc(StateId state, Arc &&arc) { AddArc(state, arc); } + + // Deletes some states, preserving original StateId ordering. + virtual void DeleteStates(const std::vector &) = 0; + + // Deletes all states. + virtual void DeleteStates() = 0; + + // Deletes some arcs at a given state. + virtual void DeleteArcs(StateId, size_t n) = 0; + + // Deletes all arcs at a given state. + virtual void DeleteArcs(StateId) = 0; + + // Optional, best effort only. + virtual void ReserveStates(StateId n) {} + + // Optional, best effort only. + virtual void ReserveArcs(StateId s, size_t n) {} + + // Returns input label symbol table or nullptr if not specified. + const SymbolTable *InputSymbols() const override = 0; + + // Returns output label symbol table or nullptr if not specified. + const SymbolTable *OutputSymbols() const override = 0; + + // Returns mutable input label symbol table or nullptr if not specified. + virtual SymbolTable *MutableInputSymbols() = 0; + + // Returns mutable output label symbol table or nullptr if not specified. + virtual SymbolTable *MutableOutputSymbols() = 0; + + // Sets input label symbol table; pass nullptr to delete table. + virtual void SetInputSymbols(const SymbolTable *isyms) = 0; + + // Sets output label symbol table; pass nullptr to delete table. + virtual void SetOutputSymbols(const SymbolTable *osyms) = 0; + + // Gets a copy of this MutableFst. See Fst<>::Copy() for further doc. + MutableFst *Copy(bool safe = false) const override = 0; + + // Reads a MutableFst from an input stream, returning nullptr on error. 
+ static MutableFst *Read(std::istream &strm, const FstReadOptions &opts) { + FstReadOptions ropts(opts); + FstHeader hdr; + if (ropts.header) { + hdr = *opts.header; + } else { + if (!hdr.Read(strm, opts.source)) return nullptr; + ropts.header = &hdr; + } + if (!(hdr.Properties() & kMutable)) { + LOG(ERROR) << "MutableFst::Read: Not a MutableFst: " << ropts.source; + return nullptr; + } + const auto &fst_type = hdr.FstType(); + const auto reader = FstRegister::GetRegister()->GetReader(fst_type); + if (!reader) { + LOG(ERROR) << "MutableFst::Read: Unknown FST type \"" << fst_type + << "\" (arc type = \"" << A::Type() << "\"): " << ropts.source; + return nullptr; + } + auto *fst = reader(strm, ropts); + if (!fst) return nullptr; + return static_cast *>(fst); + } + + // Reads a MutableFst from a file; returns nullptr on error. An empty + // filename results in reading from standard input. If convert is true, + // convert to a mutable FST subclass (given by convert_type) in the case + // that the input FST is non-mutable. NOTE(review): if the converted FST is still not mutable, the code below logs an error but still returns the static_cast result instead of nullptr; confirm this is intended. + static MutableFst *Read(const string &filename, bool convert = false, + const string &convert_type = "vector") { + if (convert == false) { + if (!filename.empty()) { + std::ifstream strm(filename, + std::ios_base::in | std::ios_base::binary); + if (!strm) { + LOG(ERROR) << "MutableFst::Read: Can't open file: " << filename; + return nullptr; + } + return Read(strm, FstReadOptions(filename)); + } else { + return Read(std::cin, FstReadOptions("standard input")); + } + } else { // Converts to 'convert_type' if not mutable. 
+ std::unique_ptr> ifst(Fst::Read(filename)); + if (!ifst) return nullptr; + if (ifst->Properties(kMutable, false)) { + return static_cast *>(ifst.release()); + } else { + std::unique_ptr> ofst(Convert(*ifst, convert_type)); + ifst.reset(); + if (!ofst) return nullptr; + if (!ofst->Properties(kMutable, false)) { + LOG(ERROR) << "MutableFst: Bad convert type: " << convert_type; + } + return static_cast *>(ofst.release()); + } + } + } + + // For generic mutable arc iterator construction; not normally called + // directly by users. + virtual void InitMutableArcIterator(StateId s, + MutableArcIteratorData *data) = 0; +}; + +// Mutable arc iterator interface, templated on the Arc definition. This is +// used by mutable arc iterator specializations that are returned by the +// InitMutableArcIterator MutableFst method. +template +class MutableArcIteratorBase : public ArcIteratorBase { + public: + // Sets current arc. + virtual void SetValue(const Arc &) = 0; +}; + +template +struct MutableArcIteratorData { + MutableArcIteratorBase *base; // Specific iterator. +}; + +// Generic mutable arc iterator, templated on the FST definition; a wrapper +// around a pointer to a more specific one. +// +// Here is a typical use: +// +// for (MutableArcIterator aiter(&fst, s); +// !aiter.Done(); +// aiter.Next()) { +// StdArc arc = aiter.Value(); +// arc.ilabel = 7; +// aiter.SetValue(arc); +// ... +// } +// +// This version requires function calls. 
+template +class MutableArcIterator { + public: + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + + MutableArcIterator(FST *fst, StateId s) { + fst->InitMutableArcIterator(s, &data_); + } + + ~MutableArcIterator() { delete data_.base; } + + bool Done() const { return data_.base->Done(); } + + const Arc &Value() const { return data_.base->Value(); } + + void Next() { data_.base->Next(); } + + size_t Position() const { return data_.base->Position(); } + + void Reset() { data_.base->Reset(); } + + void Seek(size_t a) { data_.base->Seek(a); } + + void SetValue(const Arc &arc) { data_.base->SetValue(arc); } + + uint32 Flags() const { return data_.base->Flags(); } + + void SetFlags(uint32 flags, uint32 mask) { + return data_.base->SetFlags(flags, mask); + } + + private: + MutableArcIteratorData data_; + + MutableArcIterator(const MutableArcIterator &) = delete; + MutableArcIterator &operator=(const MutableArcIterator &) = delete; +}; + +namespace internal { + +// MutableFst case: each helper forwards to the corresponding MutableFst member function. +template +inline typename Arc::Weight Final(const MutableFst &fst, + typename Arc::StateId s) { + return fst.Final(s); +} + +template +inline ssize_t NumArcs(const MutableFst &fst, typename Arc::StateId s) { + return fst.NumArcs(s); +} + +template +inline ssize_t NumInputEpsilons(const MutableFst &fst, + typename Arc::StateId s) { + return fst.NumInputEpsilons(s); +} + +template +inline ssize_t NumOutputEpsilons(const MutableFst &fst, + typename Arc::StateId s) { + return fst.NumOutputEpsilons(s); +} + +} // namespace internal + +// A useful alias when using StdArc. +using StdMutableFst = MutableFst; + +// This is a helper class template useful for attaching a MutableFst interface +// to its implementation, handling reference counting and copy-on-write (COW) semantics. 
+template > +class ImplToMutableFst : public ImplToExpandedFst { + public: + using Arc = typename Impl::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using ImplToExpandedFst::operator=; + + void SetStart(StateId s) override { + MutateCheck(); + GetMutableImpl()->SetStart(s); + } + + void SetFinal(StateId s, Weight weight) override { + MutateCheck(); + GetMutableImpl()->SetFinal(s, std::move(weight)); + } + + void SetProperties(uint64 props, uint64 mask) override { + // Can skip mutate check if extrinsic properties don't change, + // since it is then safe to update all (shallow) copies. + const auto exprops = kExtrinsicProperties & mask; + if (GetImpl()->Properties(exprops) != (props & exprops)) MutateCheck(); + GetMutableImpl()->SetProperties(props, mask); + } + + StateId AddState() override { + MutateCheck(); + return GetMutableImpl()->AddState(); + } + + void AddArc(StateId s, const Arc &arc) override { + MutateCheck(); + GetMutableImpl()->AddArc(s, arc); + } + + void AddArc(StateId s, Arc &&arc) override { + MutateCheck(); + GetMutableImpl()->AddArc(s, std::move(arc)); + } + + void DeleteStates(const std::vector &dstates) override { + MutateCheck(); + GetMutableImpl()->DeleteStates(dstates); + } + + void DeleteStates() override { + if (!Unique()) { + const auto *isymbols = GetImpl()->InputSymbols(); + const auto *osymbols = GetImpl()->OutputSymbols(); + SetImpl(std::make_shared()); + GetMutableImpl()->SetInputSymbols(isymbols); + GetMutableImpl()->SetOutputSymbols(osymbols); + } else { + GetMutableImpl()->DeleteStates(); + } + } + + void DeleteArcs(StateId s, size_t n) override { + MutateCheck(); + GetMutableImpl()->DeleteArcs(s, n); + } + + void DeleteArcs(StateId s) override { + MutateCheck(); + GetMutableImpl()->DeleteArcs(s); + } + + void ReserveStates(StateId s) override { + MutateCheck(); + GetMutableImpl()->ReserveStates(s); + } + + void ReserveArcs(StateId s, size_t n) override { + MutateCheck(); + 
GetMutableImpl()->ReserveArcs(s, n); + } + + const SymbolTable *InputSymbols() const override { + return GetImpl()->InputSymbols(); + } + + const SymbolTable *OutputSymbols() const override { + return GetImpl()->OutputSymbols(); + } + + SymbolTable *MutableInputSymbols() override { + MutateCheck(); + return GetMutableImpl()->InputSymbols(); + } + + SymbolTable *MutableOutputSymbols() override { + MutateCheck(); + return GetMutableImpl()->OutputSymbols(); + } + + void SetInputSymbols(const SymbolTable *isyms) override { + MutateCheck(); + GetMutableImpl()->SetInputSymbols(isyms); + } + + void SetOutputSymbols(const SymbolTable *osyms) override { + MutateCheck(); + GetMutableImpl()->SetOutputSymbols(osyms); + } + + protected: + using ImplToExpandedFst::GetImpl; + using ImplToExpandedFst::GetMutableImpl; + using ImplToExpandedFst::Unique; + using ImplToExpandedFst::SetImpl; + using ImplToExpandedFst::InputSymbols; + + explicit ImplToMutableFst(std::shared_ptr impl) + : ImplToExpandedFst(impl) {} + + ImplToMutableFst(const ImplToMutableFst &fst, bool safe) + : ImplToExpandedFst(fst, safe) {} + + void MutateCheck() { + if (!Unique()) SetImpl(std::make_shared(*this)); + } +}; + +} // namespace fst + +#endif // FST_MUTABLE_FST_H_ diff --git a/projects/llm_framework/include/fst/pair-weight.h b/projects/llm_framework/include/fst/pair-weight.h new file mode 100644 index 00000000..1f2e963c --- /dev/null +++ b/projects/llm_framework/include/fst/pair-weight.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Pair weight templated base class for weight classes that contain two weights +// (e.g., Product, Lexicographic). 
+ +#ifndef FST_PAIR_WEIGHT_H_ +#define FST_PAIR_WEIGHT_H_ + +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +template +class PairWeight { + public: + using ReverseWeight = + PairWeight; + + PairWeight() {} + + PairWeight(W1 w1, W2 w2) : value1_(std::move(w1)), value2_(std::move(w2)) {} + + static const PairWeight &Zero() { + static const PairWeight zero(W1::Zero(), W2::Zero()); + return zero; + } + + static const PairWeight &One() { + static const PairWeight one(W1::One(), W2::One()); + return one; + } + + static const PairWeight &NoWeight() { + static const PairWeight no_weight(W1::NoWeight(), W2::NoWeight()); + return no_weight; + } + + std::istream &Read(std::istream &strm) { + value1_.Read(strm); + return value2_.Read(strm); + } + + std::ostream &Write(std::ostream &strm) const { + value1_.Write(strm); + return value2_.Write(strm); + } + + bool Member() const { return value1_.Member() && value2_.Member(); } + + size_t Hash() const { + const auto h1 = value1_.Hash(); + const auto h2 = value2_.Hash(); + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - 5; + return h1 << lshift ^ h1 >> rshift ^ h2; + } + + PairWeight Quantize(float delta = kDelta) const { + return PairWeight(value1_.Quantize(delta), value2_.Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(value1_.Reverse(), value2_.Reverse()); + } + + const W1 &Value1() const { return value1_; } + + const W2 &Value2() const { return value2_; } + + void SetValue1(const W1 &weight) { value1_ = weight; } + + void SetValue2(const W2 &weight) { value2_ = weight; } + + private: + W1 value1_; + W2 value2_; +}; + +template +inline bool operator==(const PairWeight &w1, + const PairWeight &w2) { + return w1.Value1() == w2.Value1() && w1.Value2() == w2.Value2(); +} + +template +inline bool operator!=(const PairWeight &w1, + const PairWeight &w2) { + return w1.Value1() != w2.Value1() || w1.Value2() != 
w2.Value2(); +} + +template +inline bool ApproxEqual(const PairWeight &w1, + const PairWeight &w2, float delta = kDelta) { + return ApproxEqual(w1.Value1(), w2.Value1(), delta) && + ApproxEqual(w1.Value2(), w2.Value2(), delta); +} + +template +inline std::ostream &operator<<(std::ostream &strm, + const PairWeight &weight) { + CompositeWeightWriter writer(strm); + writer.WriteBegin(); + writer.WriteElement(weight.Value1()); + writer.WriteElement(weight.Value2()); + writer.WriteEnd(); + return strm; +} + +template +inline std::istream &operator>>(std::istream &strm, + PairWeight &weight) { + CompositeWeightReader reader(strm); + reader.ReadBegin(); + W1 w1; + reader.ReadElement(&w1); + weight.SetValue1(w1); + W2 w2; + reader.ReadElement(&w2, true); + weight.SetValue2(w2); + reader.ReadEnd(); + return strm; +} + +// This function object returns weights by calling the underlying generators +// and forming a pair. This is intended primarily for testing. +template +class WeightGenerate> { + public: + using Weight = PairWeight; + using Generate1 = WeightGenerate; + using Generate2 = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true) + : generate1_(allow_zero), generate2_(allow_zero) {} + + Weight operator()() const { return Weight(generate1_(), generate2_()); } + + private: + Generate1 generate1_; + Generate2 generate2_; +}; + +} // namespace fst + +#endif // FST_PAIR_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/partition.h b/projects/llm_framework/include/fst/partition.h new file mode 100644 index 00000000..5dbbe46a --- /dev/null +++ b/projects/llm_framework/include/fst/partition.h @@ -0,0 +1,305 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to create a partition of states. 
+ +#ifndef FST_PARTITION_H_ +#define FST_PARTITION_H_ + +#include +#include + + +#include + + +namespace fst { +namespace internal { + +template +class PartitionIterator; + +// Defines a partitioning of elements, used to represent equivalence classes +// for FST operations like minimization. T must be a signed integer type. +// +// The elements are numbered from 0 to num_elements - 1. +// Initialize(num_elements) sets up the class for a given number of elements. +// We maintain a partition of these elements into classes. The classes are also +// numbered from zero; you can add a class with AddClass(), or add them in bulk +// with AllocateClasses(num_classes). Initially the elements are not assigned +// to any class; you set up the initial mapping from elements to classes by +// calling Add(element_id, class_id). You can also move an element to a +// different class by calling Move(element_id, class_id). +// +// We also support a rather specialized interface that allows you to efficiently +// split classes in the Hopcroft minimization algorithm. This maintains a +// binary partition of each class. Let's call these, rather arbitrarily, the +// 'yes' subset and the 'no' subset of each class, and assume that by default, +// each element of a class is in its 'no' subset. When one calls +// SplitOn(element_id), element_id is moved to the 'yes' subset of its class. +// (If it was already in the 'yes' set, it just stays there). The aim is to +// enable (later) splitting the class in two in time no greater than the time +// already spent calling SplitOn() for that class. We keep a list of the classes +// which have nonempty 'yes' sets, as visited_classes_. When one calls +// FinalizeSplit(Queue *l), for each class in visited_classes_ whose 'yes' +// and 'no' sets are both nonempty, it will create a new class consisting of +// the smaller of the two subsets (and this class will be added to the queue), +// and the old class will now be the larger of the two subsets. 
This call also +// resets all the yes/no partitions so that everything is in the 'no' subsets. +// +// One cannot use the Move() function if SplitOn() has been called without +// a subsequent call to FinalizeSplit() +template +class Partition { + public: + Partition() {} + + explicit Partition(T num_elements) { Initialize(num_elements); } + + // Creates an empty partition for num_elements. This means that the elements + // are not assigned to a class (i.e class_index = -1); you should set up the + // number of classes using AllocateClasses() or AddClass(), and allocate each + // element to a class by calling Add(element, class_id). + void Initialize(size_t num_elements) { + elements_.resize(num_elements); + classes_.reserve(num_elements); + classes_.clear(); + yes_counter_ = 1; + } + + // Adds a class; returns new number of classes. + T AddClass() { + auto num_classes = classes_.size(); + classes_.resize(num_classes + 1); + return num_classes; + } + + // Adds 'num_classes' new (empty) classes. + void AllocateClasses(T num_classes) { + classes_.resize(classes_.size() + num_classes); + } + + // Adds element_id to class_id. element_id should already have been allocated + // by calling Initialize(num_elements)---or the constructor taking + // num_elements---with num_elements > element_id. element_id must not + // currently be a member of any class; once elements have been added to a + // class, use the Move() method to move them from one class to another. + void Add(T element_id, T class_id) { + auto &this_element = elements_[element_id]; + auto &this_class = classes_[class_id]; + ++this_class.size; + // Adds the element to the 'no' subset of the class. + auto no_head = this_class.no_head; + if (no_head >= 0) elements_[no_head].prev_element = element_id; + this_class.no_head = element_id; + this_element.class_id = class_id; + // Adds to the 'no' subset of the class. 
+ this_element.yes = 0; + this_element.next_element = no_head; + this_element.prev_element = -1; + } + + // Moves element_id from 'no' subset of its current class to 'no' subset of + // class class_id. This may not work correctly if you have called SplitOn() + // [for any element] and haven't subsequently called FinalizeSplit(). + void Move(T element_id, T class_id) { + auto elements = &(elements_[0]); + auto &element = elements[element_id]; + auto &old_class = classes_[element.class_id]; + --old_class.size; + // Excises the element from the 'no' list of its old class, where it is + // assumed to be. + if (element.prev_element >= 0) { + elements[element.prev_element].next_element = element.next_element; + } else { + old_class.no_head = element.next_element; + } + if (element.next_element >= 0) { + elements[element.next_element].prev_element = element.prev_element; + } + // Adds to new class. + Add(element_id, class_id); + } + + // Moves element_id to the 'yes' subset of its class if it was in the 'no' + // subset, and marks the class as having been visited. + void SplitOn(T element_id) { + auto elements = &(elements_[0]); + auto &element = elements[element_id]; + if (element.yes == yes_counter_) { + return; // Already in the 'yes' set; nothing to do. + } + auto class_id = element.class_id; + auto &this_class = classes_[class_id]; + // Excises the element from the 'no' list of its class. + if (element.prev_element >= 0) { + elements[element.prev_element].next_element = element.next_element; + } else { + this_class.no_head = element.next_element; + } + if (element.next_element >= 0) { + elements[element.next_element].prev_element = element.prev_element; + } + // Adds the element to the 'yes' list. 
+ if (this_class.yes_head >= 0) { + elements[this_class.yes_head].prev_element = element_id; + } else { + visited_classes_.push_back(class_id); + } + element.yes = yes_counter_; + element.next_element = this_class.yes_head; + element.prev_element = -1; + this_class.yes_head = element_id; + this_class.yes_size++; + } + + // This should be called after one has possibly called SplitOn for one or more + // elements, thus moving those elements to the 'yes' subset for their class. + // For each class that has a nontrivial split (i.e., it's not the case that + // all members are in the 'yes' or 'no' subset), this function creates a new + // class containing the smaller of the two subsets of elements, leaving the + // larger group of elements in the old class. The identifier of the new class + // will be added to the queue provided as the pointer L. This method then + // moves all elements to the 'no' subset of their class. + template + void FinalizeSplit(Queue *queue) { + for (const auto &visited_class : visited_classes_) { + const auto new_class = SplitRefine(visited_class); + if (new_class != -1 && queue) queue->Enqueue(new_class); + } + visited_classes_.clear(); + // Incrementation sets all the 'yes' members of the elements to false. + ++yes_counter_; + } + + const T ClassId(T element_id) const { return elements_[element_id].class_id; } + + const size_t ClassSize(T class_id) const { return classes_[class_id].size; } + + const T NumClasses() const { return classes_.size(); } + + private: + friend class PartitionIterator; + + // Information about a given element. + struct Element { + T class_id; // Class ID of this element. + T yes; // This is to be interpreted as a bool, true if it's in the + // 'yes' set of this class. The interpretation as bool is + // (yes == yes_counter_ ? true : false). 
+ T next_element; // Next element in the 'no' list or 'yes' list of this + // class, whichever of the two we belong to (think of + // this as the 'next' in a doubly-linked list, although + // it is an index into the elements array). Negative + // values corresponds to null. + T prev_element; // Previous element in the 'no' or 'yes' doubly linked + // list. Negative values corresponds to null. + }; + + // Information about a given class. + struct Class { + Class() : size(0), yes_size(0), no_head(-1), yes_head(-1) {} + T size; // Total number of elements in this class ('no' plus 'yes' + // subsets). + T yes_size; // Total number of elements of 'yes' subset of this class. + T no_head; // Index of head element of doubly-linked list in 'no' subset. + // Everything is in the 'no' subset until you call SplitOn(). + // -1 means no element. + T yes_head; // Index of head element of doubly-linked list in 'yes' subset. + // -1 means no element. + }; + + // This method, called from FinalizeSplit(), checks whether a class has to + // be split (a class will be split only if its 'yes' and 'no' subsets are + // both nonempty, but one can assume that since this function was called, the + // 'yes' subset is nonempty). It splits by taking the smaller subset and + // making it a new class, and leaving the larger subset of elements in the + // 'no' subset of the old class. It returns the new class if created, or -1 + // if none was created. + T SplitRefine(T class_id) { + auto yes_size = classes_[class_id].yes_size; + auto size = classes_[class_id].size; + auto no_size = size - yes_size; + if (no_size == 0) { + // All members are in the 'yes' subset, so we don't have to create a new + // class, just move them all to the 'no' subset. 
+ classes_[class_id].no_head = classes_[class_id].yes_head; + classes_[class_id].yes_head = -1; + classes_[class_id].yes_size = 0; + return -1; + } else { + auto new_class_id = classes_.size(); + classes_.resize(classes_.size() + 1); + auto &old_class = classes_[class_id]; + auto &new_class = classes_[new_class_id]; + // The new_class will have the values from the constructor. + if (no_size < yes_size) { + // Moves the 'no' subset to new class ('no' subset). + new_class.no_head = old_class.no_head; + new_class.size = no_size; + // And makes the 'yes' subset of the old class ('no' subset). + old_class.no_head = old_class.yes_head; + old_class.yes_head = -1; + old_class.size = yes_size; + old_class.yes_size = 0; + } else { + // Moves the 'yes' subset to the new class (to the 'no' subset) + new_class.size = yes_size; + new_class.no_head = old_class.yes_head; + // Retains only the 'no' subset in the old class. + old_class.size = no_size; + old_class.yes_size = 0; + old_class.yes_head = -1; + } + auto elements = &(elements_[0]); + // Updates the 'class_id' of all the elements we moved. + for (auto e = new_class.no_head; e >= 0; e = elements[e].next_element) { + elements[e].class_id = new_class_id; + } + return new_class_id; + } + } + + // elements_[i] contains all info about the i'th element. + std::vector elements_; + // classes_[i] contains all info about the i'th class. + std::vector classes_; + // Set of visited classes to be used in split refine. + std::vector visited_classes_; + // yes_counter_ is used in interpreting the 'yes' members of class Element. + // If element.yes == yes_counter_, we interpret that element as being in the + // 'yes' subset of its class. This allows us to, in effect, set all those + // bools to false at a stroke by incrementing yes_counter_. + T yes_counter_; +}; + +// Iterates over members of the 'no' subset of a class in a partition. (When +// this is used, everything is in the 'no' subset). 
+template +class PartitionIterator { + public: + using Element = typename Partition::Element; + + PartitionIterator(const Partition &partition, T class_id) + : partition_(partition), + element_id_(partition_.classes_[class_id].no_head), + class_id_(class_id) {} + + bool Done() { return element_id_ < 0; } + + const T Value() { return element_id_; } + + void Next() { element_id_ = partition_.elements_[element_id_].next_element; } + + void Reset() { element_id_ = partition_.classes_[class_id_].no_head; } + + private: + const Partition &partition_; + T element_id_; + T class_id_; +}; + +} // namespace internal +} // namespace fst + +#endif // FST_PARTITION_H_ diff --git a/projects/llm_framework/include/fst/power-weight.h b/projects/llm_framework/include/fst/power-weight.h new file mode 100644 index 00000000..f2f3cbdb --- /dev/null +++ b/projects/llm_framework/include/fst/power-weight.h @@ -0,0 +1,168 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Cartesian power weight semiring operation definitions. + +#ifndef FST_POWER_WEIGHT_H_ +#define FST_POWER_WEIGHT_H_ + +#include + +#include +#include + + +namespace fst { + +// Cartesian power semiring: W ^ n +// +// Forms: +// - a left semimodule when W is a left semiring, +// - a right semimodule when W is a right semiring, +// - a bisemimodule when W is a semiring, +// the free semimodule of rank n over W +// The Times operation is overloaded to provide the left and right scalar +// products. 
+template +class PowerWeight : public TupleWeight { + public: + using ReverseWeight = PowerWeight; + + PowerWeight() {} + + explicit PowerWeight(const TupleWeight &weight) + : TupleWeight(weight) {} + + template + PowerWeight(Iterator begin, Iterator end) : TupleWeight(begin, end) {} + + // Initialize component `index` to `weight`; initialize all other components + // to `default_weight` + PowerWeight(size_t index, const W &weight, + const W &default_weight = W::Zero()) + : TupleWeight(index, weight, default_weight) {} + + static const PowerWeight &Zero() { + static const PowerWeight zero(TupleWeight::Zero()); + return zero; + } + + static const PowerWeight &One() { + static const PowerWeight one(TupleWeight::One()); + return one; + } + + static const PowerWeight &NoWeight() { + static const PowerWeight no_weight(TupleWeight::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string(W::Type() + "_^" + std::to_string(n)); + return *type; + } + + static constexpr uint64 Properties() { + return W::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + PowerWeight Quantize(float delta = kDelta) const { + return PowerWeight(TupleWeight::Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(TupleWeight::Reverse()); + } +}; + +// Semiring plus operation. +template +inline PowerWeight Plus(const PowerWeight &w1, + const PowerWeight &w2) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Plus(w1.Value(i), w2.Value(i))); + } + return result; +} + +// Semiring times operation. +template +inline PowerWeight Times(const PowerWeight &w1, + const PowerWeight &w2) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(w1.Value(i), w2.Value(i))); + } + return result; +} + +// Semiring divide operation. 
+template +inline PowerWeight Divide(const PowerWeight &w1, + const PowerWeight &w2, + DivideType type = DIVIDE_ANY) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Divide(w1.Value(i), w2.Value(i), type)); + } + return result; +} + +// Semimodule left scalar product. +template +inline PowerWeight Times(const W &scalar, + const PowerWeight &weight) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(scalar, weight.Value(i))); + } + return result; +} + +// Semimodule right scalar product. +template +inline PowerWeight Times(const PowerWeight &weight, + const W &scalar) { + PowerWeight result; + for (size_t i = 0; i < n; ++i) { + result.SetValue(i, Times(weight.Value(i), scalar)); + } + return result; +} + +// Semimodule dot product. +template +inline W DotProduct(const PowerWeight &w1, const PowerWeight &w2) { + W result(W::Zero()); + for (size_t i = 0; i < n; ++i) { + result = Plus(result, Times(w1.Value(i), w2.Value(i))); + } + return result; +} + +// This function object generates weights over the Cartesian power of rank +// n over the underlying weight. This is intended primarily for testing. +template +class WeightGenerate> { + public: + using Weight = PowerWeight; + using Generate = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true) : generate_(allow_zero) {} + + Weight operator()() const { + Weight result; + for (size_t i = 0; i < n; ++i) result.SetValue(i, generate_()); + return result; + } + + private: + Generate generate_; +}; + +} // namespace fst + +#endif // FST_POWER_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/product-weight.h b/projects/llm_framework/include/fst/product-weight.h new file mode 100644 index 00000000..56a18be1 --- /dev/null +++ b/projects/llm_framework/include/fst/product-weight.h @@ -0,0 +1,107 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+// +// Product weight set and associated semiring operation definitions. + +#ifndef FST_PRODUCT_WEIGHT_H_ +#define FST_PRODUCT_WEIGHT_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Product semiring: W1 * W2. +template +class ProductWeight : public PairWeight { + public: + using ReverseWeight = + ProductWeight; + + ProductWeight() {} + + explicit ProductWeight(const PairWeight &weight) + : PairWeight(weight) {} + + ProductWeight(W1 w1, W2 w2) + : PairWeight(std::move(w1), std::move(w2)) {} + + static const ProductWeight &Zero() { + static const ProductWeight zero(PairWeight::Zero()); + return zero; + } + + static const ProductWeight &One() { + static const ProductWeight one(PairWeight::One()); + return one; + } + + static const ProductWeight &NoWeight() { + static const ProductWeight no_weight(PairWeight::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = + new string(W1::Type() + "_X_" + W2::Type()); + return *type; + } + + static constexpr uint64 Properties() { + return W1::Properties() & W2::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + ProductWeight Quantize(float delta = kDelta) const { + return ProductWeight(PairWeight::Quantize(delta)); + } + + ReverseWeight Reverse() const { + return ReverseWeight(PairWeight::Reverse()); + } +}; + +template +inline ProductWeight Plus(const ProductWeight &w1, + const ProductWeight &w2) { + return ProductWeight(Plus(w1.Value1(), w2.Value1()), + Plus(w1.Value2(), w2.Value2())); +} + +template +inline ProductWeight Times(const ProductWeight &w1, + const ProductWeight &w2) { + return ProductWeight(Times(w1.Value1(), w2.Value1()), + Times(w1.Value2(), w2.Value2())); +} + +template +inline ProductWeight Divide(const ProductWeight &w1, + const ProductWeight &w2, + DivideType typ = DIVIDE_ANY) { + return ProductWeight(Divide(w1.Value1(), w2.Value1(), typ), + Divide(w1.Value2(), w2.Value2(), typ)); +} + +// 
This function object generates weights by calling the underlying generators +// for the template weight types, like all other pair weight types. This is +// intended primarily for testing. +template +class WeightGenerate> : + public WeightGenerate> { + public: + using Weight = ProductWeight; + using Generate = WeightGenerate>; + + explicit WeightGenerate(bool allow_zero = true) : Generate(allow_zero) {} + + Weight operator()() const { return Weight(Generate::operator()()); } +}; + +} // namespace fst + +#endif // FST_PRODUCT_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/project.h b/projects/llm_framework/include/fst/project.h new file mode 100644 index 00000000..5a82cf14 --- /dev/null +++ b/projects/llm_framework/include/fst/project.h @@ -0,0 +1,159 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to project an FST on to its domain or range. + +#ifndef FST_PROJECT_H_ +#define FST_PROJECT_H_ + +#include +#include + + +namespace fst { + +// This specifies whether to project on input or output. +enum ProjectType { PROJECT_INPUT = 1, PROJECT_OUTPUT = 2 }; + +// Mapper to implement projection per arc. +template +class ProjectMapper { + public: + using FromArc = A; + using ToArc = A; + + constexpr explicit ProjectMapper(ProjectType project_type) + : project_type_(project_type) {} + + ToArc operator()(const FromArc &arc) const { + const auto label = project_type_ == PROJECT_INPUT ? arc.ilabel : arc.olabel; + return ToArc(label, label, arc.weight, arc.nextstate); + } + + constexpr MapFinalAction FinalAction() const { + return MAP_NO_SUPERFINAL; + } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return project_type_ == PROJECT_INPUT ? MAP_COPY_SYMBOLS + : MAP_CLEAR_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return project_type_ == PROJECT_OUTPUT ? 
MAP_COPY_SYMBOLS + : MAP_CLEAR_SYMBOLS; + } + + constexpr uint64 Properties(uint64 props) const { + return ProjectProperties(props, project_type_ == PROJECT_INPUT); + } + + private: + const ProjectType project_type_; +}; + +// Projects an FST onto its domain or range by either copying each arcs' input +// label to the output label or vice versa. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(1) +// +// where V is the number of states and E is the number of arcs. +template +inline void Project(const Fst &ifst, MutableFst *ofst, + ProjectType project_type) { + ArcMap(ifst, ofst, ProjectMapper(project_type)); + switch (project_type) { + case PROJECT_INPUT: + ofst->SetOutputSymbols(ifst.InputSymbols()); + return; + case PROJECT_OUTPUT: + ofst->SetInputSymbols(ifst.OutputSymbols()); + return; + } +} + +// Destructive variant of the above. +template +inline void Project(MutableFst *fst, ProjectType project_type) { + ArcMap(fst, ProjectMapper(project_type)); + switch (project_type) { + case PROJECT_INPUT: + fst->SetOutputSymbols(fst->InputSymbols()); + return; + case PROJECT_OUTPUT: + fst->SetInputSymbols(fst->OutputSymbols()); + return; + } +} + +// Projects an FST onto its domain or range by either copying each arc's input +// label to the output label or vice versa. This version is a delayed FST. +// +// Complexity: +// +// Time: O(v + e) +// Space: O(1) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time and to visit an input state or arc is assumed and exclusive of +// caching. 
+template +class ProjectFst : public ArcMapFst> { + public: + using FromArc = A; + using ToArc = A; + + using Impl = internal::ArcMapFstImpl>; + + ProjectFst(const Fst &fst, ProjectType project_type) + : ArcMapFst>(fst, ProjectMapper(project_type)) { + if (project_type == PROJECT_INPUT) { + GetMutableImpl()->SetOutputSymbols(fst.InputSymbols()); + } + if (project_type == PROJECT_OUTPUT) { + GetMutableImpl()->SetInputSymbols(fst.OutputSymbols()); + } + } + + // See Fst<>::Copy() for doc. + ProjectFst(const ProjectFst &fst, bool safe = false) + : ArcMapFst>(fst, safe) {} + + // Gets a copy of this ProjectFst. See Fst<>::Copy() for further doc. + ProjectFst *Copy(bool safe = false) const override { + return new ProjectFst(*this, safe); + } + + private: + using ImplToFst::GetMutableImpl; +}; + +// Specialization for ProjectFst. +template +class StateIterator> + : public StateIterator>> { + public: + explicit StateIterator(const ProjectFst &fst) + : StateIterator>>(fst) {} +}; + +// Specialization for ProjectFst. +template +class ArcIterator> + : public ArcIterator>> { + public: + using StateId = typename A::StateId; + + ArcIterator(const ProjectFst &fst, StateId s) + : ArcIterator>>(fst, s) {} +}; + +// Useful alias when using StdArc. +using StdProjectFst = ProjectFst; + +} // namespace fst + +#endif // FST_PROJECT_H_ diff --git a/projects/llm_framework/include/fst/properties.h b/projects/llm_framework/include/fst/properties.h new file mode 100644 index 00000000..157247a6 --- /dev/null +++ b/projects/llm_framework/include/fst/properties.h @@ -0,0 +1,468 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST property bits. + +#ifndef FST_PROPERTIES_H_ +#define FST_PROPERTIES_H_ + +#include +#include + +#include + +namespace fst { + +// The property bits here assert facts about an FST. 
If individual bits are +// added, then the composite properties below, the property functions and +// property names in properties.cc, and TestProperties() in test-properties.h +// should be updated. + +// BINARY PROPERTIES +// +// For each property below, there is a single bit. If it is set, the property is +// true. If it is not set, the property is false. + +// The Fst is an ExpandedFst. +constexpr uint64 kExpanded = 0x0000000000000001ULL; + +// The Fst is a MutableFst. +constexpr uint64 kMutable = 0x0000000000000002ULL; + +// An error was detected while constructing/using the FST. +constexpr uint64 kError = 0x0000000000000004ULL; + +// TRINARY PROPERTIES +// +// For each of these properties below there is a pair of property bits, one +// positive and one negative. If the positive bit is set, the property is true. +// If the negative bit is set, the property is false. If neither is set, the +// property has unknown value. Both should never be simultaneously set. The +// individual positive and negative bit pairs should be adjacent with the +// positive bit at an odd and lower position. + +// ilabel == olabel for each arc. +constexpr uint64 kAcceptor = 0x0000000000010000ULL; +// ilabel != olabel for some arc. +constexpr uint64 kNotAcceptor = 0x0000000000020000ULL; + +// ilabels unique leaving each state. +constexpr uint64 kIDeterministic = 0x0000000000040000ULL; +// ilabels not unique leaving some state. +constexpr uint64 kNonIDeterministic = 0x0000000000080000ULL; + +// olabels unique leaving each state. +constexpr uint64 kODeterministic = 0x0000000000100000ULL; +// olabels not unique leaving some state. +constexpr uint64 kNonODeterministic = 0x0000000000200000ULL; + +// FST has input/output epsilons. +constexpr uint64 kEpsilons = 0x0000000000400000ULL; +// FST has no input/output epsilons. +constexpr uint64 kNoEpsilons = 0x0000000000800000ULL; + +// FST has input epsilons. +constexpr uint64 kIEpsilons = 0x0000000001000000ULL; +// FST has no input epsilons. 
+constexpr uint64 kNoIEpsilons = 0x0000000002000000ULL; + +// FST has output epsilons. +constexpr uint64 kOEpsilons = 0x0000000004000000ULL; +// FST has no output epsilons. +constexpr uint64 kNoOEpsilons = 0x0000000008000000ULL; + +// ilabels sorted wrt < for each state. +constexpr uint64 kILabelSorted = 0x0000000010000000ULL; +// ilabels not sorted wrt < for some state. +constexpr uint64 kNotILabelSorted = 0x0000000020000000ULL; + +// olabels sorted wrt < for each state. +constexpr uint64 kOLabelSorted = 0x0000000040000000ULL; +// olabels not sorted wrt < for some state. +constexpr uint64 kNotOLabelSorted = 0x0000000080000000ULL; + +// Non-trivial arc or final weights. +constexpr uint64 kWeighted = 0x0000000100000000ULL; +// Only trivial arc and final weights. +constexpr uint64 kUnweighted = 0x0000000200000000ULL; + +// FST has cycles. +constexpr uint64 kCyclic = 0x0000000400000000ULL; +// FST has no cycles. +constexpr uint64 kAcyclic = 0x0000000800000000ULL; + +// FST has cycles containing the initial state. +constexpr uint64 kInitialCyclic = 0x0000001000000000ULL; +// FST has no cycles containing the initial state. +constexpr uint64 kInitialAcyclic = 0x0000002000000000ULL; + +// FST is topologically sorted. +constexpr uint64 kTopSorted = 0x0000004000000000ULL; +// FST is not topologically sorted. +constexpr uint64 kNotTopSorted = 0x0000008000000000ULL; + +// All states reachable from the initial state. +constexpr uint64 kAccessible = 0x0000010000000000ULL; +// Not all states reachable from the initial state. +constexpr uint64 kNotAccessible = 0x0000020000000000ULL; + +// All states can reach a final state. +constexpr uint64 kCoAccessible = 0x0000040000000000ULL; +// Not all states can reach a final state. +constexpr uint64 kNotCoAccessible = 0x0000080000000000ULL; + +// If NumStates() > 0, then state 0 is initial, state NumStates() - 1 is final, +// there is a transition from each non-final state i to state i + 1, and there +// are no other transitions. 
+constexpr uint64 kString = 0x0000100000000000ULL; + +// Not a string FST. +constexpr uint64 kNotString = 0x0000200000000000ULL; + +// FST has at least one weighted cycle. +constexpr uint64 kWeightedCycles = 0x0000400000000000ULL; + +// Only unweighted cycles. +constexpr uint64 kUnweightedCycles = 0x0000800000000000ULL; + +// COMPOSITE PROPERTIES + +// Properties of an empty machine. +constexpr uint64 kNullProperties = + kAcceptor | kIDeterministic | kODeterministic | kNoEpsilons | kNoIEpsilons | + kNoOEpsilons | kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | + kInitialAcyclic | kTopSorted | kAccessible | kCoAccessible | kString | + kUnweightedCycles; + +// Properties that are preserved when an FST is copied. +constexpr uint64 kCopyProperties = + kError | kAcceptor | kNotAcceptor | kIDeterministic | kNonIDeterministic | + kODeterministic | kNonODeterministic | kEpsilons | kNoEpsilons | + kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | kILabelSorted | + kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are intrinsic to the FST. +constexpr uint64 kIntrinsicProperties = + kExpanded | kMutable | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are (potentially) extrinsic to the FST.
+constexpr uint64 kExtrinsicProperties = kError; + +// Properties that are preserved when an FST start state is set. +constexpr uint64 kSetStartProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kTopSorted | kNotTopSorted | + kCoAccessible | kNotCoAccessible | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST final weight is set. +constexpr uint64 kSetFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | + kNotTopSorted | kAccessible | kNotAccessible | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST state is added. +constexpr uint64 kAddStateProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kNotAccessible | + kNotCoAccessible | kNotString | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST arc is added. 
+constexpr uint64 kAddArcProperties = + kExpanded | kMutable | kError | kNotAcceptor | kNonIDeterministic | + kNonODeterministic | kEpsilons | kIEpsilons | kOEpsilons | + kNotILabelSorted | kNotOLabelSorted | kWeighted | kCyclic | kInitialCyclic | + kNotTopSorted | kAccessible | kCoAccessible | kWeightedCycles; + +// Properties that are preserved when an FST arc is set. +constexpr uint64 kSetArcProperties = kExpanded | kMutable | kError; + +// Properties that are preserved when FST states are deleted. +constexpr uint64 kDeleteStatesProperties = + kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | kInitialAcyclic | + kTopSorted | kUnweightedCycles; + +// Properties that are preserved when FST arcs are deleted. +constexpr uint64 kDeleteArcsProperties = + kExpanded | kMutable | kError | kAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kAcyclic | kInitialAcyclic | + kTopSorted | kNotAccessible | kNotCoAccessible | kUnweightedCycles; + +// Properties that are preserved when an FST's states are reordered. +constexpr uint64 kStateSortProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST's arcs are reordered. 
+constexpr uint64 kArcSortProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kWeighted | kUnweighted | kCyclic | kAcyclic | kInitialCyclic | + kInitialAcyclic | kTopSorted | kNotTopSorted | kAccessible | + kNotAccessible | kCoAccessible | kNotCoAccessible | kString | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when an FST's input labels are changed. +constexpr uint64 kILabelInvariantProperties = + kExpanded | kMutable | kError | kODeterministic | kNonODeterministic | + kOEpsilons | kNoOEpsilons | kOLabelSorted | kNotOLabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST's output labels are changed. +constexpr uint64 kOLabelInvariantProperties = + kExpanded | kMutable | kError | kIDeterministic | kNonIDeterministic | + kIEpsilons | kNoIEpsilons | kILabelSorted | kNotILabelSorted | kWeighted | + kUnweighted | kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kTopSorted | kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString | kWeightedCycles | + kUnweightedCycles; + +// Properties that are preserved when an FST's weights are changed. This +// assumes that the set of states that are non-final is not changed. 
+constexpr uint64 kWeightInvariantProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kNonIDeterministic | kODeterministic | kNonODeterministic | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | kNoOEpsilons | + kILabelSorted | kNotILabelSorted | kOLabelSorted | kNotOLabelSorted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | + kNotTopSorted | kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible | kString | kNotString; + +// Properties that are preserved when a superfinal state is added and an FST's +// final weights are directed to it via new transitions. +constexpr uint64 kAddSuperFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | + kNonIDeterministic | kNonODeterministic | kEpsilons | kIEpsilons | + kOEpsilons | kNotILabelSorted | kNotOLabelSorted | kWeighted | kUnweighted | + kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | kNotTopSorted | + kNotAccessible | kCoAccessible | kNotCoAccessible | kNotString | + kWeightedCycles | kUnweightedCycles; + +// Properties that are preserved when a superfinal state is removed and the +// epsilon transitions directed to it are made final weights. +constexpr uint64 kRmSuperFinalProperties = + kExpanded | kMutable | kError | kAcceptor | kNotAcceptor | kIDeterministic | + kODeterministic | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kWeighted | kUnweighted | kCyclic | + kAcyclic | kInitialCyclic | kInitialAcyclic | kTopSorted | kAccessible | + kCoAccessible | kNotCoAccessible | kString | kWeightedCycles | + kUnweightedCycles; + +// All binary properties. +constexpr uint64 kBinaryProperties = 0x0000000000000007ULL; + +// All trinary properties. +constexpr uint64 kTrinaryProperties = 0x0000ffffffff0000ULL; + +// COMPUTED PROPERTIES + +// 1st bit of trinary properties. 
+constexpr uint64 kPosTrinaryProperties = kTrinaryProperties & + 0x5555555555555555ULL; + +// 2nd bit of trinary properties. +constexpr uint64 kNegTrinaryProperties = kTrinaryProperties & + 0xaaaaaaaaaaaaaaaaULL; + +// All properties. +constexpr uint64 kFstProperties = kBinaryProperties | kTrinaryProperties; + +// PROPERTY FUNCTIONS and STRING NAMES (defined in properties.cc). + +// Below are functions for getting property bit vectors when executing +// mutation operations. +inline uint64 SetStartProperties(uint64 inprops); + +template +uint64 SetFinalProperties(uint64 inprops, const Weight &old_weight, + const Weight &new_weight); + +inline uint64 AddStateProperties(uint64 inprops); + +template +uint64 AddArcProperties(uint64 inprops, typename A::StateId s, const A &arc, + const A *prev_arc); + +inline uint64 DeleteStatesProperties(uint64 inprops); + +inline uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticProps); + +inline uint64 DeleteArcsProperties(uint64 inprops); + +uint64 ClosureProperties(uint64 inprops, bool star, bool delayed = false); + +uint64 ComplementProperties(uint64 inprops); + +uint64 ComposeProperties(uint64 inprops1, uint64 inprops2); + +uint64 ConcatProperties(uint64 inprops1, uint64 inprops2, bool delayed = false); + +uint64 DeterminizeProperties(uint64 inprops, bool has_subsequential_label, + bool distinct_psubsequential_labels); + +uint64 FactorWeightProperties(uint64 inprops); + +uint64 InvertProperties(uint64 inprops); + +uint64 ProjectProperties(uint64 inprops, bool project_input); + +uint64 RandGenProperties(uint64 inprops, bool weighted); + +uint64 RelabelProperties(uint64 inprops); + +uint64 ReplaceProperties(const std::vector &inprops, size_t root, + bool epsilon_on_call, bool epsilon_on_return, + bool out_epsilon_on_call, bool out_epsilon_on_return, + bool replace_transducer, bool no_empty_fst, + bool all_ilabel_sorted, bool all_olabel_sorted, + bool all_negative_or_dense); + +uint64 ReverseProperties(uint64 inprops, 
bool has_superinitial); + +uint64 ReweightProperties(uint64 inprops); + +uint64 RmEpsilonProperties(uint64 inprops, bool delayed = false); + +uint64 ShortestPathProperties(uint64 props, bool tree = false); + +uint64 SynchronizeProperties(uint64 inprops); + +uint64 UnionProperties(uint64 inprops1, uint64 inprops2, bool delayed = false); + +// Definitions of inlined functions. + +uint64 SetStartProperties(uint64 inprops) { + auto outprops = inprops & kSetStartProperties; + if (inprops & kAcyclic) { + outprops |= kInitialAcyclic; + } + return outprops; +} + +uint64 AddStateProperties(uint64 inprops) { + return inprops & kAddStateProperties; +} + +uint64 DeleteStatesProperties(uint64 inprops) { + return inprops & kDeleteStatesProperties; +} + +uint64 DeleteAllStatesProperties(uint64 inprops, uint64 staticprops) { + const auto outprops = inprops & kError; + return outprops | kNullProperties | staticprops; +} + +uint64 DeleteArcsProperties(uint64 inprops) { + return inprops & kDeleteArcsProperties; +} + +// Definitions of template functions. + +template +uint64 SetFinalProperties(uint64 inprops, const Weight &old_weight, + const Weight &new_weight) { + auto outprops = inprops; + if (old_weight != Weight::Zero() && old_weight != Weight::One()) { + outprops &= ~kWeighted; + } + if (new_weight != Weight::Zero() && new_weight != Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + outprops &= kSetFinalProperties | kWeighted | kUnweighted; + return outprops; +} + +/// Gets the properties for the MutableFst::AddArc method. +/// +/// \param inprops the current properties of the FST +/// \param s the ID of the state to which an arc is being added. +/// \param arc the arc being added to the state with the specified ID +/// \param prev_arc the previously-added (or "last") arc of state s, or nullptr +// if s currently has no arcs. 
+template +uint64 AddArcProperties(uint64 inprops, typename Arc::StateId s, + const Arc &arc, const Arc *prev_arc) { + using Weight = typename Arc::Weight; + auto outprops = inprops; + if (arc.ilabel != arc.olabel) { + outprops |= kNotAcceptor; + outprops &= ~kAcceptor; + } + if (arc.ilabel == 0) { + outprops |= kIEpsilons; + outprops &= ~kNoIEpsilons; + if (arc.olabel == 0) { + outprops |= kEpsilons; + outprops &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + outprops |= kOEpsilons; + outprops &= ~kNoOEpsilons; + } + if (prev_arc) { + if (prev_arc->ilabel > arc.ilabel) { + outprops |= kNotILabelSorted; + outprops &= ~kILabelSorted; + } + if (prev_arc->olabel > arc.olabel) { + outprops |= kNotOLabelSorted; + outprops &= ~kOLabelSorted; + } + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + outprops |= kWeighted; + outprops &= ~kUnweighted; + } + if (arc.nextstate <= s) { + outprops |= kNotTopSorted; + outprops &= ~kTopSorted; + } + outprops &= kAddArcProperties | kAcceptor | kNoEpsilons | kNoIEpsilons | + kNoOEpsilons | kILabelSorted | kOLabelSorted | kUnweighted | + kTopSorted; + if (outprops & kTopSorted) { + outprops |= kAcyclic | kInitialAcyclic; + } + return outprops; +} + +extern const char *PropertyNames[]; + +} // namespace fst + +#endif // FST_PROPERTIES_H_ diff --git a/projects/llm_framework/include/fst/prune.h b/projects/llm_framework/include/fst/prune.h new file mode 100644 index 00000000..9e7c04bd --- /dev/null +++ b/projects/llm_framework/include/fst/prune.h @@ -0,0 +1,341 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions implementing pruning. 
+ +#ifndef FST_PRUNE_H_ +#define FST_PRUNE_H_ + +#include +#include +#include + +#include + +#include +#include +#include + + +namespace fst { +namespace internal { + +template +class PruneCompare { + public: + PruneCompare(const std::vector &idistance, + const std::vector &fdistance) + : idistance_(idistance), fdistance_(fdistance) {} + + bool operator()(const StateId x, const StateId y) const { + const auto wx = Times(IDistance(x), FDistance(x)); + const auto wy = Times(IDistance(y), FDistance(y)); + return less_(wx, wy); + } + + private: + Weight IDistance(const StateId s) const { + return s < idistance_.size() ? idistance_[s] : Weight::Zero(); + } + + Weight FDistance(const StateId s) const { + return s < fdistance_.size() ? fdistance_[s] : Weight::Zero(); + } + + const std::vector &idistance_; + const std::vector &fdistance_; + NaturalLess less_; +}; + +} // namespace internal + +template +struct PruneOptions { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit PruneOptions(const Weight &weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId, + ArcFilter filter = ArcFilter(), + std::vector *distance = nullptr, + float delta = kDelta, bool threshold_initial = false) + : weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold), + filter(std::move(filter)), + distance(distance), + delta(delta), + threshold_initial(threshold_initial) {} + + // Pruning weight threshold. + Weight weight_threshold; + // Pruning state threshold. + StateId state_threshold; + // Arc filter. + ArcFilter filter; + // If non-zero, passes in pre-computed shortest distance to final states. + const std::vector *distance; + // Determines the degree of convergence required when computing shortest + // distances. 
+ float delta; + // Determines if the shortest path weight is left (true) or right + // (false) multiplied by the threshold to get the limit for + // keeping a state or arc (matters if the semiring is not + // commutative). + bool threshold_initial; +}; + +// Pruning algorithm: this version modifies its input and it takes an options +// class as an argument. After pruning the FST contains states and arcs that +// belong to a successful path in the FST whose weight is no more than the +// weight of the shortest path Times() the provided weight threshold. When the +// state threshold is not kNoStateId, the output FST is further restricted to +// have no more than the number of states in opts.state_threshold. Weights must +// have the path property. The weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template ::value>::type * = + nullptr> +void Prune(MutableFst *fst, const PruneOptions &opts = + PruneOptions()) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using StateHeap = Heap>; + auto ns = fst->NumStates(); + if (ns < 1) return; + std::vector idistance(ns, Weight::Zero()); + std::vector tmp; + if (!opts.distance) { + tmp.reserve(ns); + ShortestDistance(*fst, &tmp, true, opts.delta); + } + const auto *fdistance = opts.distance ? opts.distance : &tmp; + if ((opts.state_threshold == 0) || (fdistance->size() <= fst->Start()) || + ((*fdistance)[fst->Start()] == Weight::Zero())) { + fst->DeleteStates(); + return; + } + internal::PruneCompare compare(idistance, *fdistance); + StateHeap heap(compare); + std::vector visited(ns, false); + std::vector enqueued(ns, StateHeap::kNoKey); + std::vector dead; + dead.push_back(fst->AddState()); + NaturalLess less; + auto s = fst->Start(); + const auto limit = opts.threshold_initial ? 
+ Times(opts.weight_threshold, (*fdistance)[s]) : + Times((*fdistance)[s], opts.weight_threshold); + StateId num_visited = 0; + + if (!less(limit, (*fdistance)[s])) { + idistance[s] = Weight::One(); + enqueued[s] = heap.Insert(s); + ++num_visited; + } + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = StateHeap::kNoKey; + visited[s] = true; + if (less(limit, Times(idistance[s], fst->Final(s)))) { + fst->SetFinal(s, Weight::Zero()); + } + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); // Copy intended. + if (!opts.filter(arc)) continue; + const auto weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() ? + (*fdistance)[arc.nextstate] : Weight::Zero()); + if (less(limit, weight)) { + arc.nextstate = dead[0]; + aiter.SetValue(arc); + continue; + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + } + if (visited[arc.nextstate]) continue; + if ((opts.state_threshold != kNoStateId) && + (num_visited >= opts.state_threshold)) { + continue; + } + if (enqueued[arc.nextstate] == StateHeap::kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + ++num_visited; + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } + for (StateId i = 0; i < visited.size(); ++i) { + if (!visited[i]) dead.push_back(i); + } + fst->DeleteStates(dead); +} + +template ::value>::type + * = nullptr> +void Prune(MutableFst *fst, const PruneOptions &opts = + PruneOptions()) { + FSTERROR() << "Prune: Weight needs to have the path property: " + << Arc::Weight::Type(); + fst->SetProperties(kError, kError); +} + +// Pruning algorithm: this version modifies its input and takes the +// pruning threshold as an argument. 
It deletes states and arcs in the +// FST that do not belong to a successful path whose weight is no more +// than the weight of the shortest path Times() the provided weight +// threshold. When the state threshold is not kNoStateId, the output +// FST is further restricted to have no more than the number of states +// in state_threshold. Weights must have the path property. The +// weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template +void Prune(MutableFst *fst, typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + const PruneOptions> opts( + weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); + Prune(fst, opts); +} + +// Pruning algorithm: this version writes the pruned input FST to an +// output MutableFst and it takes an options class as an argument. The +// output FST contains states and arcs that belong to a successful +// path in the input FST whose weight is no more than the weight of the +// shortest path Times() the provided weight threshold. When the state +// threshold is not kNoStateId, the output FST is further restricted +// to have no more than the number of states in +// opts.state_threshold. Weights must have the path property.
The weight +// of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) == Weight::One() +template ::value>::type * = + nullptr> +void Prune( + const Fst &ifst, MutableFst *ofst, + const PruneOptions &opts = PruneOptions()) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using StateHeap = Heap>; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Start() == kNoStateId) return; + NaturalLess less; + if (less(opts.weight_threshold, Weight::One()) || + (opts.state_threshold == 0)) { + return; + } + std::vector idistance; + std::vector tmp; + if (!opts.distance) ShortestDistance(ifst, &tmp, true, opts.delta); + const auto *fdistance = opts.distance ? opts.distance : &tmp; + if ((fdistance->size() <= ifst.Start()) || + ((*fdistance)[ifst.Start()] == Weight::Zero())) { + return; + } + internal::PruneCompare compare(idistance, *fdistance); + StateHeap heap(compare); + std::vector copy; + std::vector enqueued; + std::vector visited; + auto s = ifst.Start(); + const auto limit = opts.threshold_initial ? 
+ Times(opts.weight_threshold, (*fdistance)[s]) : + Times((*fdistance)[s], opts.weight_threshold); + while (copy.size() <= s) copy.push_back(kNoStateId); + copy[s] = ofst->AddState(); + ofst->SetStart(copy[s]); + while (idistance.size() <= s) idistance.push_back(Weight::Zero()); + idistance[s] = Weight::One(); + while (enqueued.size() <= s) { + enqueued.push_back(StateHeap::kNoKey); + visited.push_back(false); + } + enqueued[s] = heap.Insert(s); + while (!heap.Empty()) { + s = heap.Top(); + heap.Pop(); + enqueued[s] = StateHeap::kNoKey; + visited[s] = true; + if (!less(limit, Times(idistance[s], ifst.Final(s)))) { + ofst->SetFinal(copy[s], ifst.Final(s)); + } + for (ArcIterator> aiter(ifst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (!opts.filter(arc)) continue; + const auto weight = Times(Times(idistance[s], arc.weight), + arc.nextstate < fdistance->size() ? + (*fdistance)[arc.nextstate] : Weight::Zero()); + if (less(limit, weight)) continue; + if ((opts.state_threshold != kNoStateId) && + (ofst->NumStates() >= opts.state_threshold)) { + continue; + } + while (idistance.size() <= arc.nextstate) { + idistance.push_back(Weight::Zero()); + } + if (less(Times(idistance[s], arc.weight), idistance[arc.nextstate])) { + idistance[arc.nextstate] = Times(idistance[s], arc.weight); + } + while (copy.size() <= arc.nextstate) copy.push_back(kNoStateId); + if (copy[arc.nextstate] == kNoStateId) { + copy[arc.nextstate] = ofst->AddState(); + } + ofst->AddArc(copy[s], Arc(arc.ilabel, arc.olabel, arc.weight, + copy[arc.nextstate])); + while (enqueued.size() <= arc.nextstate) { + enqueued.push_back(StateHeap::kNoKey); + visited.push_back(false); + } + if (visited[arc.nextstate]) continue; + if (enqueued[arc.nextstate] == StateHeap::kNoKey) { + enqueued[arc.nextstate] = heap.Insert(arc.nextstate); + } else { + heap.Update(enqueued[arc.nextstate], arc.nextstate); + } + } + } +} + +template ::value>::type + * = nullptr> +void Prune(const Fst &, 
MutableFst *ofst, + const PruneOptions &) { + FSTERROR() << "Prune: Weight needs to have the path property: " + << Arc::Weight::Type(); + ofst->SetProperties(kError, kError); +} + +// Pruning algorithm: this version writes the pruned input FST to an +// output MutableFst and simply takes the pruning threshold as an +// argument. The output FST contains states and arcs that belong to a +// successful path in the input FST whose weight is no more than the +// weight of the shortest path Times() the provided weight +// threshold. When the state threshold is not kNoStateId, the output +// FST is further restricted to have no more than the number of states +// in opts.state_threshold. Weights must have the path property. The +// weight of any cycle needs to be bounded; i.e., +// +// Plus(weight, Weight::One()) = Weight::One(); +template +void Prune(const Fst &ifst, MutableFst *ofst, + typename Arc::Weight weight_threshold, + typename Arc::StateId state_threshold = kNoStateId, + float delta = kDelta) { + const PruneOptions> opts( + weight_threshold, state_threshold, AnyArcFilter(), nullptr, delta); + Prune(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_PRUNE_H_ diff --git a/projects/llm_framework/include/fst/push.h b/projects/llm_framework/include/fst/push.h new file mode 100644 index 00000000..1f772739 --- /dev/null +++ b/projects/llm_framework/include/fst/push.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to reweight/push an FST, and utility functions to weigh and reweight +// an FST. + +#ifndef FST_PUSH_H_ +#define FST_PUSH_H_ + +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// Computes the total weight (sum of the weights of all accepting paths) from +// the output of ShortestDistance, using the shortest distance from the final +// state when reverse is true and from the initial state otherwise. 
+template +typename Arc::Weight ComputeTotalWeight( + const Fst &fst, const std::vector &distance, + bool reverse) { + if (reverse) { + return fst.Start() < distance.size() ? distance[fst.Start()] + : Arc::Weight::Zero(); + } + auto sum = Arc::Weight::Zero(); + for (typename Arc::StateId s = 0; s < distance.size(); ++s) { + sum = Plus(sum, Times(distance[s], fst.Final(s))); + } + return sum; +} + +// Divides the weight of every accepting path by a fixed weight. This weight +// is also divided at the final state if at_final is true and at the initial +// state otherwise. +template +void RemoveWeight(MutableFst *fst, const typename Arc::Weight &weight, + bool at_final) { + using Weight = typename Arc::Weight; + if ((weight == Weight::One()) || (weight == Weight::Zero())) return; + if (at_final) { + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + fst->SetFinal(siter.Value(), + Divide(fst->Final(siter.Value()), weight, DIVIDE_RIGHT)); + } + } else { + const auto start = fst->Start(); + for (MutableArcIterator> aiter(fst, start); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.weight = Divide(arc.weight, weight, DIVIDE_LEFT); + aiter.SetValue(arc); + } + fst->SetFinal(start, Divide(fst->Final(start), weight, DIVIDE_LEFT)); + } +} + +// Pushes the weights in FST in the direction defined by TYPE. If +// pushing towards the initial state, the sum of the weight of the +// outgoing transitions and final weight at a non-initial state is +// equal to One() in the resulting machine. If pushing towards the +// final state, the same property holds on the reverse machine. +// +// Weight needs to be left distributive when pushing towards the +// initial state and right distributive when pushing towards the final +// states. 
+template +void Push(MutableFst *fst, ReweightType type, float delta = kDelta, + bool remove_total_weight = false) { + using Weight = typename Arc::Weight; + std::vector distance; + ShortestDistance(*fst, &distance, type == REWEIGHT_TO_INITIAL, delta); + auto total_weight = Weight::One(); + if (remove_total_weight) { + total_weight = + ComputeTotalWeight(*fst, distance, type == REWEIGHT_TO_INITIAL); + } + Reweight(fst, distance, type); + if (remove_total_weight) { + RemoveWeight(fst, total_weight, type == REWEIGHT_TO_FINAL); + } +} + +constexpr uint32 kPushWeights = 0x0001; +constexpr uint32 kPushLabels = 0x0002; +constexpr uint32 kPushRemoveTotalWeight = 0x0004; +constexpr uint32 kPushRemoveCommonAffix = 0x0008; + +// Pushes the weights and/or labels of the input FST into the output +// mutable FST by pushing weights and/or labels (as determined by the +// ptype argument) towards the initial state or final states (as +// determined by the rtype template parameter). The weight type must +// be left distributive when pushing weights towards the initial state, and +// right distributive when pushing weights towards the final states. +template +void Push(const Fst &ifst, MutableFst *ofst, uint32 ptype, + float delta = kDelta) { + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + if ((ptype & (kPushWeights | kPushLabels)) == kPushWeights) { + *ofst = ifst; + Push(ofst, rtype, delta, ptype & kPushRemoveTotalWeight); + } else if (ptype & kPushLabels) { + const auto gtype = + rtype == REWEIGHT_TO_INITIAL ?
GALLIC_LEFT : GALLIC_RIGHT; + using GallicWeight = typename GallicArc::Weight; + std::vector gdistance; + VectorFst> gfst; + ArcMap(ifst, &gfst, ToGallicMapper()); + if (ptype & kPushWeights) { + ShortestDistance(gfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } else { + ArcMapFst> uwfst(ifst, + RmWeightMapper()); + ArcMapFst, ToGallicMapper> guwfst( + uwfst, ToGallicMapper()); + ShortestDistance(guwfst, &gdistance, rtype == REWEIGHT_TO_INITIAL, delta); + } + auto total_weight = GallicWeight::One(); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + total_weight = + ComputeTotalWeight(gfst, gdistance, rtype == REWEIGHT_TO_INITIAL); + total_weight = GallicWeight( + ptype & kPushRemoveCommonAffix + ? total_weight.Value1() + : StringWeight::One(), + ptype & kPushRemoveTotalWeight ? total_weight.Value2() + : Weight::One()); + } + Reweight(&gfst, gdistance, rtype); + if (ptype & (kPushRemoveTotalWeight | kPushRemoveCommonAffix)) { + RemoveWeight(&gfst, total_weight, rtype == REWEIGHT_TO_FINAL); + } + FactorWeightFst, GallicFactor> + fwfst(gfst); + ArcMap(fwfst, ofst, FromGallicMapper()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else { + LOG(WARNING) << "Push: pushing type is set to 0, so not pushing"; + *ofst = ifst; + } +} + +} // namespace fst + +#endif // FST_PUSH_H_ diff --git a/projects/llm_framework/include/fst/queue.h b/projects/llm_framework/include/fst/queue.h new file mode 100644 index 00000000..f57d176e --- /dev/null +++ b/projects/llm_framework/include/fst/queue.h @@ -0,0 +1,948 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes for various FST state queues with a unified interface.
+ +#ifndef FST_QUEUE_H_ +#define FST_QUEUE_H_ + +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// The Queue interface is: +// +// template +// class Queue { +// public: +// using StateId = S; +// +// // Constructor: may need args (e.g., FST, comparator) for some queues. +// Queue(...) override; +// +// // Returns the head of the queue. +// StateId Head() const override; +// +// // Inserts a state. +// void Enqueue(StateId s) override; +// +// // Removes the head of the queue. +// void Dequeue() override; +// +// // Updates ordering of state s when weight changes, if necessary. +// void Update(StateId s) override; +// +// // Is the queue empty? +// bool Empty() const override; +// +// // Removes all states from the queue. +// void Clear() override; +// }; + +// State queue types. +enum QueueType { + TRIVIAL_QUEUE = 0, // Single state queue. + FIFO_QUEUE = 1, // First-in, first-out queue. + LIFO_QUEUE = 2, // Last-in, first-out queue. + SHORTEST_FIRST_QUEUE = 3, // Shortest-first queue. + TOP_ORDER_QUEUE = 4, // Topologically-ordered queue. + STATE_ORDER_QUEUE = 5, // State ID-ordered queue. + SCC_QUEUE = 6, // Component graph top-ordered meta-queue. + AUTO_QUEUE = 7, // Auto-selected queue. + OTHER_QUEUE = 8 +}; + +// QueueBase, templated on the StateId, is a virtual base class shared by all +// queues considered by AutoQueue. +template +class QueueBase { + public: + using StateId = S; + + virtual ~QueueBase() {} + + // Concrete implementation. + + explicit QueueBase(QueueType type) : queue_type_(type), error_(false) {} + + void SetError(bool error) { error_ = error; } + + bool Error() const { return error_; } + + QueueType Type() const { return queue_type_; } + + // Virtual interface. 
+ + virtual StateId Head() const = 0; + virtual void Enqueue(StateId) = 0; + virtual void Dequeue() = 0; + virtual void Update(StateId) = 0; + virtual bool Empty() const = 0; + virtual void Clear() = 0; + + private: + QueueType queue_type_; + bool error_; +}; + +// Trivial queue discipline; one may enqueue at most one state at a time. It +// can be used for strongly connected components with only one state and no +// self-loops. +template +class TrivialQueue : public QueueBase { + public: + using StateId = S; + + TrivialQueue() : QueueBase(TRIVIAL_QUEUE), front_(kNoStateId) {} + + virtual ~TrivialQueue() = default; + + StateId Head() const final { return front_; } + + void Enqueue(StateId s) final { front_ = s; } + + void Dequeue() final { front_ = kNoStateId; } + + void Update(StateId) final {} + + bool Empty() const final { return front_ == kNoStateId; } + + void Clear() final { front_ = kNoStateId; } + + private: + StateId front_; +}; + +// First-in, first-out queue discipline. +// +// This is not a final class. +template +class FifoQueue : public QueueBase { + public: + using StateId = S; + + FifoQueue() : QueueBase(FIFO_QUEUE) {} + + virtual ~FifoQueue() = default; + + StateId Head() const override { return queue_.back(); } + + void Enqueue(StateId s) override { queue_.push_front(s); } + + void Dequeue() override { queue_.pop_back(); } + + void Update(StateId) override {} + + bool Empty() const override { return queue_.empty(); } + + void Clear() override { queue_.clear(); } + + private: + std::deque queue_; +}; + +// Last-in, first-out queue discipline. 
+template <class S>
+class LifoQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  LifoQueue() : QueueBase<StateId>(LIFO_QUEUE) {}
+
+  virtual ~LifoQueue() = default;
+
+  // States enter and leave at the front of the deque, so the newest element
+  // is always queue_.front().
+  StateId Head() const final { return queue_.front(); }
+
+  void Enqueue(StateId s) final { queue_.push_front(s); }
+
+  void Dequeue() final { queue_.pop_front(); }
+
+  void Update(StateId) final {}
+
+  bool Empty() const final { return queue_.empty(); }
+
+  void Clear() final { queue_.clear(); }
+
+ private:
+  std::deque<StateId> queue_;
+};
+
+// Shortest-first queue discipline, templated on the StateId as well as a
+// comparison functor used to compare two StateIds. If a (single) state's order
+// changes, it can be reordered in the queue with a call to Update(). If update
+// is false, call to Update() does not reorder the queue.
+//
+// This is not a final class.
+template <typename S, typename Compare, bool update = true>
+class ShortestFirstQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  explicit ShortestFirstQueue(Compare comp)
+      : QueueBase<StateId>(SHORTEST_FIRST_QUEUE), heap_(comp) {}
+
+  virtual ~ShortestFirstQueue() = default;
+
+  StateId Head() const override { return heap_.Top(); }
+
+  void Enqueue(StateId s) override {
+    if (update) {
+      // Grows key_ so that key_[s] exists, then remembers the value returned
+      // by the heap insertion so Update() can hand it back to the heap.
+      for (StateId i = key_.size(); i <= s; ++i) key_.push_back(kNoStateId);
+      key_[s] = heap_.Insert(s);
+    } else {
+      heap_.Insert(s);
+    }
+  }
+
+  void Dequeue() override {
+    if (update) {
+      // Marks the popped state as no longer present in the heap.
+      key_[heap_.Pop()] = kNoStateId;
+    } else {
+      heap_.Pop();
+    }
+  }
+
+  // Re-orders s after its weight changed; a state that is not currently in
+  // the heap is enqueued instead. A no-op when update is false.
+  void Update(StateId s) override {
+    if (!update) return;
+    if (s >= key_.size() || key_[s] == kNoStateId) {
+      Enqueue(s);
+    } else {
+      heap_.Update(key_[s], s);
+    }
+  }
+
+  bool Empty() const override { return heap_.Empty(); }
+
+  void Clear() override {
+    heap_.Clear();
+    if (update) key_.clear();
+  }
+
+  const Compare &GetCompare() const { return heap_.GetCompare(); }
+
+ private:
+  Heap<StateId, Compare> heap_;
+  // Per-state value returned by heap_.Insert(), or kNoStateId when the state
+  // is not in the heap. Only maintained when update is true.
+  std::vector<ssize_t> key_;
+};
+
+namespace internal {
+
+// Given a vector that maps from states to weights, and a comparison functor
+// for weights, this class defines
a comparison function object between states. +template +class StateWeightCompare { + public: + using Weight = typename Less::Weight; + + StateWeightCompare(const std::vector &weights, const Less &less) + : weights_(weights), less_(less) {} + + bool operator()(const StateId s1, const StateId s2) const { + return less_(weights_[s1], weights_[s2]); + } + + private: + // Borrowed references. + const std::vector &weights_; + const Less &less_; +}; + +} // namespace internal + +// Shortest-first queue discipline, templated on the StateId and Weight, is +// specialized to use the weight's natural order for the comparison function. +template +class NaturalShortestFirstQueue + : public ShortestFirstQueue< + S, internal::StateWeightCompare>> { + public: + using StateId = S; + using Compare = internal::StateWeightCompare>; + + explicit NaturalShortestFirstQueue(const std::vector &distance) + : ShortestFirstQueue(Compare(distance, less_)) {} + + virtual ~NaturalShortestFirstQueue() = default; + + private: + // This is non-static because the constructor for non-idempotent weights will + // result in an error. + const NaturalLess less_{}; +}; + +// In a shortest path computation on a lattice-like FST, we may keep many old +// nonviable paths as a part of the search. Since the search process always +// expands the lowest cost path next, that lowest cost path may be a very old +// nonviable path instead of one we expect to lead to a shortest path. +// +// For instance, suppose that the current best path in an alignment has +// traversed 500 arcs with a cost of 10. We may also have a bad path in +// the queue that has traversed only 40 arcs but also has a cost of 10. +// This path is very unlikely to lead to a reasonable alignment, so this queue +// can prune it from the search space. 
+//
+// This queue relies on the caller using a shortest-first exploration order
+// like this:
+//   while (true) {
+//     StateId head = queue.Head();
+//     queue.Dequeue();
+//     for (const auto& arc : GetArcs(fst, head)) {
+//       queue.Enqueue(arc.nextstate);
+//     }
+//   }
+// We use this assumption to guess that there is an arc between Head and the
+// Enqueued state; this is how the number of path steps is measured.
+template <typename S, typename Weight>
+class PruneNaturalShortestFirstQueue
+    : public NaturalShortestFirstQueue<S, Weight> {
+ public:
+  using StateId = S;
+  using Base = NaturalShortestFirstQueue<StateId, Weight>;
+
+  // A negative threshold disables pruning (see Enqueue()).
+  explicit PruneNaturalShortestFirstQueue(const std::vector<Weight> &distance,
+                                          int threshold)
+      : Base(distance),
+        threshold_(threshold),
+        head_steps_(0),
+        max_head_steps_(0) {}
+
+  ~PruneNaturalShortestFirstQueue() override = default;
+
+  // Returns the base queue's head and, as a side effect, records that state's
+  // step count so the following Enqueue() calls can build on it.
+  StateId Head() const override {
+    const auto head = Base::Head();
+    // Stores the number of steps from the start of the graph to this state
+    // along the shortest-weight path.
+    if (head < steps_.size()) {
+      max_head_steps_ = std::max(steps_[head], max_head_steps_);
+      head_steps_ = steps_[head];
+    }
+    return head;
+  }
+
+  void Enqueue(StateId s) override {
+    // We assume that there is an arc between the Head() state and this
+    // Enqueued state.
+    const ssize_t state_steps = head_steps_ + 1;
+    if (s >= steps_.size()) {
+      steps_.resize(s + 1, state_steps);
+    }
+    // This is the number of arcs in the minimum cost path from Start to s.
+    steps_[s] = state_steps;
+    // The state is dropped (not enqueued) when it lags the longest path seen
+    // so far by threshold_ arcs or more.
+    if (state_steps > (max_head_steps_ - threshold_) || threshold_ < 0) {
+      Base::Enqueue(s);
+    }
+  }
+
+ private:
+  // A dense map from StateId to the number of arcs in the minimum weight
+  // path from Start to this state.
+  std::vector<ssize_t> steps_;
+  // We only keep paths that are within this number of arcs (not weight!)
+  // of the longest path.
+  const ssize_t threshold_;
+
+  // The following are mutable because Head() is const.
+  // The number of arcs traversed in the minimum cost path from the start
+  // state to the current Head() state.
+  mutable ssize_t head_steps_;
+  // The maximum number of arcs traversed by any low-cost path so far.
+  mutable ssize_t max_head_steps_;
+};
+
+// Topological-order queue discipline, templated on the StateId. States are
+// ordered in the queue topologically. The FST must be acyclic.
+template <class S>
+class TopOrderQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  // This constructor computes the topological order. It accepts an arc filter
+  // to limit the transitions considered in that computation (e.g., only the
+  // epsilon graph).
+  template <class Arc, class ArcFilter>
+  TopOrderQueue(const Fst<Arc> &fst, ArcFilter filter)
+      : QueueBase<StateId>(TOP_ORDER_QUEUE),
+        front_(0),
+        back_(kNoStateId),
+        order_(0),
+        state_(0) {
+    // Initialized so the flag is never read indeterminate; the visitor is
+    // expected to overwrite it.
+    bool acyclic = true;
+    TopOrderVisitor<Arc> top_order_visitor(&order_, &acyclic);
+    DfsVisit(fst, &top_order_visitor, filter);
+    if (!acyclic) {
+      FSTERROR() << "TopOrderQueue: FST is not acyclic";
+      QueueBase<S>::SetError(true);
+    }
+    state_.resize(order_.size(), kNoStateId);
+  }
+
+  // This constructor is passed the pre-computed topological order.
+  explicit TopOrderQueue(const std::vector<StateId> &order)
+      : QueueBase<StateId>(TOP_ORDER_QUEUE),
+        front_(0),
+        back_(kNoStateId),
+        order_(order),
+        state_(order.size(), kNoStateId) {}
+
+  virtual ~TopOrderQueue() = default;
+
+  StateId Head() const final { return state_[front_]; }
+
+  // Places s in the slot given by its topological position and widens the
+  // [front_, back_] window to include that slot.
+  void Enqueue(StateId s) final {
+    if (front_ > back_) {
+      front_ = back_ = order_[s];
+    } else if (order_[s] > back_) {
+      back_ = order_[s];
+    } else if (order_[s] < front_) {
+      front_ = order_[s];
+    }
+    state_[order_[s]] = s;
+  }
+
+  // Clears the head slot and advances front_ to the next occupied slot, or
+  // past back_ when none is left.
+  void Dequeue() final {
+    state_[front_] = kNoStateId;
+    while ((front_ <= back_) && (state_[front_] == kNoStateId)) ++front_;
+  }
+
+  void Update(StateId) final {}
+
+  bool Empty() const final { return front_ > back_; }
+
+  void Clear() final {
+    for (StateId s = front_; s <= back_; ++s) state_[s] = kNoStateId;
+    back_ = kNoStateId;
+    front_ = 0;
+  }
+
+ private:
+  StateId front_;
+  StateId back_;
+  std::vector<StateId> order_;  // State ID -> topological position.
+  std::vector<StateId> state_;  // Topological position -> queued state ID.
+};
+
+// State order queue discipline, templated on the StateId. States are ordered in
+// the queue by state ID.
+template <class S>
+class StateOrderQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  StateOrderQueue()
+      : QueueBase<StateId>(STATE_ORDER_QUEUE), front_(0), back_(kNoStateId) {}
+
+  virtual ~StateOrderQueue() = default;
+
+  StateId Head() const final { return front_; }
+
+  // Widens the [front_, back_] window to include s and flags s as present.
+  void Enqueue(StateId s) final {
+    if (front_ > back_) {
+      front_ = back_ = s;
+    } else if (s > back_) {
+      back_ = s;
+    } else if (s < front_) {
+      front_ = s;
+    }
+    while (enqueued_.size() <= s) enqueued_.push_back(false);
+    enqueued_[s] = true;
+  }
+
+  // Clears the head flag and advances front_ to the next flagged state, or
+  // past back_ when none is left.
+  void Dequeue() final {
+    enqueued_[front_] = false;
+    while ((front_ <= back_) && (enqueued_[front_] == false)) ++front_;
+  }
+
+  void Update(StateId) final {}
+
+  bool Empty() const final { return front_ > back_; }
+
+  void Clear() final {
+    for (StateId i = front_; i <= back_; ++i) enqueued_[i] = false;
+    front_ = 0;
+    back_ = kNoStateId;
+  }
+
+ private:
+  StateId front_;
+  StateId back_;
+  std::vector<bool> enqueued_;
+};
+
+// SCC topological-order meta-queue discipline, templated on the StateId and a
+// queue used inside each SCC. It visits the SCCs of an FST in topological
+// order. Its constructor is passed the queues to use within an SCC.
+template <class S, class Queue>
+class SccQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  // Constructor takes a vector specifying the SCC number per state and a
+  // vector giving the queue to use per SCC number. Both are borrowed, not
+  // copied. A null entry in the queue vector means that SCC is handled by the
+  // built-in one-slot trivial queue.
+  SccQueue(const std::vector<StateId> &scc,
+           std::vector<std::unique_ptr<Queue>> *queue)
+      : QueueBase<StateId>(SCC_QUEUE),
+        queue_(queue),
+        scc_(scc),
+        front_(0),
+        back_(kNoStateId) {}
+
+  virtual ~SccQueue() = default;
+
+  StateId Head() const final {
+    // Skips over SCCs whose queue (real or trivial) is currently empty; this
+    // is why front_ is mutable.
+    while ((front_ <= back_) &&
+           (((*queue_)[front_] && (*queue_)[front_]->Empty()) ||
+            (((*queue_)[front_] == nullptr) &&
+             ((front_ >= trivial_queue_.size()) ||
+              (trivial_queue_[front_] == kNoStateId))))) {
+      ++front_;
+    }
+    if ((*queue_)[front_]) {
+      return (*queue_)[front_]->Head();
+    } else {
+      return trivial_queue_[front_];
+    }
+  }
+
+  void Enqueue(StateId s) final {
+    if (front_ > back_) {
+      front_ = back_ = scc_[s];
+    } else if (scc_[s] > back_) {
+      back_ = scc_[s];
+    } else if (scc_[s] < front_) {
+      front_ = scc_[s];
+    }
+    if ((*queue_)[scc_[s]]) {
+      (*queue_)[scc_[s]]->Enqueue(s);
+    } else {
+      while (trivial_queue_.size() <= scc_[s]) {
+        trivial_queue_.push_back(kNoStateId);
+      }
+      trivial_queue_[scc_[s]] = s;
+    }
+  }
+
+  void Dequeue() final {
+    if ((*queue_)[front_]) {
+      (*queue_)[front_]->Dequeue();
+    } else if (front_ < trivial_queue_.size()) {
+      trivial_queue_[front_] = kNoStateId;
+    }
+  }
+
+  // Forwarded only to real per-SCC queues; trivial slots have no ordering.
+  void Update(StateId s) final {
+    if ((*queue_)[scc_[s]]) (*queue_)[scc_[s]]->Update(s);
+  }
+
+  bool Empty() const final {
+    // Queues SCC number back_ is not empty unless back_ == front_.
+    if (front_ < back_) {
+      return false;
+    } else if (front_ > back_) {
+      return true;
+    } else if ((*queue_)[front_]) {
+      return (*queue_)[front_]->Empty();
+    } else {
+      return (front_ >= trivial_queue_.size()) ||
+             (trivial_queue_[front_] == kNoStateId);
+    }
+  }
+
+  void Clear() final {
+    for (StateId i = front_; i <= back_; ++i) {
+      if ((*queue_)[i]) {
+        (*queue_)[i]->Clear();
+      } else if (i < trivial_queue_.size()) {
+        trivial_queue_[i] = kNoStateId;
+      }
+    }
+    front_ = 0;
+    back_ = kNoStateId;
+  }
+
+ private:
+  std::vector<std::unique_ptr<Queue>> *queue_;  // Borrowed; indexed by SCC.
+  const std::vector<StateId> &scc_;             // Borrowed; state -> SCC.
+  mutable StateId front_;
+  StateId back_;
+  std::vector<StateId> trivial_queue_;  // One slot per queue-less SCC.
+};
+
+// Automatic queue discipline. It selects a queue discipline for a given FST
+// based on its properties.
+template <class S>
+class AutoQueue : public QueueBase<S> {
+ public:
+  using StateId = S;
+
+  // This constructor takes a state distance vector that, if non-null and if
+  // the Weight type has the path property, will entertain the shortest-first
+  // queue using the natural order w.r.t. the distance.
+  template <class Arc, class ArcFilter>
+  AutoQueue(const Fst<Arc> &fst,
+            const std::vector<typename Arc::Weight> *distance, ArcFilter filter)
+      : QueueBase<StateId>(AUTO_QUEUE) {
+    using Weight = typename Arc::Weight;
+    using Less = NaturalLess<Weight>;
+    using Compare = internal::StateWeightCompare<StateId, Less>;
+    // First checks if the FST is known to have these properties.
+    const auto props =
+        fst.Properties(kAcyclic | kCyclic | kTopSorted | kUnweighted, false);
+    if ((props & kTopSorted) || fst.Start() == kNoStateId) {
+      queue_.reset(new StateOrderQueue<StateId>());
+      VLOG(2) << "AutoQueue: using state-order discipline";
+    } else if (props & kAcyclic) {
+      queue_.reset(new TopOrderQueue<StateId>(fst, filter));
+      VLOG(2) << "AutoQueue: using top-order discipline";
+    } else if ((props & kUnweighted) && (Weight::Properties() & kIdempotent)) {
+      queue_.reset(new LifoQueue<StateId>());
+      VLOG(2) << "AutoQueue: using LIFO discipline";
+    } else {
+      uint64 properties;
+      // Decomposes into strongly-connected components.
+      SccVisitor<Arc> scc_visitor(&scc_, nullptr, nullptr, &properties);
+      DfsVisit(fst, &scc_visitor, filter);
+      auto nscc = *std::max_element(scc_.begin(), scc_.end()) + 1;
+      std::vector<QueueType> queue_types(nscc);
+      std::unique_ptr<Less> less;
+      std::unique_ptr<Compare> comp;
+      if (distance && (Weight::Properties() & kPath) == kPath) {
+        less.reset(new Less);
+        comp.reset(new Compare(*distance, *less));
+      }
+      // Finds the queue type to use per SCC.
+      bool unweighted;
+      bool all_trivial;
+      SccQueueType(fst, scc_, &queue_types, filter, less.get(), &all_trivial,
+                   &unweighted);
+      // If unweighted and semiring is idempotent, uses LIFO queue.
+      if (unweighted) {
+        queue_.reset(new LifoQueue<StateId>());
+        VLOG(2) << "AutoQueue: using LIFO discipline";
+        return;
+      }
+      // If all the SCC are trivial, the FST is acyclic and the scc number gives
+      // the topological order.
+      if (all_trivial) {
+        queue_.reset(new TopOrderQueue<StateId>(scc_));
+        VLOG(2) << "AutoQueue: using top-order discipline";
+        return;
+      }
+      VLOG(2) << "AutoQueue: using SCC meta-discipline";
+      queues_.resize(nscc);
+      for (StateId i = 0; i < nscc; ++i) {
+        switch (queue_types[i]) {
+          case TRIVIAL_QUEUE:
+            // A null entry makes SccQueue use its built-in trivial slot.
+            queues_[i].reset();
+            VLOG(3) << "AutoQueue: SCC #" << i << ": using trivial discipline";
+            break;
+          case SHORTEST_FIRST_QUEUE:
+            queues_[i].reset(
+                new ShortestFirstQueue<StateId, Compare, false>(*comp));
+            VLOG(3) << "AutoQueue: SCC #" << i
+                    << ": using shortest-first discipline";
+            break;
+          case LIFO_QUEUE:
+            queues_[i].reset(new LifoQueue<StateId>());
+            VLOG(3) << "AutoQueue: SCC #" << i << ": using LIFO discipline";
+            break;
+          case FIFO_QUEUE:
+          default:
+            queues_[i].reset(new FifoQueue<StateId>());
+            VLOG(3) << "AutoQueue: SCC #" << i << ": using FIFO discipine";
+            break;
+        }
+      }
+      queue_.reset(new SccQueue<StateId, QueueBase<StateId>>(scc_, &queues_));
+    }
+  }
+
+  virtual ~AutoQueue() = default;
+
+  // All queue operations are forwarded to the discipline selected above.
+
+  StateId Head() const final { return queue_->Head(); }
+
+  void Enqueue(StateId s) final { queue_->Enqueue(s); }
+
+  void Dequeue() final { queue_->Dequeue(); }
+
+  void Update(StateId s) final { queue_->Update(s); }
+
+  bool Empty() const final { return queue_->Empty(); }
+
+  void Clear() final { queue_->Clear(); }
+
+ private:
+  template <class Arc, class ArcFilter, class Less>
+  static void SccQueueType(const Fst<Arc> &fst, const std::vector<StateId> &scc,
+                           std::vector<QueueType> *queue_types,
+                           ArcFilter filter, Less *less, bool *all_trivial,
+                           bool *unweighted);
+
+  std::unique_ptr<QueueBase<StateId>> queue_;
+  std::vector<std::unique_ptr<QueueBase<StateId>>> queues_;
+  std::vector<StateId> scc_;
+};
+
+// Examines the states in an FST's strongly connected components and determines
+// which type of queue to use per SCC. Stores result as a vector of QueueTypes
+// which is assumed to have length equal to the number of SCCs. An arc filter
+// is used to limit the transitions considered (e.g., only the epsilon graph).
+// The argument all_trivial is set to true if every queue is the trivial queue.
+// The argument unweighted is set to true if the semiring is idempotent and all
+// the arc weights are equal to Zero() or One().
+template <class S>
+template <class Arc, class ArcFilter, class Less>
+void AutoQueue<S>::SccQueueType(const Fst<Arc> &fst,
+                                const std::vector<S> &scc,
+                                std::vector<QueueType> *queue_type,
+                                ArcFilter filter, Less *less,
+                                bool *all_trivial, bool *unweighted) {
+  using StateId = typename Arc::StateId;
+  using Weight = typename Arc::Weight;
+  *all_trivial = true;
+  *unweighted = true;
+  for (StateId i = 0; i < queue_type->size(); ++i) {
+    (*queue_type)[i] = TRIVIAL_QUEUE;
+  }
+  for (StateIterator<Fst<Arc>> sit(fst); !sit.Done(); sit.Next()) {
+    const auto state = sit.Value();
+    for (ArcIterator<Fst<Arc>> ait(fst, state); !ait.Done(); ait.Next()) {
+      const auto &arc = ait.Value();
+      if (!filter(arc)) continue;
+      // Only arcs that stay inside one SCC influence that SCC's queue type.
+      if (scc[state] == scc[arc.nextstate]) {
+        auto &type = (*queue_type)[scc[state]];
+        if (!less || ((*less)(arc.weight, Weight::One()))) {
+          type = FIFO_QUEUE;
+        } else if ((type == TRIVIAL_QUEUE) || (type == LIFO_QUEUE)) {
+          if (!(Weight::Properties() & kIdempotent) ||
+              (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
+            type = SHORTEST_FIRST_QUEUE;
+          } else {
+            type = LIFO_QUEUE;
+          }
+        }
+        if (type != TRIVIAL_QUEUE) *all_trivial = false;
+      }
+      if (!(Weight::Properties() & kIdempotent) ||
+          (arc.weight != Weight::Zero() && arc.weight != Weight::One())) {
+        *unweighted = false;
+      }
+    }
+  }
+}
+
+// An A* estimate is a function object that maps from a state ID to an
+// estimate of the shortest distance to the final states.
+
+// A trivial A* estimate, yielding a queue which behaves the same in Dijkstra's
+// algorithm.
+template <typename StateId, typename Weight>
+struct TrivialAStarEstimate {
+  constexpr Weight operator()(StateId) const { return Weight::One(); }
+};
+
+// A non-trivial A* estimate using a vector of the estimated future costs.
+template <typename StateId, typename Weight>
+class NaturalAStarEstimate {
+ public:
+  // beta is borrowed, not copied.
+  NaturalAStarEstimate(const std::vector<Weight> &beta) : beta_(beta) {}
+
+  // States beyond the end of beta get Weight::Zero() as their estimate.
+  const Weight &operator()(StateId s) const {
+    return (s < beta_.size()) ? beta_[s] : kZero;
+  }
+
+ private:
+  static constexpr Weight kZero = Weight::Zero();
+
+  const std::vector<Weight> &beta_;
+};
+
+template <typename StateId, typename Weight>
+constexpr Weight NaturalAStarEstimate<StateId, Weight>::kZero;
+
+// Given a vector that maps from states to weights representing the shortest
+// distance from the initial state, a comparison function object between
+// weights, and an estimate of the shortest distance to the final states, this
+// class defines a comparison function object between states.
+template <typename S, typename Less, typename Estimate>
+class AStarWeightCompare {
+ public:
+  using StateId = S;
+  using Weight = typename Less::Weight;
+
+  // All three arguments are stored by reference and must outlive this object.
+  AStarWeightCompare(const std::vector<Weight> &weights, const Less &less,
+                     const Estimate &estimate)
+      : weights_(weights), less_(less), estimate_(estimate) {}
+
+  // Compares the states by distance-so-far Times() estimated remaining cost.
+  bool operator()(StateId s1, StateId s2) const {
+    const auto w1 = Times(weights_[s1], estimate_(s1));
+    const auto w2 = Times(weights_[s2], estimate_(s2));
+    return less_(w1, w2);
+  }
+
+  const Estimate &GetEstimate() const { return estimate_; }
+
+ private:
+  const std::vector<Weight> &weights_;
+  const Less &less_;
+  const Estimate &estimate_;
+};
+
+// A* queue discipline templated on StateId, Weight, and Estimate.
+template <typename S, typename Weight, typename Estimate>
+class NaturalAStarQueue : public ShortestFirstQueue<
+          S, AStarWeightCompare<S, NaturalLess<Weight>, Estimate>> {
+ public:
+  using StateId = S;
+  using Compare = AStarWeightCompare<StateId, NaturalLess<Weight>, Estimate>;
+
+  // distance and estimate are passed through to the comparator, which only
+  // borrows them; the comparator also borrows the less_ member below.
+  NaturalAStarQueue(const std::vector<Weight> &distance,
+                    const Estimate &estimate)
+      : ShortestFirstQueue<StateId, Compare>(
+            Compare(distance, less_, estimate)) {}
+
+  ~NaturalAStarQueue() = default;
+
+ private:
+  // This is non-static because the constructor for non-idempotent weights will
+  // result in an error.
+  const NaturalLess<Weight> less_{};
+};
+
+// A state equivalence class is a function object that maps from a state ID to
+// an equivalence class (state) ID. The trivial equivalence class maps a state
+// ID to itself.
+template <typename StateId>
+struct TrivialStateEquivClass {
+  StateId operator()(StateId s) const { return s; }
+};
+
+// Distance-based pruning queue discipline: Enqueues a state only when its
+// shortest distance (so far), as specified by distance, is less than (as
+// specified by comp) the shortest distance Times() the threshold to any state
+// in the same equivalence class, as specified by the functor class_func. The
+// underlying queue discipline is specified by queue. The ownership of queue is
+// given to this class.
+//
+// This is not a final class.
+template <typename Queue, typename Less, typename ClassFnc>
+class PruneQueue : public QueueBase<typename Queue::StateId> {
+ public:
+  using StateId = typename Queue::StateId;
+  using Weight = typename Less::Weight;
+
+  // Takes ownership of queue. distance and class_fnc are borrowed and must
+  // outlive this object; less is copied.
+  PruneQueue(const std::vector<Weight> &distance, Queue *queue,
+             const Less &less, const ClassFnc &class_fnc, Weight threshold)
+      : QueueBase<StateId>(OTHER_QUEUE),
+        distance_(distance),
+        queue_(queue),
+        less_(less),
+        class_fnc_(class_fnc),
+        threshold_(std::move(threshold)) {}
+
+  virtual ~PruneQueue() = default;
+
+  StateId Head() const override { return queue_->Head(); }
+
+  void Enqueue(StateId s) override {
+    const auto c = class_fnc_(s);
+    GrowClassDistance(c);
+    if (less_(distance_[s], class_distance_[c])) {
+      class_distance_[c] = distance_[s];
+    }
+    // Enqueues only if below threshold limit.
+    const auto limit = Times(class_distance_[c], threshold_);
+    if (less_(distance_[s], limit)) queue_->Enqueue(s);
+  }
+
+  void Dequeue() override { queue_->Dequeue(); }
+
+  void Update(StateId s) override {
+    const auto c = class_fnc_(s);
+    // Update() may be called for a state whose class was never enqueued, so
+    // the class-distance vector is grown here as well to keep the index valid.
+    GrowClassDistance(c);
+    if (less_(distance_[s], class_distance_[c])) {
+      class_distance_[c] = distance_[s];
+    }
+    queue_->Update(s);
+  }
+
+  bool Empty() const override { return queue_->Empty(); }
+
+  void Clear() override { queue_->Clear(); }
+
+ private:
+  // Ensures class_distance_[c] exists; new entries start at Weight::Zero().
+  template <typename ClassId>
+  void GrowClassDistance(const ClassId &c) {
+    if (c >= class_distance_.size()) {
+      class_distance_.resize(c + 1, Weight::Zero());
+    }
+  }
+
+  const std::vector<Weight> &distance_;  // Shortest distance to state.
+  std::unique_ptr<Queue> queue_;
+  // Held by value: callers may pass a temporary comparator, and a stored
+  // reference to it would dangle once the constructor returns.
+  const Less less_;
+  const ClassFnc &class_fnc_;           // Equivalence class functor.
+  Weight threshold_;                    // Pruning weight threshold.
+  std::vector<Weight> class_distance_;  // Shortest distance to class.
+};
+
+// Pruning queue discipline (see above) using the weight's natural order for the
+// comparison function. The ownership of the queue argument is given to this
+// class.
+template +class NaturalPruneQueue final + : public PruneQueue, ClassFnc> { + public: + using StateId = typename Queue::StateId; + + NaturalPruneQueue(const std::vector &distance, Queue *queue, + const ClassFnc &class_fnc, Weight threshold) + : PruneQueue, ClassFnc>( + distance, queue, NaturalLess(), class_fnc, threshold) {} + + virtual ~NaturalPruneQueue() = default; +}; + +// Filter-based pruning queue discipline: enqueues a state only if allowed by +// the filter, specified by the state filter functor argument. The underlying +// queue discipline is specified by the queue argument. The ownership of the +// queue is given to this class. +template +class FilterQueue : public QueueBase { + public: + using StateId = typename Queue::StateId; + + FilterQueue(Queue *queue, const Filter &filter) + : QueueBase(OTHER_QUEUE), queue_(queue), filter_(filter) {} + + virtual ~FilterQueue() = default; + + StateId Head() const final { return queue_->Head(); } + + // Enqueues only if allowed by state filter. + void Enqueue(StateId s) final { + if (filter_(s)) queue_->Enqueue(s); + } + + void Dequeue() final { queue_->Dequeue(); } + + void Update(StateId s) final {} + + bool Empty() const final { return queue_->Empty(); } + + void Clear() final { queue_->Clear(); } + + private: + std::unique_ptr queue_; + const Filter &filter_; +}; + +} // namespace fst + +#endif // FST_QUEUE_H_ diff --git a/projects/llm_framework/include/fst/randequivalent.h b/projects/llm_framework/include/fst/randequivalent.h new file mode 100644 index 00000000..73108b46 --- /dev/null +++ b/projects/llm_framework/include/fst/randequivalent.h @@ -0,0 +1,114 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Tests if two FSTS are equivalent by checking if random strings from one FST +// are transduced the same by both FSTs. 
+ +#ifndef FST_RANDEQUIVALENT_H_ +#define FST_RANDEQUIVALENT_H_ + +#include + +#include +#include +#include +#include +#include +#include + + +namespace fst { + +// Test if two FSTs are stochastically equivalent by randomly generating +// random paths through the FSTs. +// +// For each randomly generated path, the algorithm computes for each +// of the two FSTs the sum of the weights of all the successful paths +// sharing the same input and output labels as the considered randomly +// generated path and checks that these two values are within a user-specified +// delta. Returns optional error value (when FLAGS_error_fatal = false). +template +bool RandEquivalent(const Fst &fst1, const Fst &fst2, + int32 num_paths, float delta, + const RandGenOptions &opts, + bool *error = nullptr) { + using Weight = typename Arc::Weight; + if (error) *error = false; + // Checks that the symbol table are compatible. + if (!CompatSymbols(fst1.InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1.OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "RandEquivalent: Input/output symbol tables of 1st " + << "argument do not match input/output symbol tables of 2nd " + << "argument"; + if (error) *error = true; + return false; + } + static const ILabelCompare icomp; + static const OLabelCompare ocomp; + VectorFst sfst1(fst1); + VectorFst sfst2(fst2); + Connect(&sfst1); + Connect(&sfst2); + ArcSort(&sfst1, icomp); + ArcSort(&sfst2, icomp); + bool result = true; + for (int32 n = 0; n < num_paths; ++n) { + VectorFst path; + const auto &fst = rand() % 2 ? sfst1 : sfst2; // NOLINT + RandGen(fst, &path, opts); + VectorFst ipath(path); + VectorFst opath(path); + Project(&ipath, PROJECT_INPUT); + Project(&opath, PROJECT_OUTPUT); + VectorFst cfst1, pfst1; + Compose(ipath, sfst1, &cfst1); + ArcSort(&cfst1, ocomp); + Compose(cfst1, opath, &pfst1); + // Gives up if there are epsilon cycles in a non-idempotent semiring. 
+ if (!(Weight::Properties() & kIdempotent) && + pfst1.Properties(kCyclic, true)) { + continue; + } + const auto sum1 = ShortestDistance(pfst1); + VectorFst cfst2; + Compose(ipath, sfst2, &cfst2); + ArcSort(&cfst2, ocomp); + VectorFst pfst2; + Compose(cfst2, opath, &pfst2); + // Gives up if there are epsilon cycles in a non-idempotent semiring. + if (!(Weight::Properties() & kIdempotent) && + pfst2.Properties(kCyclic, true)) { + continue; + } + const auto sum2 = ShortestDistance(pfst2); + if (!ApproxEqual(sum1, sum2, delta)) { + VLOG(1) << "Sum1 = " << sum1; + VLOG(1) << "Sum2 = " << sum2; + result = false; + break; + } + } + if (fst1.Properties(kError, false) || fst2.Properties(kError, false)) { + if (error) *error = true; + return false; + } + return result; +} + +// Tests if two FSTs are equivalent by randomly generating a nnum_paths paths +// (no longer than the path_length) using a user-specified seed, optionally +// indicating an error setting an optional error argument to true. +template +bool RandEquivalent(const Fst &fst1, const Fst &fst2, int32 num_paths, + float delta = kDelta, time_t seed = time(nullptr), + int32 max_length = std::numeric_limits::max(), + bool *error = nullptr) { + const UniformArcSelector uniform_selector(seed); + const RandGenOptions> opts(uniform_selector, + max_length); + return RandEquivalent(fst1, fst2, num_paths, delta, opts, error); +} + +} // namespace fst + +#endif // FST_RANDEQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/randgen.h b/projects/llm_framework/include/fst/randgen.h new file mode 100644 index 00000000..5bcd9fd0 --- /dev/null +++ b/projects/llm_framework/include/fst/randgen.h @@ -0,0 +1,756 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes and functions to generate random paths through an FST. 
+ +#ifndef FST_RANDGEN_H_ +#define FST_RANDGEN_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// The RandGenFst class is roughly similar to ArcMapFst in that it takes two +// template parameters denoting the input and output arc types. However, it also +// takes an additional template parameter which specifies a sampler object which +// samples (with replacement) arcs from an FST state. The sampler in turn takes +// a template parameter for a selector object which actually chooses the arc. +// +// Arc selector functors are used to select a random transition given an FST +// state s, returning a number N such that 0 <= N <= NumArcs(s). If N is +// NumArcs(s), then the final weight is selected; otherwise the N-th arc is +// selected. It is assumed these are not applied to any state which is neither +// final nor has any arcs leaving it. + +// Randomly selects a transition using the uniform distribution. This class is +// not thread-safe. +template +class UniformArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // Constructs a selector with a non-deterministic seed. + UniformArcSelector() : rand_(std::random_device()()) {} + // Constructs a selector with a given seed. + explicit UniformArcSelector(uint64 seed) : rand_(seed) {} + + size_t operator()(const Fst &fst, StateId s) const { + const auto n = fst.NumArcs(s) + (fst.Final(s) != Weight::Zero()); + return static_cast( + std::uniform_int_distribution<>(0, n - 1)(rand_)); + } + + private: + mutable std::mt19937_64 rand_; +}; + +// Randomly selects a transition w.r.t. the weights treated as negative log +// probabilities after normalizing for the total weight leaving the state. Zero +// transitions are disregarded. 
It assumed that Arc::Weight::Value() accesses +// the floating point representation of the weight. This class is not +// thread-safe. +template +class LogProbArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // Constructs a selector with a non-deterministic seed. + LogProbArcSelector() : seed_(std::random_device()()), rand_(seed_) {} + // Constructs a selector with a given seed. + explicit LogProbArcSelector(uint64 seed) : seed_(seed), rand_(seed) {} + + size_t operator()(const Fst &fst, StateId s) const { + // Finds total weight leaving state. + auto sum = Log64Weight::Zero(); + ArcIterator> aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + sum = Plus(sum, to_log_weight_(arc.weight)); + } + sum = Plus(sum, to_log_weight_(fst.Final(s))); + const double threshold = + std::uniform_real_distribution<>(0, exp(-sum.Value()))(rand_); + auto p = Log64Weight::Zero(); + size_t n = 0; + for (aiter.Reset(); !aiter.Done(); aiter.Next(), ++n) { + p = Plus(p, to_log_weight_(aiter.Value().weight)); + if (exp(-p.Value()) > threshold) return n; + } + return n; + } + + uint64 Seed() const { return seed_; } + + protected: + Log64Weight ToLogWeight(const Weight &weight) const { + return to_log_weight_(weight); + } + + std::mt19937_64 &MutableRand() const { return rand_; } + + private: + const uint64 seed_; + mutable std::mt19937_64 rand_; + const WeightConvert to_log_weight_{}; +}; + +// Useful alias when using StdArc. +using StdArcSelector = LogProbArcSelector; + +// Same as LogProbArcSelector but use CacheLogAccumulator to cache the weight +// accumulation computations. This class is not thread-safe. 
+template +class FastLogProbArcSelector : public LogProbArcSelector { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using LogProbArcSelector::MutableRand; + using LogProbArcSelector::ToLogWeight; + using LogProbArcSelector::operator(); + + // Constructs a selector with a non-deterministic seed. + FastLogProbArcSelector() : LogProbArcSelector() {} + // Constructs a selector with a given seed. + explicit FastLogProbArcSelector(uint64 seed) : LogProbArcSelector( + seed) {} + + size_t operator()(const Fst &fst, StateId s, + CacheLogAccumulator *accumulator) const { + accumulator->SetState(s); + ArcIterator> aiter(fst, s); + // Finds total weight leaving state. + const double sum = + ToLogWeight(accumulator->Sum(fst.Final(s), &aiter, 0, fst.NumArcs(s))) + .Value(); + const double r = -log(std::uniform_real_distribution<>(0, 1)( + MutableRand())); + Weight w = from_log_weight_(r + sum); + aiter.Reset(); + return accumulator->LowerBound(w, &aiter); + } + + private: + const WeightConvert from_log_weight_{}; +}; + +// Random path state info maintained by RandGenFst and passed to samplers. +template +struct RandState { + using StateId = typename Arc::StateId; + + StateId state_id; // Current input FST state. + size_t nsamples; // Number of samples to be sampled at this state. + size_t length; // Length of path to this random state. + size_t select; // Previous sample arc selection. + const RandState *parent; // Previous random state on this path. + + explicit RandState(StateId state_id, size_t nsamples = 0, size_t length = 0, + size_t select = 0, const RandState *parent = nullptr) + : state_id(state_id), + nsamples(nsamples), + length(length), + select(select), + parent(parent) {} + + RandState() : RandState(kNoStateId) {} +}; + +// This class, given an arc selector, samples, with replacement, multiple random +// transitions from an FST's state. This is a generic version with a +// straightforward use of the arc selector. 
Specializations may be defined for +// arc selectors for greater efficiency or special behavior. +template +class ArcSampler { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // The max_length argument may be interpreted (or ignored) by a selector as + // it chooses. This generic version interprets this literally. + ArcSampler(const Fst &fst, const Selector &selector, + int32 max_length = std::numeric_limits::max()) + : fst_(fst), selector_(selector), max_length_(max_length) {} + + // Allow updating FST argument; pass only if changed. + ArcSampler(const ArcSampler &sampler, + const Fst *fst = nullptr) + : fst_(fst ? *fst : sampler.fst_), + selector_(sampler.selector_), + max_length_(sampler.max_length_) { + Reset(); + } + + // Samples a fixed number of samples from the given state. The length argument + // specifies the length of the path to the state. Returns true if the samples + // were collected. No samples may be collected if either there are no + // transitions leaving the state and the state is non-final, or if the path + // length has been exceeded. Iterator members are provided to read the samples + // in the order in which they were collected. + bool Sample(const RandState &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + for (size_t i = 0; i < rstate.nsamples; ++i) { + ++sample_map_[selector_(fst_, rstate.state_id)]; + } + Reset(); + return true; + } + + // More samples? + bool Done() const { return sample_iter_ == sample_map_.end(); } + + // Gets the next sample. 
+ void Next() { ++sample_iter_; } + + std::pair Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return false; } + + private: + const Fst &fst_; + const Selector &selector_; + const int32 max_length_; + + // Stores (N, K) as described for Value(). + std::map sample_map_; + std::map::const_iterator sample_iter_; + + ArcSampler &operator=(const ArcSampler &) = delete; +}; + +// Samples one sample of num_to_sample dimensions from a multinomial +// distribution parameterized by a vector of probabilities. The result +// container should be pre-initialized (e.g., an empty map or a zeroed vector +// sized the same as the vector of probabilities. +// probs.size()). +template +void OneMultinomialSample(const std::vector &probs, + size_t num_to_sample, Result *result, RNG *rng) { + // Left-over probability mass. + double norm = 0; + for (double p : probs) norm += p; + // Left-over number of samples needed. + for (size_t i = 0; i < probs.size(); ++i) { + size_t num_sampled = 0; + if (probs[i] > 0) { + std::binomial_distribution<> d(num_to_sample, probs[i] / norm); + num_sampled = d(*rng); + } + if (num_sampled != 0) (*result)[i] = num_sampled; + norm -= probs[i]; + num_to_sample -= num_sampled; + } +} + +// Specialization for FastLogProbArcSelector. +template +class ArcSampler> { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Accumulator = CacheLogAccumulator; + using Selector = FastLogProbArcSelector; + + ArcSampler(const Fst &fst, const Selector &selector, + int32 max_length = std::numeric_limits::max()) + : fst_(fst), + selector_(selector), + max_length_(max_length), + accumulator_(new Accumulator()) { + accumulator_->Init(fst); + rng_.seed(selector_.Seed()); + } + + ArcSampler(const ArcSampler &sampler, + const Fst *fst = nullptr) + : fst_(fst ? 
*fst : sampler.fst_), + selector_(sampler.selector_), + max_length_(sampler.max_length_) { + if (fst) { + accumulator_.reset(new Accumulator()); + accumulator_->Init(*fst); + } else { // Shallow copy. + accumulator_.reset(new Accumulator(*sampler.accumulator_)); + } + } + + bool Sample(const RandState &rstate) { + sample_map_.clear(); + if ((fst_.NumArcs(rstate.state_id) == 0 && + fst_.Final(rstate.state_id) == Weight::Zero()) || + rstate.length == max_length_) { + Reset(); + return false; + } + if (fst_.NumArcs(rstate.state_id) + 1 < rstate.nsamples) { + MultinomialSample(rstate); + Reset(); + return true; + } + for (size_t i = 0; i < rstate.nsamples; ++i) { + ++sample_map_[selector_(fst_, rstate.state_id, accumulator_.get())]; + } + Reset(); + return true; + } + + bool Done() const { return sample_iter_ == sample_map_.end(); } + + void Next() { ++sample_iter_; } + + std::pair Value() const { return *sample_iter_; } + + void Reset() { sample_iter_ = sample_map_.begin(); } + + bool Error() const { return accumulator_->Error(); } + + private: + using RNG = std::mt19937; + + // Sample according to the multinomial distribution of rstate.nsamples draws + // from p_. + void MultinomialSample(const RandState &rstate) { + p_.clear(); + for (ArcIterator> aiter(fst_, rstate.state_id); !aiter.Done(); + aiter.Next()) { + p_.push_back(exp(-to_log_weight_(aiter.Value().weight).Value())); + } + if (fst_.Final(rstate.state_id) != Weight::Zero()) { + p_.push_back(exp(-to_log_weight_(fst_.Final(rstate.state_id)).Value())); + } + if (rstate.nsamples < std::numeric_limits::max()) { + OneMultinomialSample(p_, rstate.nsamples, &sample_map_, &rng_); + } else { + for (size_t i = 0; i < p_.size(); ++i) { + sample_map_[i] = ceil(p_[i] * rstate.nsamples); + } + } + } + + const Fst &fst_; + const Selector &selector_; + const int32 max_length_; + + // Stores (N, K) for Value(). 
+ std::map sample_map_; + std::map::const_iterator sample_iter_; + + std::unique_ptr accumulator_; + RNG rng_; // Random number generator. + std::vector p_; // Multinomial parameters. + const WeightConvert to_log_weight_{}; +}; + +// Options for random path generation with RandGenFst. The template argument is +// a sampler, typically the class ArcSampler. Ownership of the sampler is taken +// by RandGenFst. +template +struct RandGenFstOptions : public CacheOptions { + Sampler *sampler; // How to sample transitions at a state. + int32 npath; // Number of paths to generate. + bool weighted; // Is the output tree weighted by path count, or + // is it just an unweighted DAG? + bool remove_total_weight; // Remove total weight when output is weighted. + + RandGenFstOptions(const CacheOptions &opts, Sampler *sampler, int32 npath = 1, + bool weighted = true, bool remove_total_weight = false) + : CacheOptions(opts), + sampler(sampler), + npath(npath), + weighted(weighted), + remove_total_weight(remove_total_weight) {} +}; + +namespace internal { + +// Implementation of RandGenFst. 
+template +class RandGenFstImpl : public CacheImpl { + public: + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::EmplaceArc; + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + using Label = typename FromArc::Label; + using StateId = typename FromArc::StateId; + using FromWeight = typename FromArc::Weight; + + using ToWeight = typename ToArc::Weight; + + RandGenFstImpl(const Fst &fst, + const RandGenFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + sampler_(opts.sampler), + npath_(opts.npath), + weighted_(opts.weighted), + remove_total_weight_(opts.remove_total_weight), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties( + RandGenProperties(fst.Properties(kFstProperties, false), weighted_), + kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RandGenFstImpl(const RandGenFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + sampler_(new Sampler(*impl.sampler_, fst_.get())), + npath_(impl.npath_), + weighted_(impl.weighted_), + superfinal_(kNoLabel) { + SetType("randgen"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) { + const auto s = fst_->Start(); + if (s == kNoStateId) return kNoStateId; + SetStart(state_table_.size()); + state_table_.emplace_back( + new RandState(s, npath_, 0, 0, nullptr)); + } + return CacheImpl::Start(); + } + + ToWeight Final(StateId s) { + if (!HasFinal(s)) Expand(s); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) 
Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && + (fst_->Properties(kError, false) || sampler_->Error())) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + // Computes the outgoing transitions from a state, creating new destination + // states as needed. + void Expand(StateId s) { + if (s == superfinal_) { + SetFinal(s, ToWeight::One()); + SetArcs(s); + return; + } + SetFinal(s, ToWeight::Zero()); + const auto &rstate = *state_table_[s]; + sampler_->Sample(rstate); + ArcIterator> aiter(*fst_, rstate.state_id); + const auto narcs = fst_->NumArcs(rstate.state_id); + for (; !sampler_->Done(); sampler_->Next()) { + const auto &sample_pair = sampler_->Value(); + const auto pos = sample_pair.first; + const auto count = sample_pair.second; + double prob = static_cast(count) / rstate.nsamples; + if (pos < narcs) { // Regular transition. + aiter.Seek(sample_pair.first); + const auto &aarc = aiter.Value(); + auto weight = + weighted_ ? to_weight_(Log64Weight(-log(prob))) : ToWeight::One(); + EmplaceArc(s, aarc.ilabel, aarc.olabel, std::move(weight), + state_table_.size()); + auto *nrstate = new RandState(aarc.nextstate, count, + rstate.length + 1, pos, &rstate); + state_table_.emplace_back(nrstate); + } else { // Super-final transition. + if (weighted_) { + const auto weight = + remove_total_weight_ + ? 
to_weight_(Log64Weight(-log(prob))) + : to_weight_(Log64Weight(-log(prob * npath_))); + SetFinal(s, weight); + } else { + if (superfinal_ == kNoLabel) { + superfinal_ = state_table_.size(); + state_table_.emplace_back( + new RandState(kNoStateId, 0, 0, 0, nullptr)); + } + for (size_t n = 0; n < count; ++n) { + EmplaceArc(s, 0, 0, ToWeight::One(), superfinal_); + } + } + } + } + SetArcs(s); + } + + private: + const std::unique_ptr> fst_; + std::unique_ptr sampler_; + const int32 npath_; + std::vector>> state_table_; + const bool weighted_; + bool remove_total_weight_; + StateId superfinal_; + const WeightConvert to_weight_{}; +}; + +} // namespace internal + +// FST class to randomly generate paths through an FST, with details controlled +// by RandGenOptionsFst. Output format is a tree weighted by the path count. +template +class RandGenFst + : public ImplToFst> { + public: + using Label = typename FromArc::Label; + using StateId = typename FromArc::StateId; + using Weight = typename FromArc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using Impl = internal::RandGenFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + RandGenFst(const Fst &fst, const RandGenFstOptions &opts) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RandGenFst(const RandGenFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this RandGenFst. See Fst<>::Copy() for further doc. + RandGenFst *Copy(bool safe = false) const override { + return new RandGenFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RandGenFst &operator=(const RandGenFst &) = delete; +}; + +// Specialization for RandGenFst. 
+template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const RandGenFst &fst) + : CacheStateIterator>( + fst, fst.GetMutableImpl()) {} +}; + +// Specialization for RandGenFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename FromArc::StateId; + + ArcIterator(const RandGenFst &fst, StateId s) + : CacheArcIterator>( + fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RandGenFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Options for random path generation. +template +struct RandGenOptions { + const Selector &selector; // How an arc is selected at a state. + int32 max_length; // Maximum path length. + int32 npath; // Number of paths to generate. + bool weighted; // Is the output tree weighted by path count, or + // is it just an unweighted DAG? + bool remove_total_weight; // Remove total weight when output is weighted? 
+ + explicit RandGenOptions(const Selector &selector, + int32 max_length = std::numeric_limits::max(), + int32 npath = 1, bool weighted = false, + bool remove_total_weight = false) + : selector(selector), + max_length(max_length), + npath(npath), + weighted(weighted), + remove_total_weight(remove_total_weight) {} +}; + +namespace internal { + +template +class RandGenVisitor { + public: + using StateId = typename FromArc::StateId; + using Weight = typename FromArc::Weight; + + explicit RandGenVisitor(MutableFst *ofst) : ofst_(ofst) {} + + void InitVisit(const Fst &ifst) { + ifst_ = &ifst; + ofst_->DeleteStates(); + ofst_->SetInputSymbols(ifst.InputSymbols()); + ofst_->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kError, false)) ofst_->SetProperties(kError, kError); + path_.clear(); + } + + constexpr bool InitState(StateId, StateId) const { return true; } + + bool TreeArc(StateId, const ToArc &arc) { + if (ifst_->Final(arc.nextstate) == Weight::Zero()) { + path_.push_back(arc); + } else { + OutputPath(); + } + return true; + } + + bool BackArc(StateId, const FromArc &) { + FSTERROR() << "RandGenVisitor: cyclic input"; + ofst_->SetProperties(kError, kError); + return false; + } + + bool ForwardOrCrossArc(StateId, const FromArc &) { + OutputPath(); + return true; + } + + void FinishState(StateId s, StateId p, const FromArc *) { + if (p != kNoStateId && ifst_->Final(s) == Weight::Zero()) path_.pop_back(); + } + + void FinishVisit() {} + + private: + void OutputPath() { + if (ofst_->Start() == kNoStateId) { + const auto start = ofst_->AddState(); + ofst_->SetStart(start); + } + auto src = ofst_->Start(); + for (size_t i = 0; i < path_.size(); ++i) { + const auto dest = ofst_->AddState(); + const ToArc arc(path_[i].ilabel, path_[i].olabel, Weight::One(), dest); + ofst_->AddArc(src, arc); + src = dest; + } + ofst_->SetFinal(src, Weight::One()); + } + + const Fst *ifst_; + MutableFst *ofst_; + std::vector path_; + + RandGenVisitor(const RandGenVisitor &) = 
delete; + RandGenVisitor &operator=(const RandGenVisitor &) = delete; +}; + +} // namespace internal + +// Randomly generate paths through an FST; details controlled by +// RandGenOptions. +template +void RandGen(const Fst &ifst, MutableFst *ofst, + const RandGenOptions &opts) { + using Sampler = ArcSampler; + auto *sampler = new Sampler(ifst, opts.selector, opts.max_length); + RandGenFstOptions fopts(CacheOptions(true, 0), sampler, opts.npath, + opts.weighted, opts.remove_total_weight); + RandGenFst rfst(ifst, fopts); + if (opts.weighted) { + *ofst = rfst; + } else { + internal::RandGenVisitor rand_visitor(ofst); + DfsVisit(rfst, &rand_visitor); + } +} + +// Randomly generate a path through an FST with the uniform distribution +// over the transitions. +template +void RandGen(const Fst &ifst, MutableFst *ofst) { + const UniformArcSelector uniform_selector; + RandGenOptions> opts(uniform_selector); + RandGen(ifst, ofst, opts); +} + +} // namespace fst + +#endif // FST_RANDGEN_H_ diff --git a/projects/llm_framework/include/fst/rational.h b/projects/llm_framework/include/fst/rational.h new file mode 100644 index 00000000..184ebf3f --- /dev/null +++ b/projects/llm_framework/include/fst/rational.h @@ -0,0 +1,307 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// An FST implementation and base interface for delayed unions, concatenations, +// and closures. + +#ifndef FST_RATIONAL_H_ +#define FST_RATIONAL_H_ + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +using RationalFstOptions = CacheOptions; + +// This specifies whether to add the empty string. +enum ClosureType { + CLOSURE_STAR = 0, // Add the empty string. + CLOSURE_PLUS = 1 // Don't add the empty string. 
+}; + +template +class RationalFst; + +template +void Union(RationalFst *fst1, const Fst &fst2); + +template +void Concat(RationalFst *fst1, const Fst &fst2); + +template +void Concat(const Fst &fst1, RationalFst *fst2); + +template +void Closure(RationalFst *fst, ClosureType closure_type); + +namespace internal { + +// Implementation class for delayed unions, concatenations and closures. +template +class RationalFstImpl : public FstImpl { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + explicit RationalFstImpl(const RationalFstOptions &opts) + : nonterminals_(0), replace_options_(opts, 0) { + SetType("rational"); + fst_tuples_.emplace_back(0, nullptr); + } + + RationalFstImpl(const RationalFstImpl &impl) + : rfst_(impl.rfst_), + nonterminals_(impl.nonterminals_), + replace_(impl.replace_ ? impl.replace_->Copy(true) : nullptr), + replace_options_(impl.replace_options_) { + SetType("rational"); + fst_tuples_.reserve(impl.fst_tuples_.size()); + for (const auto &pair : impl.fst_tuples_) { + fst_tuples_.emplace_back(pair.first, + pair.second ? pair.second->Copy(true) : nullptr); + } + } + + ~RationalFstImpl() override { + for (auto &tuple : fst_tuples_) delete tuple.second; + } + + StateId Start() { return Replace()->Start(); } + + Weight Final(StateId s) { return Replace()->Final(s); } + + size_t NumArcs(StateId s) { return Replace()->NumArcs(s); } + + size_t NumInputEpsilons(StateId s) { return Replace()->NumInputEpsilons(s); } + + size_t NumOutputEpsilons(StateId s) { + return Replace()->NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties. 
+ uint64 Properties(uint64 mask) const override { + if ((mask & kError) && Replace()->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + // Implementation of UnionFst(fst1, fst2). + void InitUnion(const Fst &fst1, const Fst &fst2) { + replace_.reset(); + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(0, 0, -2, Weight::One(), 1); + fst_tuples_.emplace_back(-1, fst1.Copy()); + fst_tuples_.emplace_back(-2, fst2.Copy()); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ConcatFst(fst1, fst2). + void InitConcat(const Fst &fst1, const Fst &fst2) { + replace_.reset(); + const auto props1 = fst1.Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + SetInputSymbols(fst1.InputSymbols()); + SetOutputSymbols(fst1.OutputSymbols()); + rfst_.AddState(); + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(2, Weight::One()); + rfst_.SetInputSymbols(fst1.InputSymbols()); + rfst_.SetOutputSymbols(fst1.OutputSymbols()); + nonterminals_ = 2; + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(1, 0, -2, Weight::One(), 2); + fst_tuples_.emplace_back(-1, fst1.Copy()); + fst_tuples_.emplace_back(-2, fst2.Copy()); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of ClosureFst(fst, closure_type). 
+ void InitClosure(const Fst &fst, ClosureType closure_type) { + replace_.reset(); + const auto props = fst.Properties(kFstProperties, false); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + if (closure_type == CLOSURE_STAR) { + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(0, Weight::One()); + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 0); + } else { + rfst_.AddState(); + rfst_.AddState(); + rfst_.SetStart(0); + rfst_.SetFinal(1, Weight::One()); + rfst_.EmplaceArc(0, 0, -1, Weight::One(), 1); + rfst_.EmplaceArc(1, 0, 0, Weight::One(), 0); + } + rfst_.SetInputSymbols(fst.InputSymbols()); + rfst_.SetOutputSymbols(fst.OutputSymbols()); + fst_tuples_.emplace_back(-1, fst.Copy()); + nonterminals_ = 1; + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Implementation of Union(Fst &, RationalFst *). + void AddUnion(const Fst &fst) { + replace_.reset(); + const auto props1 = FstImpl::Properties(); + const auto props2 = fst.Properties(kFstProperties, false); + VectorFst afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1); + Union(&rfst_, afst); + fst_tuples_.emplace_back(-nonterminals_, fst.Copy()); + SetProperties(UnionProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Concat(Fst &, RationalFst *). 
+ void AddConcat(const Fst &fst, bool append) { + replace_.reset(); + const auto props1 = FstImpl::Properties(); + const auto props2 = fst.Properties(kFstProperties, false); + VectorFst afst; + afst.AddState(); + afst.AddState(); + afst.SetStart(0); + afst.SetFinal(1, Weight::One()); + ++nonterminals_; + afst.EmplaceArc(0, 0, -nonterminals_, Weight::One(), 1); + if (append) { + Concat(&rfst_, afst); + } else { + Concat(afst, &rfst_); + } + fst_tuples_.emplace_back(-nonterminals_, fst.Copy()); + SetProperties(ConcatProperties(props1, props2, true), kCopyProperties); + } + + // Implementation of Closure(RationalFst *, closure_type). + void AddClosure(ClosureType closure_type) { + replace_.reset(); + const auto props = FstImpl::Properties(); + Closure(&rfst_, closure_type); + SetProperties(ClosureProperties(props, closure_type == CLOSURE_STAR, true), + kCopyProperties); + } + + // Returns the underlying ReplaceFst, preserving ownership of the underlying + // object. + ReplaceFst *Replace() const { + if (!replace_) { + fst_tuples_[0].second = rfst_.Copy(); + replace_.reset(new ReplaceFst(fst_tuples_, replace_options_)); + } + return replace_.get(); + } + + private: + // Rational topology machine, using negative non-terminals. + VectorFst rfst_; + // Number of nonterminals used. + Label nonterminals_; + // Contains the nonterminals and their corresponding FSTs. + mutable std::vector *>> fst_tuples_; + // Underlying ReplaceFst. + mutable std::unique_ptr> replace_; + const ReplaceFstOptions replace_options_; +}; + +} // namespace internal + +// Parent class for the delayed rational operations (union, concatenation, and +// closure). This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. 
+template +class RationalFst : public ImplToFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Impl = internal::RationalFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + friend void Union<>(RationalFst *fst1, const Fst &fst2); + friend void Concat<>(RationalFst *fst1, const Fst &fst2); + friend void Concat<>(const Fst &fst1, RationalFst *fst2); + friend void Closure<>(RationalFst *fst, ClosureType closure_type); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->Replace()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->Replace()->InitArcIterator(s, data); + } + + protected: + using ImplToFst::GetImpl; + + explicit RationalFst(const RationalFstOptions &opts = RationalFstOptions()) + : ImplToFst(std::make_shared(opts)) {} + + // See Fst<>::Copy() for doc. + RationalFst(const RationalFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + private: + RationalFst &operator=(const RationalFst &) = delete; +}; + +// Specialization for RationalFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const RationalFst &fst) + : StateIterator>(*(fst.GetImpl()->Replace())) {} +}; + +// Specialization for RationalFst. +template +class ArcIterator> : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RationalFst &fst, StateId s) + : ArcIterator>(*(fst.GetImpl()->Replace()), s) {} +}; + +} // namespace fst + +#endif // FST_RATIONAL_H_ diff --git a/projects/llm_framework/include/fst/register.h b/projects/llm_framework/include/fst/register.h new file mode 100644 index 00000000..2d1a6ea7 --- /dev/null +++ b/projects/llm_framework/include/fst/register.h @@ -0,0 +1,115 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+// +// Classes for registering derived FST for generic reading. + +#ifndef FST_REGISTER_H_ +#define FST_REGISTER_H_ + +#include +#include + + +#include +#include +#include + + +#include +#include + +namespace fst { + +template +class Fst; + +struct FstReadOptions; + +// This class represents a single entry in a FstRegister +template +struct FstRegisterEntry { + using Reader = Fst *(*)(std::istream &istrm, const FstReadOptions &opts); + using Converter = Fst *(*)(const Fst &fst); + + Reader reader; + Converter converter; + + explicit FstRegisterEntry(Reader reader = nullptr, + Converter converter = nullptr) + : reader(reader), converter(converter) {} +}; + +// This class maintains the correspondence between a string describing +// an FST type, and its reader and converter. +template +class FstRegister + : public GenericRegister, FstRegister> { + public: + using Reader = typename FstRegisterEntry::Reader; + using Converter = typename FstRegisterEntry::Converter; + + const Reader GetReader(const string &type) const { + return this->GetEntry(type).reader; + } + + const Converter GetConverter(const string &type) const { + return this->GetEntry(type).converter; + } + + protected: + string ConvertKeyToSoFilename(const string &key) const override { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-fst.so"; + } +}; + +// This class registers an FST type for generic reading and creating. +// The type must have a default constructor and a copy constructor from +// Fst. 
+template +class FstRegisterer : public GenericRegisterer> { + public: + using Arc = typename FST::Arc; + using Entry = typename FstRegister::Entry; + using Reader = typename FstRegister::Reader; + + FstRegisterer() + : GenericRegisterer>(FST().Type(), + BuildEntry()) {} + + private: + static Fst *ReadGeneric( + std::istream &strm, const FstReadOptions &opts) { + static_assert(std::is_base_of, FST>::value, + "FST class does not inherit from Fst"); + return FST::Read(strm, opts); + } + + static Entry BuildEntry() { + return Entry(&ReadGeneric, &FstRegisterer::Convert); + } + + static Fst *Convert(const Fst &fst) { return new FST(fst); } +}; + +// Convenience macro to generate static FstRegisterer instance. +#define REGISTER_FST(FST, Arc) \ + static fst::FstRegisterer> FST##_##Arc##_registerer + +// Converts an FST to the specified type. +template +Fst *Convert(const Fst &fst, const string &fst_type) { + auto *reg = FstRegister::GetRegister(); + const auto converter = reg->GetConverter(fst_type); + if (!converter) { + FSTERROR() << "Fst::Convert: Unknown FST type " << fst_type << " (arc type " + << Arc::Type() << ")"; + return nullptr; + } + return converter(fst); +} + +} // namespace fst + +#endif // FST_REGISTER_H_ diff --git a/projects/llm_framework/include/fst/relabel.h b/projects/llm_framework/include/fst/relabel.h new file mode 100644 index 00000000..0979b077 --- /dev/null +++ b/projects/llm_framework/include/fst/relabel.h @@ -0,0 +1,472 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to relabel an FST (either on input or output). + +#ifndef FST_RELABEL_H_ +#define FST_RELABEL_H_ + +#include +#include +#include +#include + +#include + +#include +#include + + +#include + +namespace fst { + +// Relabels either the input labels or output labels. The old to +// new labels are specified using a vector of std::pair. 
+// Any label associations not specified are assumed to be identity +// mapping. The destination labels must be valid labels (e.g., not kNoLabel). +template +void Relabel( + MutableFst *fst, + const std::vector> + &ipairs, + const std::vector> + &opairs) { + using Label = typename Arc::Label; + const auto props = fst->Properties(kFstProperties, false); + // Constructs label-to-label maps. + const std::unordered_map input_map( + ipairs.begin(), ipairs.end()); + const std::unordered_map output_map( + opairs.begin(), opairs.end()); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + for (MutableArcIterator> aiter(fst, siter.Value()); + !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + // Relabels input. + auto it = input_map.find(arc.ilabel); + if (it != input_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Input symbol ID " << arc.ilabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.ilabel = it->second; + } + // Relabels output. + it = output_map.find(arc.olabel); + if (it != output_map.end()) { + if (it->second == kNoLabel) { + FSTERROR() << "Output symbol id " << arc.olabel + << " missing from target vocabulary"; + fst->SetProperties(kError, kError); + return; + } + arc.olabel = it->second; + } + aiter.SetValue(arc); + } + } + fst->SetProperties(RelabelProperties(props), kFstProperties); +} + +// Relabels either the input labels or output labels. The old to +// new labels are specified using pairs of old and new symbol tables. +// The tables must contain (at least) all labels on the appropriate side of the +// FST. If the 'unknown_i(o)symbol' is non-empty, it is used to label any +// missing symbol in new_i(o)symbols table. 
+template +void Relabel(MutableFst *fst, + const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, + const string &unknown_isymbol, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + const string &unknown_osymbol, bool attach_new_osymbols) { + using Label = typename Arc::Label; + // Constructs vectors of input-side label pairs. + std::vector> ipairs; + if (old_isymbols && new_isymbols) { + size_t num_missing_syms = 0; + Label unknown_ilabel = kNoLabel; + if (!unknown_isymbol.empty()) { + unknown_ilabel = new_isymbols->Find(unknown_isymbol); + if (unknown_ilabel == kNoLabel) { + VLOG(1) << "Input symbol '" << unknown_isymbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + + for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); + siter.Next()) { + const auto old_index = siter.Value(); + const auto symbol = siter.Symbol(); + auto new_index = new_isymbols->Find(siter.Symbol()); + if (new_index == kNoLabel) { + if (unknown_ilabel != kNoLabel) { + new_index = unknown_ilabel; + } else { + VLOG(1) << "Input symbol ID " << old_index << " symbol '" << symbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + ipairs.push_back(std::make_pair(old_index, new_index)); + } + if (num_missing_syms > 0) { + LOG(WARNING) << "Target symbol table missing: " << num_missing_syms + << " input symbols"; + } + if (attach_new_isymbols) fst->SetInputSymbols(new_isymbols); + } + // Constructs vectors of output-side label pairs.
+ std::vector> opairs; + if (old_osymbols && new_osymbols) { + size_t num_missing_syms = 0; + Label unknown_olabel = kNoLabel; + if (!unknown_osymbol.empty()) { + unknown_olabel = new_osymbols->Find(unknown_osymbol); + if (unknown_olabel == kNoLabel) { + VLOG(1) << "Output symbol '" << unknown_osymbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + + for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); + siter.Next()) { + const auto old_index = siter.Value(); + const auto symbol = siter.Symbol(); + auto new_index = new_osymbols->Find(siter.Symbol()); + if (new_index == kNoLabel) { + if (unknown_olabel != kNoLabel) { + new_index = unknown_olabel; + } else { + VLOG(1) << "Output symbol ID " << old_index << " symbol '" << symbol + << "' missing from target symbol table"; + ++num_missing_syms; + } + } + opairs.push_back(std::make_pair(old_index, new_index)); + } + if (num_missing_syms > 0) { + LOG(WARNING) << "Target symbol table missing: " << num_missing_syms + << " output symbols"; + } + if (attach_new_osymbols) fst->SetOutputSymbols(new_osymbols); + } + // Calls relabel using vector of relabel pairs. + Relabel(fst, ipairs, opairs); +} + +// Same as previous but no special allowance for unknown symbols. Kept +// for backward compat. +template +void Relabel(MutableFst *fst, const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + bool attach_new_osymbols) { + Relabel(fst, + old_isymbols, new_isymbols, "" /* no unknown isymbol */, + attach_new_isymbols, + old_osymbols, new_osymbols, "" /* no unknown osymbol */, + attach_new_osymbols); +} + + +// Relabels either the input labels or output labels. The old to +// new labels are specified using symbol tables. Any label associations not +// specified are assumed to be identity mapping.
+template +void Relabel(MutableFst *fst, const SymbolTable *new_isymbols, + const SymbolTable *new_osymbols) { + Relabel(fst, fst->InputSymbols(), new_isymbols, true, fst->OutputSymbols(), + new_osymbols, true); +} + +using RelabelFstOptions = CacheOptions; + +template +class RelabelFst; + +namespace internal { + +// Relabels an FST from one symbol set to another. Relabeling can either be on +// input or output space. RelabelFst implements a delayed version of the +// relabel. Arcs are relabeled when a state is first expanded and are then +// stored in the cache; see Expand(). +template +class RelabelFstImpl : public CacheImpl { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::WriteHeader; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::PushArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + RelabelFstImpl(const Fst &fst, + const std::vector> &ipairs, + const std::vector> &opairs, + const RelabelFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + input_map_(ipairs.begin(), ipairs.end()), + output_map_(opairs.begin(), opairs.end()), + relabel_input_(!ipairs.empty()), + relabel_output_(!opairs.empty()) { + SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); + SetType("relabel"); + } + + RelabelFstImpl(const Fst &fst, + const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, + const SymbolTable *old_osymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + relabel_input_(false), + relabel_output_(false) { + SetType("relabel"); +
SetProperties(RelabelProperties(fst.Properties(kCopyProperties, false))); + SetInputSymbols(old_isymbols); + SetOutputSymbols(old_osymbols); + if (old_isymbols && new_isymbols && + old_isymbols->LabeledCheckSum() != new_isymbols->LabeledCheckSum()) { + for (SymbolTableIterator siter(*old_isymbols); !siter.Done(); + siter.Next()) { + input_map_[siter.Value()] = new_isymbols->Find(siter.Symbol()); + } + SetInputSymbols(new_isymbols); + relabel_input_ = true; + } + if (old_osymbols && new_osymbols && + old_osymbols->LabeledCheckSum() != new_osymbols->LabeledCheckSum()) { + for (SymbolTableIterator siter(*old_osymbols); !siter.Done(); + siter.Next()) { + output_map_[siter.Value()] = new_osymbols->Find(siter.Symbol()); + } + SetOutputSymbols(new_osymbols); + relabel_output_ = true; + } + } + + RelabelFstImpl(const RelabelFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + input_map_(impl.input_map_), + output_map_(impl.output_map_), + relabel_input_(impl.relabel_input_), + relabel_output_(impl.relabel_output_) { + SetType("relabel"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(fst_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) SetFinal(s, fst_->Final(s)); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + // Sets error if found, and returns other FST impl properties.
+ uint64 Properties(uint64 mask) const override { + if ((mask & kError) && fst_->Properties(kError, false)) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + for (ArcIterator> aiter(*fst_, s); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); + if (relabel_input_) { + auto it = input_map_.find(arc.ilabel); + if (it != input_map_.end()) arc.ilabel = it->second; + } + if (relabel_output_) { + auto it = output_map_.find(arc.olabel); + if (it != output_map_.end()) { + arc.olabel = it->second; + } + } + PushArc(s, std::move(arc)); + } + SetArcs(s); + } + + private: + std::unique_ptr> fst_; + + std::unordered_map input_map_; + std::unordered_map output_map_; + bool relabel_input_; + bool relabel_output_; +}; + +} // namespace internal + +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst.
+template +class RelabelFst : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::RelabelFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + RelabelFst(const Fst &fst, + const std::vector> &ipairs, + const std::vector> &opairs, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst(std::make_shared(fst, ipairs, opairs, opts)) {} + + RelabelFst(const Fst &fst, const SymbolTable *new_isymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst( + std::make_shared(fst, fst.InputSymbols(), new_isymbols, + fst.OutputSymbols(), new_osymbols, opts)) {} + + RelabelFst(const Fst &fst, const SymbolTable *old_isymbols, + const SymbolTable *new_isymbols, const SymbolTable *old_osymbols, + const SymbolTable *new_osymbols, + const RelabelFstOptions &opts = RelabelFstOptions()) + : ImplToFst(std::make_shared(fst, old_isymbols, new_isymbols, + old_osymbols, new_osymbols, + opts)) {} + + // See Fst<>::Copy() for doc. + RelabelFst(const RelabelFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this RelabelFst. See Fst<>::Copy() for further doc. + RelabelFst *Copy(bool safe = false) const override { + return new RelabelFst(*this, safe); + } + + void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + return GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RelabelFst &operator=(const RelabelFst &) = delete; +}; + +// Specialization for RelabelFst. 
+template +class StateIterator> : public StateIteratorBase { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const RelabelFst &fst) + : impl_(fst.GetImpl()), siter_(*impl_->fst_), s_(0) {} + + bool Done() const final { return siter_.Done(); } + + StateId Value() const final { return s_; } + + void Next() final { + if (!siter_.Done()) { + ++s_; + siter_.Next(); + } + } + + void Reset() final { + s_ = 0; + siter_.Reset(); + } + + private: + const internal::RelabelFstImpl* impl_; + StateIterator> siter_; + StateId s_; + + StateIterator(const StateIterator &) = delete; + StateIterator &operator=(const StateIterator &) = delete; +}; + +// Specialization for RelabelFst. +template +class ArcIterator> : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RelabelFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RelabelFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful alias when using StdArc. +using StdRelabelFst = RelabelFst; + +} // namespace fst + +#endif // FST_RELABEL_H_ diff --git a/projects/llm_framework/include/fst/replace-util.h b/projects/llm_framework/include/fst/replace-util.h new file mode 100644 index 00000000..42c69824 --- /dev/null +++ b/projects/llm_framework/include/fst/replace-util.h @@ -0,0 +1,629 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utility classes for the recursive replacement of FSTs (RTNs). + +#ifndef FST_REPLACE_UTIL_H_ +#define FST_REPLACE_UTIL_H_ + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + + +namespace fst { + +// This specifies what labels to output on the call or return arc. 
Note that +// REPLACE_LABEL_INPUT and REPLACE_LABEL_OUTPUT will produce transducers when +// applied to acceptors. +enum ReplaceLabelType { + // Epsilon labels on both input and output. + REPLACE_LABEL_NEITHER = 1, + // Non-epsilon labels on input and epsilon on output. + REPLACE_LABEL_INPUT = 2, + // Epsilon on input and non-epsilon on output. + REPLACE_LABEL_OUTPUT = 3, + // Non-epsilon labels on both input and output. + REPLACE_LABEL_BOTH = 4 +}; + +// By default ReplaceUtil will copy the input label of the replace arc. +// The call_label_type and return_label_type options specify how to manage +// the labels of the call arc and the return arc of the replace FST +struct ReplaceUtilOptions { + int64 root; // Root rule for expansion. + ReplaceLabelType call_label_type; // How to label call arc. + ReplaceLabelType return_label_type; // How to label return arc. + int64 return_label; // Label to put on return arc. + + explicit ReplaceUtilOptions( + int64 root = kNoLabel, + ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT, + ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER, + int64 return_label = 0) + : root(root), + call_label_type(call_label_type), + return_label_type(return_label_type), + return_label(return_label) {} + + // For backwards compatibility. + ReplaceUtilOptions(int64 root, bool epsilon_replace_arc) + : ReplaceUtilOptions(root, + epsilon_replace_arc ? REPLACE_LABEL_NEITHER + : REPLACE_LABEL_INPUT) {} +}; + +// Every non-terminal on a path appears as the first label on that path in every +// FST associated with a given SCC of the replace dependency graph. This would +// be true if the SCC were formed from left-linear grammar rules. +constexpr uint8 kReplaceSCCLeftLinear = 0x01; +// Every non-terminal on a path appears as the final label on that path in every +// FST associated with a given SCC of the replace dependency graph. This would +// be true if the SCC were formed from right-linear grammar rules. 
+constexpr uint8 kReplaceSCCRightLinear = 0x02; +// The SCC in the replace dependency graph has more than one state or a +// self-loop. +constexpr uint8 kReplaceSCCNonTrivial = 0x04; + +// Defined in replace.h. +template +void Replace( + const std::vector *>> &, + MutableFst *, const ReplaceUtilOptions &); + +// Utility class for the recursive replacement of FSTs (RTNs). The user provides +// a set of label/FST pairs at construction. These are used by methods for +// testing cyclic dependencies and connectedness and doing RTN connection and +// specific FST replacement by label or for various optimization properties. The +// modified results can be obtained with the GetFstPairs() or +// GetMutableFstPairs() methods. +template +class ReplaceUtil { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstPair = std::pair *>; + using MutableFstPair = std::pair *>; + using NonTerminalHash = std::unordered_map; + + // Constructs from mutable FSTs; FST ownership is given to ReplaceUtil. + ReplaceUtil(const std::vector &fst_pairs, + const ReplaceUtilOptions &opts); + + // Constructs from FSTs; FST ownership is retained by caller. + ReplaceUtil(const std::vector &fst_pairs, + const ReplaceUtilOptions &opts); + + // Constructs from ReplaceFst internals; FST ownership is retained by caller. + ReplaceUtil(const std::vector>> &fst_array, + const NonTerminalHash &nonterminal_hash, + const ReplaceUtilOptions &opts); + + ~ReplaceUtil() { + for (Label i = 0; i < fst_array_.size(); ++i) delete fst_array_[i]; + } + + // True if the non-terminal dependencies are cyclic. Cyclic dependencies will + // result in an unexpandable FST. + bool CyclicDependencies() const { + GetDependencies(false); + return depprops_ & kCyclic; + } + + // Returns the strongly-connected component ID in the dependency graph of the + // replace FSTS. 
+ StateId SCC(Label label) const { + GetDependencies(false); + const auto it = nonterminal_hash_.find(label); + if (it == nonterminal_hash_.end()) return kNoStateId; + return depscc_[it->second]; + } + + // Returns properties for the strongly-connected component in the dependency + // graph of the replace FSTs. If the SCC is kReplaceSCCLeftLinear or + // kReplaceSCCRightLinear, that SCC can be represented as finite-state despite + // any cyclic dependencies, but not by the usual replacement operation (see + // fst/extensions/pdt/replace.h). + uint8 SCCProperties(StateId scc_id) { + GetSCCProperties(); + return depsccprops_[scc_id]; + } + + // Returns true if no useless FSTs, states or transitions are present in the + // RTN. + bool Connected() const { + GetDependencies(false); + uint64 props = kAccessible | kCoAccessible; + for (Label i = 0; i < fst_array_.size(); ++i) { + if (!fst_array_[i]) continue; + if (fst_array_[i]->Properties(props, true) != props || !depaccess_[i]) { + return false; + } + } + return true; + } + + // Removes useless FSTs, states and transitions from the RTN. + void Connect(); + + // Replaces FSTs specified by labels, unless there are cyclic dependencies. 
+ void ReplaceLabels(const std::vector */> +class ReplaceFst + : public ImplToFst> { + public: + using Arc = A; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using StateTable = T; + using Store = CacheStore; + using State = typename CacheStore::State; + using Impl = internal::ReplaceFstImpl; + using CacheImpl = internal::CacheBaseImpl; + + using ImplToFst::Properties; + + friend class ArcIterator>; + friend class StateIterator>; + friend class ReplaceFstMatcher; + + ReplaceFst(const std::vector *>> &fst_array, + Label root) + : ImplToFst(std::make_shared( + fst_array, ReplaceFstOptions(root))) {} + + ReplaceFst(const std::vector *>> &fst_array, + const ReplaceFstOptions &opts) + : ImplToFst(std::make_shared(fst_array, opts)) {} + + // See Fst<>::Copy() for doc. + ReplaceFst(const ReplaceFst &fst, + bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this ReplaceFst. See Fst<>::Copy() for further doc. 
+ ReplaceFst *Copy( + bool safe = false) const override { + return new ReplaceFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + MatcherBase *InitMatcher(MatchType match_type) const override { + if ((GetImpl()->ArcIteratorFlags() & kArcNoCache) && + ((match_type == MATCH_INPUT && Properties(kILabelSorted, false)) || + (match_type == MATCH_OUTPUT && Properties(kOLabelSorted, false)))) { + return new ReplaceFstMatcher + (this, match_type); + } else { + VLOG(2) << "Not using replace matcher"; + return nullptr; + } + } + + bool CyclicDependencies() const { return GetImpl()->CyclicDependencies(); } + + const StateTable &GetStateTable() const { + return *GetImpl()->GetStateTable(); + } + + const Fst &GetFst(Label nonterminal) const { + return *GetImpl()->GetFst(GetImpl()->GetFstId(nonterminal)); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + ReplaceFst &operator=(const ReplaceFst &) = delete; +}; + +// Specialization for ReplaceFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const ReplaceFst &fst) + : CacheStateIterator>( + fst, fst.GetMutableImpl()) {} +}; + +// Specialization for ReplaceFst, implementing optional caching. It is be used +// as follows: +// +// ReplaceFst replace; +// ArcIterator> aiter(replace, s); +// // Note: ArcIterator< Fst> is always a caching arc iterator. +// aiter.SetFlags(kArcNoCache, kArcNoCache); +// // Uses the arc iterator, no arc will be cached, no state will be expanded. +// // Arc flags can be used to decide which component of the arc need to be +// computed. +// aiter.SetFlags(kArcILabelValue, kArcValueFlags); +// // Wants the ilabel for this arc. +// aiter.Value(); // Does not compute the destination state. 
+// aiter.Next(); +// aiter.SetFlags(kArcNextStateValue, kArcNextStateValue); +// // Wants the ilabel and next state for this arc. +// aiter.Value(); // Does compute the destination state and inserts it +// // in the replace state table. +// // No additional arcs have been cached at this point. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + using StateTuple = typename StateTable::StateTuple; + + ArcIterator(const ReplaceFst &fst, StateId s) + : fst_(fst), + s_(s), + pos_(0), + offset_(0), + flags_(kArcValueFlags), + arcs_(nullptr), + data_flags_(0), + final_flags_(0) { + cache_data_.ref_count = nullptr; + local_data_.ref_count = nullptr; + // If FST does not support optional caching, forces caching. + if (!(fst_.GetImpl()->ArcIteratorFlags() & kArcNoCache) && + !(fst_.GetImpl()->HasArcs(s_))) { + fst_.GetMutableImpl()->Expand(s_); + } + // If state is already cached, use cached arcs array. + if (fst_.GetImpl()->HasArcs(s_)) { + (fst_.GetImpl()) + ->internal::template CacheBaseImpl< + typename CacheStore::State, + CacheStore>::InitArcIterator(s_, &cache_data_); + num_arcs_ = cache_data_.narcs; + arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + } else { // Otherwise delay decision until Value() is called. + tuple_ = fst_.GetImpl()->GetStateTable()->Tuple(s_); + if (tuple_.fst_state == kNoStateId) { + num_arcs_ = 0; + } else { + // The decision to cache or not to cache has been defered until Value() + // or + // SetFlags() is called. However, the arc iterator is set up now to be + // ready for non-caching in order to keep the Value() method simple and + // efficient. + const auto *rfst = fst_.GetImpl()->GetFst(tuple_.fst_id); + rfst->InitArcIterator(tuple_.fst_state, &local_data_); + // arcs_ is a pointer to the arcs in the underlying machine. 
+ arcs_ = local_data_.arcs; + // Computes the final arc (but not its destination state) if a final arc + // is required. + bool has_final_arc = fst_.GetMutableImpl()->ComputeFinalArc( + tuple_, &final_arc_, kArcValueFlags & ~kArcNextStateValue); + // Sets the arc value flags that hold for final_arc_. + final_flags_ = kArcValueFlags & ~kArcNextStateValue; + // Computes the number of arcs. + num_arcs_ = local_data_.narcs; + if (has_final_arc) ++num_arcs_; + // Sets the offset between the underlying arc positions and the + // positions + // in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + // Defers the decision to cache or not until Value() or SetFlags() is + // called. + data_flags_ = 0; + } + } + } + + ~ArcIterator() { + if (cache_data_.ref_count) --(*cache_data_.ref_count); + if (local_data_.ref_count) --(*local_data_.ref_count); + } + + void ExpandAndCache() const { + // TODO(allauzen): revisit this. + // fst_.GetImpl()->Expand(s_, tuple_, local_data_); + // (fst_.GetImpl())->CacheImpl*>::InitArcIterator(s_, + // &cache_data_); + // + fst_.InitArcIterator(s_, &cache_data_); // Expand and cache state. + arcs_ = cache_data_.arcs; // arcs_ is a pointer to the cached arcs. + data_flags_ = kArcValueFlags; // All the arc member values are valid. + offset_ = 0; // No offset. + } + + void Init() { + if (flags_ & kArcNoCache) { // If caching is disabled + // arcs_ is a pointer to the arcs in the underlying machine. + arcs_ = local_data_.arcs; + // Sets the arcs value flags that hold for arcs_. + data_flags_ = kArcWeightValue; + if (!fst_.GetMutableImpl()->EpsilonOnCallInput()) { + data_flags_ |= kArcILabelValue; + } + // Sets the offset between the underlying arc positions and the positions + // in the arc iterator. + offset_ = num_arcs_ - local_data_.narcs; + } else { + ExpandAndCache(); + } + } + + bool Done() const { return pos_ >= num_arcs_; } + + const Arc &Value() const { + // If data_flags_ is 0, non-caching was not requested. 
+ if (!data_flags_) { + // TODO(allauzen): Revisit this. + if (flags_ & kArcNoCache) { + // Should never happen. + FSTERROR() << "ReplaceFst: Inconsistent arc iterator flags"; + } + ExpandAndCache(); + } + if (pos_ - offset_ >= 0) { // The requested arc is not the final arc. + const auto &arc = arcs_[pos_ - offset_]; + if ((data_flags_ & flags_) == (flags_ & kArcValueFlags)) { + // If the value flags match the recquired value flags then returns the + // arc. + return arc; + } else { + // Otherwise, compute the corresponding arc on-the-fly. + fst_.GetMutableImpl()->ComputeArc(tuple_, arc, &arc_, + flags_ & kArcValueFlags); + return arc_; + } + } else { // The requested arc is the final arc. + if ((final_flags_ & flags_) != (flags_ & kArcValueFlags)) { + // If the arc value flags that hold for the final arc do not match the + // requested value flags, then + // final_arc_ needs to be updated. + fst_.GetMutableImpl()->ComputeFinalArc(tuple_, &final_arc_, + flags_ & kArcValueFlags); + final_flags_ = flags_ & kArcValueFlags; + } + return final_arc_; + } + } + + void Next() { ++pos_; } + + size_t Position() const { return pos_; } + + void Reset() { pos_ = 0; } + + void Seek(size_t pos) { pos_ = pos; } + + uint32 Flags() const { return flags_; } + + void SetFlags(uint32 flags, uint32 mask) { + // Updates the flags taking into account what flags are supported + // by the FST. + flags_ &= ~mask; + flags_ |= (flags & fst_.GetImpl()->ArcIteratorFlags()); + // If non-caching is not requested (and caching has not already been + // performed), then flush data_flags_ to request caching during the next + // call to Value(). + if (!(flags_ & kArcNoCache) && data_flags_ != kArcValueFlags) { + if (!fst_.GetImpl()->HasArcs(s_)) data_flags_ = 0; + } + // If data_flags_ has been flushed but non-caching is requested before + // calling Value(), then set up the iterator for non-caching. 
+ if ((flags & kArcNoCache) && (!data_flags_)) Init(); + } + + private: + const ReplaceFst &fst_; // Reference to the FST. + StateId s_; // State in the FST. + mutable StateTuple tuple_; // Tuple corresponding to state_. + + ssize_t pos_; // Current position. + mutable ssize_t offset_; // Offset between position in iterator and in arcs_. + ssize_t num_arcs_; // Number of arcs at state_. + uint32 flags_; // Behavioral flags for the arc iterator + mutable Arc arc_; // Memory to temporarily store computed arcs. + + mutable ArcIteratorData cache_data_; // Arc iterator data in cache. + mutable ArcIteratorData local_data_; // Arc iterator data in local FST. + + mutable const Arc *arcs_; // Array of arcs. + mutable uint32 data_flags_; // Arc value flags valid for data in arcs_. + mutable Arc final_arc_; // Final arc (when required). + mutable uint32 final_flags_; // Arc value flags valid for final_arc_. + + ArcIterator(const ArcIterator &) = delete; + ArcIterator &operator=(const ArcIterator &) = delete; +}; + +// Matcher over a ReplaceFst. Label lookups are delegated to a per-component +// LocalMatcher; Find() additionally supplies the implicit epsilon self-loop +// and, when ComputeFinalArc() reports one, the arc that exits the current +// component. A state whose tuple has no component state matches only the +// implicit loop. +template +class ReplaceFstMatcher : public MatcherBase { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FST = ReplaceFst; + using LocalMatcher = MultiEpsMatcher>>; + + using StateTuple = typename StateTable::StateTuple; + + // This makes a copy of the FST. + ReplaceFstMatcher(const ReplaceFst &fst, + MatchType match_type) + : owned_fst_(fst.Copy()), + fst_(*owned_fst_), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(match_type), + current_loop_(false), + final_arc_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // This doesn't copy the FST.
+ ReplaceFstMatcher(const ReplaceFst *fst, + MatchType match_type) + : fst_(*fst), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(match_type), + current_loop_(false), + final_arc_(false), + loop_(kNoLabel, 0, Weight::One(), kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // This makes a copy of the FST. + ReplaceFstMatcher( + const ReplaceFstMatcher &matcher, + bool safe = false) + : owned_fst_(matcher.fst_.Copy(safe)), + fst_(*owned_fst_), + impl_(fst_.GetMutableImpl()), + s_(fst::kNoStateId), + match_type_(matcher.match_type_), + current_loop_(false), + final_arc_(false), + loop_(fst::kNoLabel, 0, Weight::One(), fst::kNoStateId) { + if (match_type_ == fst::MATCH_OUTPUT) { + std::swap(loop_.ilabel, loop_.olabel); + } + InitMatchers(); + } + + // Creates a local matcher for each component FST in the RTN. LocalMatcher is + // a multi-epsilon wrapper matcher. MultiEpsilonMatcher is used to match each + // non-terminal arc, since these non-terminal + // turn into epsilons on recursion. + void InitMatchers() { + const auto &fst_array = impl_->fst_array_; + matcher_.resize(fst_array.size()); + for (size_t i = 0; i < fst_array.size(); ++i) { + if (fst_array[i]) { + matcher_[i].reset( + new LocalMatcher(*fst_array[i], match_type_, kMultiEpsList)); + auto it = impl_->nonterminal_set_.begin(); + for (; it != impl_->nonterminal_set_.end(); ++it) { + matcher_[i]->AddMultiEpsLabel(*it); + } + } + } + } + + ReplaceFstMatcher *Copy( + bool safe = false) const override { + return new ReplaceFstMatcher(*this, safe); + } + + MatchType Type(bool test) const override { + if (match_type_ == MATCH_NONE) return match_type_; + const auto true_prop = + match_type_ == MATCH_INPUT ? kILabelSorted : kOLabelSorted; + const auto false_prop = + match_type_ == MATCH_INPUT ?
kNotILabelSorted : kNotOLabelSorted; + const auto props = fst_.Properties(true_prop | false_prop, test); + if (props & true_prop) { + return match_type_; + } else if (props & false_prop) { + return MATCH_NONE; + } else { + return MATCH_UNKNOWN; + } + } + + const Fst &GetFst() const override { return fst_; } + + uint64 Properties(uint64 props) const override { return props; } + + // Sets the state from which our matching happens. + void SetState(StateId s) final { + if (s_ == s) return; + s_ = s; + tuple_ = impl_->GetStateTable()->Tuple(s_); + loop_.nextstate = s_; + final_arc_ = false; + // Gets current matcher, used for non-epsilon matching. There is none when + // the tuple has no component state; the stale pointer is dropped so that + // Find(), Done() and Next() never dereference it. + current_matcher_ = tuple_.fst_state == kNoStateId + ? nullptr + : matcher_[tuple_.fst_id].get(); + done_ = (current_matcher_ == nullptr); + if (done_) { + current_loop_ = false; + return; + } + current_matcher_->SetState(tuple_.fst_state); + } + + // Searches for label from previous set state. If label == 0, first + // hallucinate an epsilon loop; otherwise use the underlying matcher to + // search for the label or epsilons. Note since the ReplaceFst recursion + // on non-terminal arcs causes epsilon transitions to be created we use + // MultiEpsilonMatcher to search for possible matches of non-terminals. If the + // component FST + // reaches a final state we also need to add the exiting final arc. + bool Find(Label label) final { + bool found = false; + label_ = label; + if (done_) { + // No component matcher (see SetState): only the implicit loop can match. + current_loop_ = (label_ == 0); + final_arc_ = false; + return current_loop_; + } + if (label_ == 0 || label_ == kNoLabel) { + // Computes loop directly, avoiding Replace::ComputeArc. + if (label_ == 0) { + current_loop_ = true; + found = true; + } + // Searches for matching multi-epsilons. + final_arc_ = impl_->ComputeFinalArc(tuple_, nullptr); + found = current_matcher_->Find(kNoLabel) || final_arc_ || found; + } else { + // Searches on a sub machine directly using sub machine matcher.
+ found = current_matcher_->Find(label_); + } + return found; + } + + bool Done() const final { + if (current_loop_ || final_arc_) return false; + // Without a component matcher there is nothing further to iterate. + return done_ || current_matcher_->Done(); + } + + const Arc &Value() const final { + if (current_loop_) return loop_; + if (final_arc_) { + impl_->ComputeFinalArc(tuple_, &arc_); + return arc_; + } + const auto &component_arc = current_matcher_->Value(); + impl_->ComputeArc(tuple_, component_arc, &arc_); + return arc_; + } + + void Next() final { + if (current_loop_) { + current_loop_ = false; + return; + } + if (final_arc_) { + final_arc_ = false; + return; + } + if (!done_) current_matcher_->Next(); + } + + ssize_t Priority(StateId s) final { return fst_.NumArcs(s); } + + private: + std::unique_ptr> owned_fst_; + const ReplaceFst &fst_; + internal::ReplaceFstImpl *impl_; + LocalMatcher *current_matcher_ = nullptr; // Null until a valid state is set. + std::vector> matcher_; + StateId s_; // Current state. + Label label_; // Current label. + MatchType match_type_; // Supplied by caller. + mutable bool done_ = true; // No component matcher for the current state. + mutable bool current_loop_; // Current arc is the implicit loop. + mutable bool final_arc_; // Current arc for exiting recursion. + mutable StateTuple tuple_; // Tuple corresponding to state_. + mutable Arc arc_; + Arc loop_; + + ReplaceFstMatcher &operator=(const ReplaceFstMatcher &) = delete; +}; + +template +inline void ReplaceFst::InitStateIterator( + StateIteratorData *data) const { + data->base = + new StateIterator>(*this); +} + +using StdReplaceFst = ReplaceFst; + +// Recursively replaces arcs in the root FSTs with other FSTs. +// This version writes the result of replacement to an output MutableFst. +// +// Replace supports replacement of arcs in one Fst with another FST. This +// replacement is recursive. Replace takes an array of FST(s). One FST +// represents the root (or topology) machine. The root FST refers to other FSTs +// by recursively replacing arcs labeled as non-terminals with the matching +// non-terminal FST.
Currently Replace uses the output symbols of the arcs to +// determine whether the arc is a non-terminal arc or not. A non-terminal can be +// any label that is not a non-zero terminal label in the output alphabet. +// +// Note that input argument is a vector of pairs. These correspond to the tuple +// of non-terminal Label and corresponding FST. +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, + ReplaceFstOptions opts = ReplaceFstOptions()) { + opts.gc = true; + opts.gc_limit = 0; // Caches only the last state for fastest copy. + *ofst = ReplaceFst(ifst_array, opts); +} + +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, const ReplaceUtilOptions &opts) { + Replace(ifst_array, ofst, ReplaceFstOptions(opts)); +} + +// For backwards compatibility. +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, typename Arc::Label root, + bool epsilon_on_replace) { + Replace(ifst_array, ofst, ReplaceFstOptions(root, epsilon_on_replace)); +} + +template +void Replace(const std::vector *>> + &ifst_array, + MutableFst *ofst, typename Arc::Label root) { + Replace(ifst_array, ofst, ReplaceFstOptions(root)); +} + +} // namespace fst + +#endif // FST_REPLACE_H_ diff --git a/projects/llm_framework/include/fst/reverse.h b/projects/llm_framework/include/fst/reverse.h new file mode 100644 index 00000000..7c7c89db --- /dev/null +++ b/projects/llm_framework/include/fst/reverse.h @@ -0,0 +1,116 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to reverse an FST. + +#ifndef FST_REVERSE_H_ +#define FST_REVERSE_H_ + +#include +#include + +#include + + +namespace fst { + +// Reverses an FST. The reversed result is written to an output mutable FST. +// If A transduces string x to y with weight a, then the reverse of A +// transduces the reverse of x to the reverse of y with weight a.Reverse(). 
+// +// Typically, a = a.Reverse() and an arc is its own reverse (e.g., for +// TropicalWeight or LogWeight). In general, e.g., when the weights only form a +// left or right semiring, the output arc type must match the input arc type +// except having the reversed Weight type. +// +// When require_superinitial is false, a superinitial state is not created in +// the reversed FST iff the input FST has exactly one final state (which becomes +// the initial state of the reversed FST) with a final weight of semiring One, +// or if it does not belong to any cycle. When require_superinitial is true, a +// superinitial state is always created. +template +void Reverse(const Fst &ifst, MutableFst *ofst, + bool require_superinitial = true) { + using StateId = typename FromArc::StateId; + using FromWeight = typename FromArc::Weight; + using ToWeight = typename ToArc::Weight; + ofst->DeleteStates(); + ofst->SetInputSymbols(ifst.InputSymbols()); + ofst->SetOutputSymbols(ifst.OutputSymbols()); + if (ifst.Properties(kExpanded, false)) { + ofst->ReserveStates(CountStates(ifst) + 1); + } + StateId istart = ifst.Start(); + StateId ostart = kNoStateId; + StateId offset = 0; + uint64 dfs_iprops = 0; + uint64 dfs_oprops = 0; + if (!require_superinitial) { + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (ifst.Final(s) == FromWeight::Zero()) continue; + if (ostart != kNoStateId) { + ostart = kNoStateId; + break; + } else { + ostart = s; + } + } + if (ostart != kNoStateId && ifst.Final(ostart) != FromWeight::One()) { + std::vector scc; + SccVisitor scc_visitor(&scc, nullptr, nullptr, &dfs_iprops); + DfsVisit(ifst, &scc_visitor); + if (count(scc.begin(), scc.end(), scc[ostart]) > 1) { + ostart = kNoStateId; + } else { + for (ArcIterator> aiter(ifst, ostart); !aiter.Done(); + aiter.Next()) { + if (aiter.Value().nextstate == ostart) { + ostart = kNoStateId; + break; + } + } + } + if (ostart != kNoStateId) dfs_oprops = kInitialAcyclic; + 
} + } + if (ostart == kNoStateId) { // Super-initial requested or needed. + ostart = ofst->AddState(); + offset = 1; + } + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto is = siter.Value(); + const auto os = is + offset; + while (ofst->NumStates() <= os) ofst->AddState(); + if (is == istart) ofst->SetFinal(os, ToWeight::One()); + const auto weight = ifst.Final(is); + if ((weight != FromWeight::Zero()) && (offset == 1)) { + const ToArc oarc(0, 0, weight.Reverse(), os); + ofst->AddArc(0, oarc); + } + for (ArcIterator> aiter(ifst, is); !aiter.Done(); + aiter.Next()) { + const auto &iarc = aiter.Value(); + const auto nos = iarc.nextstate + offset; + auto weight = iarc.weight.Reverse(); + if (!offset && (nos == ostart)) { + weight = Times(ifst.Final(ostart).Reverse(), weight); + } + const ToArc oarc(iarc.ilabel, iarc.olabel, weight, os); + while (ofst->NumStates() <= nos) ofst->AddState(); + ofst->AddArc(nos, oarc); + } + } + ofst->SetStart(ostart); + if (offset == 0 && ostart == istart) { + ofst->SetFinal(ostart, ifst.Final(ostart).Reverse()); + } + const auto iprops = ifst.Properties(kCopyProperties, false) | dfs_iprops; + const auto oprops = ofst->Properties(kFstProperties, false) | dfs_oprops; + ofst->SetProperties(ReverseProperties(iprops, offset == 1) | oprops, + kFstProperties); +} + +} // namespace fst + +#endif // FST_REVERSE_H_ diff --git a/projects/llm_framework/include/fst/reweight.h b/projects/llm_framework/include/fst/reweight.h new file mode 100644 index 00000000..64e68cb7 --- /dev/null +++ b/projects/llm_framework/include/fst/reweight.h @@ -0,0 +1,127 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to reweight an FST. 
+ +#ifndef FST_REWEIGHT_H_ +#define FST_REWEIGHT_H_ + +#include +#include + +#include + + +namespace fst { + +enum ReweightType { REWEIGHT_TO_INITIAL, REWEIGHT_TO_FINAL }; + +// Reweights an FST according to a vector of potentials in a given direction. +// The weight must be left distributive when reweighting towards the initial +// state and right distributive when reweighting towards the final states. +// +// An arc of weight w, with an origin state of potential p and destination state +// of potential q, is reweighted by p^-1 \otimes (w \otimes q) when reweighting +// towards the initial state, and by (p \otimes w) \otimes q^-1 when +// reweighting towards the final states. +template +void Reweight(MutableFst *fst, + const std::vector &potential, + ReweightType type) { + using Weight = typename Arc::Weight; + if (fst->NumStates() == 0) return; + // TODO(kbg): Make this a compile-time static_assert once we have a pleasant + // way to "deregister" this operation for non-distributive semirings so an + // informative error message is produced. 
+ if (type == REWEIGHT_TO_INITIAL && !(Weight::Properties() & kLeftSemiring)) { + FSTERROR() << "Reweight: Reweighting to the initial state requires " + << "Weight to be left distributive: " << Weight::Type(); + fst->SetProperties(kError, kError); + return; + } + StateIterator> siter(*fst); + for (; !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s == potential.size()) break; + const auto &weight = potential[s]; + if (weight != Weight::Zero()) { + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + if (arc.nextstate >= potential.size()) continue; + const auto &nextweight = potential[arc.nextstate]; + if (nextweight == Weight::Zero()) continue; + if (type == REWEIGHT_TO_INITIAL) { + arc.weight = + Divide(Times(arc.weight, nextweight), weight, DIVIDE_LEFT); + } + if (type == REWEIGHT_TO_FINAL) { + arc.weight = + Divide(Times(weight, arc.weight), nextweight, DIVIDE_RIGHT); + } + aiter.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) { + fst->SetFinal(s, Divide(fst->Final(s), weight, DIVIDE_LEFT)); + } + } + if (type == REWEIGHT_TO_FINAL) { + fst->SetFinal(s, Times(weight, fst->Final(s))); + } + } + // This handles elements past the end of the potentials array. + for (; !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (type == REWEIGHT_TO_FINAL) { + fst->SetFinal(s, Times(Weight::Zero(), fst->Final(s))); + } + } + const auto startweight = fst->Start() < potential.size() + ? 
potential[fst->Start()] + : Weight::Zero(); + if ((startweight != Weight::One()) && (startweight != Weight::Zero())) { + if (fst->Properties(kInitialAcyclic, true) & kInitialAcyclic) { + const auto s = fst->Start(); + for (MutableArcIterator> aiter(fst, s); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + if (type == REWEIGHT_TO_INITIAL) { + arc.weight = Times(startweight, arc.weight); + } else { + arc.weight = Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT), + arc.weight); + } + aiter.SetValue(arc); + } + if (type == REWEIGHT_TO_INITIAL) { + fst->SetFinal(s, Times(startweight, fst->Final(s))); + } else { + fst->SetFinal(s, Times(Divide(Weight::One(), startweight, DIVIDE_RIGHT), + fst->Final(s))); + } + } else { + const auto s = fst->AddState(); + const auto weight = + (type == REWEIGHT_TO_INITIAL) + ? startweight + : Divide(Weight::One(), startweight, DIVIDE_RIGHT); + fst->AddArc(s, Arc(0, 0, weight, fst->Start())); + fst->SetStart(s); + } + } + fst->SetProperties(ReweightProperties(fst->Properties(kFstProperties, false)), + kFstProperties); +} + +} // namespace fst + +#endif // FST_REWEIGHT_H_ diff --git a/projects/llm_framework/include/fst/rmepsilon.h b/projects/llm_framework/include/fst/rmepsilon.h new file mode 100644 index 00000000..5135bf2d --- /dev/null +++ b/projects/llm_framework/include/fst/rmepsilon.h @@ -0,0 +1,548 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes that implement epsilon-removal. 
+ +#ifndef FST_RMEPSILON_H_ +#define FST_RMEPSILON_H_ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace fst { + +template +class RmEpsilonOptions + : public ShortestDistanceOptions> { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + bool connect; // Connect output + Weight weight_threshold; // Pruning weight threshold. + StateId state_threshold; // Pruning state threshold. + + explicit RmEpsilonOptions(Queue *queue, float delta = kShortestDelta, + bool connect = true, + Weight weight_threshold = Weight::Zero(), + StateId state_threshold = kNoStateId) + : ShortestDistanceOptions>( + queue, EpsilonArcFilter(), kNoStateId, delta), + connect(connect), + weight_threshold(std::move(weight_threshold)), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Computation state of the epsilon-removal algorithm. 
+template +class RmEpsilonState { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + RmEpsilonState(const Fst &fst, std::vector *distance, + const RmEpsilonOptions &opts) + : fst_(fst), + distance_(distance), + sd_state_(fst_, distance, opts, true), + expand_id_(0) {} + + void Expand(StateId s); + + std::vector &Arcs() { return arcs_; } + + const Weight &Final() const { return final_; } + + bool Error() const { return sd_state_.Error(); } + + private: + struct Element { + Label ilabel; + Label olabel; + StateId nextstate; + + Element() {} + + Element(Label ilabel, Label olabel, StateId nexstate) + : ilabel(ilabel), olabel(olabel), nextstate(nexstate) {} + }; + + struct ElementHash { + public: + size_t operator()(const Element &element) const { + static constexpr size_t prime0 = 7853; + static constexpr size_t prime1 = 7867; + return static_cast(element.nextstate) + + static_cast(element.ilabel) * prime0 + + static_cast(element.olabel) * prime1; + } + }; + + class ElementEqual { + public: + bool operator()(const Element &e1, const Element &e2) const { + return (e1.ilabel == e2.ilabel) && (e1.olabel == e2.olabel) && + (e1.nextstate == e2.nextstate); + } + }; + + using ElementMap = std::unordered_map, + ElementHash, ElementEqual>; + + const Fst &fst_; + // Distance from state being expanded in epsilon-closure. + std::vector *distance_; + // Shortest distance algorithm computation state. + internal::ShortestDistanceState> sd_state_; + // Maps an element to a pair corresponding to a position in the arcs vector + // of the state being expanded. The element corresponds to the position in + // the arcs_ vector if p.first is equal to the state being expanded. + ElementMap element_map_; + EpsilonArcFilter eps_filter_; + std::stack eps_queue_; // Queue used to visit the epsilon-closure. + std::vector visited_; // True if the state has been visited. 
+ std::forward_list visited_states_; // List of visited states. + std::vector arcs_; // Arcs of state being expanded. + Weight final_; // Final weight of state being expanded. + StateId expand_id_; // Unique ID for each call to Expand + + RmEpsilonState(const RmEpsilonState &) = delete; + RmEpsilonState &operator=(const RmEpsilonState &) = delete; +}; + +template +void RmEpsilonState::Expand(typename Arc::StateId source) { + final_ = Weight::Zero(); + arcs_.clear(); + sd_state_.ShortestDistance(source); + if (sd_state_.Error()) return; + eps_queue_.push(source); + while (!eps_queue_.empty()) { + const auto state = eps_queue_.top(); + eps_queue_.pop(); + while (visited_.size() <= state) visited_.push_back(false); + if (visited_[state]) continue; + visited_[state] = true; + visited_states_.push_front(state); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + auto arc = aiter.Value(); + arc.weight = Times((*distance_)[state], arc.weight); + if (eps_filter_(arc)) { + while (visited_.size() <= arc.nextstate) visited_.push_back(false); + if (!visited_[arc.nextstate]) eps_queue_.push(arc.nextstate); + } else { + const Element element(arc.ilabel, arc.olabel, arc.nextstate); + auto insert_result = element_map_.insert( + std::make_pair(element, std::make_pair(expand_id_, arcs_.size()))); + if (insert_result.second) { + arcs_.push_back(std::move(arc)); + } else { + if (insert_result.first->second.first == expand_id_) { + auto &weight = arcs_[insert_result.first->second.second].weight; + weight = Plus(weight, arc.weight); + } else { + insert_result.first->second.first = expand_id_; + insert_result.first->second.second = arcs_.size(); + arcs_.push_back(std::move(arc)); + } + } + } + } + final_ = Plus(final_, Times((*distance_)[state], fst_.Final(state))); + } + while (!visited_states_.empty()) { + visited_[visited_states_.front()] = false; + visited_states_.pop_front(); + } + ++expand_id_; +} + +} // namespace internal + +// Removes epsilon-transitions 
(when both the input and output label are an +// epsilon) from a transducer. The result will be an equivalent FST that has no +// such epsilon transitions. This version modifies its input. It allows fine +// control via the options argument; see below for a simpler interface. +// +// The distance vector will be used to hold the shortest distances during the +// epsilon-closure computation. The state queue discipline and convergence delta +// are taken in the options argument. +template +void RmEpsilon(MutableFst *fst, + std::vector *distance, + const RmEpsilonOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (fst->Start() == kNoStateId) return; + // noneps_in[s] will be set to true iff s admits a non-epsilon incoming + // transition or is the start state. + std::vector noneps_in(fst->NumStates(), false); + noneps_in[fst->Start()] = true; + for (size_t i = 0; i < fst->NumStates(); ++i) { + for (ArcIterator> aiter(*fst, i); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (arc.ilabel != 0 || arc.olabel != 0) { + noneps_in[arc.nextstate] = true; + } + } + } + // States sorted in topological order when (acyclic) or generic topological + // order (cyclic). + std::vector states; + states.reserve(fst->NumStates()); + if (fst->Properties(kTopSorted, false) & kTopSorted) { + for (size_t i = 0; i < fst->NumStates(); i++) states.push_back(i); + } else if (fst->Properties(kAcyclic, false) & kAcyclic) { + std::vector order; + bool acyclic; + TopOrderVisitor top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor, EpsilonArcFilter()); + // Sanity check: should be acyclic if property bit is set. 
+ if (!acyclic) { + FSTERROR() << "RmEpsilon: Inconsistent acyclic property bit"; + fst->SetProperties(kError, kError); + return; + } + states.resize(order.size()); + for (StateId i = 0; i < order.size(); i++) states[order[i]] = i; + } else { + uint64 props; + std::vector scc; + SccVisitor scc_visitor(&scc, nullptr, nullptr, &props); + DfsVisit(*fst, &scc_visitor, EpsilonArcFilter()); + std::vector first(scc.size(), kNoStateId); + std::vector next(scc.size(), kNoStateId); + for (StateId i = 0; i < scc.size(); i++) { + if (first[scc[i]] != kNoStateId) next[i] = first[scc[i]]; + first[scc[i]] = i; + } + for (StateId i = 0; i < first.size(); i++) { + for (auto j = first[i]; j != kNoStateId; j = next[j]) { + states.push_back(j); + } + } + } + internal::RmEpsilonState rmeps_state(*fst, distance, opts); + while (!states.empty()) { + const auto state = states.back(); + states.pop_back(); + if (!noneps_in[state] && + (opts.connect || opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId)) { + continue; + } + rmeps_state.Expand(state); + fst->SetFinal(state, rmeps_state.Final()); + fst->DeleteArcs(state); + auto &arcs = rmeps_state.Arcs(); + fst->ReserveArcs(state, arcs.size()); + while (!arcs.empty()) { + fst->AddArc(state, arcs.back()); + arcs.pop_back(); + } + } + if (opts.connect || opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + for (size_t s = 0; s < fst->NumStates(); ++s) { + if (!noneps_in[s]) fst->DeleteArcs(s); + } + } + if (rmeps_state.Error()) fst->SetProperties(kError, kError); + fst->SetProperties( + RmEpsilonProperties(fst->Properties(kFstProperties, false)), + kFstProperties); + if (opts.weight_threshold != Weight::Zero() || + opts.state_threshold != kNoStateId) { + Prune(fst, opts.weight_threshold, opts.state_threshold); + } + if (opts.connect && opts.weight_threshold == Weight::Zero() && + opts.state_threshold == kNoStateId) { + Connect(fst); + } +} + +// Removes epsilon-transitions (when both 
the input and output label +// are an epsilon) from a transducer. The result will be an equivalent +// FST that has no such epsilon transitions. This version modifies its +// input. It has a simplified interface; see above for a version that +// allows finer control. +// +// Complexity: +// +// - Time: +// +// Unweighted: O(v^2 + ve). +// Acyclic: O(v^2 + V e). +// Tropical semiring: O(v^2 log V + ve). +// General: exponential. +// +// - Space: O(vE) +// +// where v is the number of states visited and e is the number of arcs visited. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. International Journal of Computer +// Science 13(1): 129-143. +template +void RmEpsilon(MutableFst *fst, bool connect = true, + typename Arc::Weight weight_threshold = Arc::Weight::Zero(), + typename Arc::StateId state_threshold = kNoStateId, + float delta = kShortestDelta) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + AutoQueue state_queue(*fst, &distance, EpsilonArcFilter()); + RmEpsilonOptions> opts( + &state_queue, delta, connect, weight_threshold, state_threshold); + RmEpsilon(fst, &distance, opts); +} + +struct RmEpsilonFstOptions : CacheOptions { + float delta; + + explicit RmEpsilonFstOptions(const CacheOptions &opts, + float delta = kShortestDelta) + : CacheOptions(opts), delta(delta) {} + + explicit RmEpsilonFstOptions(float delta = kShortestDelta) : delta(delta) {} +}; + +namespace internal { + +// Implementation of delayed RmEpsilonFst. 
+template +class RmEpsilonFstImpl : public CacheImpl { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Store = DefaultCacheStore; + using State = typename Store::State; + + using FstImpl::Properties; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheBaseImpl>::HasArcs; + using CacheBaseImpl>::HasFinal; + using CacheBaseImpl>::HasStart; + using CacheBaseImpl>::PushArc; + using CacheBaseImpl>::SetArcs; + using CacheBaseImpl>::SetFinal; + using CacheBaseImpl>::SetStart; + + RmEpsilonFstImpl(const Fst &fst, const RmEpsilonFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + delta_(opts.delta), + rmeps_state_( + *fst_, &distance_, + RmEpsilonOptions>(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties( + RmEpsilonProperties(fst.Properties(kFstProperties, false), true), + kCopyProperties); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + } + + RmEpsilonFstImpl(const RmEpsilonFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + delta_(impl.delta_), + rmeps_state_( + *fst_, &distance_, + RmEpsilonOptions>(&queue_, delta_, false)) { + SetType("rmepsilon"); + SetProperties(impl.Properties(), kCopyProperties); + SetInputSymbols(impl.InputSymbols()); + SetOutputSymbols(impl.OutputSymbols()); + } + + StateId Start() { + if (!HasStart()) SetStart(fst_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId s) { + if (!HasFinal(s)) Expand(s); + return CacheImpl::Final(s); + } + + size_t NumArcs(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumArcs(s); + } + + size_t NumInputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(StateId s) { + if (!HasArcs(s)) Expand(s); + return CacheImpl::NumOutputEpsilons(s); + } + + uint64 Properties() const override { return 
Properties(kFstProperties); } + + // Sets error if found and returns other FST impl properties. + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && + (fst_->Properties(kError, false) || rmeps_state_.Error())) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) { + if (!HasArcs(s)) Expand(s); + CacheImpl::InitArcIterator(s, data); + } + + void Expand(StateId s) { + rmeps_state_.Expand(s); + SetFinal(s, rmeps_state_.Final()); + auto &arcs = rmeps_state_.Arcs(); + while (!arcs.empty()) { + PushArc(s, std::move(arcs.back())); + arcs.pop_back(); + } + SetArcs(s); + } + + private: + std::unique_ptr> fst_; + float delta_; + std::vector distance_; + FifoQueue queue_; + internal::RmEpsilonState> rmeps_state_; +}; + +} // namespace internal + +// Removes epsilon-transitions (when both the input and output label are an +// epsilon) from a transducer. The result will be an equivalent FST that has no +// such epsilon transitions. This version is a +// delayed FST. +// +// Complexity: +// +// - Time: +// Unweighted: O(v^2 + ve). +// General: exponential. +// +// - Space: O(vE) +// +// where v is the number of states visited and e is the number of arcs visited. +// Constant time to visit an input state or arc is assumed and exclusive of +// caching. +// +// For more information, see: +// +// Mohri, M. 2002. Generic epsilon-removal and input epsilon-normalization +// algorithms for weighted transducers. International Journal of Computer +// Science 13(1): 129-143. +// +// This class attaches interface to implementation and handles +// reference counting, delegating most methods to ImplToFst. 
+template +class RmEpsilonFst : public ImplToFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::RmEpsilonFstImpl; + + friend class ArcIterator>; + friend class StateIterator>; + + explicit RmEpsilonFst(const Fst &fst) + : ImplToFst(std::make_shared(fst, RmEpsilonFstOptions())) {} + + RmEpsilonFst(const Fst &fst, const RmEpsilonFstOptions &opts) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + RmEpsilonFst(const RmEpsilonFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this RmEpsilonFst. See Fst<>::Copy() for further doc. + RmEpsilonFst *Copy(bool safe = false) const override { + return new RmEpsilonFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + RmEpsilonFst &operator=(const RmEpsilonFst &) = delete; +}; + +// Specialization for RmEpsilonFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const RmEpsilonFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for RmEpsilonFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const RmEpsilonFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void RmEpsilonFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Useful alias when using StdArc. 
+using StdRmEpsilonFst = RmEpsilonFst; + +} // namespace fst + +#endif // FST_RMEPSILON_H_ diff --git a/projects/llm_framework/include/fst/rmfinalepsilon.h b/projects/llm_framework/include/fst/rmfinalepsilon.h new file mode 100644 index 00000000..87e3a714 --- /dev/null +++ b/projects/llm_framework/include/fst/rmfinalepsilon.h @@ -0,0 +1,80 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to remove final states that have epsilon-only input arcs. + +#ifndef FST_RMFINALEPSILON_H_ +#define FST_RMFINALEPSILON_H_ + +#include +#include + +#include +#include + + +namespace fst { + +// Removes final states that have epsilon-only input arcs. +template +void RmFinalEpsilon(MutableFst *fst) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + // Determines the coaccessibility of states. + std::vector access; + std::vector coaccess; + uint64 props = 0; + SccVisitor scc_visitor(nullptr, &access, &coaccess, &props); + DfsVisit(*fst, &scc_visitor); + // Finds potential list of removable final states. These are final states that + // have no outgoing transitions or final states that have a non-coaccessible + // future. + std::unordered_set finals; + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (fst->Final(s) != Weight::Zero()) { + bool future_coaccess = false; + for (ArcIterator> aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (coaccess[arc.nextstate]) { + future_coaccess = true; + break; + } + } + if (!future_coaccess) finals.insert(s); + } + } + // Moves the final weight. + std::vector arcs; + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + auto weight = fst->Final(s); + arcs.clear(); + for (ArcIterator> aiter(*fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + // Next state is in the list of finals. 
+ if (finals.find(arc.nextstate) != finals.end()) { + // Sums up all epsilon arcs. + if (arc.ilabel == 0 && arc.olabel == 0) { + weight = Plus(Times(fst->Final(arc.nextstate), arc.weight), weight); + } else { + arcs.push_back(arc); + } + } else { + arcs.push_back(arc); + } + } + // If some arcs (epsilon arcs) were deleted, delete all arcs and add back + // only the non-epsilon arcs. + if (arcs.size() < fst->NumArcs(s)) { + fst->DeleteArcs(s); + fst->SetFinal(s, weight); + for (const auto &arc : arcs) fst->AddArc(s, arc); + } + } + Connect(fst); +} + +} // namespace fst + +#endif // FST_RMFINALEPSILON_H_ diff --git a/projects/llm_framework/include/fst/script/arc-class.h b/projects/llm_framework/include/fst/script/arc-class.h new file mode 100644 index 00000000..551266d7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arc-class.h @@ -0,0 +1,40 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ARC_CLASS_H_ +#define FST_SCRIPT_ARC_CLASS_H_ + +#include + +namespace fst { +namespace script { + +// A struct representing an arc while ignoring arc type. It is passed as an +// argument to AddArc. 
+ +struct ArcClass { + template + explicit ArcClass(const Arc &arc) + : ilabel(arc.ilabel), olabel(arc.olabel), weight(arc.weight), + nextstate(arc.nextstate) {} + + ArcClass(int64 ilabel, int64 olabel, const WeightClass &weight, + int64 nextstate) + : ilabel(ilabel), olabel(olabel), weight(weight), nextstate(nextstate) {} + + template + Arc GetArc() const { + return Arc(ilabel, olabel, *(weight.GetWeight()), + nextstate); + } + + int64 ilabel; + int64 olabel; + WeightClass weight; + int64 nextstate; +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARC_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/arciterator-class.h b/projects/llm_framework/include/fst/script/arciterator-class.h new file mode 100644 index 00000000..8e4ca4f8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arciterator-class.h @@ -0,0 +1,212 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ARCITERATOR_CLASS_H_ +#define FST_SCRIPT_ARCITERATOR_CLASS_H_ + +#include +#include + +#include +#include + +// Scripting API support for ArcIterator. +// +// A call to Value() causes the underlying arc to be used to construct the +// associated ArcClass. + +namespace fst { +namespace script { + +// Non-mutable arc iterators. + +// Virtual interface implemented by each concrete ArcIteratorImpl. +class ArcIteratorImplBase { + public: + virtual bool Done() const = 0; + virtual uint32 Flags() const = 0; + virtual void Next() = 0; + virtual size_t Position() const = 0; + virtual void Reset() = 0; + virtual void Seek(size_t a) = 0; + virtual void SetFlags(uint32 flags, uint32 mask) = 0; + virtual ArcClass Value() const = 0; + virtual ~ArcIteratorImplBase() {} +}; + +// Templated implementation. 
+template +class ArcIteratorClassImpl : public ArcIteratorImplBase { + public: + explicit ArcIteratorClassImpl(const Fst &fst, int64 s) + : aiter_(fst, s) {} + + bool Done() const final { return aiter_.Done(); } + + uint32 Flags() const final { return aiter_.Flags(); } + + void Next() final { aiter_.Next(); } + + size_t Position() const final { return aiter_.Position(); } + + void Reset() final { aiter_.Reset(); } + + void Seek(size_t a) final { aiter_.Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) final { + aiter_.SetFlags(flags, mask); + } + + // This is returned by value because it has not yet been constructed, and + // is likely to participate in return-value optimization. + ArcClass Value() const final { return ArcClass(aiter_.Value()); } + + ~ArcIteratorClassImpl() final {} + + private: + ArcIterator> aiter_; +}; + +class ArcIteratorClass; + +using InitArcIteratorClassArgs = + std::tuple; + +// Untemplated user-facing class holding a templated pimpl. +class ArcIteratorClass { + public: + ArcIteratorClass(const FstClass &fst, int64 s); + + template + ArcIteratorClass(const Fst &fst, int64 s) + : impl_(new ArcIteratorClassImpl(fst, s)) {} + + bool Done() const { return impl_->Done(); } + + uint32 Flags() const { return impl_->Flags(); } + + void Next() { impl_->Next(); } + + size_t Position() const { return impl_->Position(); } + + void Reset() { impl_->Reset(); } + + void Seek(size_t a) { impl_->Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) { impl_->SetFlags(flags, mask); } + + ArcClass Value() const { return impl_->Value(); } + + template + friend void InitArcIteratorClass(InitArcIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitArcIteratorClass(InitArcIteratorClassArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::get<2>(*args)->impl_.reset( + new ArcIteratorClassImpl(fst, std::get<1>(*args))); +} + +// Mutable arc iterators. 
+ +// Virtual interface implemented by each concrete MutableArcIteratorImpl. +class MutableArcIteratorImplBase : public ArcIteratorImplBase { + public: + virtual void SetValue(const ArcClass &) = 0; + + ~MutableArcIteratorImplBase() override {} +}; + +// Templated implementation. +template +class MutableArcIteratorClassImpl + : public MutableArcIteratorImplBase { + public: + explicit MutableArcIteratorClassImpl(MutableFst *fst, int64 s) + : aiter_(fst, s) {} + + bool Done() const final { return aiter_.Done(); } + + uint32 Flags() const final { return aiter_.Flags(); } + + void Next() final { aiter_.Next(); } + + size_t Position() const final { return aiter_.Position(); } + + void Reset() final { aiter_.Reset(); } + + void Seek(size_t a) final { aiter_.Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) final { + aiter_.SetFlags(flags, mask); + } + + void SetValue(const Arc &arc) { aiter_.SetValue(arc); } + + void SetValue(const ArcClass &ac) final { aiter_.SetValue(ac.GetArc()); } + + // This is returned by value because it has not yet been constructed, and + // is likely to participate in return-value optimization. + ArcClass Value() const final { return ArcClass(aiter_.Value()); } + + ~MutableArcIteratorClassImpl() override {} + + private: + MutableArcIterator> aiter_; +}; + +class MutableArcIteratorClass; + +using InitMutableArcIteratorClassArgs = + std::tuple; + +// Untemplated user-facing class holding a templated pimpl. 
+class MutableArcIteratorClass { + public: + MutableArcIteratorClass(MutableFstClass *fst, int64 s); + + template + MutableArcIteratorClass(MutableFst *fst, int64 s) + : impl_(new MutableArcIteratorClassImpl(fst, s)) {} + + bool Done() const { return impl_->Done(); } + + uint32 Flags() const { return impl_->Flags(); } + + void Next() { impl_->Next(); } + + size_t Position() const { return impl_->Position(); } + + void Reset() { impl_->Reset(); } + + void Seek(size_t a) { impl_->Seek(a); } + + void SetFlags(uint32 flags, uint32 mask) { impl_->SetFlags(flags, mask); } + + void SetValue(const ArcClass &ac) { impl_->SetValue(ac); } + + ArcClass Value() const { return impl_->Value(); } + + template + friend void InitMutableArcIteratorClass( + InitMutableArcIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitMutableArcIteratorClass(InitMutableArcIteratorClassArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + std::get<2>(*args)->impl_.reset( + new MutableArcIteratorClassImpl(fst, std::get<1>(*args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARCITERATOR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/arcsort.h b/projects/llm_framework/include/fst/script/arcsort.h new file mode 100644 index 00000000..3e56fe5c --- /dev/null +++ b/projects/llm_framework/include/fst/script/arcsort.h @@ -0,0 +1,44 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_ARCSORT_H_ +#define FST_SCRIPT_ARCSORT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +enum ArcSortType { + ILABEL_SORT, + OLABEL_SORT +}; + +using ArcSortArgs = std::pair; + +template +void ArcSort(ArcSortArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + switch (std::get<1>(*args)) { + case ILABEL_SORT: { + const ILabelCompare icomp; + ArcSort(fst, icomp); + return; + } + case OLABEL_SORT: { + const OLabelCompare ocomp; + ArcSort(fst, ocomp); + return; + } + } +} + +void ArcSort(MutableFstClass *ofst, ArcSortType); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARCSORT_H_ diff --git a/projects/llm_framework/include/fst/script/arg-packs.h b/projects/llm_framework/include/fst/script/arg-packs.h new file mode 100644 index 00000000..93cd4a05 --- /dev/null +++ b/projects/llm_framework/include/fst/script/arg-packs.h @@ -0,0 +1,37 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// std::pair and std::tuple are used for the arguments of FstClass operations. +// +// If a function with a return value is required, use the WithReturnValue +// template as follows: +// +// WithReturnValue> + +#ifndef FST_SCRIPT_ARG_PACKS_H_ +#define FST_SCRIPT_ARG_PACKS_H_ + +#include + +namespace fst { +namespace script { + +// Tack this on to an existing type to add a return value. The syntax for +// accessing the args is then slightly more stilted, as you must do an extra +// member access (since the args are stored as a member of this class). + +template +struct WithReturnValue { + // Avoid reference-to-reference if ArgTuple is a reference. 
+ using Args = typename std::remove_reference::type; + + Retval retval; + const Args &args; + + explicit WithReturnValue(const Args &args) : args(args) {} +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ARG_PACKS_H_ diff --git a/projects/llm_framework/include/fst/script/closure.h b/projects/llm_framework/include/fst/script/closure.h new file mode 100644 index 00000000..7c68604a --- /dev/null +++ b/projects/llm_framework/include/fst/script/closure.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CLOSURE_H_ +#define FST_SCRIPT_CLOSURE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ClosureArgs = std::pair; + +template +void Closure(ClosureArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Closure(fst, std::get<1>(*args)); +} + +void Closure(MutableFstClass *ofst, ClosureType closure_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CLOSURE_H_ diff --git a/projects/llm_framework/include/fst/script/compile-impl.h b/projects/llm_framework/include/fst/script/compile-impl.h new file mode 100644 index 00000000..943a0b72 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compile-impl.h @@ -0,0 +1,217 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to to compile a binary FST from textual input. + +#ifndef FST_SCRIPT_COMPILE_IMPL_H_ +#define FST_SCRIPT_COMPILE_IMPL_H_ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Compile a binary Fst from textual input, helper class for fstcompile.cc +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. 
+template +class FstCompiler { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // WARNING: use of negative labels not recommended as it may cause conflicts. + // If add_symbols_ is true, then the symbols will be dynamically added to the + // symbol tables. This is only useful if you set the (i/o)keep flag to attach + // the final symbol table, or use the accessors. (The input symbol tables are + // const and therefore not changed.) + FstCompiler(std::istream &istrm, const string &source, // NOLINT + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, bool allow_negative_labels = false) { + std::unique_ptr misyms(isyms ? isyms->Copy() : nullptr); + std::unique_ptr mosyms(osyms ? osyms->Copy() : nullptr); + std::unique_ptr mssyms(ssyms ? ssyms->Copy() : nullptr); + Init(istrm, source, misyms.get(), mosyms.get(), mssyms.get(), accep, + ikeep, okeep, nkeep, allow_negative_labels, false); + } + + FstCompiler(std::istream &istrm, const string &source, // NOLINT + SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels, bool add_symbols) { + Init(istrm, source, isyms, osyms, ssyms, accep, ikeep, okeep, nkeep, + allow_negative_labels, add_symbols); + } + + void Init(std::istream &istrm, const string &source, // NOLINT + SymbolTable *isyms, SymbolTable *osyms, SymbolTable *ssyms, + bool accep, bool ikeep, bool okeep, bool nkeep, + bool allow_negative_labels, bool add_symbols) { + nline_ = 0; + source_ = source; + isyms_ = isyms; + osyms_ = osyms; + ssyms_ = ssyms; + nstates_ = 0; + keep_state_numbering_ = nkeep; + allow_negative_labels_ = allow_negative_labels; + add_symbols_ = add_symbols; + bool start_state_populated = false; + char line[kLineLen]; + const string separator = FLAGS_fst_field_separator + "\n"; + while 
(istrm.getline(line, kLineLen)) { + ++nline_; + std::vector col; + SplitString(line, separator.c_str(), &col, true); + if (col.empty() || col[0][0] == '\0') + continue; + if (col.size() > 5 || (col.size() > 4 && accep) || + (col.size() == 3 && !accep)) { + FSTERROR() << "FstCompiler: Bad number of columns, source = " << source_ + << ", line = " << nline_; + fst_.SetProperties(kError, kError); + return; + } + StateId s = StrToStateId(col[0]); + while (s >= fst_.NumStates()) fst_.AddState(); + if (!start_state_populated) { + fst_.SetStart(s); + start_state_populated = true; + } + + Arc arc; + StateId d = s; + switch (col.size()) { + case 1: + fst_.SetFinal(s, Weight::One()); + break; + case 2: + fst_.SetFinal(s, StrToWeight(col[1], true)); + break; + case 3: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = arc.ilabel; + arc.weight = Weight::One(); + fst_.AddArc(s, arc); + break; + case 4: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + if (accep) { + arc.olabel = arc.ilabel; + arc.weight = StrToWeight(col[3], true); + } else { + arc.olabel = StrToOLabel(col[3]); + arc.weight = Weight::One(); + } + fst_.AddArc(s, arc); + break; + case 5: + arc.nextstate = d = StrToStateId(col[1]); + arc.ilabel = StrToILabel(col[2]); + arc.olabel = StrToOLabel(col[3]); + arc.weight = StrToWeight(col[4], true); + fst_.AddArc(s, arc); + } + while (d >= fst_.NumStates()) fst_.AddState(); + } + if (ikeep) fst_.SetInputSymbols(isyms); + if (okeep) fst_.SetOutputSymbols(osyms); + } + + const VectorFst &Fst() const { return fst_; } + + private: + // Maximum line length in text file. + static constexpr int kLineLen = 8096; + + StateId StrToId(const char *s, SymbolTable *syms, const char *name, + bool allow_negative = false) const { + StateId n = 0; + if (syms) { + n = (add_symbols_) ? 
syms->AddSymbol(s) : syms->Find(s); + if (n == -1 || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Symbol \"" << s + << "\" is not mapped to any integer " << name + << ", symbol table = " << syms->Name() + << ", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } else { + char *p; + n = strtoll(s, &p, 10); + if (p < s + strlen(s) || (!allow_negative && n < 0)) { + FSTERROR() << "FstCompiler: Bad " << name << " integer = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + } + } + return n; + } + + StateId StrToStateId(const char *s) { + StateId n = StrToId(s, ssyms_, "state ID"); + if (keep_state_numbering_) return n; + // Remaps state IDs to make dense set. + const auto it = states_.find(n); + if (it == states_.end()) { + states_[n] = nstates_; + return nstates_++; + } else { + return it->second; + } + } + + StateId StrToILabel(const char *s) const { + return StrToId(s, isyms_, "arc ilabel", allow_negative_labels_); + } + + StateId StrToOLabel(const char *s) const { + return StrToId(s, osyms_, "arc olabel", allow_negative_labels_); + } + + Weight StrToWeight(const char *s, bool allow_zero) const { + Weight w; + std::istringstream strm(s); + strm >> w; + if (!strm || (!allow_zero && w == Weight::Zero())) { + FSTERROR() << "FstCompiler: Bad weight = \"" << s + << "\", source = " << source_ << ", line = " << nline_; + fst_.SetProperties(kError, kError); + w = Weight::NoWeight(); + } + return w; + } + + mutable VectorFst fst_; + size_t nline_; + string source_; // Text FST source name. + SymbolTable *isyms_; // ilabel symbol table (not owned). + SymbolTable *osyms_; // olabel symbol table (not owned). + SymbolTable *ssyms_; // slabel symbol table (not owned). + std::unordered_map states_; // State ID map. + StateId nstates_; // Number of seen states. + bool keep_state_numbering_; + bool allow_negative_labels_; // Not recommended; may cause conflicts. 
+ bool add_symbols_; // Add to symbol tables on-the fly. + + FstCompiler(const FstCompiler &) = delete; + FstCompiler &operator=(const FstCompiler &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/compile.h b/projects/llm_framework/include/fst/script/compile.h new file mode 100644 index 00000000..c82ed477 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compile.h @@ -0,0 +1,98 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_COMPILE_H_ +#define FST_SCRIPT_COMPILE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +// This operation exists in two forms. 1 is a void operation which writes the +// compiled machine to disk; 2 returns an FstClass. I/O should normally be done +// using the binary format for efficiency, so users are STRONGLY ENCOURAGED to +// use 1 or to construct FSTs using the C++ FST mutation operations. + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! 
+struct CompileFstInnerArgs { + std::istream &istrm; + const string &source; + const string &fst_type; + const fst::SymbolTable *isyms; + const fst::SymbolTable *osyms; + const fst::SymbolTable *ssyms; + const bool accep; + const bool ikeep; + const bool okeep; + const bool nkeep; + const bool allow_negative_labels; + + CompileFstInnerArgs(std::istream &istrm, const string &source, + const string &fst_type, const fst::SymbolTable *isyms, + const fst::SymbolTable *osyms, + const fst::SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, + bool allow_negative_labels = false) + : istrm(istrm), + source(source), + fst_type(fst_type), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accep(accep), + ikeep(ikeep), + okeep(okeep), + nkeep(nkeep), + allow_negative_labels(allow_negative_labels) {} +}; + +using CompileFstArgs = WithReturnValue; + +template +void CompileFstInternal(CompileFstArgs *args) { + using fst::Convert; + using fst::Fst; + using fst::FstCompiler; + FstCompiler fstcompiler( + args->args.istrm, args->args.source, args->args.isyms, args->args.osyms, + args->args.ssyms, args->args.accep, args->args.ikeep, args->args.okeep, + args->args.nkeep, args->args.allow_negative_labels); + const Fst *fst = &fstcompiler.Fst(); + std::unique_ptr> owned_fst; + if (args->args.fst_type != "vector") { + owned_fst.reset(Convert(*fst, args->args.fst_type)); + if (!owned_fst) { + FSTERROR() << "Failed to convert FST to desired type: " + << args->args.fst_type; + } + fst = owned_fst.get(); + } + args->retval = fst ? 
new FstClass(*fst) : nullptr; +} + +void CompileFst(std::istream &istrm, const string &source, const string &dest, + const string &fst_type, const string &arc_type, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, bool okeep, + bool nkeep, bool allow_negative_labels); + +FstClass *CompileFstInternal(std::istream &istrm, const string &source, + const string &fst_type, const string &arc_type, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accep, bool ikeep, + bool okeep, bool nkeep, + bool allow_negative_labels); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPILE_H_ diff --git a/projects/llm_framework/include/fst/script/compose.h b/projects/llm_framework/include/fst/script/compose.h new file mode 100644 index 00000000..a1735803 --- /dev/null +++ b/projects/llm_framework/include/fst/script/compose.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_COMPOSE_H_ +#define FST_SCRIPT_COMPOSE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ComposeArgs = std::tuple; + +template +void Compose(ComposeArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Compose(ifst1, ifst2, ofst, opts); +} + +void Compose(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_COMPOSE_H_ diff --git a/projects/llm_framework/include/fst/script/concat.h b/projects/llm_framework/include/fst/script/concat.h new file mode 100644 index 00000000..4bf8dc61 --- /dev/null +++ b/projects/llm_framework/include/fst/script/concat.h @@ -0,0 +1,40 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_CONCAT_H_ +#define FST_SCRIPT_CONCAT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ConcatArgs1 = std::pair; + +template +void Concat(ConcatArgs1 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + const Fst &ifst = *(std::get<1>(*args).GetFst()); + Concat(ofst, ifst); +} + +using ConcatArgs2 = std::pair; + +template +void Concat(ConcatArgs2 *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Concat(ifst, ofst); +} + +void Concat(MutableFstClass *ofst, const FstClass &ifst); + +void Concat(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONCAT_H_ diff --git a/projects/llm_framework/include/fst/script/connect.h b/projects/llm_framework/include/fst/script/connect.h new file mode 100644 index 00000000..030102ac --- /dev/null +++ b/projects/llm_framework/include/fst/script/connect.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_CONNECT_H_ +#define FST_SCRIPT_CONNECT_H_ + +#include +#include + +namespace fst { +namespace script { + +template +void Connect(MutableFstClass *fst) { + Connect(fst->GetMutableFst()); +} + +void Connect(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONNECT_H_ diff --git a/projects/llm_framework/include/fst/script/convert.h b/projects/llm_framework/include/fst/script/convert.h new file mode 100644 index 00000000..1a6eeaa3 --- /dev/null +++ b/projects/llm_framework/include/fst/script/convert.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_CONVERT_H_ +#define FST_SCRIPT_CONVERT_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using ConvertInnerArgs = std::pair; + +using ConvertArgs = WithReturnValue; + +template +void Convert(ConvertArgs *args) { + const Fst &fst = *(std::get<0>(args->args).GetFst()); + const string &new_type = std::get<1>(args->args); + std::unique_ptr> result(Convert(fst, new_type)); + args->retval = result ? new FstClass(*result) : nullptr; +} + +FstClass *Convert(const FstClass &fst, const string &new_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_CONVERT_H_ diff --git a/projects/llm_framework/include/fst/script/decode.h b/projects/llm_framework/include/fst/script/decode.h new file mode 100644 index 00000000..09f25391 --- /dev/null +++ b/projects/llm_framework/include/fst/script/decode.h @@ -0,0 +1,49 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_DECODE_H_ +#define FST_SCRIPT_DECODE_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using DecodeArgs1 = std::pair; + +template +void Decode(DecodeArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + std::unique_ptr> decoder( + EncodeMapper::Read(std::get<1>(*args), DECODE)); + if (!decoder) { + fst->SetProperties(kError, kError); + return; + } + Decode(fst, *decoder); +} + +using DecodeArgs2 = std::pair; + +template +void Decode(DecodeArgs2 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const EncodeMapper &encoder = + *(std::get<1>(*args).GetEncodeMapper()); + Decode(fst, encoder); +} + +void Decode(MutableFstClass *fst, const string &coder_fname); + +void Decode(MutableFstClass *fst, const EncodeMapperClass &encoder); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DECODE_H_ diff --git a/projects/llm_framework/include/fst/script/determinize.h b/projects/llm_framework/include/fst/script/determinize.h new file mode 100644 index 00000000..383a8fe4 --- /dev/null +++ b/projects/llm_framework/include/fst/script/determinize.h @@ -0,0 +1,59 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_DETERMINIZE_H_ +#define FST_SCRIPT_DETERMINIZE_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +struct DeterminizeOptions { + const float delta; + const WeightClass &weight_threshold; + const int64 state_threshold; + const int64 subsequential_label; + const DeterminizeType det_type; + const bool increment_subsequential_label; + + DeterminizeOptions(float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + int64 subsequential_label = 0, + DeterminizeType det_type = DETERMINIZE_FUNCTIONAL, + bool increment_subsequential_label = false) + : delta(delta), + weight_threshold(weight_threshold), + state_threshold(state_threshold), + subsequential_label(subsequential_label), + det_type(det_type), + increment_subsequential_label(increment_subsequential_label) {} +}; + +using DeterminizeArgs = std::tuple; + +template +void Determinize(DeterminizeArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + const auto weight_threshold = *(opts.weight_threshold.GetWeight()); + const fst::DeterminizeOptions detargs(opts.delta, weight_threshold, + opts.state_threshold, opts.subsequential_label, opts.det_type, + opts.increment_subsequential_label); + Determinize(ifst, ofst, detargs); +} + +void Determinize(const FstClass &ifst, MutableFstClass *ofst, + const DeterminizeOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DETERMINIZE_H_ diff --git a/projects/llm_framework/include/fst/script/difference.h b/projects/llm_framework/include/fst/script/difference.h new file mode 100644 index 00000000..7af6200a --- /dev/null +++ b/projects/llm_framework/include/fst/script/difference.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_DIFFERENCE_H_ +#define FST_SCRIPT_DIFFERENCE_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using DifferenceArgs = std::tuple; + +template +void Difference(DifferenceArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Difference(ifst1, ifst2, ofst, opts); +} + +void Difference(const FstClass &ifst1, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DIFFERENCE_H_ diff --git a/projects/llm_framework/include/fst/script/disambiguate.h b/projects/llm_framework/include/fst/script/disambiguate.h new file mode 100644 index 00000000..acc1fba2 --- /dev/null +++ b/projects/llm_framework/include/fst/script/disambiguate.h @@ -0,0 +1,54 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_DISAMBIGUATE_H_ +#define FST_SCRIPT_DISAMBIGUATE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +struct DisambiguateOptions { + const float delta; + const WeightClass &weight_threshold; + const int64 state_threshold; + const int64 subsequential_label; + + DisambiguateOptions(float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + int64 subsequential_label = 0) + : delta(delta), + weight_threshold(weight_threshold), + state_threshold(state_threshold), + subsequential_label(subsequential_label) {} +}; + +using DisambiguateArgs = std::tuple; + +template +void Disambiguate(DisambiguateArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + const auto weight_threshold = *(opts.weight_threshold.GetWeight()); + const fst::DisambiguateOptions disargs(opts.delta, weight_threshold, + opts.state_threshold, + opts.subsequential_label); + Disambiguate(ifst, ofst, disargs); +} + +void Disambiguate(const FstClass &ifst, MutableFstClass *ofst, + const DisambiguateOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_DISAMBIGUATE_H_ diff --git a/projects/llm_framework/include/fst/script/draw-impl.h b/projects/llm_framework/include/fst/script/draw-impl.h new file mode 100644 index 00000000..f204b2e6 --- /dev/null +++ b/projects/llm_framework/include/fst/script/draw-impl.h @@ -0,0 +1,227 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to draw a binary FST by producing a text file in dot format, a helper +// class to fstdraw.cc. 
+ +#ifndef FST_SCRIPT_DRAW_IMPL_H_ +#define FST_SCRIPT_DRAW_IMPL_H_ + +#include +#include +#include + +#include +#include +#include + +namespace fst { + +// Print a binary FST in GraphViz textual format (helper class for fstdraw.cc). +// WARNING: Stand-alone use not recommend. +template +class FstDrawer { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + FstDrawer(const Fst &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, bool show_weight_one) + : fst_(fst), + isyms_(isyms), + osyms_(osyms), + ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), + ostrm_(nullptr), + title_(title), + width_(width), + height_(height), + portrait_(portrait), + vertical_(vertical), + ranksep_(ranksep), + nodesep_(nodesep), + fontsize_(fontsize), + precision_(precision), + float_format_(float_format), + show_weight_one_(show_weight_one) {} + + // Draws FST to an output buffer. 
+ void Draw(std::ostream *strm, const string &dest) { + ostrm_ = strm; + SetStreamState(ostrm_); + dest_ = dest; + StateId start = fst_.Start(); + if (start == kNoStateId) return; + PrintString("digraph FST {\n"); + if (vertical_) { + PrintString("rankdir = BT;\n"); + } else { + PrintString("rankdir = LR;\n"); + } + PrintString("size = \""); + Print(width_); + PrintString(","); + Print(height_); + PrintString("\";\n"); + if (!dest_.empty()) PrintString("label = \"" + title_ + "\";\n"); + PrintString("center = 1;\n"); + if (portrait_) { + PrintString("orientation = Portrait;\n"); + } else { + PrintString("orientation = Landscape;\n"); + } + PrintString("ranksep = \""); + Print(ranksep_); + PrintString("\";\n"); + PrintString("nodesep = \""); + Print(nodesep_); + PrintString("\";\n"); + // Initial state first. + DrawState(start); + for (StateIterator> siter(fst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s != start) DrawState(s); + } + PrintString("}\n"); + } + + private: + void SetStreamState(std::ostream* strm) const { + strm->precision(precision_); + if (float_format_ == "e") + strm->setf(std::ios_base::scientific, std::ios_base::floatfield); + if (float_format_ == "f") + strm->setf(std::ios_base::fixed, std::ios_base::floatfield); + // O.w. defaults to "g" per standard lib. + } + + void PrintString(const string &str) const { *ostrm_ << str; } + + // Escapes backslash and double quote if these occur in the string. Dot will + // not deal gracefully with these if they are not escaped. 
+ static string Escape(const string &str) { + string ns; + for (char c : str) { + if (c == '\\' || c == '"') ns.push_back('\\'); + ns.push_back(c); + } + return ns; + } + + void PrintId(StateId id, const SymbolTable *syms, const char *name) const { + if (syms) { + auto symbol = syms->Find(id); + if (symbol.empty()) { + FSTERROR() << "FstDrawer: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } + PrintString(Escape(symbol)); + } else { + PrintString(std::to_string(id)); + } + } + + void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); } + + void PrintILabel(Label label) const { + PrintId(label, isyms_, "arc input label"); + } + + void PrintOLabel(Label label) const { + PrintId(label, osyms_, "arc output label"); + } + + void PrintWeight(Weight w) const { + // Weight may have double quote characters in it, so escape it. + PrintString(Escape(ToString(w))); + } + + template + void Print(T t) const { *ostrm_ << t; } + + template + string ToString(T t) const { + std::stringstream ss; + SetStreamState(&ss); + ss << t; + return ss.str(); + } + + void DrawState(StateId s) const { + Print(s); + PrintString(" [label = \""); + PrintStateId(s); + const auto weight = fst_.Final(s); + if (weight != Weight::Zero()) { + if (show_weight_one_ || (weight != Weight::One())) { + PrintString("/"); + PrintWeight(weight); + } + PrintString("\", shape = doublecircle,"); + } else { + PrintString("\", shape = circle,"); + } + if (s == fst_.Start()) { + PrintString(" style = bold,"); + } else { + PrintString(" style = solid,"); + } + PrintString(" fontsize = "); + Print(fontsize_); + PrintString("]\n"); + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + PrintString("\t"); + Print(s); + PrintString(" -> "); + Print(arc.nextstate); + PrintString(" [label = \""); + PrintILabel(arc.ilabel); + if (!accep_) { + PrintString(":"); 
+ PrintOLabel(arc.olabel); + } + if (show_weight_one_ || (arc.weight != Weight::One())) { + PrintString("/"); + PrintWeight(arc.weight); + } + PrintString("\", fontsize = "); + Print(fontsize_); + PrintString("];\n"); + } + } + + const Fst &fst_; + const SymbolTable *isyms_; // ilabel symbol table. + const SymbolTable *osyms_; // olabel symbol table. + const SymbolTable *ssyms_; // slabel symbol table. + bool accep_; // Print as acceptor when possible. + std::ostream *ostrm_; // Drawn FST destination. + string dest_; // Drawn FST destination name. + + string title_; + float width_; + float height_; + bool portrait_; + bool vertical_; + float ranksep_; + float nodesep_; + int fontsize_; + int precision_; + string float_format_; + bool show_weight_one_; + + FstDrawer(const FstDrawer &) = delete; + FstDrawer &operator=(const FstDrawer &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_DRAW_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/draw.h b/projects/llm_framework/include/fst/script/draw.h new file mode 100644 index 00000000..cb37df1e --- /dev/null +++ b/projects/llm_framework/include/fst/script/draw.h @@ -0,0 +1,85 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_DRAW_H_ +#define FST_SCRIPT_DRAW_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! 
+struct FstDrawerArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accep; + const string &title; + const float width; + const float height; + const bool portrait; + const bool vertical; + const float ranksep; + const float nodesep; + const int fontsize; + const int precision; + const string &float_format; // NOLINT + const bool show_weight_one; + std::ostream *ostrm; + const string &dest; + + FstDrawerArgs(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, + bool show_weight_one, std::ostream *ostrm, const string &dest) + : fst(fst), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accep(accep), + title(title), + width(width), + height(height), + portrait(portrait), + vertical(vertical), + ranksep(ranksep), + nodesep(nodesep), + fontsize(fontsize), + precision(precision), + float_format(float_format), + show_weight_one(show_weight_one), + ostrm(ostrm), + dest(dest) {} +}; + +template +void DrawFst(FstDrawerArgs *args) { + const Fst &fst = *(args->fst.GetFst()); + FstDrawer fstdrawer(fst, args->isyms, args->osyms, args->ssyms, + args->accep, args->title, args->width, args->height, args->portrait, + args->vertical, args->ranksep, args->nodesep, args->fontsize, + args->precision, args->float_format, args->show_weight_one); + fstdrawer.Draw(args->ostrm, args->dest); +} + +void DrawFst(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + const string &title, float width, float height, bool portrait, + bool vertical, float ranksep, float nodesep, int fontsize, + int precision, const string &float_format, bool show_weight_one, + std::ostream *ostrm, const string &dest); + +} // namespace script +} // 
namespace fst + +#endif // FST_SCRIPT_DRAW_H_ diff --git a/projects/llm_framework/include/fst/script/encode.h b/projects/llm_framework/include/fst/script/encode.h new file mode 100644 index 00000000..6a869680 --- /dev/null +++ b/projects/llm_framework/include/fst/script/encode.h @@ -0,0 +1,51 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ENCODE_H_ +#define FST_SCRIPT_ENCODE_H_ + +#include +#include +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EncodeArgs1 = std::tuple; + +template +void Encode(EncodeArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const string &coder_fname = std::get<3>(*args); + // If true, reuse encode from disk. If false, make a new encoder and just use + // the filename argument as the destination state. + std::unique_ptr> encoder( + std::get<2>(*args) ? EncodeMapper::Read(coder_fname, ENCODE) + : new EncodeMapper(std::get<1>(*args), ENCODE)); + Encode(fst, encoder.get()); + if (!std::get<2>(*args)) encoder->Write(coder_fname); +} + +using EncodeArgs2 = std::pair; + +template +void Encode(EncodeArgs2 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + EncodeMapper *encoder = std::get<1>(*args)->GetEncodeMapper(); + Encode(fst, encoder); +} + +void Encode(MutableFstClass *fst, uint32 flags, bool reuse_encoder, + const string &coder_fname); + +void Encode(MutableFstClass *fst, EncodeMapperClass *encoder); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ENCODE_H_ diff --git a/projects/llm_framework/include/fst/script/encodemapper-class.h b/projects/llm_framework/include/fst/script/encodemapper-class.h new file mode 100644 index 00000000..b62824f1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/encodemapper-class.h @@ -0,0 +1,169 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_ENCODEMAPPER_CLASS_H_ +#define FST_SCRIPT_ENCODEMAPPER_CLASS_H_ + +#include +#include +#include + +#include +#include +#include + +// Scripting API support for EncodeMapper. + +namespace fst { +namespace script { + +// Virtual interface implemented by each concrete EncodeMapperClassImpl. +class EncodeMapperImplBase { + public: + // Returns an encoded ArcClass. + virtual ArcClass operator()(const ArcClass &a) = 0; + virtual const string &ArcType() const = 0; + virtual uint32 Flags() const = 0; + virtual uint64 Properties(uint64 inprops) = 0; + virtual EncodeType Type() const = 0; + virtual const SymbolTable *InputSymbols() const = 0; + virtual const SymbolTable *OutputSymbols() const = 0; + virtual void SetInputSymbols(const SymbolTable *syms) = 0; + virtual void SetOutputSymbols(const SymbolTable *syms) = 0; + virtual const string &WeightType() const = 0; + virtual ~EncodeMapperImplBase() {} +}; + +// Templated implementation. +template +class EncodeMapperClassImpl : public EncodeMapperImplBase { + public: + EncodeMapperClassImpl(uint32 flags, EncodeType type) + : encoder_(flags, type) {} + + ArcClass operator()(const ArcClass &a) final; + + const string &ArcType() const final { return Arc::Type(); } + + uint32 Flags() const final { return encoder_.Flags(); } + + uint64 Properties(uint64 inprops) final { + return encoder_.Properties(inprops); + } + + EncodeType Type() const final { return encoder_.Type(); } + + const SymbolTable *InputSymbols() const final { + return encoder_.InputSymbols(); + } + + const SymbolTable *OutputSymbols() const final { + return encoder_.OutputSymbols(); + } + + void SetInputSymbols(const SymbolTable *syms) final { + encoder_.SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable *syms) final { + encoder_.SetOutputSymbols(syms); + } + + const string &WeightType() const final { return Arc::Weight::Type(); } + + ~EncodeMapperClassImpl() override {} + + EncodeMapper *GetImpl() const { return &encoder_; 
} + + EncodeMapper *GetImpl() { return &encoder_; } + + private: + EncodeMapper encoder_; +}; + +// This is returned by value because it is very likely to undergo return-value +// optimization. +template +inline ArcClass EncodeMapperClassImpl::operator()(const ArcClass &a) { + Arc arc(a.ilabel, a.olabel, *(a.weight.GetWeight()), + a.nextstate); + return ArcClass(encoder_(arc)); +} + +class EncodeMapperClass; + +using InitEncodeMapperClassArgs = + std::tuple; + +class EncodeMapperClass { + public: + EncodeMapperClass(const string &arc_type, uint32 flags, EncodeType type); + + template + EncodeMapperClass(uint32 flags, EncodeType type) + : impl_(new EncodeMapperClassImpl(flags, type)) {} + + ArcClass operator()(const ArcClass &arc) { return (*impl_)(arc); } + + const string &ArcType() const { return impl_->ArcType(); } + + uint32 Flags() const { return impl_->Flags(); } + + uint64 Properties(uint64 inprops) { return impl_->Properties(inprops); } + + EncodeType Type() const { return impl_->Type(); } + + const SymbolTable *InputSymbols() const { return impl_->InputSymbols(); } + + const SymbolTable *OutputSymbols() const { return impl_->OutputSymbols(); } + + void SetInputSymbols(const SymbolTable *syms) { + impl_->SetInputSymbols(syms); + } + + void SetOutputSymbols(const SymbolTable *syms) { + impl_->SetOutputSymbols(syms); + } + + const string &WeightType() const { return impl_->WeightType(); } + + template + friend void InitEncodeMapperClass(InitEncodeMapperClassArgs *args); + + // Naturally, this exists in non-const and const forms. Encoding arcs or FSTs + // mutates the underlying encoder; decoding them does not. 
+ + template + EncodeMapper *GetEncodeMapper() { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + template + const EncodeMapper *GetEncodeMapper() const { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + private: + std::unique_ptr impl_; +}; + +template +void InitEncodeMapperClass(InitEncodeMapperClassArgs *args) { + std::get<2>(*args)->impl_.reset( + new EncodeMapperClassImpl(std::get<0>(*args), std::get<1>(*args))); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ENCODEMAPPER_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/epsnormalize.h b/projects/llm_framework/include/fst/script/epsnormalize.h new file mode 100644 index 00000000..b55fefae --- /dev/null +++ b/projects/llm_framework/include/fst/script/epsnormalize.h @@ -0,0 +1,31 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_EPSNORMALIZE_H_ +#define FST_SCRIPT_EPSNORMALIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using EpsNormalizeArgs = std::tuple; + +template +void EpsNormalize(EpsNormalizeArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + EpsNormalize(ifst, ofst, std::get<2>(*args)); +} + +void EpsNormalize(const FstClass &ifst, MutableFstClass *ofst, + EpsNormalizeType norm_type = EPS_NORM_INPUT); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EPSNORMALIZE_H_ diff --git a/projects/llm_framework/include/fst/script/equal.h b/projects/llm_framework/include/fst/script/equal.h new file mode 100644 index 00000000..79ea9aa4 --- /dev/null +++ b/projects/llm_framework/include/fst/script/equal.h @@ -0,0 +1,32 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_EQUAL_H_ +#define FST_SCRIPT_EQUAL_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EqualInnerArgs = std::tuple; + +using EqualArgs = WithReturnValue; + +template +void Equal(EqualArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Equal(fst1, fst2, std::get<2>(args->args)); +} + +bool Equal(const FstClass &fst1, const FstClass &fst2, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EQUAL_H_ diff --git a/projects/llm_framework/include/fst/script/equivalent.h b/projects/llm_framework/include/fst/script/equivalent.h new file mode 100644 index 00000000..7cdff45e --- /dev/null +++ b/projects/llm_framework/include/fst/script/equivalent.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_EQUIVALENT_H_ +#define FST_SCRIPT_EQUIVALENT_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using EquivalentInnerArgs = std::tuple; + +using EquivalentArgs = WithReturnValue; + +template +void Equivalent(EquivalentArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Equivalent(fst1, fst2, std::get<2>(args->args)); +} + +bool Equivalent(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_EQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/script/fst-class.h b/projects/llm_framework/include/fst/script/fst-class.h new file mode 100644 index 00000000..07319fc7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fst-class.h @@ -0,0 +1,530 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_FST_CLASS_H_ +#define FST_SCRIPT_FST_CLASS_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +// Classes to support "boxing" all existing types of FST arcs in a single +// FstClass which hides the arc types. This allows clients to load +// and work with FSTs without knowing the arc type. These classes are only +// recommended for use in high-level scripting applications. Most users should +// use the lower-level templated versions corresponding to these classes. + +namespace fst { +namespace script { + +// Abstract base class defining the set of functionalities implemented in all +// impls and passed through by all bases. Below FstClassBase the class +// hierarchy bifurcates; FstClassImplBase serves as the base class for all +// implementations (of which FstClassImpl is currently the only one) and +// FstClass serves as the base class for all interfaces. 
+ +class FstClassBase { + public: + virtual const string &ArcType() const = 0; + virtual WeightClass Final(int64) const = 0; + virtual const string &FstType() const = 0; + virtual const SymbolTable *InputSymbols() const = 0; + virtual size_t NumArcs(int64) const = 0; + virtual size_t NumInputEpsilons(int64) const = 0; + virtual size_t NumOutputEpsilons(int64) const = 0; + virtual const SymbolTable *OutputSymbols() const = 0; + virtual uint64 Properties(uint64, bool) const = 0; + virtual int64 Start() const = 0; + virtual const string &WeightType() const = 0; + virtual bool ValidStateId(int64) const = 0; + virtual bool Write(const string &) const = 0; + virtual bool Write(std::ostream &, const string &) const = 0; + virtual ~FstClassBase() {} +}; + +// Adds all the MutableFst methods. +class FstClassImplBase : public FstClassBase { + public: + virtual bool AddArc(int64, const ArcClass &) = 0; + virtual int64 AddState() = 0; + virtual FstClassImplBase *Copy() = 0; + virtual bool DeleteArcs(int64, size_t) = 0; + virtual bool DeleteArcs(int64) = 0; + virtual bool DeleteStates(const std::vector &) = 0; + virtual void DeleteStates() = 0; + virtual SymbolTable *MutableInputSymbols() = 0; + virtual SymbolTable *MutableOutputSymbols() = 0; + virtual int64 NumStates() const = 0; + virtual bool ReserveArcs(int64, size_t) = 0; + virtual void ReserveStates(int64) = 0; + virtual void SetInputSymbols(SymbolTable *) = 0; + virtual bool SetFinal(int64, const WeightClass &) = 0; + virtual void SetOutputSymbols(SymbolTable *) = 0; + virtual void SetProperties(uint64, uint64) = 0; + virtual bool SetStart(int64) = 0; + ~FstClassImplBase() override {} +}; + +// Containiner class wrapping an Fst, hiding its arc type. Whether this +// Fst pointer refers to a special kind of FST (e.g. a MutableFst) is +// known by the type of interface class that owns the pointer to this +// container. 
+ +template +class FstClassImpl : public FstClassImplBase { + public: + explicit FstClassImpl(Fst *impl, bool should_own = false) + : impl_(should_own ? impl : impl->Copy()) {} + + explicit FstClassImpl(const Fst &impl) : impl_(impl.Copy()) {} + + // Warning: calling this method casts the FST to a mutable FST. + bool AddArc(int64 s, const ArcClass &ac) final { + if (!ValidStateId(s)) return false; + // Note that we do not check that the destination state is valid, so users + // can add arcs before they add the corresponding states. Verify can be + // used to determine whether any arc has a nonexisting destination. + Arc arc(ac.ilabel, ac.olabel, *ac.weight.GetWeight(), + ac.nextstate); + static_cast *>(impl_.get())->AddArc(s, arc); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + int64 AddState() final { + return static_cast *>(impl_.get())->AddState(); + } + + const string &ArcType() const final { return Arc::Type(); } + + FstClassImpl *Copy() final { return new FstClassImpl(impl_.get()); } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteArcs(int64 s, size_t n) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->DeleteArcs(s, n); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteArcs(int64 s) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->DeleteArcs(s); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + bool DeleteStates(const std::vector &dstates) final { + for (const auto &state : dstates) + if (!ValidStateId(state)) return false; + // Warning: calling this method with any integers beyond the precision of + // the underlying FST will result in truncation. 
+ std::vector typed_dstates(dstates.size()); + std::copy(dstates.begin(), dstates.end(), typed_dstates.begin()); + static_cast *>(impl_.get())->DeleteStates(typed_dstates); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void DeleteStates() final { + static_cast *>(impl_.get())->DeleteStates(); + } + + WeightClass Final(int64 s) const final { + if (!ValidStateId(s)) return WeightClass::NoWeight(WeightType()); + WeightClass w(impl_->Final(s)); + return w; + } + + const string &FstType() const final { return impl_->Type(); } + + const SymbolTable *InputSymbols() const final { + return impl_->InputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + SymbolTable *MutableInputSymbols() final { + return static_cast *>(impl_.get())->MutableInputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + SymbolTable *MutableOutputSymbols() final { + return static_cast *>(impl_.get())->MutableOutputSymbols(); + } + + // Signals failure by returning size_t max. + size_t NumArcs(int64 s) const final { + return ValidStateId(s) ? impl_->NumArcs(s) + : std::numeric_limits::max(); + } + + // Signals failure by returning size_t max. + size_t NumInputEpsilons(int64 s) const final { + return ValidStateId(s) ? impl_->NumInputEpsilons(s) + : std::numeric_limits::max(); + } + + // Signals failure by returning size_t max. + size_t NumOutputEpsilons(int64 s) const final { + return ValidStateId(s) ? impl_->NumOutputEpsilons(s) + : std::numeric_limits::max(); + } + + // Warning: calling this method casts the FST to a mutable FST. + int64 NumStates() const final { + return static_cast *>(impl_.get())->NumStates(); + } + + uint64 Properties(uint64 mask, bool test) const final { + return impl_->Properties(mask, test); + } + + // Warning: calling this method casts the FST to a mutable FST. 
+ bool ReserveArcs(int64 s, size_t n) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->ReserveArcs(s, n); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void ReserveStates(int64 s) final { + static_cast *>(impl_.get())->ReserveStates(s); + } + + const SymbolTable *OutputSymbols() const final { + return impl_->OutputSymbols(); + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetInputSymbols(SymbolTable *isyms) final { + static_cast *>(impl_.get())->SetInputSymbols(isyms); + } + + // Warning: calling this method casts the FST to a mutable FST. + bool SetFinal(int64 s, const WeightClass &weight) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get()) + ->SetFinal(s, *weight.GetWeight()); + return true; + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetOutputSymbols(SymbolTable *osyms) final { + static_cast *>(impl_.get())->SetOutputSymbols(osyms); + } + + // Warning: calling this method casts the FST to a mutable FST. + void SetProperties(uint64 props, uint64 mask) final { + static_cast *>(impl_.get())->SetProperties(props, mask); + } + + // Warning: calling this method casts the FST to a mutable FST. + bool SetStart(int64 s) final { + if (!ValidStateId(s)) return false; + static_cast *>(impl_.get())->SetStart(s); + return true; + } + + int64 Start() const final { return impl_->Start(); } + + bool ValidStateId(int64 s) const final { + // This cowardly refuses to count states if the FST is not yet expanded. + if (!Properties(kExpanded, true)) { + FSTERROR() << "Cannot get number of states for unexpanded FST"; + return false; + } + // If the FST is already expanded, CountStates calls NumStates. 
+ if (s < 0 || s >= CountStates(*impl_)) { + FSTERROR() << "State ID " << s << " not valid"; + return false; + } + return true; + } + + const string &WeightType() const final { return Arc::Weight::Type(); } + + bool Write(const string &fname) const final { return impl_->Write(fname); } + + bool Write(std::ostream &ostr, const string &fname) const final { + const FstWriteOptions opts(fname); + return impl_->Write(ostr, opts); + } + + ~FstClassImpl() override {} + + Fst *GetImpl() const { return impl_.get(); } + + private: + std::unique_ptr> impl_; +}; + +// BASE CLASS DEFINITIONS + +class MutableFstClass; + +class FstClass : public FstClassBase { + public: + FstClass() : impl_(nullptr) {} + + template + explicit FstClass(const Fst &fst) : impl_(new FstClassImpl(fst)) {} + + FstClass(const FstClass &other) + : impl_(other.impl_ == nullptr ? nullptr : other.impl_->Copy()) {} + + FstClass &operator=(const FstClass &other) { + impl_.reset(other.impl_ == nullptr ? nullptr : other.impl_->Copy()); + return *this; + } + + WeightClass Final(int64 s) const final { return impl_->Final(s); } + + const string &ArcType() const final { return impl_->ArcType(); } + + const string &FstType() const final { return impl_->FstType(); } + + const SymbolTable *InputSymbols() const final { + return impl_->InputSymbols(); + } + + size_t NumArcs(int64 s) const final { return impl_->NumArcs(s); } + + size_t NumInputEpsilons(int64 s) const final { + return impl_->NumInputEpsilons(s); + } + + size_t NumOutputEpsilons(int64 s) const final { + return impl_->NumOutputEpsilons(s); + } + + const SymbolTable *OutputSymbols() const final { + return impl_->OutputSymbols(); + } + + uint64 Properties(uint64 mask, bool test) const final { + // Special handling for FSTs with a null impl. 
+ if (!impl_) return kError & mask; + return impl_->Properties(mask, test); + } + + static FstClass *Read(const string &fname); + + static FstClass *Read(std::istream &istrm, const string &source); + + int64 Start() const final { return impl_->Start(); } + + bool ValidStateId(int64 s) const final { return impl_->ValidStateId(s); } + + const string &WeightType() const final { return impl_->WeightType(); } + + // Helper that logs an ERROR if the weight type of an FST and a WeightClass + // don't match. + + bool WeightTypesMatch(const WeightClass &weight, const string &op_name) const; + + bool Write(const string &fname) const final { return impl_->Write(fname); } + + bool Write(std::ostream &ostr, const string &fname) const final { + return impl_->Write(ostr, fname); + } + + ~FstClass() override {} + + // These methods are required by IO registration. + + template + static FstClassImplBase *Convert(const FstClass &other) { + FSTERROR() << "Doesn't make sense to convert any class to type FstClass"; + return nullptr; + } + + template + static FstClassImplBase *Create() { + FSTERROR() << "Doesn't make sense to create an FstClass with a " + << "particular arc type"; + return nullptr; + } + + template + const Fst *GetFst() const { + if (Arc::Type() != ArcType()) { + return nullptr; + } else { + FstClassImpl *typed_impl = + static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + template + static FstClass *Read(std::istream &stream, const FstReadOptions &opts) { + if (!opts.header) { + LOG(ERROR) << "FstClass::Read: Options header not specified"; + return nullptr; + } + const FstHeader &hdr = *opts.header; + if (hdr.Properties() & kMutable) { + return ReadTypedFst>(stream, opts); + } else { + return ReadTypedFst>(stream, opts); + } + } + + protected: + explicit FstClass(FstClassImplBase *impl) : impl_(impl) {} + + const FstClassImplBase *GetImpl() const { return impl_.get(); } + + FstClassImplBase *GetImpl() { return impl_.get(); } + + // Generic template 
method for reading an arc-templated FST of type + // UnderlyingT, and returning it wrapped as FstClassT, with appropriat + // error checking. Called from arc-templated Read() static methods. + template + static FstClassT *ReadTypedFst(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr u(UnderlyingT::Read(stream, opts)); + return u ? new FstClassT(*u) : nullptr; + } + + private: + std::unique_ptr impl_; +}; + +// Specific types of FstClass with special properties + +class MutableFstClass : public FstClass { + public: + bool AddArc(int64 s, const ArcClass &ac) { + if (!WeightTypesMatch(ac.weight, "AddArc")) return false; + return GetImpl()->AddArc(s, ac); + } + + int64 AddState() { return GetImpl()->AddState(); } + + bool DeleteArcs(int64 s, size_t n) { return GetImpl()->DeleteArcs(s, n); } + + bool DeleteArcs(int64 s) { return GetImpl()->DeleteArcs(s); } + + bool DeleteStates(const std::vector &dstates) { + return GetImpl()->DeleteStates(dstates); + } + + void DeleteStates() { GetImpl()->DeleteStates(); } + + SymbolTable *MutableInputSymbols() { + return GetImpl()->MutableInputSymbols(); + } + + SymbolTable *MutableOutputSymbols() { + return GetImpl()->MutableOutputSymbols(); + } + + int64 NumStates() const { return GetImpl()->NumStates(); } + + bool ReserveArcs(int64 s, size_t n) { return GetImpl()->ReserveArcs(s, n); } + + void ReserveStates(int64 s) { GetImpl()->ReserveStates(s); } + + static MutableFstClass *Read(const string &fname, bool convert = false); + + void SetInputSymbols(SymbolTable *isyms) { + GetImpl()->SetInputSymbols(isyms); + } + + bool SetFinal(int64 s, const WeightClass &weight) { + if (!WeightTypesMatch(weight, "SetFinal")) return false; + return GetImpl()->SetFinal(s, weight); + } + + void SetOutputSymbols(SymbolTable *osyms) { + GetImpl()->SetOutputSymbols(osyms); + } + + void SetProperties(uint64 props, uint64 mask) { + GetImpl()->SetProperties(props, mask); + } + + bool SetStart(int64 s) { return GetImpl()->SetStart(s); 
} + + template + explicit MutableFstClass(const MutableFst &fst) : FstClass(fst) {} + + // These methods are required by IO registration. + + template + static FstClassImplBase *Convert(const FstClass &other) { + FSTERROR() << "Doesn't make sense to convert any class to type " + << "MutableFstClass"; + return nullptr; + } + + template + static FstClassImplBase *Create() { + FSTERROR() << "Doesn't make sense to create a MutableFstClass with a " + << "particular arc type"; + return nullptr; + } + + template + MutableFst *GetMutableFst() { + Fst *fst = const_cast *>(this->GetFst()); + MutableFst *mfst = static_cast *>(fst); + return mfst; + } + + template + static MutableFstClass *Read(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr> mfst(MutableFst::Read(stream, opts)); + return mfst ? new MutableFstClass(*mfst) : nullptr; + } + + protected: + explicit MutableFstClass(FstClassImplBase *impl) : FstClass(impl) {} +}; + +class VectorFstClass : public MutableFstClass { + public: + explicit VectorFstClass(FstClassImplBase *impl) : MutableFstClass(impl) {} + + explicit VectorFstClass(const FstClass &other); + + explicit VectorFstClass(const string &arc_type); + + static VectorFstClass *Read(const string &fname); + + template + static VectorFstClass *Read(std::istream &stream, + const FstReadOptions &opts) { + std::unique_ptr> mfst(VectorFst::Read(stream, opts)); + return mfst ? 
new VectorFstClass(*mfst) : nullptr; + } + + template + explicit VectorFstClass(const VectorFst &fst) : MutableFstClass(fst) {} + + template + static FstClassImplBase *Convert(const FstClass &other) { + return new FstClassImpl(new VectorFst(*other.GetFst()), + true); + } + + template + static FstClassImplBase *Create() { + return new FstClassImpl(new VectorFst(), true); + } +}; + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_FST_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/fstscript-decl.h b/projects/llm_framework/include/fst/script/fstscript-decl.h new file mode 100644 index 00000000..294d0159 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fstscript-decl.h @@ -0,0 +1,32 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Forward declarations for the FST and FST script classes. + +#ifndef FST_SCRIPT_FSTSCRIPT_DECL_H_ +#define FST_SCRIPT_FSTSCRIPT_DECL_H_ + +#include + +namespace fst { +namespace script { + +class ArcClass; + +class ArcIteratorClass; +class MutableArcIteratorClass; + +class EncodeMapperClass; + +class FstClass; +class MutableFstClass; +class VectorFstClass; + +class StateIteratorClass; + +class WeightClass; + +} // namespace script +} // namespace fst; + +#endif // FST_SCRIPT_FSTSCRIPT_DECL_H_ diff --git a/projects/llm_framework/include/fst/script/fstscript.h b/projects/llm_framework/include/fst/script/fstscript.h new file mode 100644 index 00000000..45f16175 --- /dev/null +++ b/projects/llm_framework/include/fst/script/fstscript.h @@ -0,0 +1,155 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// The FST script interface permits users to interact with FSTs without knowing +// their arc type. It does this by mapping compile-time polymorphism (in the +// form of a arc-templated FST types) onto a shared virtual interface. 
It also +// supports arc extension via a DSO interface. Due to the overhead of virtual +// dispatch and registered function lookups, the script API is somewhat slower +// then library API provided by types like StdVectorFst, but has the advantage +// that it is designed not to crash (and to provide useful debugging +// information) upon common user errors like passing invalid indices or +// attempting comparison of incompatible FSTs. It is used both by the FST +// binaries and the Python extension. +// +// This header includes all of the FST script functionality. + +#ifndef FST_SCRIPT_FSTSCRIPT_H_ +#define FST_SCRIPT_FSTSCRIPT_H_ + +// Major classes +#include +#include +#include +#include +#include +#include + +// Flag-to-enum parsers. +#include +// Templates like Operation<> and Apply<>. +#include + +// Operations. +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// This class is necessary because registering each of the operations +// separately overfills the stack, as there's so many of them. 
+namespace fst { +namespace script { + +template +class AllFstOperationsRegisterer { + public: + AllFstOperationsRegisterer() { + RegisterBatch1(); + RegisterBatch2(); + } + + private: + void RegisterBatch1() { + REGISTER_FST_OPERATION(ArcSort, Arc, ArcSortArgs); + REGISTER_FST_OPERATION(Closure, Arc, ClosureArgs); + REGISTER_FST_OPERATION(CompileFstInternal, Arc, CompileFstArgs); + REGISTER_FST_OPERATION(Compose, Arc, ComposeArgs); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs1); + REGISTER_FST_OPERATION(Concat, Arc, ConcatArgs2); + REGISTER_FST_OPERATION(Connect, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Convert, Arc, ConvertArgs); + REGISTER_FST_OPERATION(Decode, Arc, DecodeArgs1); + REGISTER_FST_OPERATION(Decode, Arc, DecodeArgs2); + REGISTER_FST_OPERATION(Determinize, Arc, DeterminizeArgs); + REGISTER_FST_OPERATION(Difference, Arc, DifferenceArgs); + REGISTER_FST_OPERATION(Disambiguate, Arc, DisambiguateArgs); + REGISTER_FST_OPERATION(DrawFst, Arc, FstDrawerArgs); + REGISTER_FST_OPERATION(Encode, Arc, EncodeArgs1); + REGISTER_FST_OPERATION(Encode, Arc, EncodeArgs2); + REGISTER_FST_OPERATION(EpsNormalize, Arc, EpsNormalizeArgs); + REGISTER_FST_OPERATION(Equal, Arc, EqualArgs); + REGISTER_FST_OPERATION(Equivalent, Arc, EquivalentArgs); + REGISTER_FST_OPERATION(PrintFstInfo, Arc, InfoArgs); + REGISTER_FST_OPERATION(GetFstInfo, Arc, GetInfoArgs); + REGISTER_FST_OPERATION(InitArcIteratorClass, Arc, + InitArcIteratorClassArgs); + REGISTER_FST_OPERATION(InitEncodeMapperClass, Arc, + InitEncodeMapperClassArgs); + REGISTER_FST_OPERATION(InitMutableArcIteratorClass, Arc, + InitMutableArcIteratorClassArgs); + REGISTER_FST_OPERATION(InitStateIteratorClass, Arc, + InitStateIteratorClassArgs); + } + + void RegisterBatch2() { + REGISTER_FST_OPERATION(Intersect, Arc, IntersectArgs); + REGISTER_FST_OPERATION(Invert, Arc, MutableFstClass); + REGISTER_FST_OPERATION(Map, Arc, MapArgs); + REGISTER_FST_OPERATION(Minimize, Arc, MinimizeArgs); + 
REGISTER_FST_OPERATION(PrintFst, Arc, FstPrinterArgs); + REGISTER_FST_OPERATION(Project, Arc, ProjectArgs); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs1); + REGISTER_FST_OPERATION(Prune, Arc, PruneArgs2); + REGISTER_FST_OPERATION(Push, Arc, PushArgs1); + REGISTER_FST_OPERATION(Push, Arc, PushArgs2); + REGISTER_FST_OPERATION(RandEquivalent, Arc, RandEquivalentArgs); + REGISTER_FST_OPERATION(RandGen, Arc, RandGenArgs); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs1); + REGISTER_FST_OPERATION(Relabel, Arc, RelabelArgs2); + REGISTER_FST_OPERATION(Replace, Arc, ReplaceArgs); + REGISTER_FST_OPERATION(Reverse, Arc, ReverseArgs); + REGISTER_FST_OPERATION(Reweight, Arc, ReweightArgs); + REGISTER_FST_OPERATION(RmEpsilon, Arc, RmEpsilonArgs); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs1); + REGISTER_FST_OPERATION(ShortestDistance, Arc, ShortestDistanceArgs2); + REGISTER_FST_OPERATION(ShortestPath, Arc, ShortestPathArgs); + REGISTER_FST_OPERATION(Synchronize, Arc, SynchronizeArgs); + REGISTER_FST_OPERATION(TopSort, Arc, TopSortArgs); + REGISTER_FST_OPERATION(Union, Arc, UnionArgs); + REGISTER_FST_OPERATION(Verify, Arc, VerifyArgs); + } +}; + +} // namespace script +} // namespace fst + +#define REGISTER_FST_OPERATIONS(Arc) \ + AllFstOperationsRegisterer register_all_fst_operations##Arc; + +#endif // FST_SCRIPT_FSTSCRIPT_H_ diff --git a/projects/llm_framework/include/fst/script/getters.h b/projects/llm_framework/include/fst/script/getters.h new file mode 100644 index 00000000..5cd727e8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/getters.h @@ -0,0 +1,76 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Getters for converting command-line arguments into the appropriate enums +// or bitmasks, with the simplest ones defined as inline. + +#ifndef FST_SCRIPT_GETTERS_H_ +#define FST_SCRIPT_GETTERS_H_ + +#include + +#include // For ComposeFilter. 
+#include // For DeterminizeType. +#include // For kEncodeLabels (etc.). +#include // For EpsNormalizeType. +#include // For ProjectType. +#include // For kPushWeights (etc.). +#include // For QueueType. +#include // For ClosureType. +#include // For ArcSortType. +#include // For MapType. +#include // For RandArcSelection. + +#include + +namespace fst { +namespace script { + +bool GetArcSortType(const string &str, ArcSortType *sort_type); + +inline ClosureType GetClosureType(bool closure_plus) { + return closure_plus ? CLOSURE_PLUS : CLOSURE_STAR; +} + +bool GetComposeFilter(const string &str, ComposeFilter *compose_filter); + +bool GetDeterminizeType(const string &str, DeterminizeType *det_type); + +inline uint32 GetEncodeFlags(bool encode_labels, bool encode_weights) { + return (encode_labels ? kEncodeLabels : 0) | + (encode_weights ? kEncodeWeights : 0); +} + +inline EpsNormalizeType GetEpsNormalizeType(bool eps_norm_output) { + return eps_norm_output ? EPS_NORM_OUTPUT : EPS_NORM_INPUT; +} + +bool GetMapType(const string &str, MapType *map_type); + +inline ProjectType GetProjectType(bool project_output) { + return project_output ? PROJECT_OUTPUT : PROJECT_INPUT; +} + +inline uint32 GetPushFlags(bool push_weights, bool push_labels, + bool remove_total_weight, bool remove_common_affix) { + return ((push_weights ? kPushWeights : 0) | + (push_labels ? kPushLabels : 0) | + (remove_total_weight ? kPushRemoveTotalWeight : 0) | + (remove_common_affix ? kPushRemoveCommonAffix : 0)); +} + +bool GetQueueType(const string &str, QueueType *queue_type); + +bool GetRandArcSelection(const string &str, RandArcSelection *ras); + +bool GetReplaceLabelType(const string &str, bool epsilon_on_replace, + ReplaceLabelType *rlt); + +inline ReweightType GetReweightType(bool to_final) { + return to_final ? 
REWEIGHT_TO_FINAL : REWEIGHT_TO_INITIAL; +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_GETTERS_H_ diff --git a/projects/llm_framework/include/fst/script/info-impl.h b/projects/llm_framework/include/fst/script/info-impl.h new file mode 100644 index 00000000..e8956498 --- /dev/null +++ b/projects/llm_framework/include/fst/script/info-impl.h @@ -0,0 +1,314 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to compute various information about FSTs, a helper class for +// fstinfo.cc. + +#ifndef FST_SCRIPT_INFO_IMPL_H_ +#define FST_SCRIPT_INFO_IMPL_H_ + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace fst { + +// Compute various information about FSTs, helper class for fstinfo.cc. +// WARNING: Stand-alone use of this class is not recommended, most code +// should call directly the relevant library functions: Fst::NumStates, +// Fst::NumArcs, TestProperties, etc. +class FstInfo { + public: + FstInfo() {} + + // When info_type is "short" (or "auto" and not an ExpandedFst) then only + // minimal info is computed and can be requested. + template + FstInfo(const Fst &fst, bool test_properties, + const string &arc_filter_type = "any", + const string &info_type = "auto", bool verify = true) + : fst_type_(fst.Type()), + input_symbols_(fst.InputSymbols() ? fst.InputSymbols()->Name() + : "none"), + output_symbols_(fst.OutputSymbols() ? 
fst.OutputSymbols()->Name() + : "none"), + nstates_(0), + narcs_(0), + start_(kNoStateId), + nfinal_(0), + nepsilons_(0), + niepsilons_(0), + noepsilons_(0), + ilabel_mult_(0.0), + olabel_mult_(0.0), + naccess_(0), + ncoaccess_(0), + nconnect_(0), + ncc_(0), + nscc_(0), + input_match_type_(MATCH_NONE), + output_match_type_(MATCH_NONE), + input_lookahead_(false), + output_lookahead_(false), + properties_(0), + arc_filter_type_(arc_filter_type), + long_info_(true), + arc_type_(Arc::Type()) { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (info_type == "long") { + long_info_ = true; + } else if (info_type == "short") { + long_info_ = false; + } else if (info_type == "auto") { + long_info_ = fst.Properties(kExpanded, false); + } else { + FSTERROR() << "Bad info type: " << info_type; + return; + } + if (!long_info_) return; + // If the FST is not sane, we return. + if (verify && !Verify(fst)) { + FSTERROR() << "FstInfo: Verify: FST not well-formed"; + return; + } + start_ = fst.Start(); + properties_ = fst.Properties(kFstProperties, test_properties); + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + ++nstates_; + const auto s = siter.Value(); + if (fst.Final(s) != Weight::Zero()) ++nfinal_; + std::map ilabel_count; + std::map olabel_count; + for (ArcIterator> aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + ++narcs_; + if (arc.ilabel == 0 && arc.olabel == 0) ++nepsilons_; + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + ++ilabel_count[arc.ilabel]; + ++olabel_count[arc.olabel]; + } + for (auto it = ilabel_count.begin(); it != ilabel_count.end(); ++it) { + ilabel_mult_ += it->second * it->second; + } + for (auto it = olabel_count.begin(); it != olabel_count.end(); ++it) { + olabel_mult_ += it->second * it->second; + } + } + if (narcs_ > 0) { + ilabel_mult_ /= narcs_; + olabel_mult_ /= narcs_; + } + { + std::vector cc; + 
CcVisitor cc_visitor(&cc); + FifoQueue fifo_queue; + if (arc_filter_type == "any") { + Visit(fst, &cc_visitor, &fifo_queue); + } else if (arc_filter_type == "epsilon") { + Visit(fst, &cc_visitor, &fifo_queue, EpsilonArcFilter()); + } else if (arc_filter_type == "iepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, InputEpsilonArcFilter()); + } else if (arc_filter_type == "oepsilon") { + Visit(fst, &cc_visitor, &fifo_queue, OutputEpsilonArcFilter()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + for (StateId s = 0; s < cc.size(); ++s) { + if (cc[s] >= ncc_) ncc_ = cc[s] + 1; + } + } + { + std::vector scc; + std::vector access, coaccess; + uint64 props = 0; + SccVisitor scc_visitor(&scc, &access, &coaccess, &props); + if (arc_filter_type == "any") { + DfsVisit(fst, &scc_visitor); + } else if (arc_filter_type == "epsilon") { + DfsVisit(fst, &scc_visitor, EpsilonArcFilter()); + } else if (arc_filter_type == "iepsilon") { + DfsVisit(fst, &scc_visitor, InputEpsilonArcFilter()); + } else if (arc_filter_type == "oepsilon") { + DfsVisit(fst, &scc_visitor, OutputEpsilonArcFilter()); + } else { + FSTERROR() << "Bad arc filter type: " << arc_filter_type; + return; + } + for (StateId s = 0; s < scc.size(); ++s) { + if (access[s]) ++naccess_; + if (coaccess[s]) ++ncoaccess_; + if (access[s] && coaccess[s]) ++nconnect_; + if (scc[s] >= nscc_) nscc_ = scc[s] + 1; + } + } + LookAheadMatcher> imatcher(fst, MATCH_INPUT); + input_match_type_ = imatcher.Type(test_properties); + input_lookahead_ = imatcher.Flags() & kInputLookAheadMatcher; + LookAheadMatcher> omatcher(fst, MATCH_OUTPUT); + output_match_type_ = omatcher.Type(test_properties); + output_lookahead_ = omatcher.Flags() & kOutputLookAheadMatcher; + } + + // Short info. 
+ + const string &FstType() const { return fst_type_; } + + const string &ArcType() const { return arc_type_; } + + const string &InputSymbols() const { return input_symbols_; } + + const string &OutputSymbols() const { return output_symbols_; } + + bool LongInfo() const { return long_info_; } + + const string &ArcFilterType() const { return arc_filter_type_; } + + // Long info. + + MatchType InputMatchType() const { + CheckLong(); + return input_match_type_; + } + + MatchType OutputMatchType() const { + CheckLong(); + return output_match_type_; + } + + bool InputLookAhead() const { + CheckLong(); + return input_lookahead_; + } + + bool OutputLookAhead() const { + CheckLong(); + return output_lookahead_; + } + + int64 NumStates() const { + CheckLong(); + return nstates_; + } + + size_t NumArcs() const { + CheckLong(); + return narcs_; + } + + int64 Start() const { + CheckLong(); + return start_; + } + + size_t NumFinal() const { + CheckLong(); + return nfinal_; + } + + size_t NumEpsilons() const { + CheckLong(); + return nepsilons_; + } + + size_t NumInputEpsilons() const { + CheckLong(); + return niepsilons_; + } + + size_t NumOutputEpsilons() const { + CheckLong(); + return noepsilons_; + } + + double InputLabelMultiplicity() const { + CheckLong(); + return ilabel_mult_; + } + + double OutputLabelMultiplicity() const { + CheckLong(); + return olabel_mult_; + } + + size_t NumAccessible() const { + CheckLong(); + return naccess_; + } + + size_t NumCoAccessible() const { + CheckLong(); + return ncoaccess_; + } + + size_t NumConnected() const { + CheckLong(); + return nconnect_; + } + + size_t NumCc() const { + CheckLong(); + return ncc_; + } + + size_t NumScc() const { + CheckLong(); + return nscc_; + } + + uint64 Properties() const { + CheckLong(); + return properties_; + } + + private: + void CheckLong() const { + if (!long_info_) + FSTERROR() << "FstInfo: Method only available with long info signature"; + } + + string fst_type_; + string input_symbols_; + string 
output_symbols_; + int64 nstates_; + size_t narcs_; + int64 start_; + size_t nfinal_; + size_t nepsilons_; + size_t niepsilons_; + size_t noepsilons_; + double ilabel_mult_; + double olabel_mult_; + size_t naccess_; + size_t ncoaccess_; + size_t nconnect_; + size_t ncc_; + size_t nscc_; + MatchType input_match_type_; + MatchType output_match_type_; + bool input_lookahead_; + bool output_lookahead_; + uint64 properties_; + string arc_filter_type_; + bool long_info_; + string arc_type_; +}; + +void PrintFstInfoImpl(const FstInfo &fstinfo, bool pipe = false); + +} // namespace fst + +#endif // FST_SCRIPT_INFO_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/info.h b/projects/llm_framework/include/fst/script/info.h new file mode 100644 index 00000000..039d06d8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/info.h @@ -0,0 +1,50 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_INFO_H_ +#define FST_SCRIPT_INFO_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using InfoArgs = std::tuple; + +template +void PrintFstInfo(InfoArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + const FstInfo fstinfo(fst, std::get<1>(*args), std::get<2>(*args), + std::get<3>(*args), std::get<4>(*args)); + PrintFstInfoImpl(fstinfo, std::get<5>(*args)); + if (std::get<5>(*args)) fst.Write(""); +} + +void PrintFstInfo(const FstClass &f, bool test_properties, + const string &arc_filter, const string &info_type, bool pipe, + bool verify); + +using GetInfoArgs = std::tuple; + +template +void GetFstInfo(GetInfoArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + *(std::get<5>(*args)) = FstInfo(fst, std::get<1>(*args), std::get<2>(*args), + std::get<3>(*args), std::get<4>(*args)); +} + +void GetFstInfo(const FstClass &fst, bool test_properties, + const string &arc_filter, const string &info_type, bool verify, + 
FstInfo *info); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INFO_H_ diff --git a/projects/llm_framework/include/fst/script/intersect.h b/projects/llm_framework/include/fst/script/intersect.h new file mode 100644 index 00000000..229bd56f --- /dev/null +++ b/projects/llm_framework/include/fst/script/intersect.h @@ -0,0 +1,35 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_INTERSECT_H_ +#define FST_SCRIPT_INTERSECT_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using IntersectArgs = std::tuple; + +template +void Intersect(IntersectArgs *args) { + const Fst &ifst1 = *(std::get<0>(*args).GetFst()); + const Fst &ifst2 = *(std::get<1>(*args).GetFst()); + MutableFst *ofst = std::get<2>(*args)->GetMutableFst(); + const auto &opts = std::get<3>(*args); + Intersect(ifst1, ifst2, ofst, opts); +} + +void Intersect(const FstClass &ifst, const FstClass &ifst2, + MutableFstClass *ofst, + const ComposeOptions &opts = ComposeOptions()); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INTERSECT_H_ diff --git a/projects/llm_framework/include/fst/script/invert.h b/projects/llm_framework/include/fst/script/invert.h new file mode 100644 index 00000000..5bc31317 --- /dev/null +++ b/projects/llm_framework/include/fst/script/invert.h @@ -0,0 +1,23 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_INVERT_H_ +#define FST_SCRIPT_INVERT_H_ + +#include +#include + +namespace fst { +namespace script { + +template +void Invert(MutableFstClass *fst) { + Invert(fst->GetMutableFst()); +} + +void Invert(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_INVERT_H_ diff --git a/projects/llm_framework/include/fst/script/isomorphic.h b/projects/llm_framework/include/fst/script/isomorphic.h new file mode 100644 index 00000000..94ea77f9 --- /dev/null +++ b/projects/llm_framework/include/fst/script/isomorphic.h @@ -0,0 +1,34 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_ISOMORPHIC_H_ +#define FST_SCRIPT_ISOMORPHIC_H_ + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using IsomorphicInnerArgs = std::tuple; + +using IsomorphicArgs = WithReturnValue; + +template +void Isomorphic(IsomorphicArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + args->retval = Isomorphic(fst1, fst2, std::get<2>(args->args)); +} + +bool Isomorphic(const FstClass &fst1, const FstClass &fst2, + float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_ISOMORPHIC_H_ diff --git a/projects/llm_framework/include/fst/script/map.h b/projects/llm_framework/include/fst/script/map.h new file mode 100644 index 00000000..158d98aa --- /dev/null +++ b/projects/llm_framework/include/fst/script/map.h @@ -0,0 +1,158 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_MAP_H_ +#define FST_SCRIPT_MAP_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +template +Fst *ArcMap(const Fst &fst, + const M &mapper) { + using ToArc = typename M::ToArc; + auto *ofst = new VectorFst; + ArcMap(fst, ofst, mapper); + return ofst; +} + +template +Fst *StateMap(const Fst &fst, + const M &mapper) { + using ToArc = typename M::ToArc; + auto *ofst = new VectorFst; + StateMap(fst, ofst, mapper); + return ofst; +} + +enum MapType { + ARC_SUM_MAPPER, + ARC_UNIQUE_MAPPER, + IDENTITY_MAPPER, + INPUT_EPSILON_MAPPER, + INVERT_MAPPER, + OUTPUT_EPSILON_MAPPER, + PLUS_MAPPER, + POWER_MAPPER, + QUANTIZE_MAPPER, + RMWEIGHT_MAPPER, + SUPERFINAL_MAPPER, + TIMES_MAPPER, + TO_LOG_MAPPER, + TO_LOG64_MAPPER, + TO_STD_MAPPER +}; + +using MapInnerArgs = + std::tuple; + +using MapArgs = WithReturnValue; + +template +void Map(MapArgs *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(args->args).GetFst()); + const auto map_type = std::get<1>(args->args); + switch (map_type) { + case ARC_SUM_MAPPER: { + std::unique_ptr> ofst(StateMap(ifst, ArcSumMapper(ifst))); + args->retval = new FstClass(*ofst); + return; + } + case ARC_UNIQUE_MAPPER: { + std::unique_ptr> ofst( + StateMap(ifst, ArcUniqueMapper(ifst))); + args->retval = new FstClass(*ofst); + return; + } + case IDENTITY_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, IdentityArcMapper())); + args->retval = new FstClass(*ofst); + return; + } + case INPUT_EPSILON_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, InputEpsilonMapper())); + args->retval = new FstClass(*ofst); + return; + } + case INVERT_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, InvertWeightMapper())); + args->retval = new FstClass(*ofst); + return; + } + case OUTPUT_EPSILON_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, OutputEpsilonMapper())); + args->retval = new FstClass(*ofst); + return; + } + case PLUS_MAPPER: { + const auto weight = 
*(std::get<4>(args->args).GetWeight()); + std::unique_ptr> ofst(ArcMap(ifst, PlusMapper(weight))); + args->retval = new FstClass(*ofst); + return; + } + case POWER_MAPPER: { + const auto power = std::get<3>(args->args); + std::unique_ptr> ofst(ArcMap(ifst, PowerMapper(power))); + args->retval = new FstClass(*ofst); + return; + } + case QUANTIZE_MAPPER: { + const auto delta = std::get<2>(args->args); + std::unique_ptr> ofst(ArcMap(ifst, QuantizeMapper(delta))); + args->retval = new FstClass(*ofst); + return; + } + case RMWEIGHT_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, RmWeightMapper())); + args->retval = new FstClass(*ofst); + return; + } + case SUPERFINAL_MAPPER: { + std::unique_ptr> ofst(ArcMap(ifst, SuperFinalMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TIMES_MAPPER: { + const auto weight = *(std::get<4>(args->args).GetWeight()); + std::unique_ptr> ofst(ArcMap(ifst, TimesMapper(weight))); + args->retval = new FstClass(*ofst); + return; + } + case TO_LOG_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TO_LOG64_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + case TO_STD_MAPPER: { + std::unique_ptr> ofst( + ArcMap(ifst, WeightConvertMapper())); + args->retval = new FstClass(*ofst); + return; + } + } +} + +FstClass *Map(const FstClass &ifst, MapType map_type, float delta, double power, + const WeightClass &weight); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_MAP_H_ diff --git a/projects/llm_framework/include/fst/script/minimize.h b/projects/llm_framework/include/fst/script/minimize.h new file mode 100644 index 00000000..773e8ef1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/minimize.h @@ -0,0 +1,33 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_MINIMIZE_H_ +#define FST_SCRIPT_MINIMIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using MinimizeArgs = std::tuple; + +template +void Minimize(MinimizeArgs *args) { + MutableFst *ofst1 = std::get<0>(*args)->GetMutableFst(); + MutableFst *ofst2 = (std::get<1>(*args) ? + std::get<1>(*args)->GetMutableFst() : + nullptr); + Minimize(ofst1, ofst2, std::get<2>(*args), std::get<3>(*args)); +} + +void Minimize(MutableFstClass *ofst1, MutableFstClass *ofst2 = nullptr, + float delta = kShortestDelta, bool allow_nondet = false); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_MINIMIZE_H_ diff --git a/projects/llm_framework/include/fst/script/print-impl.h b/projects/llm_framework/include/fst/script/print-impl.h new file mode 100644 index 00000000..539c6d8f --- /dev/null +++ b/projects/llm_framework/include/fst/script/print-impl.h @@ -0,0 +1,132 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Stand-alone class to print out binary FSTs in the AT&T format, a helper +// class for fstprint.cc. + +#ifndef FST_SCRIPT_PRINT_IMPL_H_ +#define FST_SCRIPT_PRINT_IMPL_H_ + +#include +#include +#include + +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { + +// Print a binary FST in textual format (helper class for fstprint.cc). +// WARNING: Stand-alone use of this class not recommended, most code should +// read/write using the binary format which is much more efficient. 
+template +class FstPrinter { + public: + using StateId = typename Arc::StateId; + using Label = typename Arc::Label; + using Weight = typename Arc::Weight; + + FstPrinter(const Fst &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, bool accep, + bool show_weight_one, const string &field_separator, + const string &missing_symbol = "") + : fst_(fst), + isyms_(isyms), + osyms_(osyms), + ssyms_(ssyms), + accep_(accep && fst.Properties(kAcceptor, true)), + ostrm_(nullptr), + show_weight_one_(show_weight_one), + sep_(field_separator), + missing_symbol_(missing_symbol) {} + + // Prints FST to an output stream. + void Print(std::ostream *ostrm, const string &dest) { + ostrm_ = ostrm; + dest_ = dest; + const auto start = fst_.Start(); + if (start == kNoStateId) return; + // Initial state first. + PrintState(start); + for (StateIterator> siter(fst_); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + if (s != start) PrintState(s); + } + } + + private: + void PrintId(StateId id, const SymbolTable *syms, const char *name) const { + if (syms) { + string symbol = syms->Find(id); + if (symbol.empty()) { + if (missing_symbol_.empty()) { + FSTERROR() << "FstPrinter: Integer " << id + << " is not mapped to any textual symbol" + << ", symbol table = " << syms->Name() + << ", destination = " << dest_; + symbol = "?"; + } else { + symbol = missing_symbol_; + } + } + *ostrm_ << symbol; + } else { + *ostrm_ << id; + } + } + + void PrintStateId(StateId s) const { PrintId(s, ssyms_, "state ID"); } + + void PrintILabel(Label l) const { PrintId(l, isyms_, "arc input label"); } + + void PrintOLabel(Label l) const { PrintId(l, osyms_, "arc output label"); } + + void PrintState(StateId s) const { + bool output = false; + for (ArcIterator> aiter(fst_, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + PrintStateId(s); + *ostrm_ << sep_; + PrintStateId(arc.nextstate); + *ostrm_ << sep_; + PrintILabel(arc.ilabel); + if 
(!accep_) { + *ostrm_ << sep_; + PrintOLabel(arc.olabel); + } + if (show_weight_one_ || arc.weight != Weight::One()) + *ostrm_ << sep_ << arc.weight; + *ostrm_ << "\n"; + output = true; + } + const auto weight = fst_.Final(s); + if (weight != Weight::Zero() || !output) { + PrintStateId(s); + if (show_weight_one_ || weight != Weight::One()) { + *ostrm_ << sep_ << weight; + } + *ostrm_ << "\n"; + } + } + + const Fst &fst_; + const SymbolTable *isyms_; // ilabel symbol table. + const SymbolTable *osyms_; // olabel symbol table. + const SymbolTable *ssyms_; // slabel symbol table. + bool accep_; // Print as acceptor when possible? + std::ostream *ostrm_; // Text FST destination. + string dest_; // Text FST destination name. + bool show_weight_one_; // Print weights equal to Weight::One()? + string sep_; // Separator character between fields. + string missing_symbol_; // Symbol to print when lookup fails (default + // "" means raise error). + // + FstPrinter(const FstPrinter &) = delete; + FstPrinter &operator=(const FstPrinter &) = delete; +}; + +} // namespace fst + +#endif // FST_SCRIPT_PRINT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/print.h b/projects/llm_framework/include/fst/script/print.h new file mode 100644 index 00000000..687606b3 --- /dev/null +++ b/projects/llm_framework/include/fst/script/print.h @@ -0,0 +1,79 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PRINT_H_ +#define FST_SCRIPT_PRINT_H_ + +#include + +#include +#include +#include + +DECLARE_string(fst_field_separator); + +namespace fst { +namespace script { + +// Note: it is safe to pass these strings as references because +// this struct is only used to pass them deeper in the call graph. +// Be sure you understand why this is so before using this struct +// for anything else! 
+struct FstPrinterArgs { + const FstClass &fst; + const SymbolTable *isyms; + const SymbolTable *osyms; + const SymbolTable *ssyms; + const bool accept; + const bool show_weight_one; + std::ostream *ostrm; + const string &dest; + const string &sep; // NOLINT + const string &missing_symbol; + + FstPrinterArgs(const FstClass &fst, const SymbolTable *isyms, + const SymbolTable *osyms, const SymbolTable *ssyms, + bool accept, bool show_weight_one, std::ostream *ostrm, + const string &dest, const string &sep, + const string &missing_sym = "") + : fst(fst), + isyms(isyms), + osyms(osyms), + ssyms(ssyms), + accept(accept), + show_weight_one(show_weight_one), + ostrm(ostrm), + dest(dest), + sep(sep), + missing_symbol(missing_sym) {} +}; + +template +void PrintFst(FstPrinterArgs *args) { + const Fst &fst = *(args->fst.GetFst()); + FstPrinter fstprinter(fst, args->isyms, args->osyms, args->ssyms, + args->accept, args->show_weight_one, args->sep, + args->missing_symbol); + fstprinter.Print(args->ostrm, args->dest); +} + +void PrintFst(const FstClass &fst, std::ostream &ostrm, const string &dest, + const SymbolTable *isyms, const SymbolTable *osyms, + const SymbolTable *ssyms, bool accept, bool show_weight_one, + const string &missing_sym = ""); + +// The same, but with more sensible defaults. 
+template +void PrintFst(const Fst &fst, std::ostream &ostrm, const string &dest = "", + const SymbolTable *isyms = nullptr, + const SymbolTable *osyms = nullptr, + const SymbolTable *ssyms = nullptr) { + const string sep = FLAGS_fst_field_separator.substr(0, 1); + FstPrinter fstprinter(fst, isyms, osyms, ssyms, true, true, sep); + fstprinter.Print(&ostrm, dest); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PRINT_H_ diff --git a/projects/llm_framework/include/fst/script/project.h b/projects/llm_framework/include/fst/script/project.h new file mode 100644 index 00000000..13edeb1d --- /dev/null +++ b/projects/llm_framework/include/fst/script/project.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_PROJECT_H_ +#define FST_SCRIPT_PROJECT_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ProjectArgs = std::pair; + +template +void Project(ProjectArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Project(fst, std::get<1>(*args)); +} + +void Project(MutableFstClass *fst, ProjectType project_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PROJECT_H_ diff --git a/projects/llm_framework/include/fst/script/prune.h b/projects/llm_framework/include/fst/script/prune.h new file mode 100644 index 00000000..ed10b540 --- /dev/null +++ b/projects/llm_framework/include/fst/script/prune.h @@ -0,0 +1,51 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_PRUNE_H_ +#define FST_SCRIPT_PRUNE_H_ + +#include +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using PruneArgs1 = std::tuple; + +template +void Prune(PruneArgs1 *args) { + using Weight = typename Arc::Weight; + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto weight_threshold = *(std::get<2>(*args).GetWeight()); + Prune(ifst, ofst, weight_threshold, std::get<3>(*args), std::get<4>(*args)); +} + +using PruneArgs2 = std::tuple; + +template +void Prune(PruneArgs2 *args) { + using Weight = typename Arc::Weight; + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const auto weight_threshold = *(std::get<1>(*args).GetWeight()); + Prune(fst, weight_threshold, std::get<2>(*args), std::get<3>(*args)); +} + +void Prune(const FstClass &ifst, MutableFstClass *ofst, + const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, + float delta = kDelta); + +void Prune(MutableFstClass *fst, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PRUNE_H_ diff --git a/projects/llm_framework/include/fst/script/push.h b/projects/llm_framework/include/fst/script/push.h new file mode 100644 index 00000000..018cd8f8 --- /dev/null +++ b/projects/llm_framework/include/fst/script/push.h @@ -0,0 +1,53 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_PUSH_H_ +#define FST_SCRIPT_PUSH_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using PushArgs1 = std::tuple; + +template +void Push(PushArgs1 *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + Push(fst, std::get<1>(*args), std::get<2>(*args), std::get<3>(*args)); +} + +using PushArgs2 = std::tuple; + +template +void Push(PushArgs2 *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + switch (std::get<3>(*args)) { + case REWEIGHT_TO_FINAL: { + Push(ifst, ofst, std::get<2>(*args), + std::get<4>(*args)); + return; + } + case REWEIGHT_TO_INITIAL: { + Push(ifst, ofst, std::get<2>(*args), + std::get<4>(*args)); + return; + } + } +} + +void Push(MutableFstClass *fst, ReweightType rew_type, float delta = kDelta, + bool remove_total_weight = false); + +void Push(const FstClass &ifst, MutableFstClass *ofst, uint32 flags, + ReweightType rew_type, float delta = kDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_PUSH_H_ diff --git a/projects/llm_framework/include/fst/script/randequivalent.h b/projects/llm_framework/include/fst/script/randequivalent.h new file mode 100644 index 00000000..945f8a06 --- /dev/null +++ b/projects/llm_framework/include/fst/script/randequivalent.h @@ -0,0 +1,67 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_RANDEQUIVALENT_H_ +#define FST_SCRIPT_RANDEQUIVALENT_H_ + +#include +#include + +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +using RandEquivalentInnerArgs = std::tuple &>; + +using RandEquivalentArgs = WithReturnValue; + +template +void RandEquivalent(RandEquivalentArgs *args) { + const Fst &fst1 = *(std::get<0>(args->args).GetFst()); + const Fst &fst2 = *(std::get<1>(args->args).GetFst()); + const auto seed = std::get<4>(args->args); + const auto &opts = std::get<5>(args->args); + switch (opts.selector) { + case UNIFORM_ARC_SELECTOR: { + const UniformArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + case FAST_LOG_PROB_ARC_SELECTOR: { + const FastLogProbArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + case LOG_PROB_ARC_SELECTOR: { + const LogProbArcSelector selector(seed); + const RandGenOptions> ropts(selector, + opts.max_length); + args->retval = RandEquivalent(fst1, fst2, std::get<2>(args->args), + std::get<3>(args->args), ropts); + return; + } + } +} + +bool RandEquivalent(const FstClass &fst1, const FstClass &fst2, int32 npath = 1, + float delta = kDelta, time_t seed = time(nullptr), + const RandGenOptions &opts = + RandGenOptions(UNIFORM_ARC_SELECTOR)); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RANDEQUIVALENT_H_ diff --git a/projects/llm_framework/include/fst/script/randgen.h b/projects/llm_framework/include/fst/script/randgen.h new file mode 100644 index 00000000..5ce79d01 --- /dev/null +++ b/projects/llm_framework/include/fst/script/randgen.h @@ -0,0 +1,63 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state 
transducer library. + +#ifndef FST_SCRIPT_RANDGEN_H_ +#define FST_SCRIPT_RANDGEN_H_ + +#include + +#include + +#include +#include +#include + +namespace fst { +namespace script { + +using RandGenArgs = std::tuple &>; + +template +void RandGen(RandGenArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const time_t seed = std::get<2>(*args); + const auto &opts = std::get<3>(*args); + switch (opts.selector) { + case UNIFORM_ARC_SELECTOR: { + const UniformArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + case FAST_LOG_PROB_ARC_SELECTOR: { + const FastLogProbArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + case LOG_PROB_ARC_SELECTOR: { + const LogProbArcSelector selector(seed); + const RandGenOptions> ropts( + selector, opts.max_length, opts.npath, opts.weighted, + opts.remove_total_weight); + RandGen(ifst, ofst, ropts); + return; + } + } +} + +void RandGen(const FstClass &ifst, MutableFstClass *ofst, + time_t seed = time(nullptr), + const RandGenOptions &opts = + RandGenOptions(UNIFORM_ARC_SELECTOR)); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RANDGEN_H_ diff --git a/projects/llm_framework/include/fst/script/register.h b/projects/llm_framework/include/fst/script/register.h new file mode 100644 index 00000000..d66e7ade --- /dev/null +++ b/projects/llm_framework/include/fst/script/register.h @@ -0,0 +1,99 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_REGISTER_H_ +#define FST_SCRIPT_REGISTER_H_ + +#include +#include + +#include +#include +#include + +// Holds methods and classes responsible for maintaining +// the register for FstClass arc types. + +namespace fst { +namespace script { + +// Registers for reading and converting various kinds of FST classes. + +// This class definition is to avoid a nested class definition inside the +// IORegistration struct. + +template +struct FstClassRegEntry { + Reader reader; + Creator creator; + Converter converter; + + FstClassRegEntry(Reader r, Creator cr, Converter co) + : reader(r), creator(cr), converter(co) {} + + FstClassRegEntry() + : reader(nullptr), creator(nullptr), converter(nullptr) {} +}; + +template +class FstClassIORegister + : public GenericRegister, + FstClassIORegister> { + public: + Reader GetReader(const string &arc_type) const { + return this->GetEntry(arc_type).reader; + } + + Creator GetCreator(const string &arc_type) const { + return this->GetEntry(arc_type).creator; + } + + Converter GetConverter(const string &arc_type) const { + return this->GetEntry(arc_type).converter; + } + + protected: + string ConvertKeyToSoFilename(const string &key) const final { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-arc.so"; + } +}; + +// Struct containing everything needed to register a particular type +// of FST class (e.g., a plain FstClass, or a MutableFstClass, etc.). +template +struct IORegistration { + using Reader = FstClassType *(*)(std::istream &stream, + const FstReadOptions &opts); + + using Creator = FstClassImplBase *(*)(); + + using Converter = FstClassImplBase *(*)(const FstClass &other); + + using Entry = FstClassRegEntry; + + // FST class Register. + using Register = FstClassIORegister; + + // FST class Register-er. 
+ using Registerer = + GenericRegisterer>; +}; + +#define REGISTER_FST_CLASS(Class, Arc) \ + static IORegistration::Registerer Class##_##Arc##_registerer( \ + Arc::Type(), \ + IORegistration::Entry(Class::Read, Class::Create, \ + Class::Convert)) + +#define REGISTER_FST_CLASSES(Arc) \ + REGISTER_FST_CLASS(FstClass, Arc); \ + REGISTER_FST_CLASS(MutableFstClass, Arc); \ + REGISTER_FST_CLASS(VectorFstClass, Arc); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REGISTER_H_ diff --git a/projects/llm_framework/include/fst/script/relabel.h b/projects/llm_framework/include/fst/script/relabel.h new file mode 100644 index 00000000..74443490 --- /dev/null +++ b/projects/llm_framework/include/fst/script/relabel.h @@ -0,0 +1,64 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RELABEL_H_ +#define FST_SCRIPT_RELABEL_H_ + +#include +#include +#include +#include + +#include +#include + +namespace fst { +namespace script { + +using RelabelArgs1 = std::tuple; + +template +void Relabel(RelabelArgs1 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + Relabel(ofst, std::get<1>(*args), std::get<2>(*args), std::get<3>(*args), + std::get<4>(*args), std::get<5>(*args), std::get<6>(*args), + std::get<7>(*args), std::get<8>(*args)); +} + +using LabelPair = std::pair; + +using RelabelArgs2 = std::tuple &, + const std::vector &>; + +template +void Relabel(RelabelArgs2 *args) { + MutableFst *ofst = std::get<0>(*args)->GetMutableFst(); + using LabelPair = std::pair; + // In case the MutableFstClass::Label is not the same as Arc::Label, + // make a copy. 
+ std::vector typed_ipairs(std::get<1>(*args).size()); + std::copy(std::get<1>(*args).begin(), std::get<1>(*args).end(), + typed_ipairs.begin()); + std::vector typed_opairs(std::get<2>(*args).size()); + std::copy(std::get<2>(*args).begin(), std::get<2>(*args).end(), + typed_opairs.begin()); + Relabel(ofst, typed_ipairs, typed_opairs); +} + +void Relabel(MutableFstClass *ofst, + const SymbolTable *old_isymbols, const SymbolTable *new_isymbols, + const string &unknown_isymbol, bool attach_new_isymbols, + const SymbolTable *old_osymbols, const SymbolTable *new_osymbols, + const string &unknown_osymbol, bool attach_new_osymbols); + +void Relabel(MutableFstClass *ofst, const std::vector &ipairs, + const std::vector &opairs); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RELABEL_H_ diff --git a/projects/llm_framework/include/fst/script/replace.h b/projects/llm_framework/include/fst/script/replace.h new file mode 100644 index 00000000..926ece24 --- /dev/null +++ b/projects/llm_framework/include/fst/script/replace.h @@ -0,0 +1,72 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REPLACE_H_ +#define FST_SCRIPT_REPLACE_H_ + +#include +#include +#include + +#include +#include + +namespace fst { +namespace script { + +struct ReplaceOptions { + const int64 root; // Root rule for expansion. + const ReplaceLabelType call_label_type; // How to label call arc. + const ReplaceLabelType return_label_type; // How to label return arc. + const int64 return_label; // Specifies return arc label. 
+ + explicit ReplaceOptions(int64 root, + ReplaceLabelType call_label_type = REPLACE_LABEL_INPUT, + ReplaceLabelType return_label_type = REPLACE_LABEL_NEITHER, + int64 return_label = 0) + : root(root), + call_label_type(call_label_type), + return_label_type(return_label_type), + return_label(return_label) {} +}; + +using LabelFstClassPair = std::pair; + +using ReplaceArgs = std::tuple &, + MutableFstClass *, const ReplaceOptions &>; + +template +void Replace(ReplaceArgs *args) { + using LabelFstPair = std::pair *>; + // Now that we know the arc type, we construct a vector of + // std::pair that the real Replace will use. + const auto &untyped_pairs = std::get<0>(*args); + std::vector typed_pairs; + typed_pairs.reserve(untyped_pairs.size()); + for (const auto &untyped_pair : untyped_pairs) { + typed_pairs.emplace_back(untyped_pair.first, // Converts label. + untyped_pair.second->GetFst()); + } + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const auto &opts = std::get<2>(*args); + ReplaceFstOptions typed_opts(opts.root, opts.call_label_type, + opts.return_label_type, opts.return_label); + ReplaceFst rfst(typed_pairs, typed_opts); + // Checks for cyclic dependencies before attempting expansion. + if (rfst.CyclicDependencies()) { + FSTERROR() << "Replace: Cyclic dependencies detected; cannot expand"; + ofst->SetProperties(kError, kError); + return; + } + typed_opts.gc = true; // Caching options to speed up batch copy. 
+ typed_opts.gc_limit = 0; + *ofst = rfst; +} + +void Replace(const std::vector &pairs, + MutableFstClass *ofst, const ReplaceOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REPLACE_H_ diff --git a/projects/llm_framework/include/fst/script/reverse.h b/projects/llm_framework/include/fst/script/reverse.h new file mode 100644 index 00000000..badd96b5 --- /dev/null +++ b/projects/llm_framework/include/fst/script/reverse.h @@ -0,0 +1,30 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_REVERSE_H_ +#define FST_SCRIPT_REVERSE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using ReverseArgs = std::tuple; + +template +void Reverse(ReverseArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Reverse(ifst, ofst, std::get<2>(*args)); +} + +void Reverse(const FstClass &ifst, MutableFstClass *ofst, + bool require_superinitial = true); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REVERSE_H_ diff --git a/projects/llm_framework/include/fst/script/reweight.h b/projects/llm_framework/include/fst/script/reweight.h new file mode 100644 index 00000000..3893ad85 --- /dev/null +++ b/projects/llm_framework/include/fst/script/reweight.h @@ -0,0 +1,37 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_REWEIGHT_H_ +#define FST_SCRIPT_REWEIGHT_H_ + +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +using ReweightArgs = std::tuple &, ReweightType>; + +template +void Reweight(ReweightArgs *args) { + using Weight = typename Arc::Weight; + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const std::vector &potentials = std::get<1>(*args); + std::vector typed_potentials; + internal::CopyWeights(potentials, &typed_potentials); + Reweight(fst, typed_potentials, std::get<2>(*args)); +} + +void Reweight(MutableFstClass *fst, const std::vector &potentials, + ReweightType reweight_type); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_REWEIGHT_H_ diff --git a/projects/llm_framework/include/fst/script/rmepsilon.h b/projects/llm_framework/include/fst/script/rmepsilon.h new file mode 100644 index 00000000..42986c85 --- /dev/null +++ b/projects/llm_framework/include/fst/script/rmepsilon.h @@ -0,0 +1,109 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_RMEPSILON_H_ +#define FST_SCRIPT_RMEPSILON_H_ + +#include +#include + +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +struct RmEpsilonOptions : public ShortestDistanceOptions { + const bool connect; + const WeightClass &weight_threshold; + const int64 state_threshold; + + RmEpsilonOptions(QueueType queue_type, bool connect, + const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId, float delta = kDelta) + : ShortestDistanceOptions(queue_type, EPSILON_ARC_FILTER, kNoStateId, + delta), + connect(connect), + weight_threshold(weight_threshold), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Code to implement switching on queue types. 
+ +template +void RmEpsilon(MutableFst *fst, + std::vector *distance, + const RmEpsilonOptions &opts, Queue *queue) { + using Weight = typename Arc::Weight; + const fst::RmEpsilonOptions ropts( + queue, opts.delta, opts.connect, + *opts.weight_threshold.GetWeight(), opts.state_threshold); + RmEpsilon(fst, distance, ropts); +} + +template +void RmEpsilon(MutableFst *fst, const RmEpsilonOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + AutoQueue queue(*fst, &distance, EpsilonArcFilter()); + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case FIFO_QUEUE: { + FifoQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case LIFO_QUEUE: { + LifoQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case SHORTEST_FIRST_QUEUE: { + NaturalShortestFirstQueue queue(distance); + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case STATE_ORDER_QUEUE: { + StateOrderQueue queue; + RmEpsilon(fst, &distance, opts, &queue); + return; + } + case TOP_ORDER_QUEUE: { + TopOrderQueue queue(*fst, EpsilonArcFilter()); + internal::RmEpsilon(fst, &distance, opts, &queue); + return; + } + default: { + FSTERROR() << "RmEpsilon: Unknown queue type: " << opts.queue_type; + fst->SetProperties(kError, kError); + return; + } + } +} + +} // namespace internal + +using RmEpsilonArgs = std::pair; + +template +void RmEpsilon(RmEpsilonArgs *args) { + MutableFst *fst = std::get<0>(*args)->GetMutableFst(); + const auto &opts = std::get<1>(*args); + internal::RmEpsilon(fst, opts); +} + +void RmEpsilon(MutableFstClass *fst, const RmEpsilonOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_RMEPSILON_H_ diff --git a/projects/llm_framework/include/fst/script/script-impl.h b/projects/llm_framework/include/fst/script/script-impl.h new file mode 100644 index 00000000..33c2853a --- /dev/null +++ 
b/projects/llm_framework/include/fst/script/script-impl.h @@ -0,0 +1,211 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// This file defines the registration mechanism for new operations. +// These operations are designed to enable scripts to work with FST classes +// at a high level. +// +// If you have a new arc type and want these operations to work with FSTs +// with that arc type, see below for the registration steps +// you must take. +// +// These methods are only recommended for use in high-level scripting +// applications. Most users should use the lower-level templated versions +// corresponding to these. +// +// If you have a new arc type you'd like these operations to work with, +// use the REGISTER_FST_OPERATIONS macro defined in fstscript.h. +// +// If you have a custom operation you'd like to define, you need four +// components. In the following, assume you want to create a new operation +// with the signature +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst); +// +// You need: +// +// 1) A way to bundle the args that your new Foo operation will take, as +// a single struct. The template structs in arg-packs.h provide a handy +// way to do this. In Foo's case, that might look like this: +// +// using FooArgs = std::pair<const FstClass &, MutableFstClass *>; +// +// Note: this package of args is going to be passed by non-const pointer. +// +// 2) A function template that is able to perform Foo, given the args and +// arc type. Yours might look like this: +// +// template <class Arc> +// void Foo(FooArgs *args) { +// // Pulls out the actual, arc-templated FSTs. +// const Fst<Arc> &ifst = *(std::get<0>(*args).GetFst<Arc>()); +// MutableFst<Arc> *ofst = std::get<1>(*args)->GetMutableFst<Arc>(); +// // Actually perform Foo on ifst and ofst. +// } +// +// 3) A client-facing function for your operation.
This would look like +// the following: +// +// void Foo(const FstClass &ifst, MutableFstClass *ofst) { +// // Check that the arc types of the FSTs match +// if (!ArcTypesMatch(ifst, *ofst, "Foo")) return; +// // package the args +// FooArgs args(ifst, ofst); +// // Finally, call the operation +// Apply<Operation<FooArgs>>("Foo", ifst.ArcType(), &args); +// } +// +// The Apply<> function template takes care of the link between 2 and 3, +// provided you also have: +// +// 4) A registration for your new operation, on the arc types you care about. +// This can be provided easily by the REGISTER_FST_OPERATION macro defined +// later in this file: +// +// REGISTER_FST_OPERATION(Foo, StdArc, FooArgs); +// REGISTER_FST_OPERATION(Foo, MyArc, FooArgs); +// // .. etc +// +// +// That's it! Now when you call Foo(const FstClass &, MutableFstClass *), +// it dispatches (in #3) via the Apply<> function to the correct +// instantiation of the template function in #2. +// + +#ifndef FST_SCRIPT_SCRIPT_IMPL_H_ +#define FST_SCRIPT_SCRIPT_IMPL_H_ + +// This file contains general-purpose templates which are used in the +// implementation of the operations. + +#include +#include + +#include +#include + +#include + +namespace fst { +namespace script { + +enum RandArcSelection { + UNIFORM_ARC_SELECTOR, + LOG_PROB_ARC_SELECTOR, + FAST_LOG_PROB_ARC_SELECTOR +}; + +// A generic register for operations with various kinds of signatures. +// Needed since every function signature requires a new registration class. +// The std::pair<string, string> is understood to be the operation name and arc +// type; subclasses (or typedefs) need only provide the operation signature.
+ +template +class GenericOperationRegister + : public GenericRegister, OperationSignature, + GenericOperationRegister> { + public: + void RegisterOperation(const string &operation_name, const string &arc_type, + OperationSignature op) { + this->SetEntry(std::make_pair(operation_name, arc_type), op); + } + + OperationSignature GetOperation(const string &operation_name, + const string &arc_type) { + return this->GetEntry(std::make_pair(operation_name, arc_type)); + } + + protected: + string ConvertKeyToSoFilename( + const std::pair &key) const final { + // Uses the old-style FST for now. + string legal_type(key.second); // The arc type. + ConvertToLegalCSymbol(&legal_type); + return legal_type + "-arc.so"; + } +}; + +// Operation package: everything you need to register a new type of operation. +// The ArgPack should be the type that's passed into each wrapped function; +// for instance, it might be a struct containing all the args. It's always +// passed by pointer, so const members should be used to enforce constness where +// it's needed. Return values should be implemented as a member of ArgPack as +// well. + +template +struct Operation { + using ArgPack = Args; + + using OpType = void (*)(ArgPack *args); + + // The register (hash) type. + using Register = GenericOperationRegister; + + // The register-er type + using Registerer = GenericRegisterer; +}; + +// Macro for registering new types of operations. + +#define REGISTER_FST_OPERATION(Op, Arc, ArgPack) \ + static fst::script::Operation::Registerer \ + arc_dispatched_operation_##ArgPack##Op##Arc##_registerer \ + (std::make_pair(#Op, Arc::Type()), Op) + +// Template function to apply an operation by name. 
+ +template +void Apply(const string &op_name, const string &arc_type, + typename OpReg::ArgPack *args) { + const auto op = OpReg::Register::GetRegister()->GetOperation(op_name, + arc_type); + if (!op) { + FSTERROR() << "No operation found for " << op_name << " on " + << "arc type " << arc_type; + return; + } + op(args); +} + +namespace internal { + +// Helper that logs to ERROR if the arc types of m and n don't match, +// assuming that both m and n implement .ArcType(). The op_name argument is +// used to construct the error message. +template +bool ArcTypesMatch(const M &m, const N &n, const string &op_name) { + if (m.ArcType() != n.ArcType()) { + FSTERROR() << "Arguments with non-matching arc types passed to " + << op_name << ":\t" << m.ArcType() << " and " << n.ArcType(); + return false; + } + return true; +} + +// From untyped to typed weights. +template +void CopyWeights(const std::vector &weights, + std::vector *typed_weights) { + typed_weights->clear(); + typed_weights->reserve(weights.size()); + for (const auto &weight : weights) { + typed_weights->push_back(*weight.GetWeight()); + } +} + +// From typed to untyped weights. +template +void CopyWeights(const std::vector &typed_weights, + std::vector *weights) { + weights->clear(); + weights->reserve(typed_weights.size()); + for (const auto &typed_weight : typed_weights) { + weights->emplace_back(typed_weight); + } +} + +} // namespace internal +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SCRIPT_IMPL_H_ diff --git a/projects/llm_framework/include/fst/script/shortest-distance.h b/projects/llm_framework/include/fst/script/shortest-distance.h new file mode 100644 index 00000000..a44a6c9b --- /dev/null +++ b/projects/llm_framework/include/fst/script/shortest-distance.h @@ -0,0 +1,214 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_SHORTEST_DISTANCE_H_ +#define FST_SCRIPT_SHORTEST_DISTANCE_H_ + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace fst { +namespace script { + +enum ArcFilterType { + ANY_ARC_FILTER, + EPSILON_ARC_FILTER, + INPUT_EPSILON_ARC_FILTER, + OUTPUT_EPSILON_ARC_FILTER +}; + +struct ShortestDistanceOptions { + const QueueType queue_type; + const ArcFilterType arc_filter_type; + const int64 source; + const float delta; + + ShortestDistanceOptions(QueueType queue_type, ArcFilterType arc_filter_type, + int64 source, float delta) + : queue_type(queue_type), + arc_filter_type(arc_filter_type), + source(source), + delta(delta) {} +}; + +namespace internal { + +// Code to implement switching on queue and arc filter types. + +template +struct QueueConstructor { + using Weight = typename Arc::Weight; + + static Queue *Construct(const Fst &, const std::vector *) { + return new Queue(); + } +}; + +// Specializations to support queues with different constructors. 
+ +template +struct QueueConstructor, ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + // template + static AutoQueue *Construct(const Fst &fst, + const std::vector *distance) { + return new AutoQueue(fst, distance, ArcFilter()); + } +}; + +template +struct QueueConstructor< + Arc, NaturalShortestFirstQueue, + ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + static NaturalShortestFirstQueue *Construct( + const Fst &, const std::vector *distance) { + return new NaturalShortestFirstQueue(*distance); + } +}; + +template +struct QueueConstructor, ArcFilter> { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + static TopOrderQueue *Construct(const Fst &fst, + const std::vector *) { + return new TopOrderQueue(fst, ArcFilter()); + } +}; + +template +void ShortestDistance(const Fst &fst, + std::vector *distance, + const ShortestDistanceOptions &opts) { + std::unique_ptr queue( + QueueConstructor::Construct(fst, distance)); + const fst::ShortestDistanceOptions sopts( + queue.get(), ArcFilter(), opts.source, opts.delta); + ShortestDistance(fst, distance, sopts); +} + +template +void ShortestDistance(const Fst &fst, + std::vector *distance, + const ShortestDistanceOptions &opts) { + switch (opts.arc_filter_type) { + case ANY_ARC_FILTER: { + ShortestDistance>(fst, distance, opts); + return; + } + case EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, opts); + return; + } + case INPUT_EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, + opts); + return; + } + case OUTPUT_EPSILON_ARC_FILTER: { + ShortestDistance>(fst, distance, + opts); + return; + } + default: { + FSTERROR() << "ShortestDistance: Unknown arc filter type: " + << opts.arc_filter_type; + distance->clear(); + distance->resize(1, Arc::Weight::NoWeight()); + return; + } + } +} + +} // namespace internal + +using ShortestDistanceArgs1 = + std::tuple *, + const 
ShortestDistanceOptions &>; + +template +void ShortestDistance(ShortestDistanceArgs1 *args) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + const Fst &fst = *(std::get<0>(*args).GetFst()); + const auto &opts = std::get<2>(*args); + std::vector typed_distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case FIFO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case LIFO_QUEUE: { + internal::ShortestDistance>(fst, &typed_distance, + opts); + break; + } + case SHORTEST_FIRST_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + case STATE_ORDER_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + case TOP_ORDER_QUEUE: { + internal::ShortestDistance>( + fst, &typed_distance, opts); + break; + } + default: { + FSTERROR() << "ShortestDistance: Unknown queue type: " << opts.queue_type; + typed_distance.clear(); + typed_distance.resize(1, Arc::Weight::NoWeight()); + break; + } + } + internal::CopyWeights(typed_distance, std::get<1>(*args)); +} + +using ShortestDistanceArgs2 = + std::tuple *, bool, double>; + +template +void ShortestDistance(ShortestDistanceArgs2 *args) { + using Weight = typename Arc::Weight; + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::vector typed_distance; + ShortestDistance(fst, &typed_distance, std::get<2>(*args), + std::get<3>(*args)); + internal::CopyWeights(typed_distance, std::get<1>(*args)); +} + +void ShortestDistance(const FstClass &fst, std::vector *distance, + const ShortestDistanceOptions &opts); + +void ShortestDistance(const FstClass &ifst, std::vector *distance, + bool reverse = false, + double delta = fst::kShortestDelta); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SHORTEST_DISTANCE_H_ diff --git a/projects/llm_framework/include/fst/script/shortest-path.h 
b/projects/llm_framework/include/fst/script/shortest-path.h new file mode 100644 index 00000000..86bc88da --- /dev/null +++ b/projects/llm_framework/include/fst/script/shortest-path.h @@ -0,0 +1,116 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_SHORTEST_PATH_H_ +#define FST_SCRIPT_SHORTEST_PATH_H_ + +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +// Slightly simplified interface: `has_distance` and `first_path` are disabled. + +struct ShortestPathOptions : public ShortestDistanceOptions { + const int32 nshortest; + const bool unique; + const WeightClass &weight_threshold; + const int64 state_threshold; + + ShortestPathOptions(QueueType queue_type, int32 nshortest, bool unique, + float delta, const WeightClass &weight_threshold, + int64 state_threshold = kNoStateId) + : ShortestDistanceOptions(queue_type, ANY_ARC_FILTER, kNoStateId, delta), + nshortest(nshortest), + unique(unique), + weight_threshold(weight_threshold), + state_threshold(state_threshold) {} +}; + +namespace internal { + +// Code to implement switching on queue types. 
+ +template +void ShortestPath(const Fst &ifst, MutableFst *ofst, + std::vector *distance, + const ShortestPathOptions &opts) { + using ArcFilter = AnyArcFilter; + using Weight = typename Arc::Weight; + const std::unique_ptr queue( + QueueConstructor::Construct(ifst, distance)); + const fst::ShortestPathOptions sopts( + queue.get(), ArcFilter(), opts.nshortest, opts.unique, + /* has_distance=*/false, opts.delta, /* first_path=*/false, + *opts.weight_threshold.GetWeight(), opts.state_threshold); + ShortestPath(ifst, ofst, distance, sopts); +} + +template +void ShortestPath(const Fst &ifst, MutableFst *ofst, + const ShortestPathOptions &opts) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + std::vector distance; + switch (opts.queue_type) { + case AUTO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case FIFO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case LIFO_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case SHORTEST_FIRST_QUEUE: { + ShortestPath>(ifst, ofst, + &distance, + opts); + return; + } + case STATE_ORDER_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + case TOP_ORDER_QUEUE: { + ShortestPath>(ifst, ofst, &distance, opts); + return; + } + default: { + FSTERROR() << "ShortestPath: Unknown queue type: " + << opts.queue_type; + ofst->SetProperties(kError, kError); + return; + } + } +} + +} // namespace internal + +using ShortestPathArgs = std::tuple; + +template +void ShortestPath(ShortestPathArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + const ShortestPathOptions &opts = std::get<2>(*args); + internal::ShortestPath(ifst, ofst, opts); +} + +void ShortestPath(const FstClass &ifst, MutableFstClass *ofst, + const ShortestPathOptions &opts); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SHORTEST_PATH_H_ diff --git 
a/projects/llm_framework/include/fst/script/stateiterator-class.h b/projects/llm_framework/include/fst/script/stateiterator-class.h new file mode 100644 index 00000000..f6fddfe6 --- /dev/null +++ b/projects/llm_framework/include/fst/script/stateiterator-class.h @@ -0,0 +1,85 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_STATEITERATOR_CLASS_H_ +#define FST_SCRIPT_STATEITERATOR_CLASS_H_ + +#include + +#include +#include + +// Scripting API support for StateIterator. + +namespace fst { +namespace script { + +// Virtual interface implemented by each concrete StateIteratorImpl. +class StateIteratorImplBase { + public: + virtual bool Done() const = 0; + virtual int64 Value() const = 0; + virtual void Next() = 0; + virtual void Reset() = 0; + virtual ~StateIteratorImplBase() {} +}; + +// Templated implementation. +template +class StateIteratorClassImpl : public StateIteratorImplBase { + public: + explicit StateIteratorClassImpl(const Fst &fst) : siter_(fst) {} + + bool Done() const final { return siter_.Done(); } + + int64 Value() const final { return siter_.Value(); } + + void Next() final { siter_.Next(); } + + void Reset() final { siter_.Reset(); } + + ~StateIteratorClassImpl() override {} + + private: + StateIterator> siter_; +}; + +class StateIteratorClass; + +using InitStateIteratorClassArgs = + std::pair; + +// Untemplated user-facing class holding a templated pimpl. 
+class StateIteratorClass { + public: + explicit StateIteratorClass(const FstClass &fst); + + template + explicit StateIteratorClass(const Fst &fst) + : impl_(new StateIteratorClassImpl(fst)) {} + + bool Done() const { return impl_->Done(); } + + int64 Value() const { return impl_->Value(); } + + void Next() { impl_->Next(); } + + void Reset() { impl_->Reset(); } + + template + friend void InitStateIteratorClass(InitStateIteratorClassArgs *args); + + private: + std::unique_ptr impl_; +}; + +template +void InitStateIteratorClass(InitStateIteratorClassArgs *args) { + const Fst &fst = *(std::get<0>(*args).GetFst()); + std::get<1>(*args)->impl_.reset(new StateIteratorClassImpl(fst)); +} + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_STATEITERATOR_CLASS_H_ diff --git a/projects/llm_framework/include/fst/script/synchronize.h b/projects/llm_framework/include/fst/script/synchronize.h new file mode 100644 index 00000000..01df151a --- /dev/null +++ b/projects/llm_framework/include/fst/script/synchronize.h @@ -0,0 +1,29 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_SYNCHRONIZE_H_ +#define FST_SCRIPT_SYNCHRONIZE_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using SynchronizeArgs = std::pair; + +template +void Synchronize(SynchronizeArgs *args) { + const Fst &ifst = *(std::get<0>(*args).GetFst()); + MutableFst *ofst = std::get<1>(*args)->GetMutableFst(); + Synchronize(ifst, ofst); +} + +void Synchronize(const FstClass &ifst, MutableFstClass *ofst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_SYNCHRONIZE_H_ diff --git a/projects/llm_framework/include/fst/script/text-io.h b/projects/llm_framework/include/fst/script/text-io.h new file mode 100644 index 00000000..464bf885 --- /dev/null +++ b/projects/llm_framework/include/fst/script/text-io.h @@ -0,0 +1,28 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utilities for reading and writing textual strings representing states, +// labels, and weights and files specifying label-label pairs and potentials +// (state-weight pairs). + +#ifndef FST_SCRIPT_TEXT_IO_H__ +#define FST_SCRIPT_TEXT_IO_H__ + +#include +#include + +#include + +namespace fst { +namespace script { + +bool ReadPotentials(const string &weight_type, const string &filename, + std::vector *potentials); + +bool WritePotentials(const string &filename, + const std::vector &potentials); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TEXT_IO_H__ diff --git a/projects/llm_framework/include/fst/script/topsort.h b/projects/llm_framework/include/fst/script/topsort.h new file mode 100644 index 00000000..fb6738d7 --- /dev/null +++ b/projects/llm_framework/include/fst/script/topsort.h @@ -0,0 +1,26 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_TOPSORT_H_ +#define FST_SCRIPT_TOPSORT_H_ + +#include +#include +#include + +namespace fst { +namespace script { + +using TopSortArgs = WithReturnValue; + +template +void TopSort(TopSortArgs *args) { + args->retval = TopSort(args->args->GetMutableFst()); +} + +bool TopSort(MutableFstClass *fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_TOPSORT_H_ diff --git a/projects/llm_framework/include/fst/script/union.h b/projects/llm_framework/include/fst/script/union.h new file mode 100644 index 00000000..9493e2b1 --- /dev/null +++ b/projects/llm_framework/include/fst/script/union.h @@ -0,0 +1,29 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. + +#ifndef FST_SCRIPT_UNION_H_ +#define FST_SCRIPT_UNION_H_ + +#include + +#include +#include + +namespace fst { +namespace script { + +using UnionArgs = std::pair; + +template +void Union(UnionArgs *args) { + MutableFst *fst1 = std::get<0>(*args)->GetMutableFst(); + const Fst &fst2 = *(std::get<1>(*args).GetFst()); + Union(fst1, fst2); +} + +void Union(MutableFstClass *fst1, const FstClass &fst2); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_UNION_H_ diff --git a/projects/llm_framework/include/fst/script/verify.h b/projects/llm_framework/include/fst/script/verify.h new file mode 100644 index 00000000..52f58641 --- /dev/null +++ b/projects/llm_framework/include/fst/script/verify.h @@ -0,0 +1,27 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+ +#ifndef FST_SCRIPT_VERIFY_H_ +#define FST_SCRIPT_VERIFY_H_ + +#include +#include +#include + +namespace fst { +namespace script { + +using VerifyArgs = WithReturnValue; + +template +void Verify(VerifyArgs *args) { + const Fst &fst = *(args->args.GetFst()); + args->retval = Verify(fst); +} + +bool Verify(const FstClass &fst); + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_VERIFY_H_ diff --git a/projects/llm_framework/include/fst/script/weight-class.h b/projects/llm_framework/include/fst/script/weight-class.h new file mode 100644 index 00000000..6dadf92c --- /dev/null +++ b/projects/llm_framework/include/fst/script/weight-class.h @@ -0,0 +1,235 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Represents a generic weight in an FST; that is, represents a specific type +// of weight underneath while hiding that type from a client. + +#ifndef FST_SCRIPT_WEIGHT_CLASS_H_ +#define FST_SCRIPT_WEIGHT_CLASS_H_ + +#include +#include +#include + +#include +#include +#include +#include + +namespace fst { +namespace script { + +class WeightImplBase { + public: + virtual WeightImplBase *Copy() const = 0; + virtual void Print(std::ostream *o) const = 0; + virtual const string &Type() const = 0; + virtual string ToString() const = 0; + virtual bool Member() const = 0; + virtual bool operator==(const WeightImplBase &other) const = 0; + virtual bool operator!=(const WeightImplBase &other) const = 0; + virtual WeightImplBase &PlusEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &TimesEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &DivideEq(const WeightImplBase &other) = 0; + virtual WeightImplBase &PowerEq(size_t n) = 0; + virtual ~WeightImplBase() {} +}; + +template +class WeightClassImpl : public WeightImplBase { + public: + explicit WeightClassImpl(const W &weight) : weight_(weight) {} + + WeightClassImpl *Copy() const final { + return new WeightClassImpl(weight_); 
+ } + + const string &Type() const final { return W::Type(); } + + void Print(std::ostream *ostrm) const final { *ostrm << weight_; } + + string ToString() const final { + string str; + WeightToStr(weight_, &str); + return str; + } + + bool Member() const final { return weight_.Member(); } + + bool operator==(const WeightImplBase &other) const final { + const auto *typed_other = static_cast *>(&other); + return weight_ == typed_other->weight_; + } + + bool operator!=(const WeightImplBase &other) const final { + return !(*this == other); + } + + WeightClassImpl &PlusEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Plus(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &TimesEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Times(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &DivideEq(const WeightImplBase &other) final { + const auto *typed_other = static_cast *>(&other); + weight_ = Divide(weight_, typed_other->weight_); + return *this; + } + + WeightClassImpl &PowerEq(size_t n) final { + weight_ = Power(weight_, n); + return *this; + } + + W *GetImpl() { return &weight_; } + + private: + W weight_; +}; + + +class WeightClass { + public: + WeightClass() = default; + + template + explicit WeightClass(const W &weight) + : impl_(new WeightClassImpl(weight)) {} + + template + explicit WeightClass(const WeightClassImpl &impl) + : impl_(new WeightClassImpl(impl)) {} + + WeightClass(const string &weight_type, const string &weight_str); + + WeightClass(const WeightClass &other) + : impl_(other.impl_ ? other.impl_->Copy() : nullptr) {} + + WeightClass &operator=(const WeightClass &other) { + impl_.reset(other.impl_ ? 
other.impl_->Copy() : nullptr); + return *this; + } + + static constexpr const char *__ZERO__ = "__ZERO__"; // NOLINT + + static WeightClass Zero(const string &weight_type); + + static constexpr const char *__ONE__ = "__ONE__"; // NOLINT + + static WeightClass One(const string &weight_type); + + static constexpr const char *__NOWEIGHT__ = "__NOWEIGHT__"; // NOLINT + + static WeightClass NoWeight(const string &weight_type); + + template + const W *GetWeight() const { + if (W::Type() != impl_->Type()) { + return nullptr; + } else { + auto *typed_impl = static_cast *>(impl_.get()); + return typed_impl->GetImpl(); + } + } + + string ToString() const { return (impl_) ? impl_->ToString() : "none"; } + + const string &Type() const { + if (impl_) return impl_->Type(); + static const string *const no_type = new string("none"); + return *no_type; + } + + bool Member() const { return impl_ && impl_->Member(); } + + bool WeightTypesMatch(const WeightClass &other, const string &op_name) const; + + friend bool operator==(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Times(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs); + + friend WeightClass Power(const WeightClass &w, size_t n); + + private: + const WeightImplBase *GetImpl() const { return impl_.get(); } + + WeightImplBase *GetImpl() { return impl_.get(); } + + std::unique_ptr impl_; + + friend std::ostream &operator<<(std::ostream &o, const WeightClass &c); +}; + +bool operator==(const WeightClass &lhs, const WeightClass &rhs); + +bool operator!=(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Plus(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Times(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Divide(const WeightClass &lhs, const WeightClass &rhs); + +WeightClass Power(const WeightClass &w, 
size_t n); + +std::ostream &operator<<(std::ostream &o, const WeightClass &c); + +// Registration for generic weight types. + +using StrToWeightImplBaseT = WeightImplBase *(*)(const string &str, + const string &src, + size_t nline); + +template +WeightImplBase *StrToWeightImplBase(const string &str, const string &src, + size_t nline) { + if (str == WeightClass::__ZERO__) + return new WeightClassImpl(W::Zero()); + else if (str == WeightClass::__ONE__) + return new WeightClassImpl(W::One()); + else if (str == WeightClass::__NOWEIGHT__) + return new WeightClassImpl(W::NoWeight()); + return new WeightClassImpl(StrToWeight(str, src, nline)); +} + +class WeightClassRegister : public GenericRegister { + protected: + string ConvertKeyToSoFilename(const string &key) const final { + string legal_type(key); + ConvertToLegalCSymbol(&legal_type); + return legal_type + ".so"; + } +}; + +using WeightClassRegisterer = GenericRegisterer; + +// Internal version; needs to be called by wrapper in order for macro args to +// expand. +#define REGISTER_FST_WEIGHT__(Weight, line) \ + static WeightClassRegisterer weight_registerer##_##line( \ + Weight::Type(), StrToWeightImplBase) + +// This layer is where __FILE__ and __LINE__ are expanded. +#define REGISTER_FST_WEIGHT_EXPANDER(Weight, line) \ + REGISTER_FST_WEIGHT__(Weight, line) + +// Macro for registering new weight types; clients call this. +#define REGISTER_FST_WEIGHT(Weight) \ + REGISTER_FST_WEIGHT_EXPANDER(Weight, __LINE__) + +} // namespace script +} // namespace fst + +#endif // FST_SCRIPT_WEIGHT_CLASS_H_ diff --git a/projects/llm_framework/include/fst/set-weight.h b/projects/llm_framework/include/fst/set-weight.h new file mode 100644 index 00000000..dd665f0d --- /dev/null +++ b/projects/llm_framework/include/fst/set-weight.h @@ -0,0 +1,618 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. 
+// +// Weights consisting of sets (of integral Labels) and +// associated semiring operation definitions using intersect +// and union. + +#ifndef FST_SET_WEIGHT_H_ +#define FST_SET_WEIGHT_H_ + +#include + +#include +#include +#include +#include + +#include +#include + + +namespace fst { + +constexpr int kSetEmpty = 0; // Label for the empty set. +constexpr int kSetUniv = -1; // Label for the universal set. +constexpr int kSetBad = -2; // Label for a non-set. +constexpr char kSetSeparator = '_'; // Label separator in sets. + +// Determines whether to use (intersect, union) or (union, intersect) +// as (+, *) for the semiring. SET_INTERSECT_UNION_RESTRICTED is a +// restricted version of (intersect, union) that requires summed +// arguments to be equal (or an error is signalled), useful for +// algorithms that require a unique labelled path weight. SET_BOOLEAN +// treats all non-Zero() elements as equivalent (with Zero() == +// UnivSet()), useful for algorithms that don't really depend on the +// detailed sets. +enum SetType { SET_INTERSECT_UNION = 0, + SET_UNION_INTERSECT = 1, + SET_INTERSECT_UNION_RESTRICT = 2, + SET_BOOLEAN = 3 }; + +template +class SetWeightIterator; + +// Set semiring of integral labels. +template +class SetWeight { + public: + using Label = Label_; + using ReverseWeight = SetWeight; + using Iterator = SetWeightIterator; + friend class SetWeightIterator; + // Allow type-converting copy and move constructors private access. + template + friend class SetWeight; + + SetWeight() {} + + // Input should be positive, sorted and unique. + template + SetWeight(const Iterator &begin, const Iterator &end) { + for (auto iter = begin; iter != end; ++iter) PushBack(*iter); + } + + // Input should be positive. (Non-positive value has + // special internal meaning w.r.t. integral constants above.) 
+ explicit SetWeight(Label label) { PushBack(label); } + + template + explicit SetWeight(const SetWeight &w) + : first_(w.first_), rest_(w.rest_) {} + + template + explicit SetWeight(SetWeight &&w) + : first_(w.first_), rest_(std::move(w.rest_)) { w.Clear(); } + + template + SetWeight &operator=(const SetWeight &w) { + first_ = w.first_; + rest_ = w.rest_; + return *this; + } + + template + SetWeight &operator=(SetWeight &&w) { + first_ = w.first_; + rest_ = std::move(w.rest_); + w.Clear(); + return *this; + } + + static const SetWeight &Zero() { + return S == SET_UNION_INTERSECT ? EmptySet() : UnivSet(); + } + + static const SetWeight &One() { + return S == SET_UNION_INTERSECT ? UnivSet() : EmptySet(); + } + + static const SetWeight &NoWeight() { + static const auto *const no_weight = new SetWeight(Label(kSetBad)); + return *no_weight; + } + + static const string &Type() { + static const string *const type = new string( + S == SET_UNION_INTERSECT + ? "union_intersect_set" + : (S == SET_INTERSECT_UNION + ? "intersect_union_set" + : (S == SET_INTERSECT_UNION_RESTRICT + ? "restricted_set_intersect_union" + : "boolean_set"))); + return *type; + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + SetWeight Quantize(float delta = kDelta) const { return *this; } + + ReverseWeight Reverse() const; + + static constexpr uint64 Properties() { + return kIdempotent | kLeftSemiring | kRightSemiring | kCommutative; + } + + // These operations combined with the SetWeightIterator + // provide the access and mutation of the set internal elements. + + // The empty set. + static const SetWeight &EmptySet() { + static const auto *const empty = new SetWeight(Label(kSetEmpty)); + return *empty; + } + + // The univeral set. + static const SetWeight &UnivSet() { + static const auto *const univ = new SetWeight(Label(kSetUniv)); + return *univ; + } + + // Clear existing SetWeight. 
+ void Clear() { + first_ = kSetEmpty; + rest_.clear(); + } + + size_t Size() const { return first_ == kSetEmpty ? 0 : rest_.size() + 1; } + + Label Back() { + if (rest_.empty()) { + return first_; + } else { + return rest_.back(); + } + } + + // Caller must add in sort order and be unique (or error signalled). + // Input should also be positive. Non-positive value (for the first + // push) has special internal meaning w.r.t. integral constants above. + void PushBack(Label label) { + if (first_ == kSetEmpty) { + first_ = label; + } else { + if (label <= Back() || label <= 0) { + FSTERROR() << "SetWeight: labels must be positive, added" + << " in sort order and be unique."; + rest_.push_back(Label(kSetBad)); + } + rest_.push_back(label); + } + } + + private: + Label first_ = kSetEmpty; // First label in set (kSetEmpty if empty). + std::list &fst); +// +// // Required copy constructor that allows updating FST argument; +// // pass only if relevant and changed. +// StateMapper(const StateMapper &mapper, const Fst *fst = 0); +// +// // Specifies initial state of result. +// B::StateId Start() const; +// // Specifies state's final weight in result. +// B::Weight Final(B::StateId state) const; +// +// // These methods iterate through a state's arcs in result. +// +// // Specifies state to iterate over. +// void SetState(B::StateId state); +// +// // End of arcs? +// bool Done() const; +// +// // Current arc. +// const B &Value() const; +// +// // Advances to next arc (when !Done) +// void Next(); +// +// // Specifies input symbol table action the mapper requires (see above). +// MapSymbolsAction InputSymbolsAction() const; +// +// // Specifies output symbol table action the mapper requires (see above). +// MapSymbolsAction OutputSymbolsAction() const; +// +// // This specifies the known properties of an FST mapped by this +// // mapper. It takes as argument the input FST's known properties. 
+// uint64 Properties(uint64 props) const; +// }; +// +// We include a various state map versions below. One dimension of variation is +// whether the mapping mutates its input, writes to a new result FST, or is an +// on-the-fly Fst. Another dimension is how we pass the mapper. We allow passing +// the mapper by pointer for cases that we need to change the state of the +// user's mapper. We also include map versions that pass the mapper by value or +// const reference when this suffices. + +// Maps an arc type A using a mapper function object C, passed by pointer. This +// version modifies the input FST. +template +void StateMap(MutableFst *fst, C *mapper) { + if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + fst->SetOutputSymbols(nullptr); + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kFstProperties, false); + fst->SetStart(mapper->Start()); + for (StateIterator> siter(*fst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + mapper->SetState(state); + fst->DeleteArcs(state); + for (; !mapper->Done(); mapper->Next()) { + fst->AddArc(state, mapper->Value()); + } + fst->SetFinal(state, mapper->Final(state)); + } + fst->SetProperties(mapper->Properties(props), kFstProperties); +} + +// Maps an arc type A using a mapper function object C, passed by value. +// This version modifies the input FST. +template +void StateMap(MutableFst *fst, C mapper) { + StateMap(fst, &mapper); +} + +// Maps an arc type A to an arc type B using mapper functor C, passed by +// pointer. This version writes to an output FST. 
+template +void StateMap(const Fst &ifst, MutableFst *ofst, C *mapper) { + ofst->DeleteStates(); + if (mapper->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetInputSymbols(ifst.InputSymbols()); + } else if (mapper->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetInputSymbols(nullptr); + } + if (mapper->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + ofst->SetOutputSymbols(ifst.OutputSymbols()); + } else if (mapper->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + ofst->SetOutputSymbols(nullptr); + } + const auto iprops = ifst.Properties(kCopyProperties, false); + if (ifst.Start() == kNoStateId) { + if (iprops & kError) ofst->SetProperties(kError, kError); + return; + } + // Adds all states. + if (ifst.Properties(kExpanded, false)) ofst->ReserveStates(CountStates(ifst)); + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + ofst->AddState(); + } + ofst->SetStart(mapper->Start()); + for (StateIterator> siter(ifst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + mapper->SetState(state); + for (; !mapper->Done(); mapper->Next()) { + ofst->AddArc(state, mapper->Value()); + } + ofst->SetFinal(state, mapper->Final(state)); + } + const auto oprops = ofst->Properties(kFstProperties, false); + ofst->SetProperties(mapper->Properties(iprops) | oprops, kFstProperties); +} + +// Maps an arc type A to an arc type B using mapper functor object C, passed by +// value. This version writes to an output FST. +template +void StateMap(const Fst &ifst, MutableFst *ofst, C mapper) { + StateMap(ifst, ofst, &mapper); +} + +using StateMapFstOptions = CacheOptions; + +template +class StateMapFst; + +// Facade around StateIteratorBase inheriting from StateIteratorBase. 
+template +class StateMapStateIteratorBase : public StateIteratorBase { + public: + using Arc = B; + using StateId = typename Arc::StateId; + + explicit StateMapStateIteratorBase(StateIteratorBase *base) + : base_(base) {} + + bool Done() const final { return base_->Done(); } + + StateId Value() const final { return base_->Value(); } + + void Next() final { base_->Next(); } + + void Reset() final { base_->Reset(); } + + private: + std::unique_ptr> base_; + + StateMapStateIteratorBase() = delete; +}; + +namespace internal { + +// Implementation of delayed StateMapFst. +template +class StateMapFstImpl : public CacheImpl { + public: + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + + using CacheImpl::PushArc; + using CacheImpl::HasArcs; + using CacheImpl::HasFinal; + using CacheImpl::HasStart; + using CacheImpl::SetArcs; + using CacheImpl::SetFinal; + using CacheImpl::SetStart; + + friend class StateIterator>; + + StateMapFstImpl(const Fst &fst, const C &mapper, + const StateMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(new C(mapper, fst_.get())), + own_mapper_(true) { + Init(); + } + + StateMapFstImpl(const Fst &fst, C *mapper, const StateMapFstOptions &opts) + : CacheImpl(opts), + fst_(fst.Copy()), + mapper_(mapper), + own_mapper_(false) { + Init(); + } + + StateMapFstImpl(const StateMapFstImpl &impl) + : CacheImpl(impl), + fst_(impl.fst_->Copy(true)), + mapper_(new C(*impl.mapper_, fst_.get())), + own_mapper_(true) { + Init(); + } + + ~StateMapFstImpl() override { + if (own_mapper_) delete mapper_; + } + + StateId Start() { + if (!HasStart()) SetStart(mapper_->Start()); + return CacheImpl::Start(); + } + + Weight Final(StateId state) { + if (!HasFinal(state)) SetFinal(state, mapper_->Final(state)); + return CacheImpl::Final(state); + } + + size_t NumArcs(StateId 
state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumArcs(state); + } + + size_t NumInputEpsilons(StateId state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumInputEpsilons(state); + } + + size_t NumOutputEpsilons(StateId state) { + if (!HasArcs(state)) Expand(state); + return CacheImpl::NumOutputEpsilons(state); + } + + void InitStateIterator(StateIteratorData *datb) const { + StateIteratorData data; + fst_->InitStateIterator(&data); + datb->base = data.base ? new StateMapStateIteratorBase(data.base) + : nullptr; + datb->nstates = data.nstates; + } + + void InitArcIterator(StateId state, ArcIteratorData *data) { + if (!HasArcs(state)) Expand(state); + CacheImpl::InitArcIterator(state, data); + } + + uint64 Properties() const override { return Properties(kFstProperties); } + + uint64 Properties(uint64 mask) const override { + if ((mask & kError) && (fst_->Properties(kError, false) || + (mapper_->Properties(0) & kError))) { + SetProperties(kError, kError); + } + return FstImpl::Properties(mask); + } + + void Expand(StateId state) { + // Adds exiting arcs. 
+ for (mapper_->SetState(state); !mapper_->Done(); mapper_->Next()) { + PushArc(state, mapper_->Value()); + } + SetArcs(state); + } + + const Fst *GetFst() const { return fst_.get(); } + + private: + void Init() { + SetType("statemap"); + if (mapper_->InputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetInputSymbols(fst_->InputSymbols()); + } else if (mapper_->InputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetInputSymbols(nullptr); + } + if (mapper_->OutputSymbolsAction() == MAP_COPY_SYMBOLS) { + SetOutputSymbols(fst_->OutputSymbols()); + } else if (mapper_->OutputSymbolsAction() == MAP_CLEAR_SYMBOLS) { + SetOutputSymbols(nullptr); + } + const auto props = fst_->Properties(kCopyProperties, false); + SetProperties(mapper_->Properties(props)); + } + + std::unique_ptr> fst_; + C *mapper_; + bool own_mapper_; +}; + +} // namespace internal + +// Maps an arc type A to an arc type B using Mapper function object +// C. This version is a delayed FST. +template +class StateMapFst : public ImplToFst> { + public: + friend class ArcIterator>; + + using Arc = B; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using Store = DefaultCacheStore; + using State = typename Store::State; + using Impl = internal::StateMapFstImpl; + + StateMapFst(const Fst &fst, const C &mapper, + const StateMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + StateMapFst(const Fst &fst, C *mapper, const StateMapFstOptions &opts) + : ImplToFst(std::make_shared(fst, mapper, opts)) {} + + StateMapFst(const Fst &fst, const C &mapper) + : ImplToFst( + std::make_shared(fst, mapper, StateMapFstOptions())) {} + + StateMapFst(const Fst &fst, C *mapper) + : ImplToFst( + std::make_shared(fst, mapper, StateMapFstOptions())) {} + + // See Fst<>::Copy() for doc. + StateMapFst(const StateMapFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Get a copy of this StateMapFst. See Fst<>::Copy() for further doc. 
+ StateMapFst *Copy(bool safe = false) const override { + return new StateMapFst(*this, safe); + } + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId state, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(state, data); + } + + protected: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + private: + StateMapFst &operator=(const StateMapFst &) = delete; +}; + +// Specialization for StateMapFst. +template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename A::StateId; + + ArcIterator(const StateMapFst &fst, StateId state) + : CacheArcIterator>(fst.GetMutableImpl(), state) { + if (!fst.GetImpl()->HasArcs(state)) fst.GetMutableImpl()->Expand(state); + } +}; + +// Utility mappers. + +// Mapper that returns its input. +template +class IdentityStateMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit IdentityStateMapper(const Fst &fst) : fst_(fst) {} + + // Allows updating FST argument; pass only if changed. + IdentityStateMapper(const IdentityStateMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? 
*fst : mapper.fst_) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + aiter_.reset(new ArcIterator>(fst_, state)); + } + + bool Done() const { return aiter_->Done(); } + + const Arc &Value() const { return aiter_->Value(); } + + void Next() { aiter_->Next(); } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { return props; } + + private: + const Fst &fst_; + std::unique_ptr>> aiter_; +}; + +template +class ArcSumMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ArcSumMapper(const Fst &fst) : fst_(fst), i_(0) {} + + // Allows updating FST argument; pass only if changed. + ArcSumMapper(const ArcSumMapper &mapper, const Fst *fst = nullptr) + : fst_(fst ? *fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(state)); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + // First sorts the exiting arcs by input label, output label and destination + // state and then sums weights of arcs with the same input label, output + // label, and destination state. 
+ std::sort(arcs_.begin(), arcs_.end(), comp_); + size_t narcs = 0; + for (const auto &arc : arcs_) { + if (narcs > 0 && equal_(arc, arcs_[narcs - 1])) { + arcs_[narcs - 1].weight = Plus(arcs_[narcs - 1].weight, arc.weight); + } else { + arcs_[narcs] = arc; + ++narcs; + } + } + arcs_.resize(narcs); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & kDeleteArcsProperties & + kWeightInvariantProperties; + } + + private: + struct Compare { + bool operator()(const Arc &x, const Arc &y) const { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const Arc &x, const Arc &y) const { + return (x.ilabel == y.ilabel && x.olabel == y.olabel && + x.nextstate == y.nextstate); + } + }; + + const Fst &fst_; + Compare comp_; + Equal equal_; + std::vector arcs_; + ssize_t i_; // Current arc position. + + ArcSumMapper &operator=(const ArcSumMapper &) = delete; +}; + +template +class ArcUniqueMapper { + public: + using FromArc = Arc; + using ToArc = Arc; + + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + explicit ArcUniqueMapper(const Fst &fst) : fst_(fst), i_(0) {} + + // Allows updating FST argument; pass only if changed. + ArcUniqueMapper(const ArcUniqueMapper &mapper, + const Fst *fst = nullptr) + : fst_(fst ? 
*fst : mapper.fst_), i_(0) {} + + StateId Start() const { return fst_.Start(); } + + Weight Final(StateId state) const { return fst_.Final(state); } + + void SetState(StateId state) { + i_ = 0; + arcs_.clear(); + arcs_.reserve(fst_.NumArcs(state)); + for (ArcIterator> aiter(fst_, state); !aiter.Done(); + aiter.Next()) { + arcs_.push_back(aiter.Value()); + } + // First sorts the exiting arcs by input label, output label and destination + // state and then uniques identical arcs. + std::sort(arcs_.begin(), arcs_.end(), comp_); + arcs_.erase(std::unique(arcs_.begin(), arcs_.end(), equal_), arcs_.end()); + } + + bool Done() const { return i_ >= arcs_.size(); } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + constexpr MapSymbolsAction InputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + constexpr MapSymbolsAction OutputSymbolsAction() const { + return MAP_COPY_SYMBOLS; + } + + uint64 Properties(uint64 props) const { + return props & kArcSortProperties & kDeleteArcsProperties; + } + + private: + struct Compare { + bool operator()(const Arc &x, const Arc &y) const { + if (x.ilabel < y.ilabel) return true; + if (x.ilabel > y.ilabel) return false; + if (x.olabel < y.olabel) return true; + if (x.olabel > y.olabel) return false; + if (x.nextstate < y.nextstate) return true; + if (x.nextstate > y.nextstate) return false; + return false; + } + }; + + struct Equal { + bool operator()(const Arc &x, const Arc &y) const { + return (x.ilabel == y.ilabel && x.olabel == y.olabel && + x.nextstate == y.nextstate && x.weight == y.weight); + } + }; + + const Fst &fst_; + Compare comp_; + Equal equal_; + std::vector arcs_; + size_t i_; // Current arc position. + + ArcUniqueMapper &operator=(const ArcUniqueMapper &) = delete; +}; + +// Useful aliases when using StdArc. 
+ +using StdArcSumMapper = ArcSumMapper; + +using StdArcUniqueMapper = ArcUniqueMapper; + +} // namespace fst + +#endif // FST_STATE_MAP_H_ diff --git a/projects/llm_framework/include/fst/state-reachable.h b/projects/llm_framework/include/fst/state-reachable.h new file mode 100644 index 00000000..36b5559a --- /dev/null +++ b/projects/llm_framework/include/fst/state-reachable.h @@ -0,0 +1,224 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Class to determine whether a given (final) state can be reached from some +// other given state. + +#ifndef FST_STATE_REACHABLE_H_ +#define FST_STATE_REACHABLE_H_ + +#include + +#include + +#include +#include +#include +#include +#include + + +namespace fst { + +// Computes the (final) states reachable from a given state in an FST. After +// this visitor has been called, a final state f can be reached from a state +// s iff (*isets)[s].Member(state2index[f]) is true, where (*isets[s]) is a +// set of half-open inteval of final state indices and state2index[f] maps from +// a final state to its index. If state2index is empty, it is filled-in with +// suitable indices. If it is non-empty, those indices are used; in this case, +// the final states must have out-degree 0. +template > +class IntervalReachVisitor { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Index = I; + using ISet = S; + using Interval = typename ISet::Interval; + + IntervalReachVisitor(const Fst &fst, std::vector *isets, + std::vector *state2index) + : fst_(fst), + isets_(isets), + state2index_(state2index), + index_(state2index->empty() ? 
1 : -1), + error_(false) { + isets_->clear(); + } + + void InitVisit(const Fst &) { error_ = false; } + + bool InitState(StateId s, StateId r) { + while (isets_->size() <= s) isets_->push_back(S()); + while (state2index_->size() <= s) state2index_->push_back(-1); + if (fst_.Final(s) != Weight::Zero()) { + // Create tree interval. + auto *intervals = (*isets_)[s].MutableIntervals(); + if (index_ < 0) { // Uses state2index_ map to set index. + if (fst_.NumArcs(s) > 0) { + FSTERROR() << "IntervalReachVisitor: state2index map must be empty " + << "for this FST"; + error_ = true; + return false; + } + const auto index = (*state2index_)[s]; + if (index < 0) { + FSTERROR() << "IntervalReachVisitor: state2index map incomplete"; + error_ = true; + return false; + } + intervals->push_back(Interval(index, index + 1)); + } else { // Use pre-order index. + intervals->push_back(Interval(index_, index_ + 1)); + (*state2index_)[s] = index_++; + } + } + return true; + } + + constexpr bool TreeArc(StateId, const Arc &) const { return true; } + + bool BackArc(StateId s, const Arc &arc) { + FSTERROR() << "IntervalReachVisitor: Cyclic input"; + error_ = true; + return false; + } + + bool ForwardOrCrossArc(StateId s, const Arc &arc) { + // Non-tree interval. + (*isets_)[s].Union((*isets_)[arc.nextstate]); + return true; + } + + void FinishState(StateId s, StateId p, const Arc *) { + if (index_ >= 0 && fst_.Final(s) != Weight::Zero()) { + auto *intervals = (*isets_)[s].MutableIntervals(); + (*intervals)[0].end = index_; // Updates tree interval end. + } + (*isets_)[s].Normalize(); + if (p != kNoStateId) { + (*isets_)[p].Union((*isets_)[s]); // Propagates intervals to parent. + } + } + + void FinishVisit() {} + + bool Error() const { return error_; } + + private: + const Fst &fst_; + std::vector *isets_; + std::vector *state2index_; + Index index_; + bool error_; +}; + +// Tests reachability of final states from a given state. 
To test for +// reachability from a state s, first do SetState(s). Then a final state f can +// be reached from state s of FST iff Reach(f) is true. The input can be cyclic, +// but no cycle may contain a final state. +template > +class StateReachable { + public: + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using Index = I; + using ISet = S; + using Interval = typename ISet::Interval; + + explicit StateReachable(const Fst &fst) : error_(false) { + if (fst.Properties(kAcyclic, true)) { + AcyclicStateReachable(fst); + } else { + CyclicStateReachable(fst); + } + } + + explicit StateReachable(const StateReachable &reachable) { + FSTERROR() << "Copy constructor for state reachable class " + << "not implemented."; + error_ = true; + } + + // Sets current state. + void SetState(StateId s) { s_ = s; } + + // Can reach this final state from current state? + bool Reach(StateId s) { + if (s >= state2index_.size()) return false; + const auto i = state2index_[s]; + if (i < 0) { + FSTERROR() << "StateReachable: State non-final: " << s; + error_ = true; + return false; + } + return isets_[s_].Member(i); + } + + // Access to the state-to-index mapping. Unassigned states have index -1. + std::vector &State2Index() { return state2index_; } + + // Access to the interval sets. These specify the reachability to the final + // states as intervals of the final state indices. + const std::vector &IntervalSets() { return isets_; } + + bool Error() const { return error_; } + + private: + void AcyclicStateReachable(const Fst &fst) { + IntervalReachVisitor reach_visitor(fst, &isets_, + &state2index_); + DfsVisit(fst, &reach_visitor); + if (reach_visitor.Error()) error_ = true; + } + + void CyclicStateReachable(const Fst &fst) { + // Finds state reachability on the acyclic condensation FST. 
+ VectorFst cfst; + std::vector scc; + Condense(fst, &cfst, &scc); + StateReachable reachable(cfst); + if (reachable.Error()) { + error_ = true; + return; + } + // Gets the number of states per SCC. + std::vector nscc; + for (StateId s = 0; s < scc.size(); ++s) { + const auto c = scc[s]; + while (c >= nscc.size()) nscc.push_back(0); + ++nscc[c]; + } + // Constructs the interval sets and state index mapping for the original + // FST from the condensation FST. + state2index_.resize(scc.size(), -1); + isets_.resize(scc.size()); + for (StateId s = 0; s < scc.size(); ++s) { + const auto c = scc[s]; + isets_[s] = reachable.IntervalSets()[c]; + state2index_[s] = reachable.State2Index()[c]; + // Checks that each final state in an input FST is not contained in a + // cycle (i.e., not in a non-trivial SCC). + if (cfst.Final(c) != Weight::Zero() && nscc[c] > 1) { + FSTERROR() << "StateReachable: Final state contained in a cycle"; + error_ = true; + return; + } + } + } + + StateId s_; // Current state. + std::vector isets_; // Interval sets per state. + std::vector state2index_; // Finds index for a final state. + bool error_; + + StateReachable &operator=(const StateReachable &) = delete; +}; + +} // namespace fst + +#endif // FST_STATE_REACHABLE_H_ diff --git a/projects/llm_framework/include/fst/state-table.h b/projects/llm_framework/include/fst/state-table.h new file mode 100644 index 00000000..a5067592 --- /dev/null +++ b/projects/llm_framework/include/fst/state-table.h @@ -0,0 +1,494 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Classes for representing the mapping between state tuples and state IDs. 
+ +#ifndef FST_STATE_TABLE_H_ +#define FST_STATE_TABLE_H_ + +#include +#include +#include + +#include + +#include +#include +#include + + +namespace fst { + +// State tables determine the bijective mapping between state tuples (e.g., in +// composition, triples of two FST states and a composition filter state) and +// their corresponding state IDs. They are classes, templated on state tuples, +// with the following interface: +// +// template +// class StateTable { +// public: +// using StateTuple = T; +// +// // Required constructors. +// StateTable(); +// +// StateTable(const StateTable &); +// +// // Looks up state ID by tuple. If it doesn't exist, then add it. +// StateId FindState(const StateTuple &tuple); +// +// // Looks up state tuple by state ID. +// const StateTuple &Tuple(StateId s) const; +// +// // # of stored tuples. +// StateId Size() const; +// }; +// +// A state tuple has the form: +// +// template +// struct StateTuple { +// using StateId = S; +// +// // Required constructors. +// +// StateTuple(); +// +// StateTuple(const StateTuple &tuple); +// }; + +// An implementation using a hash map for the tuple to state ID mapping. The +// state tuple T must support operator==. +template +class HashStateTable : public HashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using HashBiTable::FindId; + using HashBiTable::FindEntry; + using HashBiTable::Size; + + HashStateTable() : HashBiTable() {} + + explicit HashStateTable(size_t table_size) + : HashBiTable(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a hash map for the tuple to state ID mapping. The +// state tuple T must support operator==. 
+template +class CompactHashStateTable + : public CompactHashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using CompactHashBiTable::FindId; + using CompactHashBiTable::FindEntry; + using CompactHashBiTable::Size; + + CompactHashStateTable() : CompactHashBiTable() {} + + explicit CompactHashStateTable(size_t table_size) + : CompactHashBiTable(table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a vector for the tuple to state mapping. It is +// passed a fingerprint functor that should fingerprint tuples uniquely to an +// integer that can used as a vector index. Normally, VectorStateTable +// constructs the fingerprint functor. Alternately, the user can pass this +// object, in which case the table takes ownership. +template +class VectorStateTable : public VectorBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using VectorBiTable::FindId; + using VectorBiTable::FindEntry; + using VectorBiTable::Size; + using VectorBiTable::Fingerprint; + + explicit VectorStateTable(FP *fingerprint = nullptr, size_t table_size = 0) + : VectorBiTable(fingerprint, table_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a vector and a compact hash table. The selection +// functor returns true for tuples to be hashed in the vector. The fingerprint +// functor should fingerprint tuples uniquely to an integer that can be used as +// a vector index. A hash functor is used when hashing tuples into the compact +// hash table. 
+template +class VectorHashStateTable + : public VectorHashBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using VectorHashBiTable::FindId; + using VectorHashBiTable::FindEntry; + using VectorHashBiTable::Size; + using VectorHashBiTable::Selector; + using VectorHashBiTable::Fingerprint; + using VectorHashBiTable::Hash; + + VectorHashStateTable(Select *select, FP *fingerprint, H *hash, + size_t vector_size = 0, size_t tuple_size = 0) + : VectorHashBiTable( + select, fingerprint, hash, vector_size, tuple_size) {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// An implementation using a hash map to map from tuples to state IDs. This +// version permits erasing of states. The state tuple's default constructor +// must produce a tuple that will never be seen and the table must suppor +// operator==. +template +class ErasableStateTable : public ErasableBiTable { + public: + using StateTuple = T; + using StateId = typename StateTuple::StateId; + + using ErasableBiTable::FindId; + using ErasableBiTable::FindEntry; + using ErasableBiTable::Size; + using ErasableBiTable::Erase; + + ErasableStateTable() : ErasableBiTable() {} + + StateId FindState(const StateTuple &tuple) { return FindId(tuple); } + + const StateTuple &Tuple(StateId s) const { return FindEntry(s); } +}; + +// The composition state table has the form: +// +// template +// class ComposeStateTable { +// public: +// using StateId = typename Arc::StateId; +// +// // Required constructors. +// +// ComposeStateTable(const Fst &fst1, const Fst &fst2); +// ComposeStateTable(const ComposeStateTable &table); +// +// // Looks up a state ID by tuple, adding it if doesn't exist. +// StateId FindState(const StateTuple &tuple); +// +// // Looks up a tuple by state ID. +// const ComposeStateTuple &Tuple(StateId s) const; +// +// // The number of of stored tuples. 
+// StateId Size() const; +// +// // Return true if error was encountered. +// bool Error() const; +// }; +// +// The following interface is used to represent the composition state. +// +// template +// class CompositionStateTuple { +// public: +// using StateId = typename StateId; +// using FS = FilterState; +// +// // Required constructors. +// StateTuple(); +// StateTuple(StateId s1, StateId s2, const FilterState &fs); +// +// StateId StateId1() const; +// StateId StateId2() const; +// +// FilterState GetFilterState() const; +// +// std::pair StatePair() const; +// +// size_t Hash() const; +// +// friend bool operator==(const StateTuple& x, const StateTuple &y); +// } +// +template +class DefaultComposeStateTuple { + public: + using StateId = S; + using FilterState = FS; + + DefaultComposeStateTuple() + : state_pair_(kNoStateId, kNoStateId), fs_(FilterState::NoState()) {} + + DefaultComposeStateTuple(StateId s1, StateId s2, const FilterState &fs) + : state_pair_(s1, s2), fs_(fs) {} + + StateId StateId1() const { return state_pair_.first; } + + StateId StateId2() const { return state_pair_.second; } + + FilterState GetFilterState() const { return fs_; } + + const std::pair &StatePair() const { return state_pair_; } + + friend bool operator==(const DefaultComposeStateTuple &x, + const DefaultComposeStateTuple &y) { + return (&x == &y) || (x.state_pair_ == y.state_pair_ && x.fs_ == y.fs_); + } + + size_t Hash() const { + return static_cast(StateId1()) + + static_cast(StateId2()) * 7853u + + GetFilterState().Hash() * 7867u; + } + + private: + std::pair state_pair_; + FilterState fs_; // State of composition filter. +}; + +// Specialization for TrivialFilterState that does not explicitely store the +// filter state since it is always the unique non-blocking state. 
+template +class DefaultComposeStateTuple { + public: + using StateId = S; + using FilterState = TrivialFilterState; + + DefaultComposeStateTuple() + : state_pair_(kNoStateId, kNoStateId) {} + + DefaultComposeStateTuple(StateId s1, StateId s2, const FilterState &) + : state_pair_(s1, s2) {} + + StateId StateId1() const { return state_pair_.first; } + + StateId StateId2() const { return state_pair_.second; } + + FilterState GetFilterState() const { return FilterState(true); } + + const std::pair &StatePair() const { return state_pair_; } + + friend bool operator==(const DefaultComposeStateTuple &x, + const DefaultComposeStateTuple &y) { + return (&x == &y) || (x.state_pair_ == y.state_pair_); + } + + size_t Hash() const { return StateId1() + StateId2() * 7853; } + + private: + std::pair state_pair_; +}; + +// Hashing of composition state tuples. +template +class ComposeHash { + public: + size_t operator()(const T &t) const { return t.Hash(); } +}; + +// A HashStateTable over composition tuples. +template , + typename StateTable = + CompactHashStateTable>> +class GenericComposeStateTable : public StateTable { + public: + using StateId = typename Arc::StateId; + + GenericComposeStateTable(const Fst &fst1, const Fst &fst2) {} + + GenericComposeStateTable(const Fst &fst1, const Fst &fst2, + size_t table_size) + : StateTable(table_size) {} + + constexpr bool Error() const { return false; } + + private: + GenericComposeStateTable &operator=(const GenericComposeStateTable &table) = + delete; +}; + +// Fingerprint for general composition tuples. +template +class ComposeFingerprint { + public: + using StateId = typename StateTuple::StateId; + + // Required but suboptimal constructor. + ComposeFingerprint() : mult1_(8192), mult2_(8192) { + LOG(WARNING) << "TupleFingerprint: # of FST states should be provided."; + } + + // Constructor is provided the sizes of the input FSTs. 
+ ComposeFingerprint(StateId nstates1, StateId nstates2) + : mult1_(nstates1), mult2_(nstates1 * nstates2) {} + + size_t operator()(const StateTuple &tuple) { + return tuple.StateId1() + tuple.StateId2() * mult1_ + + tuple.GetFilterState().Hash() * mult2_; + } + + private: + const ssize_t mult1_; + const ssize_t mult2_; +}; + +// Useful when the first composition state determines the tuple. +template +class ComposeState1Fingerprint { + public: + size_t operator()(const StateTuple &tuple) { return tuple.StateId1(); } +}; + +// Useful when the second composition state determines the tuple. +template +class ComposeState2Fingerprint { + public: + size_t operator()(const StateTuple &tuple) { return tuple.StateId2(); } +}; + +// A VectorStateTable over composition tuples. This can be used when the +// product of number of states in FST1 and FST2 (and the composition filter +// state hash) is manageable. If the FSTs are not expanded FSTs, they will +// first have their states counted. +template +class ProductComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + ProductComposeStateTable(const Fst &fst1, const Fst &fst2, + size_t table_size = 0) + : StateTable(new ComposeFingerprint(CountStates(fst1), + CountStates(fst2)), + table_size) {} + + ProductComposeStateTable( + const ProductComposeStateTable &table) + : StateTable(new ComposeFingerprint(table.Fingerprint())) {} + + constexpr bool Error() const { return false; } + + private: + ProductComposeStateTable &operator=(const ProductComposeStateTable &table) = + delete; +}; + +// A vector-backed table over composition tuples which can be used when the +// first FST is a string (i.e., satisfies kString property) and the second is +// deterministic and epsilon-free. 
It should be used with a composition filter +// that creates at most one filter state per tuple under these conditions (e.g., +// SequenceComposeFilter or MatchComposeFilter). +template +class StringDetComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + StringDetComposeStateTable(const Fst &fst1, const Fst &fst2) + : error_(false) { + static constexpr auto props2 = kIDeterministic | kNoIEpsilons; + if (fst1.Properties(kString, true) != kString) { + FSTERROR() << "StringDetComposeStateTable: 1st FST is not a string"; + error_ = true; + } else if (fst2.Properties(props2, true) != props2) { + FSTERROR() << "StringDetComposeStateTable: 2nd FST is not deterministic " + "and epsilon-free"; + error_ = true; + } + } + + StringDetComposeStateTable( + const StringDetComposeStateTable &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + StringDetComposeStateTable &operator=(const StringDetComposeStateTable &) = + delete; +}; + +// A vector-backed table over composition tuples which can be used when the +// first FST is deterministic and epsilon-free and the second is a string (i.e., +// satisfies kString). It should be used with a composition filter that creates +// at most one filter state per tuple under these conditions (e.g., +// SequenceComposeFilter or MatchComposeFilter). 
+template +class DetStringComposeStateTable + : public VectorStateTable> { + public: + using StateId = typename Arc::StateId; + using StateTable = + VectorStateTable>; + + DetStringComposeStateTable(const Fst &fst1, const Fst &fst2) + : error_(false) { + static constexpr auto props = kODeterministic | kNoOEpsilons; + if (fst1.Properties(props, true) != props) { + FSTERROR() << "StringDetComposeStateTable: 1st FST is not " + << "input-deterministic and epsilon-free"; + error_ = true; + } else if (fst2.Properties(kString, true) != kString) { + FSTERROR() << "DetStringComposeStateTable: 2nd FST is not a string"; + error_ = true; + } + } + + DetStringComposeStateTable( + const DetStringComposeStateTable &table) + : StateTable(table), error_(table.error_) {} + + bool Error() const { return error_; } + + private: + bool error_; + + DetStringComposeStateTable &operator=(const DetStringComposeStateTable &) = + delete; +}; + +// An erasable table over composition tuples. The Erase(StateId) method can be +// called if the user either is sure that composition will never return to that +// tuple or doesn't care that if it does, it is assigned a new state ID. +template +class ErasableComposeStateTable + : public ErasableStateTable> { + public: + ErasableComposeStateTable(const Fst &fst1, const Fst &fst2) {} + + constexpr bool Error() const { return false; } + + private: + ErasableComposeStateTable &operator=(const ErasableComposeStateTable &table) = + delete; +}; + +} // namespace fst + +#endif // FST_STATE_TABLE_H_ diff --git a/projects/llm_framework/include/fst/statesort.h b/projects/llm_framework/include/fst/statesort.h new file mode 100644 index 00000000..346c7d32 --- /dev/null +++ b/projects/llm_framework/include/fst/statesort.h @@ -0,0 +1,74 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to sort states of an FST. 
+ +#ifndef FST_STATESORT_H_ +#define FST_STATESORT_H_ + +#include +#include + +#include + +#include + + +namespace fst { + +// Sorts the input states of an FST. order[i] gives the the state ID after +// sorting that corresponds to the state ID i before sorting; it must +// therefore be a permutation of the input FST's states ID sequence. +template +void StateSort(MutableFst *fst, + const std::vector &order) { + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + if (order.size() != fst->NumStates()) { + FSTERROR() << "StateSort: Bad order vector size: " << order.size(); + fst->SetProperties(kError, kError); + return; + } + if (fst->Start() == kNoStateId) return; + const auto props = fst->Properties(kStateSortProperties, false); + std::vector done(order.size(), false); + std::vector arcsa; + std::vector arcsb; + fst->SetStart(order[fst->Start()]); + for (StateIterator> siter(*fst); !siter.Done(); + siter.Next()) { + auto s1 = siter.Value(); + StateId s2; + if (done[s1]) continue; + auto final1 = fst->Final(s1); + auto final2 = Weight::Zero(); + arcsa.clear(); + for (ArcIterator> aiter(*fst, s1); !aiter.Done(); + aiter.Next()) { + arcsa.push_back(aiter.Value()); + } + for (; !done[s1]; s1 = s2, final1 = final2, std::swap(arcsa, arcsb)) { + s2 = order[s1]; + if (!done[s2]) { + final2 = fst->Final(s2); + arcsb.clear(); + for (ArcIterator> aiter(*fst, s2); !aiter.Done(); + aiter.Next()) { + arcsb.push_back(aiter.Value()); + } + } + fst->SetFinal(s2, final1); + fst->DeleteArcs(s2); + for (auto arc : arcsa) { // Copy intended. 
+ arc.nextstate = order[arc.nextstate]; + fst->AddArc(s2, arc); + } + done[s1] = true; + } + } + fst->SetProperties(props, kFstProperties); +} + +} // namespace fst + +#endif // FST_STATESORT_H_ diff --git a/projects/llm_framework/include/fst/string-weight.h b/projects/llm_framework/include/fst/string-weight.h new file mode 100644 index 00000000..fb3e70f1 --- /dev/null +++ b/projects/llm_framework/include/fst/string-weight.h @@ -0,0 +1,807 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// String weight set and associated semiring operation definitions. + +#ifndef FST_STRING_WEIGHT_H_ +#define FST_STRING_WEIGHT_H_ + +#include + +#include +#include +#include + +#include +#include +#include + + +namespace fst { + +constexpr int kStringInfinity = -1; // Label for the infinite string. +constexpr int kStringBad = -2; // Label for a non-string. +constexpr char kStringSeparator = '_'; // Label separator in strings. + +// Determines whether to use left or right string semiring. Includes a +// 'restricted' version that signals an error if proper prefixes/suffixes +// would otherwise be returned by Plus, useful with various +// algorithms that require functional transducer input with the +// string semirings. +enum StringType { STRING_LEFT = 0, STRING_RIGHT = 1, STRING_RESTRICT = 2 }; + +constexpr StringType ReverseStringType(StringType s) { + return s == STRING_LEFT ? STRING_RIGHT + : (s == STRING_RIGHT ? 
STRING_LEFT : STRING_RESTRICT); +} + +template +class StringWeightIterator; +template +class StringWeightReverseIterator; + +// String semiring: (longest_common_prefix/suffix, ., Infinity, Epsilon) +template +class StringWeight { + public: + using Label = Label_; + using ReverseWeight = StringWeight; + using Iterator = StringWeightIterator; + using ReverseIterator = StringWeightReverseIterator; + + friend class StringWeightIterator; + friend class StringWeightReverseIterator; + + StringWeight() {} + + template + StringWeight(const Iterator &begin, const Iterator &end) { + for (auto iter = begin; iter != end; ++iter) PushBack(*iter); + } + + explicit StringWeight(Label label) { PushBack(label); } + + static const StringWeight &Zero() { + static const auto *const zero = new StringWeight(Label(kStringInfinity)); + return *zero; + } + + static const StringWeight &One() { + static const auto *const one = new StringWeight(); + return *one; + } + + static const StringWeight &NoWeight() { + static const auto *const no_weight = new StringWeight(Label(kStringBad)); + return *no_weight; + } + + static const string &Type() { + static const string *const type = new string( + S == STRING_LEFT + ? "left_string" + : (S == STRING_RIGHT ? "right_string" : "restricted_string")); + return *type; + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + StringWeight Quantize(float delta = kDelta) const { return *this; } + + ReverseWeight Reverse() const; + + static constexpr uint64 Properties() { + return kIdempotent | + (S == STRING_LEFT ? kLeftSemiring + : (S == STRING_RIGHT + ? kRightSemiring + : /* S == STRING_RESTRICT */ kLeftSemiring | + kRightSemiring)); + } + + // These operations combined with the StringWeightIterator and + // StringWeightReverseIterator provide the access and mutation of the string + // internal elements. + + // Clear existing StringWeight. 
+ void Clear() { + first_ = 0; + rest_.clear(); + } + + size_t Size() const { return first_ ? rest_.size() + 1 : 0; } + + void PushFront(Label label) { + if (first_) rest_.push_front(first_); + first_ = label; + } + + void PushBack(Label label) { + if (!first_) { + first_ = label; + } else { + rest_.push_back(label); + } + } + + private: + Label first_ = 0; // First label in string (0 if empty). + std::list; + + friend class ArcIterator>; + friend class StateIterator>; + + explicit SynchronizeFst( + const Fst &fst, + const SynchronizeFstOptions &opts = SynchronizeFstOptions()) + : ImplToFst(std::make_shared(fst, opts)) {} + + // See Fst<>::Copy() for doc. + SynchronizeFst(const SynchronizeFst &fst, bool safe = false) + : ImplToFst(fst, safe) {} + + // Gets a copy of this SynchronizeFst. See Fst<>::Copy() for further doc. + SynchronizeFst *Copy(bool safe = false) const override { + return new SynchronizeFst(*this, safe); + } + + inline void InitStateIterator(StateIteratorData *data) const override; + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetMutableImpl()->InitArcIterator(s, data); + } + + private: + using ImplToFst::GetImpl; + using ImplToFst::GetMutableImpl; + + SynchronizeFst &operator=(const SynchronizeFst &) = delete; +}; + +// Specialization for SynchronizeFst. +template +class StateIterator> + : public CacheStateIterator> { + public: + explicit StateIterator(const SynchronizeFst &fst) + : CacheStateIterator>(fst, fst.GetMutableImpl()) {} +}; + +// Specialization for SynchronizeFst. 
+template +class ArcIterator> + : public CacheArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const SynchronizeFst &fst, StateId s) + : CacheArcIterator>(fst.GetMutableImpl(), s) { + if (!fst.GetImpl()->HasArcs(s)) fst.GetMutableImpl()->Expand(s); + } +}; + +template +inline void SynchronizeFst::InitStateIterator( + StateIteratorData *data) const { + data->base = new StateIterator>(*this); +} + +// Synchronizes a transducer. This version writes the synchronized result to a +// MutableFst. The result will be an equivalent FST that has the property that +// during the traversal of a path, the delay is either zero or strictly +// increasing, where the delay is the difference between the number of +// non-epsilon output labels and input labels along the path. +// +// For the algorithm to terminate, the input transducer must have bounded +// delay, i.e., the delay of every cycle must be zero. +// +// Complexity: +// +// - A has bounded delay: exponential. +// - A does not have bounded delay: does not terminate. +// +// For more information, see: +// +// Mohri, M. 2003. Edit-distance of weighted automata: General definitions and +// algorithms. International Journal of Computer Science 14(6): 957-982. +template +void Synchronize(const Fst &ifst, MutableFst *ofst) { + // Caches only the last state for fastest copy. + const SynchronizeFstOptions opts(FLAGS_fst_default_cache_gc, 0); + *ofst = SynchronizeFst(ifst, opts); +} + +} // namespace fst + +#endif // FST_SYNCHRONIZE_H_ diff --git a/projects/llm_framework/include/fst/test-properties.h b/projects/llm_framework/include/fst/test-properties.h new file mode 100644 index 00000000..677ed01b --- /dev/null +++ b/projects/llm_framework/include/fst/test-properties.h @@ -0,0 +1,246 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions to manipulate and test property bits. 
+ +#ifndef FST_TEST_PROPERTIES_H_ +#define FST_TEST_PROPERTIES_H_ + +#include + +#include +#include + +#include +#include + + +DECLARE_bool(fst_verify_properties); + +namespace fst { +// namespace internal { + +// For a binary property, the bit is always returned set. For a trinary (i.e., +// two-bit) property, both bits are returned set iff either corresponding input +// bit is set. +inline uint64 KnownProperties(uint64 props) { + return kBinaryProperties | (props & kTrinaryProperties) | + ((props & kPosTrinaryProperties) << 1) | + ((props & kNegTrinaryProperties) >> 1); +} + +// Tests compatibility between two sets of properties. +inline bool CompatProperties(uint64 props1, uint64 props2) { + const auto known_props1 = KnownProperties(props1); + const auto known_props2 = KnownProperties(props2); + const auto known_props = known_props1 & known_props2; + const auto incompat_props = (props1 & known_props) ^ (props2 & known_props); + if (incompat_props) { + uint64 prop = 1; + for (int i = 0; i < 64; ++i, prop <<= 1) { + if (prop & incompat_props) { + LOG(ERROR) << "CompatProperties: Mismatch: " << PropertyNames[i] + << ": props1 = " << (props1 & prop ? "true" : "false") + << ", props2 = " << (props2 & prop ? "true" : "false"); + } + } + return false; + } else { + return true; + } +} + +// Computes FST property values defined in properties.h. The value of each +// property indicated in the mask will be determined and returned (these will +// never be unknown here). In the course of determining the properties +// specifically requested in the mask, certain other properties may be +// determined (those with little additional expense) and their values will be +// returned as well. The complete set of known properties (whether true or +// false) determined by this operation will be assigned to the the value pointed +// to by KNOWN. If 'use_stored' is true, pre-computed FST properties may be used +// when possible. 
'mask & required_mask' is used to determine whether the stored +// propertoes can be used. This routine is seldom called directly; instead it is +// used to implement fst.Properties(mask, true). +template +uint64 ComputeProperties(const Fst &fst, uint64 mask, uint64 *known, + bool use_stored) { + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + const auto fst_props = fst.Properties(kFstProperties, false); // FST-stored. + // Check stored FST properties first if allowed. + if (use_stored) { + const auto known_props = KnownProperties(fst_props); + // If FST contains required info, return it. + if ((known_props & mask) == mask) { + if (known) *known = known_props; + return fst_props; + } + } + // Computes (trinary) properties explicitly. + // Initialize with binary properties (already known). + uint64 comp_props = fst_props & kBinaryProperties; + // Computes these trinary properties with a DFS. We compute only those that + // need a DFS here, since we otherwise would like to avoid a DFS since its + // stack could grow large. 
+ uint64 dfs_props = kCyclic | kAcyclic | kInitialCyclic | kInitialAcyclic | + kAccessible | kNotAccessible | kCoAccessible | + kNotCoAccessible; + std::vector scc; + if (mask & (dfs_props | kWeightedCycles | kUnweightedCycles)) { + SccVisitor scc_visitor(&scc, nullptr, nullptr, &comp_props); + DfsVisit(fst, &scc_visitor); + } + // Computes any remaining trinary properties via a state and arcs iterations + if (mask & ~(kBinaryProperties | dfs_props)) { + comp_props |= kAcceptor | kNoEpsilons | kNoIEpsilons | kNoOEpsilons | + kILabelSorted | kOLabelSorted | kUnweighted | kTopSorted | + kString; + if (mask & (kIDeterministic | kNonIDeterministic)) { + comp_props |= kIDeterministic; + } + if (mask & (kODeterministic | kNonODeterministic)) { + comp_props |= kODeterministic; + } + if (mask & (dfs_props | kWeightedCycles | kUnweightedCycles)) { + comp_props |= kUnweightedCycles; + } + std::unique_ptr> ilabels; + std::unique_ptr> olabels; + StateId nfinal = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + Arc prev_arc; + // Creates these only if we need to. + if (mask & (kIDeterministic | kNonIDeterministic)) { + ilabels.reset(new std::unordered_set &fst1, const Fst &fst2) { + VLOG(1) << "Check FSTs for sanity (including property bits)."; + CHECK(Verify(fst1)); + CHECK(Verify(fst2)); + + // Ensures seed used once per instantiation. 
+ static UniformArcSelector uniform_selector(seed_); + RandGenOptions> opts(uniform_selector, + kRandomPathLength); + return RandEquivalent(fst1, fst2, kNumRandomPaths, kTestDelta, opts); + } + + // Tests FSA is unambiguous + bool Unambiguous(const Fst &fst) { + VectorFst sfst, dfst; + VectorFst lfst1, lfst2; + Map(fst, &sfst, RmWeightMapper()); + Determinize(sfst, &dfst); + Map(fst, &lfst1, RmWeightMapper()); + Map(dfst, &lfst2, RmWeightMapper()); + return Equiv(lfst1, lfst2); + } + + // Ensures input-epsilon free transducers fst1 and fst2 have the + // same domain and that for each string pair '(is, os)' in fst1, + // '(is, os)' is the minimum weight match to 'is' in fst2. + template + bool MinRelated(const Fst &fst1, const Fst &fst2) { + // Same domain + VectorFst P1(fst1), P2(fst2); + Project(&P1, PROJECT_INPUT); + Project(&P2, PROJECT_INPUT); + if (!Equiv(P1, P2)) { + LOG(ERROR) << "Inputs not equivalent"; + return false; + } + + // Ensures seed used once per instantiation. + static UniformArcSelector uniform_selector(seed_); + RandGenOptions> opts(uniform_selector, + kRandomPathLength); + + VectorFst path, paths1, paths2; + for (ssize_t n = 0; n < kNumRandomPaths; ++n) { + RandGen(fst1, &path, opts); + Invert(&path); + Map(&path, RmWeightMapper()); + Compose(path, fst2, &paths1); + Weight sum1 = ShortestDistance(paths1); + Compose(paths1, path, &paths2); + Weight sum2 = ShortestDistance(paths2); + if (!ApproxEqual(Plus(sum1, sum2), sum2, kTestDelta)) { + LOG(ERROR) << "Sums not equivalent: " << sum1 << " " << sum2; + return false; + } + } + return true; + } + + // Tests ShortestDistance(A - P) >= + // ShortestDistance(A) times Threshold. 
+ template + bool PruneEquiv(const Fst &fst, const Fst &pfst, Weight threshold) { + VLOG(1) << "Check FSTs for sanity (including property bits)."; + CHECK(Verify(fst)); + CHECK(Verify(pfst)); + + DifferenceFst D(fst, DeterminizeFst(RmEpsilonFst( + ArcMapFst>( + pfst, RmWeightMapper())))); + Weight sum1 = Times(ShortestDistance(fst), threshold); + Weight sum2 = ShortestDistance(D); + return ApproxEqual(Plus(sum1, sum2), sum1, kTestDelta); + } + + // Random seed. + int seed_; + // FST with no states + VectorFst zero_fst_; + // FST with one state that accepts epsilon. + VectorFst one_fst_; + // FST with one state that accepts all strings. + VectorFst univ_fst_; + // Generates weights used in testing. + WeightGenerator *weight_generator_; + // Maximum random path length. + static const int kRandomPathLength; + // Number of random paths to explore. + static const int kNumRandomPaths; + // Maximum number of nshortest paths. + static const int kNumRandomShortestPaths; + // Maximum number of nshortest states. + static const int kNumShortestStates; + // Delta for equivalence tests. + static const float kTestDelta; + + WeightedTester(const WeightedTester &) = delete; + WeightedTester &operator=(const WeightedTester &) = delete; +}; + +template +const int WeightedTester::kRandomPathLength = 25; + +template +const int WeightedTester::kNumRandomPaths = 100; + +template +const int WeightedTester::kNumRandomShortestPaths = 100; + +template +const int WeightedTester::kNumShortestStates = 10000; + +template +const float WeightedTester::kTestDelta = .05; + +// This class tests a variety of identities and properties that must +// hold for various algorithms on unweighted FSAs and that are not tested +// by WeightedTester. Only the specialization does anything interesting. 
+template +class UnweightedTester { + public: + UnweightedTester(const Fst &zero_fsa, const Fst &one_fsa, + const Fst &univ_fsa) {} + + void Test(const Fst &A1, const Fst &A2, const Fst &A3) {} +}; + +// Specialization for StdArc. This should work for any commutative, +// idempotent semiring when restricted to the unweighted case +// (being isomorphic to the boolean semiring). +template <> +class UnweightedTester { + public: + typedef StdArc Arc; + typedef Arc::Label Label; + typedef Arc::StateId StateId; + typedef Arc::Weight Weight; + + UnweightedTester(const Fst &zero_fsa, const Fst &one_fsa, + const Fst &univ_fsa) + : zero_fsa_(zero_fsa), one_fsa_(one_fsa), univ_fsa_(univ_fsa) {} + + void Test(const Fst &A1, const Fst &A2, const Fst &A3) { + TestRational(A1, A2, A3); + TestIntersect(A1, A2, A3); + TestOptimize(A1); + } + + private: + // Tests rational operations with identities + void TestRational(const Fst &A1, const Fst &A2, + const Fst &A3) { + { + VLOG(1) << "Check the union contains its arguments (destructive)."; + VectorFst U(A1); + Union(&U, A2); + + CHECK(Subset(A1, U)); + CHECK(Subset(A2, U)); + } + + { + VLOG(1) << "Check the union contains its arguments (delayed)."; + UnionFst U(A1, A2); + + CHECK(Subset(A1, U)); + CHECK(Subset(A2, U)); + } + + { + VLOG(1) << "Check if A^n c A* (destructive)."; + VectorFst C(one_fsa_); + int n = rand() % 5; + for (int i = 0; i < n; ++i) Concat(&C, A1); + + VectorFst S(A1); + Closure(&S, CLOSURE_STAR); + CHECK(Subset(C, S)); + } + + { + VLOG(1) << "Check if A^n c A* (delayed)."; + int n = rand() % 5; + Fst *C = new VectorFst(one_fsa_); + for (int i = 0; i < n; ++i) { + ConcatFst *F = new ConcatFst(*C, A1); + delete C; + C = F; + } + ClosureFst S(A1, CLOSURE_STAR); + CHECK(Subset(*C, S)); + delete C; + } + } + + // Tests intersect-based operations. 
+ void TestIntersect(const Fst &A1, const Fst &A2, + const Fst &A3) { + VectorFst S1(A1); + VectorFst S2(A2); + VectorFst S3(A3); + + ILabelCompare comp; + + ArcSort(&S1, comp); + ArcSort(&S2, comp); + ArcSort(&S3, comp); + + { + VLOG(1) << "Check the intersection is contained in its arguments."; + IntersectFst I1(S1, S2); + CHECK(Subset(I1, S1)); + CHECK(Subset(I1, S2)); + } + + { + VLOG(1) << "Check union distributes over intersection."; + IntersectFst I1(S1, S2); + UnionFst U1(I1, S3); + + UnionFst U2(S1, S3); + UnionFst U3(S2, S3); + ArcSortFst> S4(U3, comp); + IntersectFst I2(U2, S4); + + CHECK(Equiv(U1, I2)); + } + + VectorFst C1; + VectorFst C2; + Complement(S1, &C1); + Complement(S2, &C2); + ArcSort(&C1, comp); + ArcSort(&C2, comp); + + { + VLOG(1) << "Check S U S' = Sigma*"; + UnionFst U(S1, C1); + CHECK(Equiv(U, univ_fsa_)); + } + + { + VLOG(1) << "Check S n S' = {}"; + IntersectFst I(S1, C1); + CHECK(Equiv(I, zero_fsa_)); + } + + { + VLOG(1) << "Check (S1' U S2') == (S1 n S2)'"; + UnionFst U(C1, C2); + + IntersectFst I(S1, S2); + VectorFst C3; + Complement(I, &C3); + CHECK(Equiv(U, C3)); + } + + { + VLOG(1) << "Check (S1' n S2') == (S1 U S2)'"; + IntersectFst I(C1, C2); + + UnionFst U(S1, S2); + VectorFst C3; + Complement(U, &C3); + CHECK(Equiv(I, C3)); + } + } + + // Tests optimization operations + void TestOptimize(const Fst &A) { + { + VLOG(1) << "Check determinized FSA is equivalent to its input."; + DeterminizeFst D(A); + CHECK(Equiv(A, D)); + } + + { + VLOG(1) << "Check disambiguated FSA is equivalent to its input."; + VectorFst R(A), D; + RmEpsilon(&R); + + Disambiguate(R, &D); + CHECK(Equiv(R, D)); + } + + { + VLOG(1) << "Check minimized FSA is equivalent to its input."; + int n; + { + RmEpsilonFst R(A); + DeterminizeFst D(R); + VectorFst M(D); + Minimize(&M, static_cast *>(nullptr), kDelta); + CHECK(Equiv(A, M)); + n = M.NumStates(); + } + + if (n) { // Skip test if A is the empty machine + VLOG(1) << "Check that Hopcroft's and Revuz's 
algorithms lead to the" + << " same number of states as Brozozowski's algorithm"; + VectorFst R; + Reverse(A, &R); + RmEpsilon(&R); + DeterminizeFst DR(R); + VectorFst RD; + Reverse(DR, &RD); + DeterminizeFst DRD(RD); + VectorFst M(DRD); + CHECK_EQ(n + 1, M.NumStates()); // Accounts for the epsilon transition + // to the initial state + } + } + } + + // Tests if two FSAS are equivalent. + bool Equiv(const Fst &fsa1, const Fst &fsa2) { + VLOG(1) << "Check FSAs for sanity (including property bits)."; + CHECK(Verify(fsa1)); + CHECK(Verify(fsa2)); + + VectorFst vfsa1(fsa1); + VectorFst vfsa2(fsa2); + RmEpsilon(&vfsa1); + RmEpsilon(&vfsa2); + DeterminizeFst dfa1(vfsa1); + DeterminizeFst dfa2(vfsa2); + + // Test equivalence using union-find algorithm + bool equiv1 = Equivalent(dfa1, dfa2); + + // Test equivalence by checking if (S1 - S2) U (S2 - S1) is empty + ILabelCompare comp; + VectorFst sdfa1(dfa1); + ArcSort(&sdfa1, comp); + VectorFst sdfa2(dfa2); + ArcSort(&sdfa2, comp); + + DifferenceFst dfsa1(sdfa1, sdfa2); + DifferenceFst dfsa2(sdfa2, sdfa1); + + VectorFst ufsa(dfsa1); + Union(&ufsa, dfsa2); + Connect(&ufsa); + bool equiv2 = ufsa.NumStates() == 0; + + // Check two equivalence tests match + CHECK((equiv1 && equiv2) || (!equiv1 && !equiv2)); + + return equiv1; + } + + // Tests if FSA1 is a subset of FSA2 (disregarding weights). + bool Subset(const Fst &fsa1, const Fst &fsa2) { + VLOG(1) << "Check FSAs (incl. 
property bits) for sanity"; + CHECK(Verify(fsa1)); + CHECK(Verify(fsa2)); + + VectorFst vfsa1; + VectorFst vfsa2; + RmEpsilon(&vfsa1); + RmEpsilon(&vfsa2); + ILabelCompare comp; + ArcSort(&vfsa1, comp); + ArcSort(&vfsa2, comp); + IntersectFst ifsa(vfsa1, vfsa2); + DeterminizeFst dfa1(vfsa1); + DeterminizeFst dfa2(ifsa); + return Equivalent(dfa1, dfa2); + } + + // Returns complement Fsa + void Complement(const Fst &ifsa, MutableFst *ofsa) { + RmEpsilonFst rfsa(ifsa); + DeterminizeFst dfa(rfsa); + DifferenceFst cfsa(univ_fsa_, dfa); + *ofsa = cfsa; + } + + // FSA with no states + VectorFst zero_fsa_; + + // FSA with one state that accepts epsilon. + VectorFst one_fsa_; + + // FSA with one state that accepts all strings. + VectorFst univ_fsa_; +}; + +// This class tests a variety of identities and properties that must +// hold for various FST algorithms. It randomly generates FSTs, using +// function object 'weight_generator' to select weights. 'WeightTester' +// and 'UnweightedTester' are then called. 
+template +class AlgoTester { + public: + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + AlgoTester(WeightGenerator generator, int seed) + : weight_generator_(generator) { + one_fst_.AddState(); + one_fst_.SetStart(0); + one_fst_.SetFinal(0, Weight::One()); + + univ_fst_.AddState(); + univ_fst_.SetStart(0); + univ_fst_.SetFinal(0, Weight::One()); + for (int i = 0; i < kNumRandomLabels; ++i) + univ_fst_.AddArc(0, Arc(i, i, Weight::One(), 0)); + + weighted_tester_ = new WeightedTester( + seed, zero_fst_, one_fst_, univ_fst_, &weight_generator_); + + unweighted_tester_ = + new UnweightedTester(zero_fst_, one_fst_, univ_fst_); + } + + ~AlgoTester() { + delete weighted_tester_; + delete unweighted_tester_; + } + + void MakeRandFst(MutableFst *fst) { + RandFst(kNumRandomStates, kNumRandomArcs, + kNumRandomLabels, kAcyclicProb, + &weight_generator_, fst); + } + + void Test() { + VLOG(1) << "weight type = " << Weight::Type(); + + for (int i = 0; i < FLAGS_repeat; ++i) { + // Random transducers + VectorFst T1; + VectorFst T2; + VectorFst T3; + MakeRandFst(&T1); + MakeRandFst(&T2); + MakeRandFst(&T3); + weighted_tester_->Test(T1, T2, T3); + + VectorFst A1(T1); + VectorFst A2(T2); + VectorFst A3(T3); + Project(&A1, PROJECT_OUTPUT); + Project(&A2, PROJECT_INPUT); + Project(&A3, PROJECT_INPUT); + ArcMap(&A1, rm_weight_mapper_); + ArcMap(&A2, rm_weight_mapper_); + ArcMap(&A3, rm_weight_mapper_); + unweighted_tester_->Test(A1, A2, A3); + } + } + + private: + // Generates weights used in testing. + WeightGenerator weight_generator_; + + // FST with no states + VectorFst zero_fst_; + + // FST with one state that accepts epsilon. + VectorFst one_fst_; + + // FST with one state that accepts all strings. 
+ VectorFst univ_fst_; + + // Tests weighted FSTs + WeightedTester *weighted_tester_; + + // Tests unweighted FSTs + UnweightedTester *unweighted_tester_; + + // Mapper to remove weights from an Fst + RmWeightMapper rm_weight_mapper_; + + // Maximum number of states in random test Fst. + static const int kNumRandomStates; + + // Maximum number of arcs in random test Fst. + static const int kNumRandomArcs; + + // Number of alternative random labels. + static const int kNumRandomLabels; + + // Probability to force an acyclic Fst + static const float kAcyclicProb; + + // Maximum random path length. + static const int kRandomPathLength; + + // Number of random paths to explore. + static const int kNumRandomPaths; + + AlgoTester(const AlgoTester &) = delete; + AlgoTester &operator=(const AlgoTester &) = delete; +}; + +template +const int AlgoTester::kNumRandomStates = 10; + +template +const int AlgoTester::kNumRandomArcs = 25; + +template +const int AlgoTester::kNumRandomLabels = 5; + +template +const float AlgoTester::kAcyclicProb = .25; + +template +const int AlgoTester::kRandomPathLength = 25; + +template +const int AlgoTester::kNumRandomPaths = 100; + +} // namespace fst + +#endif // FST_TEST_ALGO_TEST_H_ diff --git a/projects/llm_framework/include/fst/test/fst_test.h b/projects/llm_framework/include/fst/test/fst_test.h new file mode 100644 index 00000000..7d536d90 --- /dev/null +++ b/projects/llm_framework/include/fst/test/fst_test.h @@ -0,0 +1,318 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Regression test for FST classes. + +#ifndef FST_TEST_FST_TEST_H_ +#define FST_TEST_FST_TEST_H_ + +#include +#include +#include +#include +#include + +DECLARE_string(tmpdir); + +namespace fst { + +// This tests an Fst F that is assumed to have a copy method from an +// arbitrary Fst. Some test functions make further assumptions mostly +// obvious from their name. 
These tests are written as member temple +// functions that take a test fst as its argument so that different +// Fsts in the interface hierarchy can be tested separately and so +// that we can instantiate only those tests that make sense for a +// particular Fst. +template +class FstTester { + public: + typedef typename F::Arc Arc; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + typedef typename Arc::Label Label; + + FstTester() { + VectorFst vfst; + InitFst(&vfst, 128); + testfst_ = new F(vfst); + } + + explicit FstTester(F *testfst) : testfst_(testfst) {} + + ~FstTester() { delete testfst_; } + + // This verifies the contents described in InitFst() using + // methods defined in a generic Fst. + template + void TestBase(const G &fst) const { + CHECK(Verify(fst)); + CHECK_EQ(fst.Start(), 0); + StateId ns = 0; + StateIterator siter(fst); + Matcher matcher(fst, MATCH_INPUT); + MatchType match_type = matcher.Type(true); + for (; !siter.Done(); siter.Next()) { + } + for (siter.Reset(); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + matcher.SetState(s); + CHECK_EQ(fst.Final(s), NthWeight(s)); + size_t na = 0; + ArcIterator aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + } + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + ++na; + const Arc &arc = aiter.Value(); + CHECK_EQ(arc.ilabel, na); + CHECK_EQ(arc.olabel, 0); + CHECK_EQ(arc.weight, NthWeight(na)); + CHECK_EQ(arc.nextstate, s); + if (match_type == MATCH_INPUT) { + CHECK(matcher.Find(arc.ilabel)); + CHECK_EQ(matcher.Value().ilabel, arc.ilabel); + } + } + CHECK_EQ(na, s); + CHECK_EQ(na, aiter.Position()); + CHECK_EQ(fst.NumArcs(s), s); + CHECK_EQ(fst.NumInputEpsilons(s), 0); + CHECK_EQ(fst.NumOutputEpsilons(s), s); + CHECK(!matcher.Find(s + 1)); // out-of-range + CHECK(!matcher.Find(kNoLabel)); // no explicit epsilons + CHECK(matcher.Find(0)); + CHECK_EQ(matcher.Value().ilabel, kNoLabel); // implicit epsilon loop + ++ns; + } + 
CHECK(fst.Properties(kNotAcceptor, true)); + CHECK(fst.Properties(kOEpsilons, true)); + } + + void TestBase() const { TestBase(*testfst_); } + + // This verifies methods specfic to an ExpandedFst. + template + void TestExpanded(const G &fst) const { + StateId ns = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + ++ns; + } + CHECK_EQ(fst.NumStates(), ns); + CHECK(fst.Properties(kExpanded, false)); + } + + void TestExpanded() const { TestExpanded(*testfst_); } + + // This verifies methods specific to a MutableFst. + template + void TestMutable(G *fst) const { + for (StateIterator siter(*fst); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + size_t na = 0; + size_t ni = fst->NumInputEpsilons(s); + MutableArcIterator aiter(fst, s); + for (; !aiter.Done(); aiter.Next()) { + } + for (aiter.Reset(); !aiter.Done(); aiter.Next()) { + ++na; + Arc arc = aiter.Value(); + arc.ilabel = 0; + aiter.SetValue(arc); + arc = aiter.Value(); + CHECK_EQ(arc.ilabel, 0); + CHECK_EQ(fst->NumInputEpsilons(s), ni + 1); + arc.ilabel = na; + aiter.SetValue(arc); + CHECK_EQ(fst->NumInputEpsilons(s), ni); + } + } + + G *cfst1 = fst->Copy(); + cfst1->DeleteStates(); + CHECK_EQ(cfst1->NumStates(), 0); + delete cfst1; + + G *cfst2 = fst->Copy(); + for (StateIterator siter(*cfst2); !siter.Done(); siter.Next()) { + StateId s = siter.Value(); + cfst2->DeleteArcs(s); + CHECK_EQ(cfst2->NumArcs(s), 0); + CHECK_EQ(cfst2->NumInputEpsilons(s), 0); + CHECK_EQ(cfst2->NumOutputEpsilons(s), 0); + } + delete cfst2; + } + + void TestMutable() { TestMutable(testfst_); } + + // This verifies the copy methods. 
+ template + void TestAssign(G *fst) const { + // Assignment from G + G afst1; + afst1 = *fst; + CHECK(Equal(*fst, afst1)); + + // Assignment from Fst + G afst2; + afst2 = *static_cast *>(fst); + CHECK(Equal(*fst, afst2)); + + // Assignment from self + afst2.operator=(afst2); + CHECK(Equal(*fst, afst2)); + } + + void TestAssign() { TestAssign(testfst_); } + + // This verifies the copy methods. + template + void TestCopy(const G &fst) const { + // Copy from G + G c1fst(fst); + TestBase(c1fst); + + // Copy from Fst + const G c2fst(static_cast &>(fst)); + TestBase(c2fst); + + // Copy from self + const G *c3fst = fst.Copy(); + TestBase(*c3fst); + delete c3fst; + } + + void TestCopy() const { TestCopy(*testfst_); } + + // This verifies the read/write methods. + template + void TestIO(const G &fst) const { + const string filename = FLAGS_tmpdir + "/test.fst"; + const string aligned = FLAGS_tmpdir + "/aligned.fst"; + { + // write/read + CHECK(fst.Write(filename)); + G *ffst = G::Read(filename); + CHECK(ffst); + TestBase(*ffst); + delete ffst; + } + + { + // generic read/cast/test + Fst *gfst = Fst::Read(filename); + CHECK(gfst); + G *dfst = static_cast(gfst); + TestBase(*dfst); + + // generic write/read/test + CHECK(gfst->Write(filename)); + Fst *hfst = Fst::Read(filename); + CHECK(hfst); + TestBase(*hfst); + delete gfst; + delete hfst; + } + + { + // check mmaping by first writing the file with the aligned attribute set + { + std::ofstream ostr(aligned); + FstWriteOptions opts; + opts.source = aligned; + opts.align = true; + CHECK(fst.Write(ostr, opts)); + } + std::ifstream istr(aligned); + FstReadOptions opts; + opts.mode = FstReadOptions::ReadMode("map"); + opts.source = aligned; + G *gfst = G::Read(istr, opts); + CHECK(gfst); + TestBase(*gfst); + delete gfst; + } + + // check mmaping of unaligned files to make sure it does not fail. 
+ { + { + std::ofstream ostr(aligned); + FstWriteOptions opts; + opts.source = aligned; + opts.align = false; + CHECK(fst.Write(ostr, opts)); + } + std::ifstream istr(aligned); + FstReadOptions opts; + opts.mode = FstReadOptions::ReadMode("map"); + opts.source = aligned; + G *gfst = G::Read(istr, opts); + CHECK(gfst); + TestBase(*gfst); + delete gfst; + } + + // expanded write/read/test + if (fst.Properties(kExpanded, false)) { + ExpandedFst *efst = ExpandedFst::Read(filename); + CHECK(efst); + TestBase(*efst); + TestExpanded(*efst); + delete efst; + } + + // mutable write/read/test + if (fst.Properties(kMutable, false)) { + MutableFst *mfst = MutableFst::Read(filename); + CHECK(mfst); + TestBase(*mfst); + TestExpanded(*mfst); + TestMutable(mfst); + delete mfst; + } + } + + void TestIO() const { TestIO(*testfst_); } + + private: + // This constructs test FSTs. Given a mutable FST, will leave + // the FST as follows: + // (I) NumStates() = nstates + // (II) Start() = 0 + // (III) Final(s) = NthWeight(s) + // (IV) For state s: + // (a) NumArcs(s) == s + // (b) For ith arc of s: + // (1) ilabel = i + // (2) olabel = 0 + // (3) weight = NthWeight(i) + // (4) nextstate = s + void InitFst(MutableFst *fst, size_t nstates) const { + fst->DeleteStates(); + CHECK_GT(nstates, 0); + + for (StateId s = 0; s < nstates; ++s) { + fst->AddState(); + fst->SetFinal(s, NthWeight(s)); + for (size_t i = 1; i <= s; ++i) { + Arc arc(i, 0, NthWeight(i), s); + fst->AddArc(s, arc); + } + } + + fst->SetStart(0); + } + + // Generates One() + ... 
+ One() (n times) + Weight NthWeight(int n) const { + Weight w = Weight::Zero(); + for (int i = 0; i < n; ++i) w = Plus(w, Weight::One()); + return w; + } + + F *testfst_; // what we're testing +}; + +} // namespace fst + +#endif // FST_TEST_FST_TEST_H_ diff --git a/projects/llm_framework/include/fst/test/rand-fst.h b/projects/llm_framework/include/fst/test/rand-fst.h new file mode 100644 index 00000000..f2f34c67 --- /dev/null +++ b/projects/llm_framework/include/fst/test/rand-fst.h @@ -0,0 +1,90 @@ +#ifndef FST_TEST_RAND_FST_H_ +#define FST_TEST_RAND_FST_H_ + +#include +#include +#include + +namespace fst { + +// Generates a random FST. +template +void RandFst(const int num_random_states, const int num_random_arcs, + const int num_random_labels, const float acyclic_prob, + WeightGenerator *weight_generator, MutableFst *fst) { + typedef typename Arc::Label Label; + typedef typename Arc::StateId StateId; + typedef typename Arc::Weight Weight; + + // Determines direction of the arcs wrt state numbering. This way we + // can force acyclicity when desired. + enum ArcDirection { + ANY_DIRECTION = 0, + FORWARD_DIRECTION = 1, + REVERSE_DIRECTION = 2, + NUM_DIRECTIONS = 3 + }; + + ArcDirection arc_direction = ANY_DIRECTION; + if (rand() / (RAND_MAX + 1.0) < acyclic_prob) + arc_direction = rand() % 2 ? 
FORWARD_DIRECTION : REVERSE_DIRECTION; + + fst->DeleteStates(); + StateId ns = rand() % num_random_states; + + if (ns == 0) return; + for (StateId s = 0; s < ns; ++s) fst->AddState(); + + StateId start = rand() % ns; + fst->SetStart(start); + + size_t na = rand() % num_random_arcs; + for (size_t n = 0; n < na; ++n) { + StateId s = rand() % ns; + Arc arc; + arc.ilabel = rand() % num_random_labels; + arc.olabel = rand() % num_random_labels; + arc.weight = (*weight_generator)(); + arc.nextstate = rand() % ns; + + if ((arc_direction == FORWARD_DIRECTION || + arc_direction == REVERSE_DIRECTION) && + s == arc.nextstate) { + continue; // skips self-loops + } + + if ((arc_direction == FORWARD_DIRECTION && s > arc.nextstate) || + (arc_direction == REVERSE_DIRECTION && s < arc.nextstate)) { + StateId t = s; // reverses arcs + s = arc.nextstate; + arc.nextstate = t; + } + + fst->AddArc(s, arc); + } + + StateId nf = rand() % (ns + 1); + for (StateId n = 0; n < nf; ++n) { + StateId s = rand() % ns; + Weight final = (*weight_generator)(); + fst->SetFinal(s, final); + } + VLOG(1) << "Check FST for sanity (including property bits)."; + CHECK(Verify(*fst)); + + // Get/compute all properties. + uint64 props = fst->Properties(kFstProperties, true); + + // Select random set of properties to be unknown. + uint64 mask = 0; + for (int n = 0; n < 8; ++n) { + mask |= rand() & 0xff; + mask <<= 8; + } + mask &= ~kTrinaryProperties; + fst->SetProperties(props & ~mask, mask); +} + +} // namespace fst + +#endif // FST_TEST_RAND_FST_H_ diff --git a/projects/llm_framework/include/fst/test/weight-tester.h b/projects/llm_framework/include/fst/test/weight-tester.h new file mode 100644 index 00000000..b7de665e --- /dev/null +++ b/projects/llm_framework/include/fst/test/weight-tester.h @@ -0,0 +1,207 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Utility class for regression testing of FST weights. 
+ +#ifndef FST_TEST_WEIGHT_TESTER_H_ +#define FST_TEST_WEIGHT_TESTER_H_ + +#include +#include + +#include + +#include +#include + +namespace fst { + +// This class tests a variety of identities and properties that must +// hold for the Weight class to be well-defined. It calls function object +// WEIGHT_GENERATOR to select weights that are used in the tests. +template +class WeightTester { + public: + WeightTester(WeightGenerator generator) + : weight_generator_(std::move(generator)) {} + + void Test(int iterations, bool test_division = true) { + for (int i = 0; i < iterations; ++i) { + // Selects the test weights. + const Weight w1(weight_generator_()); + const Weight w2(weight_generator_()); + const Weight w3(weight_generator_()); + + VLOG(1) << "weight type = " << Weight::Type(); + VLOG(1) << "w1 = " << w1; + VLOG(1) << "w2 = " << w2; + VLOG(1) << "w3 = " << w3; + + TestSemiring(w1, w2, w3); + if (test_division) TestDivision(w1, w2); + TestReverse(w1, w2); + TestEquality(w1, w2, w3); + TestIO(w1); + TestCopy(w1); + } + } + + private: + // Note in the tests below we use ApproxEqual rather than == and add + // kDelta to inequalities where the weights might be inexact. + + // Tests (Plus, Times, Zero, One) defines a commutative semiring. + void TestSemiring(Weight w1, Weight w2, Weight w3) { + // Checks that the operations are closed. + CHECK(Plus(w1, w2).Member()); + CHECK(Times(w1, w2).Member()); + + // Checks that the operations are associative. + CHECK(ApproxEqual(Plus(w1, Plus(w2, w3)), Plus(Plus(w1, w2), w3))); + CHECK(ApproxEqual(Times(w1, Times(w2, w3)), Times(Times(w1, w2), w3))); + + // Checks the identity elements. + CHECK(Plus(w1, Weight::Zero()) == w1); + CHECK(Plus(Weight::Zero(), w1) == w1); + CHECK(Times(w1, Weight::One()) == w1); + CHECK(Times(Weight::One(), w1) == w1); + + // Check the no weight element. 
+ CHECK(!Weight::NoWeight().Member()); + CHECK(!Plus(w1, Weight::NoWeight()).Member()); + CHECK(!Plus(Weight::NoWeight(), w1).Member()); + CHECK(!Times(w1, Weight::NoWeight()).Member()); + CHECK(!Times(Weight::NoWeight(), w1).Member()); + + // Checks that the operations commute. + CHECK(ApproxEqual(Plus(w1, w2), Plus(w2, w1))); + + if (Weight::Properties() & kCommutative) + CHECK(ApproxEqual(Times(w1, w2), Times(w2, w1))); + + // Checks Zero() is the annihilator. + CHECK(Times(w1, Weight::Zero()) == Weight::Zero()); + CHECK(Times(Weight::Zero(), w1) == Weight::Zero()); + + // Check Power(w, 0) is Weight::One() + CHECK(Power(w1, 0) == Weight::One()); + + // Check Power(w, 1) is w + CHECK(Power(w1, 1) == w1); + + // Check Power(w, 3) is Times(w, Times(w, w)) + CHECK(Power(w1, 3) == Times(w1, Times(w1, w1))); + + // Checks distributivity. + if (Weight::Properties() & kLeftSemiring) { + CHECK(ApproxEqual(Times(w1, Plus(w2, w3)), + Plus(Times(w1, w2), Times(w1, w3)))); + } + if (Weight::Properties() & kRightSemiring) + CHECK(ApproxEqual(Times(Plus(w1, w2), w3), + Plus(Times(w1, w3), Times(w2, w3)))); + + if (Weight::Properties() & kIdempotent) CHECK(Plus(w1, w1) == w1); + + if (Weight::Properties() & kPath) + CHECK(Plus(w1, w2) == w1 || Plus(w1, w2) == w2); + + // Ensure weights form a left or right semiring. + CHECK(Weight::Properties() & (kLeftSemiring | kRightSemiring)); + + // Check when Times() is commutative that it is marked as a semiring. + if (Weight::Properties() & kCommutative) + CHECK(Weight::Properties() & kSemiring); + } + + // Tests division operation. 
+ void TestDivision(Weight w1, Weight w2) { + Weight p = Times(w1, w2); + + if (Weight::Properties() & kLeftSemiring) { + Weight d = Divide(p, w1, DIVIDE_LEFT); + if (d.Member()) CHECK(ApproxEqual(p, Times(w1, d))); + CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_LEFT).Member()); + CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_LEFT).Member()); + } + + if (Weight::Properties() & kRightSemiring) { + Weight d = Divide(p, w2, DIVIDE_RIGHT); + if (d.Member()) CHECK(ApproxEqual(p, Times(d, w2))); + CHECK(!Divide(w1, Weight::NoWeight(), DIVIDE_RIGHT).Member()); + CHECK(!Divide(Weight::NoWeight(), w1, DIVIDE_RIGHT).Member()); + } + + if (Weight::Properties() & kCommutative) { + Weight d = Divide(p, w1, DIVIDE_RIGHT); + if (d.Member()) CHECK(ApproxEqual(p, Times(d, w1))); + } + } + + // Tests reverse operation. + void TestReverse(Weight w1, Weight w2) { + typedef typename Weight::ReverseWeight ReverseWeight; + + ReverseWeight rw1 = w1.Reverse(); + ReverseWeight rw2 = w2.Reverse(); + + CHECK(rw1.Reverse() == w1); + CHECK(Plus(w1, w2).Reverse() == Plus(rw1, rw2)); + CHECK(Times(w1, w2).Reverse() == Times(rw2, rw1)); + } + + // Tests == is an equivalence relation. + void TestEquality(Weight w1, Weight w2, Weight w3) { + // Checks reflexivity. + CHECK(w1 == w1); + + // Checks symmetry. + CHECK((w1 == w2) == (w2 == w1)); + + // Checks transitivity. + if (w1 == w2 && w2 == w3) CHECK(w1 == w3); + } + + // Tests binary serialization and textual I/O. + void TestIO(Weight w) { + // Tests binary I/O + { + std::ostringstream os; + w.Write(os); + os.flush(); + std::istringstream is(os.str()); + Weight v; + v.Read(is); + CHECK_EQ(w, v); + } + + // Tests textual I/O. 
+ { + std::ostringstream os; + os << w; + std::istringstream is(os.str()); + Weight v(Weight::One()); + is >> v; + CHECK(ApproxEqual(w, v)); + } + } + + // Tests copy constructor and assignment operator + void TestCopy(Weight w) { + Weight x = w; + CHECK(w == x); + + x = Weight(w); + CHECK(w == x); + + x.operator=(x); + CHECK(w == x); + } + + // Generates weights used in testing. + WeightGenerator weight_generator_; +}; + +} // namespace fst + +#endif // FST_TEST_WEIGHT_TESTER_H_ diff --git a/projects/llm_framework/include/fst/topsort.h b/projects/llm_framework/include/fst/topsort.h new file mode 100644 index 00000000..cae3e154 --- /dev/null +++ b/projects/llm_framework/include/fst/topsort.h @@ -0,0 +1,95 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Topological sort of FSTs. + +#ifndef FST_TOPSORT_H_ +#define FST_TOPSORT_H_ + +#include +#include + + +#include +#include +#include + + +namespace fst { + +// DFS visitor class to return topological ordering. +template +class TopOrderVisitor { + public: + using StateId = typename Arc::StateId; + + // If acyclic, order[i] gives the topological position of StateId i; + // otherwise it is unchanged. acyclic_ will be true iff the FST has no + // cycles. The caller retains ownership of the state order vector. 
+ TopOrderVisitor(std::vector *order, bool *acyclic) + : order_(order), acyclic_(acyclic) {} + + void InitVisit(const Fst &fst) { + finish_.reset(new std::vector()); + *acyclic_ = true; + } + + constexpr bool InitState(StateId, StateId) const { return true; } + + constexpr bool TreeArc(StateId, const Arc &) const { return true; } + + bool BackArc(StateId, const Arc &) { return (*acyclic_ = false); } + + constexpr bool ForwardOrCrossArc(StateId, const Arc &) const { return true; } + + void FinishState(StateId s, StateId, const Arc *) { finish_->push_back(s); } + + void FinishVisit() { + if (*acyclic_) { + order_->clear(); + for (StateId s = 0; s < finish_->size(); ++s) { + order_->push_back(kNoStateId); + } + for (StateId s = 0; s < finish_->size(); ++s) { + (*order_)[(*finish_)[finish_->size() - s - 1]] = s; + } + } + finish_.reset(); + } + + private: + std::vector *order_; + bool *acyclic_; + // States in finish-time order. + std::unique_ptr> finish_; +}; + +// Topologically sorts its input if acyclic, modifying it. Otherwise, the input +// is unchanged. When sorted, all transitions are from lower to higher state +// IDs. +// +// Complexity: +// +// Time: O(V + E) +// Space: O(V + E) +// +// where V is the number of states and E is the number of arcs. 
+template +bool TopSort(MutableFst *fst) { + std::vector order; + bool acyclic; + TopOrderVisitor top_order_visitor(&order, &acyclic); + DfsVisit(*fst, &top_order_visitor); + if (acyclic) { + StateSort(fst, order); + fst->SetProperties(kAcyclic | kInitialAcyclic | kTopSorted, + kAcyclic | kInitialAcyclic | kTopSorted); + } else { + fst->SetProperties(kCyclic | kNotTopSorted, kCyclic | kNotTopSorted); + } + return acyclic; +} + +} // namespace fst + +#endif // FST_TOPSORT_H_ diff --git a/projects/llm_framework/include/fst/tuple-weight.h b/projects/llm_framework/include/fst/tuple-weight.h new file mode 100644 index 00000000..c8ecab76 --- /dev/null +++ b/projects/llm_framework/include/fst/tuple-weight.h @@ -0,0 +1,163 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Tuple weight set operation definitions. + +#ifndef FST_TUPLE_WEIGHT_H_ +#define FST_TUPLE_WEIGHT_H_ + +#include +#include +#include +#include +#include + +#include +#include + +#include + + +namespace fst { + +// n-tuple weight, element of the n-th Cartesian power of W. 
+template +class TupleWeight { + public: + using ReverseWeight = TupleWeight; + + using Weight = W; + using Index = size_t; + + template + TupleWeight(Iterator begin, Iterator end) { + std::copy(begin, end, values_.begin()); + } + + explicit TupleWeight(const W &weight = W::Zero()) { values_.fill(weight); } + + // Initialize component `index` to `weight`; initialize all other components + // to `default_weight` + TupleWeight(Index index, const W &weight, const W &default_weight) + : TupleWeight(default_weight) { + values_[index] = weight; + } + + static const TupleWeight &Zero() { + static const TupleWeight zero(W::Zero()); + return zero; + } + + static const TupleWeight &One() { + static const TupleWeight one(W::One()); + return one; + } + + static const TupleWeight &NoWeight() { + static const TupleWeight no_weight(W::NoWeight()); + return no_weight; + } + + constexpr static size_t Length() { return n; } + + std::istream &Read(std::istream &istrm) { + for (size_t i = 0; i < n; ++i) values_[i].Read(istrm); + return istrm; + } + + std::ostream &Write(std::ostream &ostrm) const { + for (size_t i = 0; i < n; ++i) values_[i].Write(ostrm); + return ostrm; + } + + bool Member() const { + return std::all_of(values_.begin(), values_.end(), + std::mem_fn(&W::Member)); + } + + size_t Hash() const { + uint64 hash = 0; + for (size_t i = 0; i < n; ++i) hash = 5 * hash + values_[i].Hash(); + return size_t(hash); + } + + TupleWeight Quantize(float delta = kDelta) const { + TupleWeight weight; + for (size_t i = 0; i < n; ++i) { + weight.values_[i] = values_[i].Quantize(delta); + } + return weight; + } + + ReverseWeight Reverse() const { + TupleWeight w; + for (size_t i = 0; i < n; ++i) w.values_[i] = values_[i].Reverse(); + return w; + } + + const W &Value(size_t i) const { return values_[i]; } + + void SetValue(size_t i, const W &w) { values_[i] = w; } + + private: + std::array values_; +}; + +template +inline bool operator==(const TupleWeight &w1, + const TupleWeight &w2) { + 
for (size_t i = 0; i < n; ++i) { + if (w1.Value(i) != w2.Value(i)) return false; + } + return true; +} + +template +inline bool operator!=(const TupleWeight &w1, + const TupleWeight &w2) { + for (size_t i = 0; i < n; ++i) { + if (w1.Value(i) != w2.Value(i)) return true; + } + return false; +} + +template +inline bool ApproxEqual(const TupleWeight &w1, + const TupleWeight &w2, float delta = kDelta) { + for (size_t i = 0; i < n; ++i) { + if (!ApproxEqual(w1.Value(i), w2.Value(i), delta)) return false; + } + return true; +} + +template +inline std::ostream &operator<<(std::ostream &strm, + const TupleWeight &w) { + CompositeWeightWriter writer(strm); + writer.WriteBegin(); + for (size_t i = 0; i < n; ++i) writer.WriteElement(w.Value(i)); + writer.WriteEnd(); + return strm; +} + +template +inline std::istream &operator>>(std::istream &strm, TupleWeight &w) { + CompositeWeightReader reader(strm); + reader.ReadBegin(); + W v; + // Reads first n-1 elements. + static_assert(n > 0, "Size must be positive."); + for (size_t i = 0; i < n - 1; ++i) { + reader.ReadElement(&v); + w.SetValue(i, v); + } + // Reads n-th element. + reader.ReadElement(&v, true); + w.SetValue(n - 1, v); + reader.ReadEnd(); + return strm; +} + +} // namespace fst + +#endif // FST_TUPLE_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/types.h b/projects/llm_framework/include/fst/types.h new file mode 100644 index 00000000..9c0b7998 --- /dev/null +++ b/projects/llm_framework/include/fst/types.h @@ -0,0 +1,41 @@ +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Various type definitions (mostly for Google compatibility). + +#include // For std::ptrdiff_t. +#include // for ssize_t. +#include // for ?int*_t. + +#ifndef FST_LIB_TYPES_H_ +#define FST_LIB_TYPES_H_ + +using int8 = int8_t; +using int16 = int16_t; +using int32 = int32_t; +using int64 = int64_t; + +using uint8 = uint8_t; +using uint16 = uint16_t; +using uint32 = uint32_t; +using uint64 = uint64_t; + +#ifdef _MSC_VER +// Not really Windows-specific: they should have used ptrdiff_t in the first +// place. But on Windows there has never been ssize_t. +using ssize_t = std::ptrdiff_t; +#endif // _MSC_VER + +#endif // FST_LIB_TYPES_H_ diff --git a/projects/llm_framework/include/fst/union-find.h b/projects/llm_framework/include/fst/union-find.h new file mode 100644 index 00000000..b5c7f10b --- /dev/null +++ b/projects/llm_framework/include/fst/union-find.h @@ -0,0 +1,87 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Union-find algorithm for dense sets of non-negative integers, implemented +// using disjoint tree forests with rank heuristics and path compression. + +#ifndef FST_UNION_FIND_H_ +#define FST_UNION_FIND_H_ + +#include +#include + +namespace fst { + +// Union-Find algorithm for dense sets of non-negative integers. +template +class UnionFind { + public: + // Creates a disjoint set forest for the range [0; max); 'fail' is a value + // indicating that an element hasn't been initialized using MakeSet(...). + // The upper bound of the range can be reset (increased) using MakeSet(...). + UnionFind(T max, T fail) : parent_(max, fail), rank_(max), fail_(fail) {} + + // Finds the representative of the set 'item' belongs to, performing path + // compression if necessary.
+ T FindSet(T item) { + if (item >= parent_.size() || item == fail_ || parent_[item] == fail_) { + return fail_; + } + auto *p = &parent_[item]; + for (; *p != item; item = *p, p = &parent_[item]) exec_stack_.push(p); + for (; !exec_stack_.empty(); exec_stack_.pop()) *exec_stack_.top() = *p; + return *p; + } + + // Creates the (destructive) union of the sets x and y belong to. + void Union(T x, T y) { Link(FindSet(x), FindSet(y)); } + + // Initialization of an element: creates a singleton set containing 'item'. + // The range [0; max) is reset if item >= max. + T MakeSet(T item) { + if (item >= parent_.size()) { + // New value in parent_ should be initialized to fail_. + const auto nitem = item > 0 ? 2 * item : 2; + parent_.resize(nitem, fail_); + rank_.resize(nitem); + } + parent_[item] = item; + return item; + } + + // Initialization of all elements starting from 0 to max - 1 to distinct sets. + // rank_ is resized together with parent_; otherwise a later Union() on an + // element beyond the constructor's range would index past the end of rank_. + void MakeAllSet(T max) { + parent_.resize(max); + rank_.resize(max); + for (T item = 0; item < max; ++item) parent_[item] = item; + } + + private: + // Links trees rooted in 'x' and 'y'. + void Link(T x, T y) { + if (x == y) return; + if (rank_[x] > rank_[y]) { + parent_[y] = x; + } else { + parent_[x] = y; + if (rank_[x] == rank_[y]) { + ++rank_[y]; + } + } + } + + UnionFind(const UnionFind &) = delete; + + UnionFind &operator=(const UnionFind &) = delete; + + std::vector parent_; // Parent nodes. + std::vector rank_; // Rank of an element = min. depth in tree. + T fail_; // Value indicating lookup failure. + std::stack exec_stack_; // Used for path compression. +}; + +} // namespace fst + +#endif // FST_UNION_FIND_H_ diff --git a/projects/llm_framework/include/fst/union-weight.h b/projects/llm_framework/include/fst/union-weight.h new file mode 100644 index 00000000..bb2dea96 --- /dev/null +++ b/projects/llm_framework/include/fst/union-weight.h @@ -0,0 +1,505 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library.
+// +// Union weight set and associated semiring operation definitions. +// +// TODO(riley): add in normalizer functor + +#ifndef FST_UNION_WEIGHT_H_ +#define FST_UNION_WEIGHT_H_ + +#include + +#include +#include +#include +#include +#include + +#include + + +namespace fst { + +// Example UnionWeightOptions for UnionWeight template below. The Merge +// operation is used to collapse elements of the set and the Compare function +// to efficiently implement the merge. In the simplest case, merge would just +// apply with equality of set elements so the result is a set (and not a +// multiset). More generally, this can be used to maintain the multiplicity or +// other such weight associated with the set elements (cf. Gallic weights). + +// template +// struct UnionWeightOptions { +// // Comparison function C is a total order on W that is monotonic w.r.t. +// // Times: for all a, b, c != Zero(): C(a, b) => C(ca, cb) and is +// // anti-monotonic w.r.t. Divide: C(a, b) => C(c/b, c/a). +// // +// // For all a, b: exactly one of C(a, b), C(b, a) or a ~ b must be true where +// // ~ is an equivalence relation on W. Also we require a ~ b iff +// // a.Reverse() ~ b.Reverse(). +// using Compare = NaturalLess; +// +// // How to combine two weights if a ~ b as above. For all a, b: a ~ b => +// // merge(a, b) ~ a; Merge must define a semiring endomorphism from the +// // unmerged weight sets to the merged weight sets. +// struct Merge { +// W operator()(const W &w1, const W &w2) const { return w1; } +// }; +// +// // For ReverseWeight. +// using ReverseOptions = UnionWeightOptions; +// }; + +template +class UnionWeight; + +template +class UnionWeightIterator; + +template +class UnionWeightReverseIterator; + +template +bool operator==(const UnionWeight &, const UnionWeight &); + +// Semiring that uses Times() and One() from W and union and the empty set +// for Plus() and Zero(), respectively. Template argument O specifies the union +// weight options as above.
+template +class UnionWeight { + public: + using Weight = W; + using Compare = typename O::Compare; + using Merge = typename O::Merge; + + using ReverseWeight = + UnionWeight; + + friend class UnionWeightIterator; + friend class UnionWeightReverseIterator; + friend bool operator== + <>(const UnionWeight &, const UnionWeight &); + + // Sets represented as first_ weight + rest_ weights. Uses first_ as + // NoWeight() to indicate the union weight Zero(), a.k.a. the empty set. Uses + // rest_ containing NoWeight() to indicate the union weight NoWeight(). + UnionWeight() : first_(W::NoWeight()) {} + + explicit UnionWeight(W weight) : first_(weight) { + if (weight == W::NoWeight()) rest_.push_back(weight); + } + + static const UnionWeight &Zero() { + static const UnionWeight zero(W::NoWeight()); + return zero; + } + + static const UnionWeight &One() { + static const UnionWeight one(W::One()); + return one; + } + + static const UnionWeight &NoWeight() { + static const UnionWeight no_weight(W::Zero(), W::NoWeight()); + return no_weight; + } + + static const string &Type() { + static const string *const type = new string(W::Type() + "_union"); + return *type; + } + + static constexpr uint64 Properties() { + return W::Properties() & + (kLeftSemiring | kRightSemiring | kCommutative | kIdempotent); + } + + bool Member() const; + + std::istream &Read(std::istream &strm); + + std::ostream &Write(std::ostream &strm) const; + + size_t Hash() const; + + UnionWeight Quantize(float delta = kDelta) const; + + ReverseWeight Reverse() const; + + // These operations combined with the UnionWeightIterator and + // UnionWeightReverseIterator provide the access and mutation of the union + // weight internal elements. + + // Common initializer among constructors; clears existing UnionWeight. + void Clear() { + first_ = W::NoWeight(); + rest_.clear(); + } + + size_t Size() const { return first_.Member() ? rest_.size() + 1 : 0; } + + const W &Back() const { return rest_.empty() ?
first_ : rest_.back(); } + + // When srt is true, assumes elements added sorted w.r.t Compare and merging + // of weights performed as needed. Otherwise, just ensures first_ is the + // least element wrt Compare. + void PushBack(W weight, bool srt); + + // Sorts the elements of the set. Assumes that first_, if present, is the + // least element. + void Sort() { rest_.sort(comp_); } + + private: + W &Back() { + if (rest_.empty()) { + return first_; + } else { + return rest_.back(); + } + } + + UnionWeight(W w1, W w2) : first_(std::move(w1)), rest_(1, std::move(w2)) {} + + W first_; // First weight in set. + std::list rest_; // Remaining weights in set. + Compare comp_; // Ordering used by PushBack() and Sort(). + Merge merge_; // Combines the back element with a new weight that comp_ does not order after it. +}; + +template +void UnionWeight::PushBack(W weight, bool srt) { + if (!weight.Member()) { + rest_.push_back(std::move(weight)); + } else if (!first_.Member()) { + first_ = std::move(weight); + } else if (srt) { + auto &back = Back(); + if (comp_(back, weight)) { + rest_.push_back(std::move(weight)); + } else { + back = merge_(back, std::move(weight)); + } + } else { + if (comp_(first_, weight)) { + rest_.push_back(std::move(weight)); + } else { + rest_.push_back(first_); + first_ = std::move(weight); + } + } +} + +// Traverses union weight in the forward direction. +template +class UnionWeightIterator { + public: + explicit UnionWeightIterator(const UnionWeight &weight) + : first_(weight.first_), + rest_(weight.rest_), + init_(true), + it_(rest_.begin()) {} + + bool Done() const { return init_ ? !first_.Member() : it_ == rest_.end(); } + + const W &Value() const { return init_ ? first_ : *it_; } + + void Next() { + if (init_) { + init_ = false; + } else { + ++it_; + } + } + + void Reset() { + init_ = true; + it_ = rest_.begin(); + } + + private: + const W &first_; + const std::list &rest_; + bool init_; // in the initialized state? + typename std::list::const_iterator it_; +}; + +// Traverses union weight in backward direction.
+template +class UnionWeightReverseIterator { + public: + explicit UnionWeightReverseIterator(const UnionWeight &weight) + : first_(weight.first_), + rest_(weight.rest_), + fin_(!first_.Member()), + it_(rest_.rbegin()) {} + + bool Done() const { return fin_; } + + const L &Value() const { return it_ == rest_.rend() ? first_ : *it_; } + + void Next() { + if (it_ == rest_.rend()) { + fin_ = true; + } else { + ++it_; + } + } + + void Reset() { + fin_ = !first_.Member(); + it_ = rest_.rbegin(); + } + + private: + const L &first_; + const std::list &rest_; + bool fin_; // in the final state? + typename std::list::const_reverse_iterator it_; +}; + +// UnionWeight member functions follow that require UnionWeightIterator. +template +inline std::istream &UnionWeight::Read(std::istream &istrm) { + Clear(); + int32 size = 0; // Initialized so a failed count read cannot leave the loop bound indeterminate. + ReadType(istrm, &size); + for (int i = 0; i < size; ++i) { + W weight; + ReadType(istrm, &weight); + PushBack(weight, true); + } + return istrm; +} + +template +inline std::ostream &UnionWeight::Write(std::ostream &ostrm) const { + const int32 size = Size(); + WriteType(ostrm, size); + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + WriteType(ostrm, it.Value()); + } + return ostrm; +} + +template +inline bool UnionWeight::Member() const { + if (Size() <= 1) return true; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + if (!it.Value().Member()) return false; + } + return true; +} + +template +inline UnionWeight UnionWeight::Quantize(float delta) const { + UnionWeight weight; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + weight.PushBack(it.Value().Quantize(delta), true); + } + return weight; +} + +template +inline typename UnionWeight::ReverseWeight UnionWeight::Reverse() + const { + ReverseWeight weight; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + weight.PushBack(it.Value().Reverse(), false); + } + weight.Sort(); + return weight; +} + +template +inline size_t UnionWeight::Hash() const { + 
size_t h = 0; + static constexpr int lshift = 5; + static constexpr int rshift = CHAR_BIT * sizeof(size_t) - lshift; + for (UnionWeightIterator it(*this); !it.Done(); it.Next()) { + h = h << lshift ^ h >> rshift ^ it.Value().Hash(); + } + return h; +} + +// Requires union weight has been canonicalized. +template +inline bool operator==(const UnionWeight &w1, + const UnionWeight &w2) { + if (w1.Size() != w2.Size()) return false; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + for (; !it1.Done(); it1.Next(), it2.Next()) { + if (it1.Value() != it2.Value()) return false; + } + return true; +} + +// Requires union weight has been canonicalized. +template +inline bool operator!=(const UnionWeight &w1, + const UnionWeight &w2) { + return !(w1 == w2); +} + +// Requires union weight has been canonicalized. +template +inline bool ApproxEqual(const UnionWeight &w1, + const UnionWeight &w2, float delta = kDelta) { + if (w1.Size() != w2.Size()) return false; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + for (; !it1.Done(); it1.Next(), it2.Next()) { + if (!ApproxEqual(it1.Value(), it2.Value(), delta)) return false; + } + return true; +} + +template +inline std::ostream &operator<<(std::ostream &ostrm, + const UnionWeight &weight) { + UnionWeightIterator it(weight); + if (it.Done()) { + return ostrm << "EmptySet"; + } else if (!weight.Member()) { + return ostrm << "BadSet"; + } else { + CompositeWeightWriter writer(ostrm); + writer.WriteBegin(); + for (; !it.Done(); it.Next()) writer.WriteElement(it.Value()); + writer.WriteEnd(); + } + return ostrm; +} + +template +inline std::istream &operator>>(std::istream &istrm, + UnionWeight &weight) { + string s; + istrm >> s; + if (s == "EmptySet") { + weight = UnionWeight::Zero(); + } else if (s == "BadSet") { + weight = UnionWeight::NoWeight(); + } else { + weight = UnionWeight::Zero(); + std::istringstream sstrm(s); + CompositeWeightReader reader(sstrm); + reader.ReadBegin(); + bool more = true; + 
while (more) { + W v; + more = reader.ReadElement(&v); + weight.PushBack(v, true); + } + reader.ReadEnd(); + } + return istrm; +} + +template +inline UnionWeight Plus(const UnionWeight &w1, + const UnionWeight &w2) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero()) return w2; + if (w2 == UnionWeight::Zero()) return w1; + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + UnionWeight sum; + typename O::Compare comp; + while (!it1.Done() && !it2.Done()) { + const auto v1 = it1.Value(); + const auto v2 = it2.Value(); + if (comp(v1, v2)) { + sum.PushBack(v1, true); + it1.Next(); + } else { + sum.PushBack(v2, true); + it2.Next(); + } + } + for (; !it1.Done(); it1.Next()) sum.PushBack(it1.Value(), true); + for (; !it2.Done(); it2.Next()) sum.PushBack(it2.Value(), true); + return sum; +} + +template +inline UnionWeight Times(const UnionWeight &w1, + const UnionWeight &w2) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero() || w2 == UnionWeight::Zero()) { + return UnionWeight::Zero(); + } + UnionWeightIterator it1(w1); + UnionWeightIterator it2(w2); + UnionWeight prod1; + for (; !it1.Done(); it1.Next()) { + UnionWeight prod2; + for (; !it2.Done(); it2.Next()) { + prod2.PushBack(Times(it1.Value(), it2.Value()), true); + } + prod1 = Plus(prod1, prod2); + it2.Reset(); + } + return prod1; +} + +template +inline UnionWeight Divide(const UnionWeight &w1, + const UnionWeight &w2, DivideType typ) { + if (!w1.Member() || !w2.Member()) return UnionWeight::NoWeight(); + if (w1 == UnionWeight::Zero() || w2 == UnionWeight::Zero()) { + return UnionWeight::Zero(); + } + UnionWeightIterator it1(w1); + UnionWeightReverseIterator it2(w2); + UnionWeight quot; + if (w1.Size() == 1) { + for (; !it2.Done(); it2.Next()) { + quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true); + } + } else if (w2.Size() == 1) { + for (; !it1.Done(); it1.Next()) { + 
quot.PushBack(Divide(it1.Value(), it2.Value(), typ), true); + } + } else { + quot = UnionWeight::NoWeight(); + } + return quot; +} + +// This function object generates weights over the union of weights for the +// underlying generators for the template weight types. This is intended +// primarily for testing. +template +class WeightGenerate> { + public: + using Weight = UnionWeight; + using Generate = WeightGenerate; + + explicit WeightGenerate(bool allow_zero = true, + size_t num_random_weights = kNumRandomWeights) + : generate_(false), allow_zero_(allow_zero), + num_random_weights_(num_random_weights) {} + + Weight operator()() const { + const int n = rand() % (num_random_weights_ + 1); // NOLINT + if (allow_zero_ && n == num_random_weights_) { + return Weight::Zero(); + } else if (n % 2 == 0) { + return Weight(generate_()); + } else { + return Plus(Weight(generate_()), Weight(generate_())); + } + } + + private: + Generate generate_; + // Permits Zero() and zero divisors. + bool allow_zero_; + // The number of alternative random weights. + const size_t num_random_weights_; +}; + +} // namespace fst + +#endif // FST_UNION_WEIGHT_H_ diff --git a/projects/llm_framework/include/fst/union.h b/projects/llm_framework/include/fst/union.h new file mode 100644 index 00000000..257099b8 --- /dev/null +++ b/projects/llm_framework/include/fst/union.h @@ -0,0 +1,157 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Functions and classes to compute the union of two FSTs. + +#ifndef FST_UNION_H_ +#define FST_UNION_H_ + +#include +#include +#include + +#include +#include + + +namespace fst { + +// Computes the union (sum) of two FSTs. This version writes the union to an +// output MutableFst. If A transduces string x to y with weight a and B +// transduces string w to v with weight b, then their union transduces x to y +// with weight a and w to v with weight b.
+// +// Complexity: +// +// Time: O(V_2 + E_2) +// Space: O(V_2 + E_2) +// +// where Vi is the number of states, and Ei is the number of arcs, in the ith +// FST. +template +void Union(MutableFst *fst1, const Fst &fst2) { + using Weight = typename Arc::Weight; + // Checks for symbol table compatibility. + if (!CompatSymbols(fst1->InputSymbols(), fst2.InputSymbols()) || + !CompatSymbols(fst1->OutputSymbols(), fst2.OutputSymbols())) { + FSTERROR() << "Union: Input/output symbol tables of 1st argument " + << "do not match input/output symbol tables of 2nd argument"; + fst1->SetProperties(kError, kError); + return; + } + const auto numstates1 = fst1->NumStates(); + const bool initial_acyclic1 = fst1->Properties(kInitialAcyclic, true); + const auto props1 = fst1->Properties(kFstProperties, false); + const auto props2 = fst2.Properties(kFstProperties, false); + const auto start2 = fst2.Start(); + if (start2 == kNoStateId) { + if (props2 & kError) fst1->SetProperties(kError, kError); + return; + } + if (fst2.Properties(kExpanded, false)) { + fst1->ReserveStates(numstates1 + CountStates(fst2) + + (initial_acyclic1 ? 0 : 1)); + } + for (StateIterator> siter(fst2); !siter.Done(); siter.Next()) { + const auto s1 = fst1->AddState(); + const auto s2 = siter.Value(); + fst1->SetFinal(s1, fst2.Final(s2)); + fst1->ReserveArcs(s1, fst2.NumArcs(s2)); + for (ArcIterator> aiter(fst2, s2); !aiter.Done(); aiter.Next()) { + auto arc = aiter.Value(); // Copy intended.
+ arc.nextstate += numstates1; + fst1->AddArc(s1, std::move(arc)); + } + } + const auto start1 = fst1->Start(); + if (start1 == kNoStateId) { + fst1->SetStart(start2 + numstates1); // Same offset as the arcs above: fst2's states follow fst1's existing ones. + fst1->SetProperties(props2, kCopyProperties); + return; + } + if (initial_acyclic1) { + fst1->AddArc(start1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } else { + const auto nstart1 = fst1->AddState(); + fst1->SetStart(nstart1); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start1)); + fst1->AddArc(nstart1, Arc(0, 0, Weight::One(), start2 + numstates1)); + } + fst1->SetProperties(UnionProperties(props1, props2), kFstProperties); +} + +// Computes the union of two FSTs, modifying the RationalFst argument. +template +void Union(RationalFst *fst1, const Fst &fst2) { + fst1->GetMutableImpl()->AddUnion(fst2); +} + +using UnionFstOptions = RationalFstOptions; + +// Computes the union (sum) of two FSTs. This version is a delayed FST. If A +// transduces string x to y with weight a and B transduces string w to v with +// weight b, then their union transduces x to y with weight a and w to v with +// weight b. +// +// Complexity: +// +// Time: O(v_1 + e_1 + v_2 + e_2) +// Space: O(v_1 + v_2) +// +// where vi is the number of states visited, and ei is the number of arcs +// visited, in the ith FST. Constant time and space to visit an input state or +// arc is assumed and exclusive of caching. +template +class UnionFst : public RationalFst { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + UnionFst(const Fst &fst1, const Fst &fst2) { + GetMutableImpl()->InitUnion(fst1, fst2); + } + + UnionFst(const Fst &fst1, const Fst &fst2, + const UnionFstOptions &opts) + : RationalFst(opts) { + GetMutableImpl()->InitUnion(fst1, fst2); + } + + // See Fst<>::Copy() for doc. + UnionFst(const UnionFst &fst, bool safe = false) + : RationalFst(fst, safe) {} + + // Gets a copy of this UnionFst. See Fst<>::Copy() for further doc.
+ UnionFst *Copy(bool safe = false) const override { + return new UnionFst(*this, safe); + } + + private: + using ImplToFst>::GetImpl; + using ImplToFst>::GetMutableImpl; +}; + +// Specialization for UnionFst. +template +class StateIterator> : public StateIterator> { + public: + explicit StateIterator(const UnionFst &fst) + : StateIterator>(fst) {} +}; + +// Specialization for UnionFst. +template +class ArcIterator> : public ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const UnionFst &fst, StateId s) + : ArcIterator>(fst, s) {} +}; + +using StdUnionFst = UnionFst; + +} // namespace fst + +#endif // FST_UNION_H_ diff --git a/projects/llm_framework/include/fst/util.h b/projects/llm_framework/include/fst/util.h new file mode 100644 index 00000000..c7520213 --- /dev/null +++ b/projects/llm_framework/include/fst/util.h @@ -0,0 +1,400 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// FST utility inline definitions. + +#ifndef FST_UTIL_H_ +#define FST_UTIL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include + + +// Utility for error handling. + +DECLARE_bool(fst_error_fatal); + +#define FSTERROR() \ + (FLAGS_fst_error_fatal ? LOG(FATAL) : LOG(ERROR)) + +namespace fst { + +// Utility for type I/O. + +// Reads types from an input stream. + +// Generic case. +template ::value, T>::type* = nullptr> +inline std::istream &ReadType(std::istream &strm, T *t) { + return t->Read(strm); +} + +// Numeric (boolean, integral, floating-point) case. +template ::value, T>::type* = nullptr> +inline std::istream &ReadType(std::istream &strm, T *t) { + return strm.read(reinterpret_cast(t), sizeof(T)); \ +} + +// String case.
+inline std::istream &ReadType(std::istream &strm, string *s) { // NOLINT + s->clear(); + int32 ns = 0; + strm.read(reinterpret_cast(&ns), sizeof(ns)); + for (int32 i = 0; i < ns; ++i) { + char c; + strm.read(&c, 1); + *s += c; + } + return strm; +} + +// Pair case. +template +inline std::istream &ReadType(std::istream &strm, std::pair *p) { + ReadType(strm, &p->first); + ReadType(strm, &p->second); + return strm; +} + +template +inline std::istream &ReadType(std::istream &strm, std::pair *p) { + ReadType(strm, const_cast(&p->first)); + ReadType(strm, &p->second); + return strm; +} + +namespace internal { +template +std::istream &ReadContainerType(std::istream &strm, C *c, ReserveFn reserve) { + c->clear(); + int64 n = 0; + ReadType(strm, &n); + reserve(c, n); + auto insert = std::inserter(*c, c->begin()); + for (int64 i = 0; i < n; ++i) { + typename C::value_type value; + ReadType(strm, &value); + *insert = value; + } + return strm; +} +} // namespace internal + +template +std::istream &ReadType(std::istream &strm, std::vector *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +template +std::istream &ReadType(std::istream &strm, std::list *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::set *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::map *c) { + return internal::ReadContainerType(strm, c, [](decltype(c) v, int n) {}); +} + +template +std::istream &ReadType(std::istream &strm, std::unordered_set *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +template +std::istream &ReadType(std::istream &strm, std::unordered_map *c) { + return internal::ReadContainerType( + strm, c, [](decltype(c) v, int n) { v->reserve(n); }); +} + +// Writes types to an output stream. 
+ +// Generic case. +template ::value, T>::type* = nullptr> +inline std::ostream &WriteType(std::ostream &strm, const T t) { + t.Write(strm); + return strm; +} + +// Numeric (boolean, integral, floating-point) case. +template ::value, T>::type* = nullptr> +inline std::ostream &WriteType(std::ostream &strm, const T t) { + return strm.write(reinterpret_cast(&t), sizeof(T)); +} + +// String case. +inline std::ostream &WriteType(std::ostream &strm, const string &s) { // NOLINT + int32 ns = s.size(); + strm.write(reinterpret_cast(&ns), sizeof(ns)); + return strm.write(s.data(), ns); +} + +// Pair case. +template +inline std::ostream &WriteType(std::ostream &strm, + const std::pair &p) { // NOLINT + WriteType(strm, p.first); + WriteType(strm, p.second); + return strm; +} + +namespace internal { +template +std::ostream &WriteContainer(std::ostream &strm, const C &c) { + const int64 n = c.size(); + WriteType(strm, n); + for (const auto &e : c) { + WriteType(strm, e); + } + return strm; +} +} // namespace internal + +template +std::ostream &WriteType(std::ostream &strm, const std::vector &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::list &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::set &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::map &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::unordered_set &c) { + return internal::WriteContainer(strm, c); +} + +template +std::ostream &WriteType(std::ostream &strm, const std::unordered_map &c) { + return internal::WriteContainer(strm, c); +} + +// Utilities for converting between int64 or Weight and string. 
+ +int64 StrToInt64(const string &s, const string &src, size_t nline, + bool allow_negative, bool *error = nullptr); + +template +Weight StrToWeight(const string &s, const string &src, size_t nline) { + Weight w; + std::istringstream strm(s); + strm >> w; + if (!strm) { + FSTERROR() << "StrToWeight: Bad weight = \"" << s << "\", source = " << src + << ", line = " << nline; + return Weight::NoWeight(); + } + return w; +} + +template +void WeightToStr(Weight w, string *s) { + std::ostringstream strm; + strm.precision(9); + strm << w; + s->append(strm.str().data(), strm.str().size()); +} + +// Utilities for reading/writing integer pairs (typically labels) + +// Modifies line using a vector of pointers to a buffer beginning with line. +void SplitString(char *line, const char *delim, std::vector *vec, + bool omit_empty_strings); + +template +bool ReadIntPairs(const string &filename, std::vector> *pairs, + bool allow_negative = false) { + std::ifstream strm(filename, std::ios_base::in); + if (!strm) { + LOG(ERROR) << "ReadIntPairs: Can't open file: " << filename; + return false; + } + const int kLineLen = 8096; + char line[kLineLen]; + size_t nline = 0; + pairs->clear(); + while (strm.getline(line, kLineLen)) { + ++nline; + std::vector col; + SplitString(line, "\n\t ", &col, true); + // empty line or comment? 
+ if (col.empty() || col[0][0] == '\0' || col[0][0] == '#') continue; + if (col.size() != 2) { + LOG(ERROR) << "ReadIntPairs: Bad number of columns, " + << "file = " << filename << ", line = " << nline; + return false; + } + bool err; + I i1 = StrToInt64(col[0], filename, nline, allow_negative, &err); + if (err) return false; + I i2 = StrToInt64(col[1], filename, nline, allow_negative, &err); + if (err) return false; + pairs->push_back(std::make_pair(i1, i2)); + } + return true; +} + +template +bool WriteIntPairs(const string &filename, + const std::vector> &pairs) { + std::ostream *strm = &std::cout; + if (!filename.empty()) { + strm = new std::ofstream(filename); + if (!*strm) { + LOG(ERROR) << "WriteIntPairs: Can't open file: " << filename; + return false; + } + } + for (ssize_t n = 0; n < pairs.size(); ++n) { + *strm << pairs[n].first << "\t" << pairs[n].second << "\n"; + } + if (!*strm) { + LOG(ERROR) << "WriteIntPairs: Write failed: " + << (filename.empty() ? "standard output" : filename); + return false; + } + if (strm != &std::cout) delete strm; + return true; +} + +// Utilities for reading/writing label pairs. + +template +bool ReadLabelPairs(const string &filename, + std::vector> *pairs, + bool allow_negative = false) { + return ReadIntPairs(filename, pairs, allow_negative); +} + +template +bool WriteLabelPairs(const string &filename, + const std::vector> &pairs) { + return WriteIntPairs(filename, pairs); +} + +// Utilities for converting a type name to a legal C symbol. + +void ConvertToLegalCSymbol(string *s); + +// Utilities for stream I/O. + +bool AlignInput(std::istream &strm); +bool AlignOutput(std::ostream &strm); + +// An associative container for which testing membership is faster than an STL +// set if members are restricted to an interval that excludes most non-members. +// A Key must have ==, !=, and < operators defined. Element NoKey should be a +// key that marks an uninitialized key and is otherwise unused. 
Find() returns +// an STL const_iterator to the match found, otherwise it equals End(). +template +class CompactSet { + public: + using const_iterator = typename std::set::const_iterator; + + CompactSet() : min_key_(NoKey), max_key_(NoKey) {} + + CompactSet(const CompactSet &compact_set) + : set_(compact_set.set_), + min_key_(compact_set.min_key_), + max_key_(compact_set.max_key_) {} + + void Insert(Key key) { + set_.insert(key); + if (min_key_ == NoKey || key < min_key_) min_key_ = key; + if (max_key_ == NoKey || max_key_ < key) max_key_ = key; + } + + void Erase(Key key) { + set_.erase(key); + if (set_.empty()) { + min_key_ = max_key_ = NoKey; + } else if (key == min_key_) { + ++min_key_; + } else if (key == max_key_) { + --max_key_; + } + } + + void Clear() { + set_.clear(); + min_key_ = max_key_ = NoKey; + } + + const_iterator Find(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return set_.end(); + } else { + return set_.find(key); + } + } + + bool Member(Key key) const { + if (min_key_ == NoKey || key < min_key_ || max_key_ < key) { + return false; // out of range + } else if (min_key_ != NoKey && max_key_ + 1 == min_key_ + set_.size()) { + return true; // dense range + } else { + return set_.count(key); + } + } + + const_iterator Begin() const { return set_.begin(); } + + const_iterator End() const { return set_.end(); } + + // All stored keys are greater than or equal to this value. + Key LowerBound() const { return min_key_; } + + // All stored keys are less than or equal to this value. 
+ Key UpperBound() const { return max_key_; } + + private: + std::set set_; + Key min_key_; + Key max_key_; + + void operator=(const CompactSet &) = delete; +}; + +} // namespace fst + +#endif // FST_UTIL_H_ diff --git a/projects/llm_framework/include/fst/vector-fst.h b/projects/llm_framework/include/fst/vector-fst.h new file mode 100644 index 00000000..7514bc55 --- /dev/null +++ b/projects/llm_framework/include/fst/vector-fst.h @@ -0,0 +1,796 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Simple concrete, mutable FST whose states and arcs are stored in STL vectors. + +#ifndef FST_VECTOR_FST_H_ +#define FST_VECTOR_FST_H_ + +#include +#include +#include + +#include + +#include // For optional argument declarations +#include +#include + + +namespace fst { + +template +class VectorFst; + +template +void Cast(const F &, G *); + +// Arcs (of type A) implemented by an STL vector per state. M specifies Arc +// allocator (default declared in fst-decl.h). +template */> +class VectorState { + public: + using Arc = A; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + using ArcAllocator = M; + using StateAllocator = + typename ArcAllocator::template rebind>::other; + + // Provide STL allocator for arcs. 
+ explicit VectorState(const ArcAllocator &alloc) + : final_(Weight::Zero()), niepsilons_(0), noepsilons_(0), arcs_(alloc) {} + + VectorState(const VectorState &state, const ArcAllocator &alloc) + : final_(state.Final()), + niepsilons_(state.NumInputEpsilons()), + noepsilons_(state.NumOutputEpsilons()), + arcs_(state.arcs_.begin(), state.arcs_.end(), alloc) {} + + void Reset() { + final_ = Weight::Zero(); + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + Weight Final() const { return final_; } + + size_t NumInputEpsilons() const { return niepsilons_; } + + size_t NumOutputEpsilons() const { return noepsilons_; } + + size_t NumArcs() const { return arcs_.size(); } + + const Arc &GetArc(size_t n) const { return arcs_[n]; } + + const Arc *Arcs() const { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + Arc *MutableArcs() { return !arcs_.empty() ? &arcs_[0] : nullptr; } + + void ReserveArcs(size_t n) { arcs_.reserve(n); } + + void SetFinal(Weight weight) { final_ = std::move(weight); } + + void SetNumInputEpsilons(size_t n) { niepsilons_ = n; } + + void SetNumOutputEpsilons(size_t n) { noepsilons_ = n; } + + void AddArc(const Arc &arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(arc); + } + + void AddArc(Arc &&arc) { + IncrementNumEpsilons(arc); + arcs_.push_back(std::move(arc)); + } + + template + void EmplaceArc(T&&... ctor_args) { + arcs_.emplace_back(std::forward(ctor_args)...); + IncrementNumEpsilons(arcs_.back()); + } + + void SetArc(const Arc &arc, size_t n) { + if (arcs_[n].ilabel == 0) --niepsilons_; + if (arcs_[n].olabel == 0) --noepsilons_; + IncrementNumEpsilons(arc); + arcs_[n] = arc; + } + + void DeleteArcs() { + niepsilons_ = 0; + noepsilons_ = 0; + arcs_.clear(); + } + + void DeleteArcs(size_t n) { + for (size_t i = 0; i < n; ++i) { + if (arcs_.back().ilabel == 0) --niepsilons_; + if (arcs_.back().olabel == 0) --noepsilons_; + arcs_.pop_back(); + } + } + + // For state class allocation. 
+ void *operator new(size_t size, StateAllocator *alloc) { + return alloc->allocate(1); + } + + // For state destruction and memory freeing. + static void Destroy(VectorState *state, StateAllocator *alloc) { + if (state) { + state->~VectorState(); + alloc->deallocate(state, 1); + } + } + + private: + // Update the number of epsilons as a result of having added an arc. + void IncrementNumEpsilons(const Arc &arc) { + if (arc.ilabel == 0) ++niepsilons_; + if (arc.olabel == 0) ++noepsilons_; + } + + Weight final_; // Final weight. + size_t niepsilons_; // # of input epsilons + size_t noepsilons_; // # of output epsilons + std::vector arcs_; // Arc container. +}; + +namespace internal { + +// States are implemented by STL vectors, templated on the +// State definition. This does not manage the Fst properties. +template +class VectorFstBaseImpl : public FstImpl { + public: + using State = S; + using Arc = typename State::Arc; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + VectorFstBaseImpl() : start_(kNoStateId) {} + + ~VectorFstBaseImpl() override { + for (size_t s = 0; s < states_.size(); ++s) { + State::Destroy(states_[s], &state_alloc_); + } + } + + // Copying is not permitted. + VectorFstBaseImpl(const VectorFstBaseImpl &) = delete; + VectorFstBaseImpl &operator=(const VectorFstBaseImpl &) = delete; + + // Moving is permitted. 
+ VectorFstBaseImpl(VectorFstBaseImpl &&impl) noexcept + : FstImpl(), + states_(std::move(impl.states_)), + start_(impl.start_) { + impl.states_.clear(); + impl.start_ = kNoStateId; + } + + VectorFstBaseImpl &operator=(VectorFstBaseImpl &&impl) noexcept { + states_ = std::move(impl.states_); + start_ = impl.start_; + impl.states_.clear(); + impl.start_ = kNoStateId; + return *this; + } + + StateId Start() const { return start_; } + + Weight Final(StateId state) const { return states_[state]->Final(); } + + StateId NumStates() const { return states_.size(); } + + size_t NumArcs(StateId state) const { return states_[state]->NumArcs(); } + + size_t NumInputEpsilons(StateId state) const { + return GetState(state)->NumInputEpsilons(); + } + + size_t NumOutputEpsilons(StateId state) const { + return GetState(state)->NumOutputEpsilons(); + } + + void SetStart(StateId state) { start_ = state; } + + void SetFinal(StateId state, Weight weight) { + states_[state]->SetFinal(std::move(weight)); + } + + StateId AddState() { + states_.push_back(new (&state_alloc_) State(arc_alloc_)); + return states_.size() - 1; + } + + StateId AddState(State *state) { + states_.push_back(state); + return states_.size() - 1; + } + + void AddArc(StateId state, const Arc &arc) { states_[state]->AddArc(arc); } + + void AddArc(StateId state, Arc &&arc) { + states_[state]->AddArc(std::move(arc)); + } + + template + void EmplaceArc(StateId state, T&&... 
ctor_args) { + states_[state]->EmplaceArc(std::forward(ctor_args)...); + } + + void DeleteStates(const std::vector &dstates) { + std::vector newid(states_.size(), 0); + for (size_t i = 0; i < dstates.size(); ++i) newid[dstates[i]] = kNoStateId; + StateId nstates = 0; + for (StateId state = 0; state < states_.size(); ++state) { + if (newid[state] != kNoStateId) { + newid[state] = nstates; + if (state != nstates) states_[nstates] = states_[state]; + ++nstates; + } else { + State::Destroy(states_[state], &state_alloc_); + } + } + states_.resize(nstates); + for (StateId state = 0; state < states_.size(); ++state) { + auto *arcs = states_[state]->MutableArcs(); + size_t narcs = 0; + auto nieps = states_[state]->NumInputEpsilons(); + auto noeps = states_[state]->NumOutputEpsilons(); + for (size_t i = 0; i < states_[state]->NumArcs(); ++i) { + const auto t = newid[arcs[i].nextstate]; + if (t != kNoStateId) { + arcs[i].nextstate = t; + if (i != narcs) arcs[narcs] = arcs[i]; + ++narcs; + } else { + if (arcs[i].ilabel == 0) --nieps; + if (arcs[i].olabel == 0) --noeps; + } + } + states_[state]->DeleteArcs(states_[state]->NumArcs() - narcs); + states_[state]->SetNumInputEpsilons(nieps); + states_[state]->SetNumOutputEpsilons(noeps); + } + if (Start() != kNoStateId) SetStart(newid[Start()]); + } + + void DeleteStates() { + for (size_t state = 0; state < states_.size(); ++state) { + State::Destroy(states_[state], &state_alloc_); + } + states_.clear(); + SetStart(kNoStateId); + } + + void DeleteArcs(StateId state, size_t n) { states_[state]->DeleteArcs(n); } + + void DeleteArcs(StateId state) { states_[state]->DeleteArcs(); } + + State *GetState(StateId state) { return states_[state]; } + + const State *GetState(StateId state) const { return states_[state]; } + + void SetState(StateId state, State *vstate) { states_[state] = vstate; } + + void ReserveStates(StateId n) { states_.reserve(n); } + + void ReserveArcs(StateId state, size_t n) { states_[state]->ReserveArcs(n); } + + // 
Provide information needed for generic state iterator. + void InitStateIterator(StateIteratorData *data) const { + data->base = nullptr; + data->nstates = states_.size(); + } + + // Provide information needed for generic arc iterator. + void InitArcIterator(StateId state, ArcIteratorData *data) const { + data->base = nullptr; + data->narcs = states_[state]->NumArcs(); + data->arcs = states_[state]->Arcs(); + data->ref_count = nullptr; + } + + private: + std::vector states_; // States represenation. + StateId start_; // Initial state. + typename State::StateAllocator state_alloc_; // For state allocation. + typename State::ArcAllocator arc_alloc_; // For arc allocation. +}; + +// This is a VectorFstBaseImpl container that holds VectorStates and manages FST +// properties. +template +class VectorFstImpl : public VectorFstBaseImpl { + public: + using State = S; + using Arc = typename State::Arc; + using Label = typename Arc::Label; + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + using FstImpl::SetInputSymbols; + using FstImpl::SetOutputSymbols; + using FstImpl::SetType; + using FstImpl::SetProperties; + using FstImpl::Properties; + + using VectorFstBaseImpl::Start; + using VectorFstBaseImpl::NumStates; + using VectorFstBaseImpl::GetState; + using VectorFstBaseImpl::ReserveArcs; + + friend class MutableArcIterator>; + + using BaseImpl = VectorFstBaseImpl; + + VectorFstImpl() { + SetType("vector"); + SetProperties(kNullProperties | kStaticProperties); + } + + explicit VectorFstImpl(const Fst &fst); + + static VectorFstImpl *Read(std::istream &strm, const FstReadOptions &opts); + + void SetStart(StateId state) { + BaseImpl::SetStart(state); + SetProperties(SetStartProperties(Properties())); + } + + void SetFinal(StateId state, Weight weight) { + const auto old_weight = BaseImpl::Final(state); + const auto properties = + SetFinalProperties(Properties(), old_weight, weight); + BaseImpl::SetFinal(state, std::move(weight)); + 
SetProperties(properties); + } + + StateId AddState() { + const auto state = BaseImpl::AddState(); + SetProperties(AddStateProperties(Properties())); + return state; + } + + void AddArc(StateId state, const Arc &arc) { + BaseImpl::AddArc(state, arc); + UpdatePropertiesAfterAddArc(state); + } + + void AddArc(StateId state, Arc &&arc) { + BaseImpl::AddArc(state, std::move(arc)); + UpdatePropertiesAfterAddArc(state); + } + + template + void EmplaceArc(StateId state, T&&... ctor_args) { + BaseImpl::EmplaceArc(state, std::forward(ctor_args)...); + UpdatePropertiesAfterAddArc(state); + } + + void DeleteStates(const std::vector &dstates) { + BaseImpl::DeleteStates(dstates); + SetProperties(DeleteStatesProperties(Properties())); + } + + void DeleteStates() { + BaseImpl::DeleteStates(); + SetProperties(DeleteAllStatesProperties(Properties(), kStaticProperties)); + } + + void DeleteArcs(StateId state, size_t n) { + BaseImpl::DeleteArcs(state, n); + SetProperties(DeleteArcsProperties(Properties())); + } + + void DeleteArcs(StateId state) { + BaseImpl::DeleteArcs(state); + SetProperties(DeleteArcsProperties(Properties())); + } + + // Properties always true of this FST class + static constexpr uint64 kStaticProperties = kExpanded | kMutable; + + private: + void UpdatePropertiesAfterAddArc(StateId state) { + auto *vstate = GetState(state); + const size_t num_arcs{vstate->NumArcs()}; + if (num_arcs) { + const auto &arc = vstate->GetArc(num_arcs - 1); + const auto *parc = (num_arcs < 2) + ? nullptr + : &(vstate->GetArc(num_arcs - 2)); + SetProperties(AddArcProperties(Properties(), state, arc, parc)); + } + } + + // Minimum file format version supported. 
+ static constexpr int kMinFileVersion = 2; +}; + +template +constexpr uint64 VectorFstImpl::kStaticProperties; + +template +constexpr int VectorFstImpl::kMinFileVersion; + +template +VectorFstImpl::VectorFstImpl(const Fst &fst) { + SetType("vector"); + SetInputSymbols(fst.InputSymbols()); + SetOutputSymbols(fst.OutputSymbols()); + BaseImpl::SetStart(fst.Start()); + if (fst.Properties(kExpanded, false)) { + BaseImpl::ReserveStates(CountStates(fst)); + } + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + const auto state = siter.Value(); + BaseImpl::AddState(); + BaseImpl::SetFinal(state, fst.Final(state)); + ReserveArcs(state, fst.NumArcs(state)); + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + BaseImpl::AddArc(state, arc); + } + } + SetProperties(fst.Properties(kCopyProperties, false) | kStaticProperties); +} + +template +VectorFstImpl *VectorFstImpl::Read(std::istream &strm, + const FstReadOptions &opts) { + std::unique_ptr> impl(new VectorFstImpl()); + FstHeader hdr; + if (!impl->ReadHeader(strm, opts, kMinFileVersion, &hdr)) return nullptr; + impl->BaseImpl::SetStart(hdr.Start()); + if (hdr.NumStates() != kNoStateId) impl->ReserveStates(hdr.NumStates()); + StateId state = 0; + for (; hdr.NumStates() == kNoStateId || state < hdr.NumStates(); ++state) { + Weight weight; + if (!weight.Read(strm)) break; + impl->BaseImpl::AddState(); + auto *vstate = impl->GetState(state); + vstate->SetFinal(weight); + int64 narcs; + ReadType(strm, &narcs); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->ReserveArcs(state, narcs); + for (int64 i = 0; i < narcs; ++i) { + Arc arc; + ReadType(strm, &arc.ilabel); + ReadType(strm, &arc.olabel); + arc.weight.Read(strm); + ReadType(strm, &arc.nextstate); + if (!strm) { + LOG(ERROR) << "VectorFst::Read: Read failed: " << opts.source; + return nullptr; + } + impl->BaseImpl::AddArc(state, std::move(arc)); + 
} + } + if (hdr.NumStates() != kNoStateId && state != hdr.NumStates()) { + LOG(ERROR) << "VectorFst::Read: Unexpected end of file: " << opts.source; + return nullptr; + } + return impl.release(); +} + +} // namespace internal + +// Simple concrete, mutable FST. This class attaches interface to implementation +// and handles reference counting, delegating most methods to ImplToMutableFst. +// Also supports ReserveStates and ReserveArcs methods (cf. STL vector methods). +// The second optional template argument gives the State definition. +template */> +class VectorFst : public ImplToMutableFst> { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using State = S; + using Impl = internal::VectorFstImpl; + + friend class StateIterator>; + friend class ArcIterator>; + friend class MutableArcIterator>; + + template + friend void Cast(const F &, G *); + + VectorFst() : ImplToMutableFst(std::make_shared()) {} + + explicit VectorFst(const Fst &fst) + : ImplToMutableFst(std::make_shared(fst)) {} + + VectorFst(const VectorFst &fst, bool safe = false) + : ImplToMutableFst(fst) {} + + VectorFst(VectorFst &&) noexcept; + + // Get a copy of this VectorFst. See Fst<>::Copy() for further doc. + VectorFst *Copy(bool safe = false) const override { + return new VectorFst(*this, safe); + } + + VectorFst &operator=(const VectorFst &) = default; + + VectorFst &operator=(VectorFst &&) noexcept; + + VectorFst &operator=(const Fst &fst) override { + if (this != &fst) SetImpl(std::make_shared(fst)); + return *this; + } + + template + void EmplaceArc(StateId state, T&&... ctor_args) { + MutateCheck(); + GetMutableImpl()->EmplaceArc(state, std::forward(ctor_args)...); + } + + // Reads a VectorFst from an input stream, returning nullptr on error. + static VectorFst *Read(std::istream &strm, + const FstReadOptions &opts) { + auto *impl = Impl::Read(strm, opts); + return impl ? 
new VectorFst(std::shared_ptr(impl)) + : nullptr; + } + + // Read a VectorFst from a file, returning nullptr on error; empty filename + // reads from standard input. + static VectorFst *Read(const string &filename) { + auto *impl = ImplToExpandedFst>::Read(filename); + return impl ? new VectorFst(std::shared_ptr(impl)) + : nullptr; + } + + bool Write(std::ostream &strm, const FstWriteOptions &opts) const override { + return WriteFst(*this, strm, opts); + } + + bool Write(const string &filename) const override { + return Fst::WriteFile(filename); + } + + template + static bool WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts); + + void InitStateIterator(StateIteratorData *data) const override { + GetImpl()->InitStateIterator(data); + } + + void InitArcIterator(StateId s, ArcIteratorData *data) const override { + GetImpl()->InitArcIterator(s, data); + } + + inline void InitMutableArcIterator(StateId s, + MutableArcIteratorData *) override; + + using ImplToMutableFst>::ReserveArcs; + using ImplToMutableFst>::ReserveStates; + + private: + using ImplToMutableFst>::GetImpl; + using ImplToMutableFst>::GetMutableImpl; + using ImplToMutableFst>::MutateCheck; + using ImplToMutableFst>::SetImpl; + + explicit VectorFst(std::shared_ptr impl) + : ImplToMutableFst(impl) {} +}; + +template +inline VectorFst::VectorFst( + VectorFst &&fst) noexcept = default; + +template +inline VectorFst &VectorFst::operator=( + VectorFst &&fst) noexcept = default; + +// Writes FST to file in Vector format, potentially with a pass over the machine +// before writing to compute number of states. 
+template +template +bool VectorFst::WriteFst(const FST &fst, std::ostream &strm, + const FstWriteOptions &opts) { + static constexpr int file_version = 2; + bool update_header = true; + FstHeader hdr; + hdr.SetStart(fst.Start()); + hdr.SetNumStates(kNoStateId); + std::streampos start_offset = 0; + if (fst.Properties(kExpanded, false) || opts.stream_write || + (start_offset = strm.tellp()) != -1) { + hdr.SetNumStates(CountStates(fst)); + update_header = false; + } + const auto properties = + fst.Properties(kCopyProperties, false) | Impl::kStaticProperties; + internal::FstImpl::WriteFstHeader(fst, strm, opts, file_version, + "vector", properties, &hdr); + StateId num_states = 0; + for (StateIterator siter(fst); !siter.Done(); siter.Next()) { + const auto s = siter.Value(); + fst.Final(s).Write(strm); + const int64 narcs = fst.NumArcs(s); + WriteType(strm, narcs); + for (ArcIterator aiter(fst, s); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + WriteType(strm, arc.ilabel); + WriteType(strm, arc.olabel); + arc.weight.Write(strm); + WriteType(strm, arc.nextstate); + } + ++num_states; + } + strm.flush(); + if (!strm) { + LOG(ERROR) << "VectorFst::Write: Write failed: " << opts.source; + return false; + } + if (update_header) { + hdr.SetNumStates(num_states); + return internal::FstImpl::UpdateFstHeader( + fst, strm, opts, file_version, "vector", properties, &hdr, + start_offset); + } else { + if (num_states != hdr.NumStates()) { + LOG(ERROR) << "Inconsistent number of states observed during write"; + return false; + } + } + return true; +} + +// Specialization for VectorFst; see generic version in fst.h for sample usage +// (but use the VectorFst type instead). This version should inline. 
+template +class StateIterator> { + public: + using StateId = typename Arc::StateId; + + explicit StateIterator(const VectorFst &fst) + : nstates_(fst.GetImpl()->NumStates()), s_(0) {} + + bool Done() const { return s_ >= nstates_; } + + StateId Value() const { return s_; } + + void Next() { ++s_; } + + void Reset() { s_ = 0; } + + private: + const StateId nstates_; + StateId s_; +}; + +// Specialization for VectorFst; see generic version in fst.h for sample usage +// (but use the VectorFst type instead). This version should inline. +template +class ArcIterator> { + public: + using StateId = typename Arc::StateId; + + ArcIterator(const VectorFst &fst, StateId s) + : arcs_(fst.GetImpl()->GetState(s)->Arcs()), + narcs_(fst.GetImpl()->GetState(s)->NumArcs()), + i_(0) {} + + bool Done() const { return i_ >= narcs_; } + + const Arc &Value() const { return arcs_[i_]; } + + void Next() { ++i_; } + + void Reset() { i_ = 0; } + + void Seek(size_t a) { i_ = a; } + + size_t Position() const { return i_; } + + constexpr uint32 Flags() const { return kArcValueFlags; } + + void SetFlags(uint32, uint32) {} + + private: + const Arc *arcs_; + size_t narcs_; + size_t i_; +}; + +// Specialization for VectorFst; see generic version in mutable-fst.h for sample +// usage (but use the VectorFst type instead). This version should inline. 
+template +class MutableArcIterator> + : public MutableArcIteratorBase { + public: + using StateId = typename Arc::StateId; + using Weight = typename Arc::Weight; + + MutableArcIterator(VectorFst *fst, StateId s) : i_(0) { + fst->MutateCheck(); + state_ = fst->GetMutableImpl()->GetState(s); + properties_ = &fst->GetImpl()->properties_; + } + + bool Done() const final { return i_ >= state_->NumArcs(); } + + const Arc &Value() const final { return state_->GetArc(i_); } + + void Next() final { ++i_; } + + size_t Position() const final { return i_; } + + void Reset() final { i_ = 0; } + + void Seek(size_t a) final { i_ = a; } + + void SetValue(const Arc &arc) final { + const auto &oarc = state_->GetArc(i_); + if (oarc.ilabel != oarc.olabel) *properties_ &= ~kNotAcceptor; + if (oarc.ilabel == 0) { + *properties_ &= ~kIEpsilons; + if (oarc.olabel == 0) *properties_ &= ~kEpsilons; + } + if (oarc.olabel == 0) *properties_ &= ~kOEpsilons; + if (oarc.weight != Weight::Zero() && oarc.weight != Weight::One()) { + *properties_ &= ~kWeighted; + } + state_->SetArc(arc, i_); + if (arc.ilabel != arc.olabel) { + *properties_ |= kNotAcceptor; + *properties_ &= ~kAcceptor; + } + if (arc.ilabel == 0) { + *properties_ |= kIEpsilons; + *properties_ &= ~kNoIEpsilons; + if (arc.olabel == 0) { + *properties_ |= kEpsilons; + *properties_ &= ~kNoEpsilons; + } + } + if (arc.olabel == 0) { + *properties_ |= kOEpsilons; + *properties_ &= ~kNoOEpsilons; + } + if (arc.weight != Weight::Zero() && arc.weight != Weight::One()) { + *properties_ |= kWeighted; + *properties_ &= ~kUnweighted; + } + *properties_ &= kSetArcProperties | kAcceptor | kNotAcceptor | kEpsilons | + kNoEpsilons | kIEpsilons | kNoIEpsilons | kOEpsilons | + kNoOEpsilons | kWeighted | kUnweighted; + } + + uint32 Flags() const final { return kArcValueFlags; } + + void SetFlags(uint32, uint32) final {} + + private: + State *state_; + uint64 *properties_; + size_t i_; +}; + +// Provides information needed for the generic mutable arc 
iterator. +template +inline void VectorFst::InitMutableArcIterator( + StateId s, MutableArcIteratorData *data) { + data->base = new MutableArcIterator>(this, s); +} + +// A useful alias when using StdArc. +using StdVectorFst = VectorFst; + +} // namespace fst + +#endif // FST_VECTOR_FST_H_ diff --git a/projects/llm_framework/include/fst/verify.h b/projects/llm_framework/include/fst/verify.h new file mode 100644 index 00000000..2ea8a64a --- /dev/null +++ b/projects/llm_framework/include/fst/verify.h @@ -0,0 +1,100 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// Function to verify an FST's contents. + +#ifndef FST_VERIFY_H_ +#define FST_VERIFY_H_ + +#include + +#include +#include + + +namespace fst { + +// Verifies that an Fst's contents are sane. +template +bool Verify(const Fst &fst, bool allow_negative_labels = false) { + using StateId = typename Arc::StateId; + const auto start = fst.Start(); + const auto *isyms = fst.InputSymbols(); + const auto *osyms = fst.OutputSymbols(); + // Count states + StateId ns = 0; + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) ++ns; + if (start == kNoStateId && ns > 0) { + LOG(ERROR) << "Verify: FST start state ID not set"; + return false; + } else if (start >= ns) { + LOG(ERROR) << "Verify: FST start state ID exceeds number of states"; + return false; + } + for (StateIterator> siter(fst); !siter.Done(); siter.Next()) { + auto state = siter.Value(); + size_t na = 0; + for (ArcIterator> aiter(fst, state); !aiter.Done(); aiter.Next()) { + const auto &arc = aiter.Value(); + if (!allow_negative_labels && arc.ilabel < 0) { + LOG(ERROR) << "Verify: FST input label ID of arc at position " << na + << " of state " << state << " is negative"; + return false; + } else if (isyms && isyms->Find(arc.ilabel).empty()) { + LOG(ERROR) << "Verify: FST input label ID " << arc.ilabel + << " of arc at position " << na << " of state " << state + << " is missing from input 
symbol table \"" << isyms->Name() + << "\""; + return false; + } else if (!allow_negative_labels && arc.olabel < 0) { + LOG(ERROR) << "Verify: FST output label ID of arc at position " << na + << " of state " << state << " is negative"; + return false; + } else if (osyms && osyms->Find(arc.olabel).empty()) { + LOG(ERROR) << "Verify: FST output label ID " << arc.olabel + << " of arc at position " << na << " of state " << state + << " is missing from output symbol table \"" << osyms->Name() + << "\""; + return false; + } else if (!arc.weight.Member()) { + LOG(ERROR) << "Verify: FST weight of arc at position " << na + << " of state " << state << " is invalid"; + return false; + } else if (arc.nextstate < 0) { + LOG(ERROR) << "Verify: FST destination state ID of arc at position " + << na << " of state " << state << " is negative"; + return false; + } else if (arc.nextstate >= ns) { + LOG(ERROR) << "Verify: FST destination state ID of arc at position " + << na << " of state " << state + << " exceeds number of states"; + return false; + } + ++na; + } + if (!fst.Final(state).Member()) { + LOG(ERROR) << "Verify: FST final weight of state " << state + << " is invalid"; + return false; + } + } + const auto fst_props = fst.Properties(kFstProperties, false); + if (fst_props & kError) { + LOG(ERROR) << "Verify: FST error property is set"; + return false; + } + uint64 known_props; + uint64 test_props = + ComputeProperties(fst, kFstProperties, &known_props, false); + if (!CompatProperties(fst_props, test_props)) { + LOG(ERROR) << "Verify: Stored FST properties incorrect " + << "(props1 = stored props, props2 = tested)"; + return false; + } else { + return true; + } +} + +} // namespace fst + +#endif // FST_VERIFY_H_ diff --git a/projects/llm_framework/include/fst/visit.h b/projects/llm_framework/include/fst/visit.h new file mode 100644 index 00000000..c6658047 --- /dev/null +++ b/projects/llm_framework/include/fst/visit.h @@ -0,0 +1,321 @@ +// See www.openfst.org for extensive 
documentation on this weighted +// finite-state transducer library. +// +// Queue-dependent visitation of finite-state transducers. See also dfs-visit.h. + +#ifndef FST_VISIT_H_ +#define FST_VISIT_H_ + + +#include +#include + + +namespace fst { + +// Visitor Interface: class determining actions taken during a visit. If any of +// the boolean member functions return false, the visit is aborted by first +// calling FinishState() on all unfinished (grey) states and then calling +// FinishVisit(). +// +// Note this is more general than the visitor interface in dfs-visit.h but lacks +// some DFS-specific behavior. +// +// template +// class Visitor { +// public: +// using StateId = typename Arc::StateId; +// +// Visitor(T *return_data); +// +// // Invoked before visit. +// void InitVisit(const Fst &fst); +// +// // Invoked when state discovered (2nd arg is visitation root). +// bool InitState(StateId s, StateId root); +// +// // Invoked when arc to white/undiscovered state examined. +// bool WhiteArc(StateId s, const Arc &arc); +// +// // Invoked when arc to grey/unfinished state examined. +// bool GreyArc(StateId s, const Arc &arc); +// +// // Invoked when arc to black/finished state examined. +// bool BlackArc(StateId s, const Arc &arc); +// +// // Invoked when state finished. +// void FinishState(StateId s); +// +// // Invoked after visit. +// void FinishVisit(); +// }; + +// Performs queue-dependent visitation. Visitor class argument determines +// actions and contains any return data. ArcFilter determines arcs that are +// considered. If 'access_only' is true, performs visitation only to states +// accessible from the initial state. 
+template +void Visit(const FST &fst, Visitor *visitor, Queue *queue, ArcFilter filter, + bool access_only = false) { + using Arc = typename FST::Arc; + using StateId = typename Arc::StateId; + visitor->InitVisit(fst); + const auto start = fst.Start(); + if (start == kNoStateId) { + visitor->FinishVisit(); + return; + } + // An FST's state's visit color. + static constexpr uint8 kWhiteState = 0x01; // Undiscovered. + static constexpr uint8 kGreyState = 0x02; // Discovered & unfinished. + static constexpr uint8 kBlackState = 0x04; // Finished. + // We destroy an iterator as soon as possible and mark it so. + static constexpr uint8 kArcIterDone = 0x08; + std::vector state_status; + std::vector *> arc_iterator; + MemoryPool> aiter_pool; + StateId nstates = start + 1; // Number of known states in general case. + bool expanded = false; + if (fst.Properties(kExpanded, false)) { // Tests if expanded, then uses + nstates = CountStates(fst); // ExpandedFst::NumStates(). + expanded = true; + } + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + StateIterator> siter(fst); + // Continues visit while true. + bool visit = true; + // Iterates over trees in visit forest. + for (auto root = start; visit && root < nstates;) { + visit = visitor->InitState(root, root); + state_status[root] = kGreyState; + queue->Enqueue(root); + while (!queue->Empty()) { + auto state = queue->Head(); + if (state >= state_status.size()) { + nstates = state + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Creates arc iterator if needed. + if (!arc_iterator[state] && !(state_status[state] & kArcIterDone) && + visit) { + arc_iterator[state] = new (&aiter_pool) ArcIterator(fst, state); + } + // Deletes arc iterator if done. 
+ auto *aiter = arc_iterator[state]; + if ((aiter && aiter->Done()) || !visit) { + Destroy(aiter, &aiter_pool); + arc_iterator[state] = nullptr; + state_status[state] |= kArcIterDone; + } + // Dequeues state and marks black if done. + if (state_status[state] & kArcIterDone) { + queue->Dequeue(); + visitor->FinishState(state); + state_status[state] = kBlackState; + continue; + } + const auto &arc = aiter->Value(); + if (arc.nextstate >= state_status.size()) { + nstates = arc.nextstate + 1; + state_status.resize(nstates, kWhiteState); + arc_iterator.resize(nstates); + } + // Visits respective arc types. + if (filter(arc)) { + // Enqueues destination state and marks grey if white. + if (state_status[arc.nextstate] == kWhiteState) { + visit = visitor->WhiteArc(state, arc); + if (!visit) continue; + visit = visitor->InitState(arc.nextstate, root); + state_status[arc.nextstate] = kGreyState; + queue->Enqueue(arc.nextstate); + } else if (state_status[arc.nextstate] == kBlackState) { + visit = visitor->BlackArc(state, arc); + } else { + visit = visitor->GreyArc(state, arc); + } + } + aiter->Next(); + // Destroys an iterator ASAP for efficiency. + if (aiter->Done()) { + Destroy(aiter, &aiter_pool); + arc_iterator[state] = nullptr; + state_status[state] |= kArcIterDone; + } + } + if (access_only) break; + // Finds next tree root. + for (root = (root == start) ? 0 : root + 1; + root < nstates && state_status[root] != kWhiteState; ++root) { + } + // Check for a state beyond the largest known state. + if (!expanded && root == nstates) { + for (; !siter.Done(); siter.Next()) { + if (siter.Value() == nstates) { + ++nstates; + state_status.push_back(kWhiteState); + arc_iterator.push_back(nullptr); + break; + } + } + } + } + visitor->FinishVisit(); +} + +template +inline void Visit(const Fst &fst, Visitor *visitor, Queue *queue) { + Visit(fst, visitor, queue, AnyArcFilter()); +} + +// Copies input FST to mutable FST following queue order. 
+template +class CopyVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + explicit CopyVisitor(MutableFst *ofst) : ifst_(nullptr), ofst_(ofst) {} + + void InitVisit(const Fst &ifst) { + ifst_ = &ifst; + ofst_->DeleteStates(); + ofst_->SetStart(ifst_->Start()); + } + + bool InitState(StateId state, StateId) { + while (ofst_->NumStates() <= state) ofst_->AddState(); + return true; + } + + bool WhiteArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + bool GreyArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + bool BlackArc(StateId state, const Arc &arc) { + ofst_->AddArc(state, arc); + return true; + } + + void FinishState(StateId state) { + ofst_->SetFinal(state, ifst_->Final(state)); + } + + void FinishVisit() {} + + private: + const Fst *ifst_; + MutableFst *ofst_; +}; + +// Visits input FST up to a state limit following queue order. +template +class PartialVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + explicit PartialVisitor(StateId maxvisit) + : fst_(nullptr), maxvisit_(maxvisit) {} + + void InitVisit(const Fst &ifst) { + fst_ = &ifst; + ninit_ = 0; + nfinish_ = 0; + } + + bool InitState(StateId state, StateId root) { + ++ninit_; + return ninit_ <= maxvisit_; + } + + bool WhiteArc(StateId state, const Arc &arc) { return true; } + + bool GreyArc(StateId state, const Arc &arc) { return true; } + + bool BlackArc(StateId state, const Arc &arc) { return true; } + + void FinishState(StateId state) { + fst_->Final(state); // Visits super-final arc. + ++nfinish_; + } + + void FinishVisit() {} + + StateId NumInitialized() { return ninit_; } + + StateId NumFinished() { return nfinish_; } + + private: + const Fst *fst_; + StateId maxvisit_; + StateId ninit_; + StateId nfinish_; +}; + +// Copies input FST to mutable FST up to a state limit following queue order. 
+template +class PartialCopyVisitor : public CopyVisitor { + public: + using Arc = A; + using StateId = typename Arc::StateId; + + using CopyVisitor::WhiteArc; + + PartialCopyVisitor(MutableFst *ofst, StateId maxvisit, + bool copy_grey = true, bool copy_black = true) + : CopyVisitor(ofst), maxvisit_(maxvisit), + copy_grey_(copy_grey), copy_black_(copy_black) {} + + void InitVisit(const Fst &ifst) { + CopyVisitor::InitVisit(ifst); + ninit_ = 0; + nfinish_ = 0; + } + + bool InitState(StateId state, StateId root) { + CopyVisitor::InitState(state, root); + ++ninit_; + return ninit_ <= maxvisit_; + } + + bool GreyArc(StateId state, const Arc &arc) { + if (copy_grey_) return CopyVisitor::GreyArc(state, arc); + return true; + } + + bool BlackArc(StateId state, const Arc &arc) { + if (copy_black_) return CopyVisitor::BlackArc(state, arc); + return true; + } + + void FinishState(StateId state) { + CopyVisitor::FinishState(state); + ++nfinish_; + } + + void FinishVisit() {} + + StateId NumInitialized() { return ninit_; } + + StateId NumFinished() { return nfinish_; } + + private: + StateId maxvisit_; + StateId ninit_; + StateId nfinish_; + const bool copy_grey_; + const bool copy_black_; +}; + +} // namespace fst + +#endif // FST_VISIT_H_ diff --git a/projects/llm_framework/include/fst/weight.h b/projects/llm_framework/include/fst/weight.h new file mode 100644 index 00000000..ea012ec6 --- /dev/null +++ b/projects/llm_framework/include/fst/weight.h @@ -0,0 +1,389 @@ +// See www.openfst.org for extensive documentation on this weighted +// finite-state transducer library. +// +// General weight set and associated semiring operation definitions. 
+ +#ifndef FST_WEIGHT_H_ +#define FST_WEIGHT_H_ + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + + +DECLARE_string(fst_weight_parentheses); +DECLARE_string(fst_weight_separator); + +namespace fst { + +// A semiring is specified by two binary operations Plus and Times and two +// designated elements Zero and One with the following properties: +// +// Plus: associative, commutative, and has Zero as its identity. +// +// Times: associative and has identity One, distributes w.r.t. Plus, and +// has Zero as an annihilator: +// Times(Zero(), a) == Times(a, Zero()) = Zero(). +// +// A left semiring distributes on the left; a right semiring is similarly +// defined. +// +// A Weight class must have binary functions Plus and Times and static member +// functions Zero() and One() and these must form (at least) a left or right +// semiring. +// +// In addition, the following should be defined for a Weight: +// +// Member: predicate on set membership. +// +// NoWeight: static member function that returns an element that is +// not a set member; used to signal an error. +// +// >>: reads textual representation of a weight. +// +// <<: prints textual representation of a weight. +// +// Read(istream &istrm): reads binary representation of a weight. +// +// Write(ostream &ostrm): writes binary representation of a weight. +// +// Hash: maps weight to size_t. +// +// ApproxEqual: approximate equality (for inexact weights) +// +// Quantize: quantizes w.r.t delta (for inexact weights) +// +// Divide: +// - In a left semiring, for all a, b, b', c: +// if Times(a, b) = c, Divide(c, a, DIVIDE_LEFT) = b' and b'.Member(), +// then Times(a, b') = c. +// - In a right semiring, for all a, a', b, c: +// if Times(a, b) = c, Divide(c, b, DIVIDE_RIGHT) = a' and a'.Member(), +// then Times(a', b) = c. 
+// - In a commutative semiring, +// * for all a, c: +// Divide(c, a, DIVIDE_ANY) = Divide(c, a, DIVIDE_LEFT) +// = Divide(c, a, DIVIDE_RIGHT) +// * for all a, b, b', c: +// if Times(a, b) = c, Divide(c, a, DIVIDE_ANY) = b' and b'.Member(), +// then Times(a, b') = c +// - In the case where there exists no b such that c = Times(a, b), the +// return value of Divide(c, a, DIVIDE_LEFT) is unspecified. Returning +// Weight::NoWeight() is recommended but not required in order to +// allow the most efficient implementation. +// - All algorithms in this library only call Divide(c, a) when it is +// guaranteed that there exists a b such that c = Times(a, b). +// +// ReverseWeight: the type of the corresponding reverse weight. +// +// Typically the same type as Weight for a (both left and right) semiring. +// For the left string semiring, it is the right string semiring. +// +// Reverse: a mapping from Weight to ReverseWeight s.t. +// +// --> Reverse(Reverse(a)) = a +// --> Reverse(Plus(a, b)) = Plus(Reverse(a), Reverse(b)) +// --> Reverse(Times(a, b)) = Times(Reverse(b), Reverse(a)) +// Typically the identity mapping in a (both left and right) semiring. +// In the left string semiring, it maps to the reverse string in the right +// string semiring. +// +// Properties: specifies additional properties that hold: +// LeftSemiring: indicates weights form a left semiring. +// RightSemiring: indicates weights form a right semiring. +// Commutative: for all a,b: Times(a,b) == Times(b,a) +// Idempotent: for all a: Plus(a, a) == a. +// Path: for all a, b: Plus(a, b) == a or Plus(a, b) == b. + +// CONSTANT DEFINITIONS + +// A representable float near .001. +constexpr float kDelta = 1.0F / 1024.0F; + +// For all a, b, c: Times(c, Plus(a, b)) = Plus(Times(c, a), Times(c, b)). +constexpr uint64 kLeftSemiring = 0x0000000000000001ULL; + +// For all a, b, c: Times(Plus(a, b), c) = Plus(Times(a, c), Times(b, c)). 
+constexpr uint64 kRightSemiring = 0x0000000000000002ULL; + +constexpr uint64 kSemiring = kLeftSemiring | kRightSemiring; + +// For all a, b: Times(a, b) = Times(b, a). +constexpr uint64 kCommutative = 0x0000000000000004ULL; + +// For all a: Plus(a, a) = a. +constexpr uint64 kIdempotent = 0x0000000000000008ULL; + +// For all a, b: Plus(a, b) = a or Plus(a, b) = b. +constexpr uint64 kPath = 0x0000000000000010ULL; + +// For random weight generation: default number of distinct weights. +// This is also used for a few other weight generation defaults. +constexpr size_t kNumRandomWeights = 5; + +// Weight property boolean constants needed for SFINAE. + +// MSVC compiler bug workaround: an expression containing W::Properties() cannot +// be directly used as a value argument to std::enable_if or integral_constant. +// WeightPropertiesThunk::Properties works instead, however. +namespace bug { +template +struct WeightPropertiesThunk { + WeightPropertiesThunk() = delete; + constexpr static const uint64 Properties = W::Properties(); +}; + +template +using TestWeightProperties = std::integral_constant::Properties & props) == props>; +} // namespace bug + +template +using IsIdempotent = bug::TestWeightProperties; + +template +using IsPath = bug::TestWeightProperties; + + +// Determines direction of division. +enum DivideType { + DIVIDE_LEFT, // left division + DIVIDE_RIGHT, // right division + DIVIDE_ANY +}; // division in a commutative semiring + +// NATURAL ORDER +// +// By definition: +// +// a <= b iff a + b = a +// +// The natural order is a negative partial order iff the semiring is +// idempotent. It is trivially monotonic for plus. It is left +// (resp. right) monotonic for times iff the semiring is left +// (resp. right) distributive. It is a total order iff the semiring +// has the path property. +// +// For more information, see: +// +// Mohri, M. 2002. 
Semiring framework and algorithms for shortest-distance +// problems, Journal of Automata, Languages and +// Combinatorics 7(3): 321-350, 2002. +// +// We define the strict version of this order below. + +template +class NaturalLess { +public: + using Weight = W; + + NaturalLess() { + if (!(W::Properties() & kIdempotent)) { + FSTERROR() << "NaturalLess: Weight type is not idempotent: " << W::Type(); + } + } + + bool operator()(const W &w1, const W &w2) const { + return (Plus(w1, w2) == w1) && w1 != w2; + } +}; + +// Power is the iterated product for arbitrary semirings such that Power(w, 0) +// is One() for the semiring, and Power(w, n) = Times(Power(w, n - 1), w). +template +Weight Power(const Weight &weight, size_t n) { + auto result = Weight::One(); + for (size_t i = 0; i < n; ++i) result = Times(result, weight); + return result; +} + +// Simple default adder class. Specializations might be more complex. +template +class Adder { + public: + explicit Adder(Weight w = Weight::Zero()) : sum_(w) { } + + Weight Add(const Weight &w) { + sum_ = Plus(sum_, w); + return sum_; + } + + Weight Sum() { return sum_; } + + void Reset(Weight w = Weight::Zero()) { sum_ = w; } + + private: + Weight sum_; +}; + +// General weight converter: raises error. +template +struct WeightConvert { + W2 operator()(W1 w1) const { + FSTERROR() << "WeightConvert: Can't convert weight from \"" << W1::Type() + << "\" to \"" << W2::Type(); + return W2::NoWeight(); + } +}; + +// Specialized weight converter to self. +template +struct WeightConvert { + W operator()(W weight) const { return weight; } +}; + +// General random weight generator: raises error. 
+template +struct WeightGenerate { + W operator()() const { + FSTERROR() << "WeightGenerate: No random generator for " << W::Type(); + return W::NoWeight(); + } +}; + +namespace internal { + +class CompositeWeightIO { + public: + CompositeWeightIO(); + CompositeWeightIO(char separator, std::pair parentheses); + + std::pair parentheses() const { + return {open_paren_, close_paren_}; + } + char separator() const { return separator_; } + + bool error() const { return error_; } + + protected: + const char separator_; + const char open_paren_; + const char close_paren_; + + private: + bool error_; +}; + +} // namespace internal + +// Helper class for writing textual composite weights. +class CompositeWeightWriter : public internal::CompositeWeightIO { + public: + // Uses configuration from flags (FLAGS_fst_weight_separator, + // FLAGS_fst_weight_parentheses). + explicit CompositeWeightWriter(std::ostream &ostrm); + + // parentheses defines the opening and closing parenthesis characters. + // Set parentheses = {0, 0} to disable writing parenthesis. + CompositeWeightWriter(std::ostream &ostrm, char separator, + std::pair parentheses); + + CompositeWeightWriter(const CompositeWeightWriter &) = delete; + CompositeWeightWriter &operator=(const CompositeWeightWriter &) = delete; + + // Writes open parenthesis to a stream if option selected. + void WriteBegin(); + + // Writes element to a stream. + template + void WriteElement(const T &comp) { + if (i_++ > 0) ostrm_ << separator_; + ostrm_ << comp; + } + + // Writes close parenthesis to a stream if option selected. + void WriteEnd(); + + private: + std::ostream &ostrm_; + int i_ = 0; // Element position. +}; + +// Helper class for reading textual composite weights. Elements are separated by +// a separator character. There must be at least one element per textual +// representation. Parentheses characters should be set if the composite +// weights themselves contain composite weights to ensure proper parsing. 
+class CompositeWeightReader : public internal::CompositeWeightIO { + public: + // Uses configuration from flags (FLAGS_fst_weight_separator, + // FLAGS_fst_weight_parentheses). + explicit CompositeWeightReader(std::istream &istrm); + + // parentheses defines the opening and closing parenthesis characters. + // Set parentheses = {0, 0} to disable reading parenthesis. + CompositeWeightReader(std::istream &istrm, char separator, + std::pair parentheses); + + CompositeWeightReader(const CompositeWeightReader &) = delete; + CompositeWeightReader &operator=(const CompositeWeightReader &) = delete; + + // Reads open parenthesis from a stream if option selected. + void ReadBegin(); + + // Reads element from a stream. The second argument, when true, indicates that + // this will be the last element (allowing more forgiving formatting of the + // last element). Returns false when last element is read. + template + bool ReadElement(T *comp, bool last = false); + + // Finalizes reading. + void ReadEnd(); + + private: + std::istream &istrm_; // Input stream. + int c_ = 0; // Last character read, or EOF. + int depth_ = 0; // Weight parentheses depth. +}; + +template +inline bool CompositeWeightReader::ReadElement(T *comp, bool last) { + string s; + const bool has_parens = open_paren_ != 0; + while ((c_ != std::istream::traits_type::eof()) && !std::isspace(c_) && + (c_ != separator_ || depth_ > 1 || last) && + (c_ != close_paren_ || depth_ != 1)) { + s += c_; + // If parentheses encountered before separator, they must be matched. + if (has_parens && c_ == open_paren_) { + ++depth_; + } else if (has_parens && c_ == close_paren_) { + // Failure on unmatched parentheses. 
+ if (depth_ == 0) { + FSTERROR() << "CompositeWeightReader: Unmatched close paren: " + << "Is the fst_weight_parentheses flag set correctly?"; + istrm_.clear(std::ios::badbit); + return false; + } + --depth_; + } + c_ = istrm_.get(); + } + if (s.empty()) { + FSTERROR() << "CompositeWeightReader: Empty element: " + << "Is the fst_weight_parentheses flag set correctly?"; + istrm_.clear(std::ios::badbit); + return false; + } + std::istringstream istrm(s); + istrm >> *comp; + // Skips separator/close parenthesis. + if (c_ != std::istream::traits_type::eof() && !std::isspace(c_)) { + c_ = istrm_.get(); + } + const bool is_eof = c_ == std::istream::traits_type::eof(); + // Clears fail bit if just EOF. + if (is_eof && !istrm_.bad()) istrm_.clear(std::ios::eofbit); + return !is_eof && !std::isspace(c_); +} + +} // namespace fst + +#endif // FST_WEIGHT_H_ diff --git a/projects/llm_framework/include/gflags/defines.h b/projects/llm_framework/include/gflags/defines.h new file mode 100644 index 00000000..944ed7db --- /dev/null +++ b/projects/llm_framework/include/gflags/defines.h @@ -0,0 +1,48 @@ +/* Generated from defines.h.in during build configuration using CMake. */ + +// Note: This header file is only used internally. It is not part of public interface! +// Any cmakedefine is defined using the -D flag instead when Bazel is used. +// For Bazel, this file is thus not used to avoid a private file in $(GENDIR). + +#ifndef GFLAGS_DEFINES_H_ +#define GFLAGS_DEFINES_H_ + + +// Define if you build this library for a MS Windows OS. +/* #undef OS_WINDOWS */ + +// Define if you have the <stdint.h> header file. +#define HAVE_STDINT_H + +// Define if you have the <sys/types.h> header file. +#define HAVE_SYS_TYPES_H + +// Define if you have the <inttypes.h> header file. +#define HAVE_INTTYPES_H + +// Define if you have the <sys/stat.h> header file. +#define HAVE_SYS_STAT_H + +// Define if you have the <unistd.h> header file. +#define HAVE_UNISTD_H + +// Define if you have the <fnmatch.h> header file. 
+#define HAVE_FNMATCH_H + +// Define if you have the <shlwapi.h> header file (Windows 2000/XP). +/* #undef HAVE_SHLWAPI_H */ + +// Define if you have the strtoll function. +#define HAVE_STRTOLL + +// Define if you have the strtoq function. +/* #undef HAVE_STRTOQ */ + +// Define if you have the <pthread.h> header file. +/* #undef HAVE_PTHREAD */ + +// Define if your pthread library defines the type pthread_rwlock_t +/* #undef HAVE_RWLOCK */ + + +#endif // GFLAGS_DEFINES_H_ diff --git a/projects/llm_framework/include/gflags/gflags.h b/projects/llm_framework/include/gflags/gflags.h new file mode 100644 index 00000000..9273da8d --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags.h @@ -0,0 +1,626 @@ +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// Revamped and reorganized by Craig Silverstein +// +// This is the file that should be included by any file which declares +// or defines a command line flag or wants to parse command line flags +// or print a program usage message (which will include information about +// flags). Executive summary, in the form of an example foo.cc file: +// +// #include "foo.h" // foo.h has a line "DECLARE_int32(start);" +// #include "validators.h" // hypothetical file defining ValidateIsFile() +// +// DEFINE_int32(end, 1000, "The last record to read"); +// +// DEFINE_string(filename, "my_file.txt", "The file to read"); +// // Crash if the specified file does not exist. +// static bool dummy = RegisterFlagValidator(&FLAGS_filename, +// &ValidateIsFile); +// +// DECLARE_bool(verbose); // some other file has a DEFINE_bool(verbose, ...) +// +// void MyFunc() { +// if (FLAGS_verbose) printf("Records %d-%d\n", FLAGS_start, FLAGS_end); +// } +// +// Then, at the command-line: +// ./foo --noverbose --start=5 --end=100 +// +// For more details, see +// doc/gflags.html +// +// --- A note about thread-safety: +// +// We describe many functions in this routine as being thread-hostile, +// thread-compatible, or thread-safe. Here are the meanings we use: +// +// thread-safe: it is safe for multiple threads to call this routine +// (or, when referring to a class, methods of this class) +// concurrently. 
+// thread-hostile: it is not safe for multiple threads to call this +// routine (or methods of this class) concurrently. In gflags, +// most thread-hostile routines are intended to be called early in, +// or even before, main() -- that is, before threads are spawned. +// thread-compatible: it is safe for multiple threads to read from +// this variable (when applied to variables), or to call const +// methods of this class (when applied to classes), as long as no +// other thread is writing to the variable or calling non-const +// methods of this class. + +#ifndef GFLAGS_GFLAGS_H_ +#define GFLAGS_GFLAGS_H_ + +#include +#include + +#include "gflags/gflags_declare.h" // IWYU pragma: export + + +// We always want to export variables defined in user code +#ifndef GFLAGS_DLL_DEFINE_FLAG +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DEFINE_FLAG __declspec(dllexport) +# else +# define GFLAGS_DLL_DEFINE_FLAG +# endif +#endif + + +namespace GFLAGS_NAMESPACE { + + +// -------------------------------------------------------------------- +// To actually define a flag in a file, use DEFINE_bool, +// DEFINE_string, etc. at the bottom of this file. You may also find +// it useful to register a validator with the flag. This ensures that +// when the flag is parsed from the commandline, or is later set via +// SetCommandLineOption, we call the validation function. It is _not_ +// called when you assign the value to the flag directly using the = operator. +// +// The validation function should return true if the flag value is valid, and +// false otherwise. If the function returns false for the new setting of the +// flag, the flag will retain its current value. If it returns false for the +// default value, ParseCommandLineFlags() will die. +// +// This function is safe to call at global construct time (as in the +// example below). 
+// +// Example use: +// static bool ValidatePort(const char* flagname, int32 value) { +// if (value > 0 && value < 32768) // value is ok +// return true; +// printf("Invalid value for --%s: %d\n", flagname, (int)value); +// return false; +// } +// DEFINE_int32(port, 0, "What port to listen on"); +// static bool dummy = RegisterFlagValidator(&FLAGS_port, &ValidatePort); + +// Returns true if successfully registered, false if not (because the +// first argument doesn't point to a command-line flag, or because a +// validator is already registered for this flag). +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const bool* flag, bool (*validate_fn)(const char*, bool)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const int32* flag, bool (*validate_fn)(const char*, int32)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const uint32* flag, bool (*validate_fn)(const char*, uint32)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const int64* flag, bool (*validate_fn)(const char*, int64)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const uint64* flag, bool (*validate_fn)(const char*, uint64)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const double* flag, bool (*validate_fn)(const char*, double)); +extern GFLAGS_DLL_DECL bool RegisterFlagValidator(const std::string* flag, bool (*validate_fn)(const char*, const std::string&)); + +// Convenience macro for the registration of a flag validator +#define DEFINE_validator(name, validator) \ + static const bool name##_validator_registered = \ + GFLAGS_NAMESPACE::RegisterFlagValidator(&FLAGS_##name, validator) + + +// -------------------------------------------------------------------- +// These methods are the best way to get access to info about the +// list of commandline flags. Note that these routines are pretty slow. +// GetAllFlags: mostly-complete info about the list, sorted by file. 
+// ShowUsageWithFlags: pretty-prints the list to stdout (what --help does) +// ShowUsageWithFlagsRestrict: limit to filenames with restrict as a substr +// +// In addition to accessing flags, you can also access argv[0] (the program +// name) and argv (the entire commandline), which we sock away a copy of. +// These variables are static, so you should only set them once. +// +// No need to export this data only structure from DLL, avoiding VS warning 4251. +struct CommandLineFlagInfo { + std::string name; // the name of the flag + std::string type; // the type of the flag: int32, etc + std::string description; // the "help text" associated with the flag + std::string current_value; // the current value, as a string + std::string default_value; // the default value, as a string + std::string filename; // 'cleaned' version of filename holding the flag + bool has_validator_fn; // true if RegisterFlagValidator called on this flag + bool is_default; // true if the flag has the default value and + // has not been set explicitly from the cmdline + // or via SetCommandLineOption + const void* flag_ptr; // pointer to the flag's current value (i.e. FLAGS_foo) +}; + +// Using this inside of a validator is a recipe for a deadlock. +// TODO(user) Fix locking when validators are running, to make it safe to +// call validators during ParseAllFlags. +// Also make sure then to uncomment the corresponding unit test in +// gflags_unittest.sh +extern GFLAGS_DLL_DECL void GetAllFlags(std::vector* OUTPUT); +// These two are actually defined in gflags_reporting.cc. +extern GFLAGS_DLL_DECL void ShowUsageWithFlags(const char *argv0); // what --help does +extern GFLAGS_DLL_DECL void ShowUsageWithFlagsRestrict(const char *argv0, const char *restrict); + +// Create a descriptive string for a flag. +// Goes to some trouble to make pretty line breaks. 
+extern GFLAGS_DLL_DECL std::string DescribeOneFlag(const CommandLineFlagInfo& flag); + +// Thread-hostile; meant to be called before any threads are spawned. +extern GFLAGS_DLL_DECL void SetArgv(int argc, const char** argv); + +// The following functions are thread-safe as long as SetArgv() is +// only called before any threads start. +extern GFLAGS_DLL_DECL const std::vector& GetArgvs(); +extern GFLAGS_DLL_DECL const char* GetArgv(); // all of argv as a string +extern GFLAGS_DLL_DECL const char* GetArgv0(); // only argv0 +extern GFLAGS_DLL_DECL uint32 GetArgvSum(); // simple checksum of argv +extern GFLAGS_DLL_DECL const char* ProgramInvocationName(); // argv0, or "UNKNOWN" if not set +extern GFLAGS_DLL_DECL const char* ProgramInvocationShortName(); // basename(argv0) + +// ProgramUsage() is thread-safe as long as SetUsageMessage() is only +// called before any threads start. +extern GFLAGS_DLL_DECL const char* ProgramUsage(); // string set by SetUsageMessage() + +// VersionString() is thread-safe as long as SetVersionString() is only +// called before any threads start. +extern GFLAGS_DLL_DECL const char* VersionString(); // string set by SetVersionString() + + + +// -------------------------------------------------------------------- +// Normally you access commandline flags by just saying "if (FLAGS_foo)" +// or whatever, and set them by calling "FLAGS_foo = bar" (or, more +// commonly, via the DEFINE_foo macro). But if you need a bit more +// control, we have programmatic ways to get/set the flags as well. +// These programmatic ways to access flags are thread-safe, but direct +// access is only thread-compatible. + +// Return true iff the flagname was found. +// OUTPUT is set to the flag's value, or unchanged if we return false. +extern GFLAGS_DLL_DECL bool GetCommandLineOption(const char* name, std::string* OUTPUT); + +// Return true iff the flagname was found. OUTPUT is set to the flag's +// CommandLineFlagInfo or unchanged if we return false. 
+extern GFLAGS_DLL_DECL bool GetCommandLineFlagInfo(const char* name, CommandLineFlagInfo* OUTPUT); + +// Return the CommandLineFlagInfo of the flagname. exit() if name not found. +// Example usage, to check if a flag's value is currently the default value: +// if (GetCommandLineFlagInfoOrDie("foo").is_default) ... +extern GFLAGS_DLL_DECL CommandLineFlagInfo GetCommandLineFlagInfoOrDie(const char* name); + +enum GFLAGS_DLL_DECL FlagSettingMode { + // update the flag's value (can call this multiple times). + SET_FLAGS_VALUE, + // update the flag's value, but *only if* it has not yet been updated + // with SET_FLAGS_VALUE, SET_FLAG_IF_DEFAULT, or "FLAGS_xxx = nondef". + SET_FLAG_IF_DEFAULT, + // set the flag's default value to this. If the flag has not been updated + // yet (via SET_FLAGS_VALUE, SET_FLAG_IF_DEFAULT, or "FLAGS_xxx = nondef") + // change the flag's current value to the new default value as well. + SET_FLAGS_DEFAULT +}; + +// Set a particular flag ("command line option"). Returns a string +// describing the new value that the option has been set to. The +// return value API is not well-specified, so basically just depend on +// it to be empty if the setting failed for some reason -- the name is +// not a valid flag name, or the value is not a valid value -- and +// non-empty otherwise. + +// SetCommandLineOption uses set_mode == SET_FLAGS_VALUE (the common case) +extern GFLAGS_DLL_DECL std::string SetCommandLineOption (const char* name, const char* value); +extern GFLAGS_DLL_DECL std::string SetCommandLineOptionWithMode(const char* name, const char* value, FlagSettingMode set_mode); + + +// -------------------------------------------------------------------- +// Saves the states (value, default value, whether the user has set +// the flag, registered validators, etc) of all flags, and restores +// them when the FlagSaver is destroyed. 
This is very useful in +// tests, say, when you want to let your tests change the flags, but +// make sure that they get reverted to the original states when your +// test is complete. +// +// Example usage: +// void TestFoo() { +// FlagSaver s1; +// FLAGS_foo = false; +// FLAGS_bar = "some value"; +// +// // test happens here. You can return at any time +// // without worrying about restoring the FLAGS values. +// } +// +// Note: This class is marked with an unused attribute (GFLAGS_ATTRIBUTE_UNUSED; +// it appears below as __attribute((unused))) because all +// the work is done in the constructor and destructor, so in the standard +// usage example above, the compiler would complain that it's an +// unused variable. +// +// This class is thread-safe. However, its destructor writes to +// exactly the set of flags that have changed value during its +// lifetime, so concurrent _direct_ access to those flags +// (i.e. FLAGS_foo instead of {Get,Set}CommandLineOption()) is unsafe. + +class GFLAGS_DLL_DECL FlagSaver { + public: + FlagSaver(); + ~FlagSaver(); + + private: + class FlagSaverImpl* impl_; // we use pimpl here to keep API steady + + FlagSaver(const FlagSaver&); // no copying! + void operator=(const FlagSaver&); +}__attribute((unused)); + +// -------------------------------------------------------------------- +// Some deprecated or hopefully-soon-to-be-deprecated functions. + +// This is often used for logging. TODO(csilvers): figure out a better way +extern GFLAGS_DLL_DECL std::string CommandlineFlagsIntoString(); +// Usually where this is used, a FlagSaver should be used instead. +extern GFLAGS_DLL_DECL +bool ReadFlagsFromString(const std::string& flagfilecontents, + const char* prog_name, + bool errors_are_fatal); // uses SET_FLAGS_VALUE + +// These let you manually implement --flagfile functionality. +// DEPRECATED. 
+extern GFLAGS_DLL_DECL bool AppendFlagsIntoFile(const std::string& filename, const char* prog_name); +extern GFLAGS_DLL_DECL bool ReadFromFlagsFile(const std::string& filename, const char* prog_name, bool errors_are_fatal); // uses SET_FLAGS_VALUE + + +// -------------------------------------------------------------------- +// Useful routines for initializing flags from the environment. +// In each case, if 'varname' does not exist in the environment +// return defval. If 'varname' does exist but is not valid +// (e.g., not a number for an int32 flag), abort with an error. +// Otherwise, return the value. NOTE: for booleans, for true use +// 't' or 'T' or 'true' or '1', for false 'f' or 'F' or 'false' or '0'. + +extern GFLAGS_DLL_DECL bool BoolFromEnv(const char *varname, bool defval); +extern GFLAGS_DLL_DECL int32 Int32FromEnv(const char *varname, int32 defval); +extern GFLAGS_DLL_DECL uint32 Uint32FromEnv(const char *varname, uint32 defval); +extern GFLAGS_DLL_DECL int64 Int64FromEnv(const char *varname, int64 defval); +extern GFLAGS_DLL_DECL uint64 Uint64FromEnv(const char *varname, uint64 defval); +extern GFLAGS_DLL_DECL double DoubleFromEnv(const char *varname, double defval); +extern GFLAGS_DLL_DECL const char *StringFromEnv(const char *varname, const char *defval); + + +// -------------------------------------------------------------------- +// The next two functions parse gflags from main(): + +// Set the "usage" message for this program. For example: +// string usage("This program does nothing. Sample usage:\n"); +// usage += argv[0] + " "; +// SetUsageMessage(usage); +// Do not include commandline flags in the usage: we do that for you! +// Thread-hostile; meant to be called before any threads are spawned. +extern GFLAGS_DLL_DECL void SetUsageMessage(const std::string& usage); + +// Sets the version string, which is emitted with --version. +// For instance: SetVersionString("1.3"); +// Thread-hostile; meant to be called before any threads are spawned. 
+extern GFLAGS_DLL_DECL void SetVersionString(const std::string& version); + + +// Looks for flags in argv and parses them. Rearranges argv to put +// flags first, or removes them entirely if remove_flags is true. +// If a flag is defined more than once in the command line or flag +// file, the last definition is used. Returns the index (into argv) +// of the first non-flag argument. +// See top-of-file for more details on this function. +#ifndef SWIG // In swig, use ParseCommandLineFlagsScript() instead. +extern GFLAGS_DLL_DECL uint32 ParseCommandLineFlags(int *argc, char*** argv, bool remove_flags); +#endif + + +// Calls to ParseCommandLineNonHelpFlags and then to +// HandleCommandLineHelpFlags can be used instead of a call to +// ParseCommandLineFlags during initialization, in order to allow for +// changing default values for some FLAGS (via +// e.g. SetCommandLineOptionWithMode calls) between the time of +// command line parsing and the time of dumping help information for +// the flags as a result of command line parsing. If a flag is +// defined more than once in the command line or flag file, the last +// definition is used. Returns the index (into argv) of the first +// non-flag argument. (If remove_flags is true, will always return 1.) +extern GFLAGS_DLL_DECL uint32 ParseCommandLineNonHelpFlags(int *argc, char*** argv, bool remove_flags); + +// This is actually defined in gflags_reporting.cc. +// This function is misnamed (it also handles --version, etc.), but +// it's too late to change that now. :-( +extern GFLAGS_DLL_DECL void HandleCommandLineHelpFlags(); // in gflags_reporting.cc + +// Allow command line reparsing. Disables the error normally +// generated when an unknown flag is found, since it may be found in a +// later parse. Thread-hostile; meant to be called before any threads +// are spawned. +extern GFLAGS_DLL_DECL void AllowCommandLineReparsing(); + +// Reparse the flags that have not yet been recognized. 
Only flags +// registered since the last parse will be recognized. Any flag value +// must be provided as part of the argument using "=", not as a +// separate command line argument that follows the flag argument. +// Intended for handling flags from dynamically loaded libraries, +// since their flags are not registered until they are loaded. +extern GFLAGS_DLL_DECL void ReparseCommandLineNonHelpFlags(); + +// Clean up memory allocated by flags. This is only needed to reduce +// the quantity of "potentially leaked" reports emitted by memory +// debugging tools such as valgrind. It is not required for normal +// operation, or for the google perftools heap-checker. It must only +// be called when the process is about to exit, and all threads that +// might access flags are quiescent. Referencing flags after this is +// called will have unexpected consequences. This is not safe to run +// when multiple threads might be running: the function is +// thread-hostile. +extern GFLAGS_DLL_DECL void ShutDownCommandLineFlags(); + + +// -------------------------------------------------------------------- +// Now come the command line flag declaration/definition macros that +// will actually be used. They're kind of hairy. A major reason +// for this is initialization: we want people to be able to access +// variables in global constructors and have that not crash, even if +// their global constructor runs before the global constructor here. +// (Obviously, we can't guarantee the flags will have the correct +// default value in that case, but at least accessing them is safe.) +// The only way to do that is have flags point to a static buffer. +// So we make one, using a union to ensure proper alignment, and +// then use placement-new to actually set up the flag with the +// correct default value. In the same vein, we have to worry about +// flag access in global destructors, so FlagRegisterer has to be +// careful never to destroy the flag-values it constructs. 
+// +// Note that when we define a flag variable FLAGS_##name, we also +// preemptively define a junk variable, FLAGS_no##name. This is to +// cause a link-time error if someone tries to define 2 flags with +// names like "logging" and "nologging". We do this because a bool +// flag FLAG can be set from the command line to true with a "-FLAG" +// argument, and to false with a "-noFLAG" argument, and so this can +// potentially avert confusion. +// +// We also put flags into their own namespace. It is purposefully +// named in an opaque way that people should have trouble typing +// directly. The idea is that DEFINE puts the flag in the weird +// namespace, and DECLARE imports the flag from there into the current +// namespace. The net result is to force people to use DECLARE to get +// access to a flag, rather than saying "extern GFLAGS_DLL_DECL bool FLAGS_whatever;" +// or some such instead. We want this so we can put extra +// functionality (like sanity-checking) in DECLARE if we want, and +// make sure it is picked up everywhere. +// +// We also put the type of the variable in the namespace, so that +// people can't DECLARE_int32 something that they DEFINE_bool'd +// elsewhere. + +class GFLAGS_DLL_DECL FlagRegisterer { + public: + // We instantiate this template ctor for all supported types, + // so it is possible to place implementation of the FlagRegisterer ctor in + // .cc file. + // Calling this constructor with unsupported type will produce linker error. + template + FlagRegisterer(const char* name, + const char* help, const char* filename, + FlagType* current_storage, FlagType* defvalue_storage); +}; + +// Force compiler to not generate code for the given template specialization.
+#if defined(_MSC_VER) && _MSC_VER < 1800 // Visual Studio 2013 version 12.0 + #define GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(type) +#else + #define GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(type) \ + extern template GFLAGS_DLL_DECL FlagRegisterer::FlagRegisterer( \ + const char* name, const char* help, const char* filename, \ + type* current_storage, type* defvalue_storage) +#endif + +// Do this for all supported flag types. +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(bool); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(int32); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(uint32); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(int64); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(uint64); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(double); +GFLAGS_DECLARE_FLAG_REGISTERER_CTOR(std::string); + +#undef GFLAGS_DECLARE_FLAG_REGISTERER_CTOR + +// If your application #defines STRIP_FLAG_HELP to a non-zero value +// before #including this file, we remove the help message from the +// binary file. This can reduce the size of the resulting binary +// somewhat, and may also be useful for security reasons. + +extern GFLAGS_DLL_DECL const char kStrippedFlagHelp[]; + + +} // namespace GFLAGS_NAMESPACE + + +#ifndef SWIG // In swig, ignore the main flag declarations + +#if defined(STRIP_FLAG_HELP) && STRIP_FLAG_HELP > 0 +// Need this construct to avoid the 'defined but not used' warning. +#define MAYBE_STRIPPED_HELP(txt) \ + (false ? (txt) : GFLAGS_NAMESPACE::kStrippedFlagHelp) +#else +#define MAYBE_STRIPPED_HELP(txt) txt +#endif + +// Each command-line flag has two variables associated with it: one +// with the current value, and one with the default value. However, +// we have a third variable, which is where value is assigned; it's a +// constant. This guarantees that FLAGS_##name is initialized at +// static initialization time (e.g. before program-start) rather than +// global construction time (which is after program-start but +// before main), at least when 'value' is a compile-time constant.
We +// use a small trick for the "default value" variable, and call it +// FLAGS_no##name. This serves the second purpose of assuring a +// compile error if someone tries to define a flag named no##name +// which is illegal (--foo and --nofoo both affect the "foo" flag). +#define DEFINE_VARIABLE(type, shorttype, name, value, help) \ + namespace fL##shorttype { \ + static const type FLAGS_nono##name = value; \ + /* We always want to export defined variables, dll or no */ \ + GFLAGS_DLL_DEFINE_FLAG type FLAGS_##name = FLAGS_nono##name; \ + static type FLAGS_no##name = FLAGS_nono##name; \ + static GFLAGS_NAMESPACE::FlagRegisterer o_##name( \ + #name, MAYBE_STRIPPED_HELP(help), __FILE__, \ + &FLAGS_##name, &FLAGS_no##name); \ + } \ + using fL##shorttype::FLAGS_##name + +// For DEFINE_bool, we want to do the extra check that the passed-in +// value is actually a bool, and not a string or something that can be +// coerced to a bool. These declarations (no definition needed!) will +// help us do that, and never evaluate From, which is important. +// We'll use 'sizeof(IsBoolFlag(val))' to distinguish. This code requires +// that the compiler have different sizes for bool & double. Since +// this is not guaranteed by the standard, we check it with a +// COMPILE_ASSERT. +namespace fLB { +struct CompileAssert {}; +typedef CompileAssert expected_sizeof_double_neq_sizeof_bool[ + (sizeof(double) != sizeof(bool)) ? 1 : -1]; +template double GFLAGS_DLL_DECL IsBoolFlag(const From& from); +GFLAGS_DLL_DECL bool IsBoolFlag(bool from); +} // namespace fLB + +// Here are the actual DEFINE_*-macros. The respective DECLARE_*-macros +// are in a separate include, gflags_declare.h, for reducing +// the physical transitive size for DECLARE use. +#define DEFINE_bool(name, val, txt) \ + namespace fLB { \ + typedef ::fLB::CompileAssert FLAG_##name##_value_is_not_a_bool[ \ + (sizeof(::fLB::IsBoolFlag(val)) != sizeof(double))?
1: -1]; \ + } \ + DEFINE_VARIABLE(bool, B, name, val, txt) + +#define DEFINE_int32(name, val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::int32, I, \ + name, val, txt) + +#define DEFINE_uint32(name,val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::uint32, U, \ + name, val, txt) + +#define DEFINE_int64(name, val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::int64, I64, \ + name, val, txt) + +#define DEFINE_uint64(name,val, txt) \ + DEFINE_VARIABLE(GFLAGS_NAMESPACE::uint64, U64, \ + name, val, txt) + +#define DEFINE_double(name, val, txt) \ + DEFINE_VARIABLE(double, D, name, val, txt) + +// Strings are trickier, because they're not a POD, so we can't +// construct them at static-initialization time (instead they get +// constructed at global-constructor time, which is much later). To +// try to avoid crashes in that case, we use a char buffer to store +// the string, which we can static-initialize, and then placement-new +// into it later. It's not perfect, but the best we can do. + +namespace fLS { + +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + const char *value) { + return new(stringspot) clstring(value); +} +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + const clstring &value) { + return new(stringspot) clstring(value); +} +inline clstring* dont_pass0toDEFINE_string(char *stringspot, + int value); + +// Auxiliary class used to explicitly call destructor of string objects +// allocated using placement new during static program deinitialization. +// The destructor MUST be an inline function such that the explicit +// destruction occurs in the same compilation unit as the placement new. 
+class StringFlagDestructor { + void *current_storage_; + void *defvalue_storage_; + +public: + + StringFlagDestructor(void *current, void *defvalue) + : current_storage_(current), defvalue_storage_(defvalue) {} + + ~StringFlagDestructor() { + reinterpret_cast(current_storage_ )->~clstring(); + reinterpret_cast(defvalue_storage_)->~clstring(); + } +}; + +} // namespace fLS + +// We need to define a var named FLAGS_no##name so people don't define +// --string and --nostring. And we need a temporary place to put val +// so we don't have to evaluate it twice. Two great needs that go +// great together! +// The weird 'using' + 'extern' inside the fLS namespace is to work around +// an unknown compiler bug/issue with the gcc 4.2.1 on SUSE 10. See +// http://code.google.com/p/google-gflags/issues/detail?id=20 +#define DEFINE_string(name, val, txt) \ + namespace fLS { \ + using ::fLS::clstring; \ + using ::fLS::StringFlagDestructor; \ + static union { void* align; char s[sizeof(clstring)]; } s_##name[2]; \ + clstring* const FLAGS_no##name = ::fLS:: \ + dont_pass0toDEFINE_string(s_##name[0].s, \ + val); \ + static GFLAGS_NAMESPACE::FlagRegisterer o_##name( \ + #name, MAYBE_STRIPPED_HELP(txt), __FILE__, \ + FLAGS_no##name, new (s_##name[1].s) clstring(*FLAGS_no##name)); \ + static StringFlagDestructor d_##name(s_##name[0].s, s_##name[1].s); \ + extern GFLAGS_DLL_DEFINE_FLAG clstring& FLAGS_##name; \ + using fLS::FLAGS_##name; \ + clstring& FLAGS_##name = *FLAGS_no##name; \ + } \ + using fLS::FLAGS_##name + +#endif // SWIG + + + + + +#endif // GFLAGS_GFLAGS_H_ diff --git a/projects/llm_framework/include/gflags/gflags_completions.h b/projects/llm_framework/include/gflags/gflags_completions.h new file mode 100644 index 00000000..2fa0db6d --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags_completions.h @@ -0,0 +1,121 @@ +// Copyright (c) 2008, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// --- + +// +// Implement helpful bash-style command line flag completions +// +// ** Functional API: +// HandleCommandLineCompletions() should be called early during +// program startup, but after command line flag code has been +// initialized, such as the beginning of HandleCommandLineHelpFlags(). +// It checks the value of the flag --tab_completion_word. If this +// flag is empty, nothing happens here. 
If it contains a string, +// however, then HandleCommandLineCompletions() will hijack the +// process, attempting to identify the intention behind this +// completion. Regardless of the outcome of this deduction, the +// process will be terminated, similar to --helpshort flag +// handling. +// +// ** Overview of Bash completions: +// Bash can be told to programmatically determine completions for the +// current 'cursor word'. It does this by (in this case) invoking a +// command with some additional arguments identifying the command +// being executed, the word being completed, and the previous word +// (if any). Bash then expects a sequence of output lines to be +// printed to stdout. If these lines all contain a common prefix +// longer than the cursor word, bash will replace the cursor word +// with that common prefix, and display nothing. If there isn't such +// a common prefix, bash will display the lines in pages using 'more'. +// +// ** Strategy taken for command line completions: +// If we can deduce either the exact flag intended, or a common flag +// prefix, we'll output exactly that. Otherwise, if information +// must be displayed to the user, we'll take the opportunity to add +// some helpful information beyond just the flag name (specifically, +// we'll include the default flag value and as much of the flag's +// description as can fit on a single terminal line width, as specified +// by the flag --tab_completion_columns). Furthermore, we'll try to +// make bash order the output such that the most useful or relevant +// flags are the most likely to be shown at the top. +// +// ** Additional features: +// To assist in finding that one really useful flag, substring matching +// was implemented. Before pressing TAB to get completion for the +// current word, you can append one or more '?' to the flag to do +// substring matching. Here's the semantics: +// --foo Show me all flags with names prefixed by 'foo' +// --foo?
Show me all flags with 'foo' somewhere in the name +// --foo?? Same as prior case, but also search in module +// definition path for 'foo' +// --foo??? Same as prior case, but also search in flag +// descriptions for 'foo' +// Finally, we'll trim the output to a relatively small number of +// flags to keep bash quiet about the verbosity of output. If one +// really wanted to see all possible matches, appending a '+' to the +// search word will force the exhaustive list of matches to be printed. +// +// ** How to have bash accept completions from a binary: +// Bash requires that it be informed about each command that programmatic +// completion should be enabled for. Example addition to a .bashrc +// file would be (your path to gflags_completions.sh file may differ): + +/* +$ complete -o bashdefault -o default -o nospace -C \ + '/home/build/eng/bash/bash_completions.sh --tab_completion_columns $COLUMNS' \ + time env binary_name another_binary [...] +*/ + +// This would allow the following to work: +// $ /path/to/binary_name --vmodule +// Or: +// $ ./bin/path/another_binary --gfs_u +// (etc) +// +// Sadly, it appears that bash gives no easy way to force this behavior for +// all commands. That's where the "time" in the above example comes in. +// If you haven't specifically added a command to the list of completion +// supported commands, you can still get completions by prefixing the +// entire command with "env". +// $ env /some/brand/new/binary --vmod +// Assuming that "binary" is a newly compiled binary, this should still +// produce the expected completion output. 
+ + +#ifndef GFLAGS_COMPLETIONS_H_ +#define GFLAGS_COMPLETIONS_H_ + +namespace gflags { + +extern void HandleCommandLineCompletions(void); + +} + +#endif // GFLAGS_COMPLETIONS_H_ diff --git a/projects/llm_framework/include/gflags/gflags_declare.h b/projects/llm_framework/include/gflags/gflags_declare.h new file mode 100644 index 00000000..69cf1129 --- /dev/null +++ b/projects/llm_framework/include/gflags/gflags_declare.h @@ -0,0 +1,156 @@ +// Copyright (c) 1999, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// --- +// +// Revamped and reorganized by Craig Silverstein +// +// This is the file that should be included by any file which declares +// command line flag. + +#ifndef GFLAGS_DECLARE_H_ +#define GFLAGS_DECLARE_H_ + + +// --------------------------------------------------------------------------- +// Namespace of gflags library symbols. +#define GFLAGS_NAMESPACE gflags + +// --------------------------------------------------------------------------- +// Windows DLL import/export. + +// Whether gflags library is a DLL. +// +// Set to 1 by default when the shared gflags library was built on Windows. +// Must be overwritten when this header file is used with the optionally also +// built static library instead; set by CMake's INTERFACE_COMPILE_DEFINITIONS. +#ifndef GFLAGS_IS_A_DLL +# define GFLAGS_IS_A_DLL 1 +#endif + +// We always want to import the symbols of the gflags library. +#ifndef GFLAGS_DLL_DECL +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DECL __declspec(dllimport) +# elif defined(__GNUC__) && __GNUC__ >= 4 +# define GFLAGS_DLL_DECL __attribute__((visibility("default"))) +# else +# define GFLAGS_DLL_DECL +# endif +#endif + +// We always want to import variables declared in user code. 
+#ifndef GFLAGS_DLL_DECLARE_FLAG +# if GFLAGS_IS_A_DLL && defined(_MSC_VER) +# define GFLAGS_DLL_DECLARE_FLAG __declspec(dllimport) +# elif defined(__GNUC__) && __GNUC__ >= 4 +# define GFLAGS_DLL_DECLARE_FLAG __attribute__((visibility("default"))) +# else +# define GFLAGS_DLL_DECLARE_FLAG +# endif +#endif + +// --------------------------------------------------------------------------- +// Flag types +#include +#if 1 +# include // the normal place uint32_t is defined +#elif 1 +# include // the normal place u_int32_t is defined +#elif 1 +# include // a third place for uint32_t or u_int32_t +#endif + +namespace GFLAGS_NAMESPACE { + +#if 1 // C99 +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#elif 0 // BSD +typedef int32_t int32; +typedef u_int32_t uint32; +typedef int64_t int64; +typedef u_int64_t uint64; +#elif 0 // Windows +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#else +# error Do not know how to define a 32-bit integer quantity on your system +#endif + +} // namespace GFLAGS_NAMESPACE + + +namespace fLS { + +// The meaning of "string" might be different between now and when the +// macros below get invoked (e.g., if someone is experimenting with +// other string implementations that get defined after this file is +// included). Save the current meaning now and use it in the macros. 
+typedef std::string clstring; + +} // namespace fLS + + +#define DECLARE_VARIABLE(type, shorttype, name) \ + /* We always want to import declared variables, dll or no */ \ + namespace fL##shorttype { extern GFLAGS_DLL_DECLARE_FLAG type FLAGS_##name; } \ + using fL##shorttype::FLAGS_##name + +#define DECLARE_bool(name) \ + DECLARE_VARIABLE(bool, B, name) + +#define DECLARE_int32(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::int32, I, name) + +#define DECLARE_uint32(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::uint32, U, name) + +#define DECLARE_int64(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::int64, I64, name) + +#define DECLARE_uint64(name) \ + DECLARE_VARIABLE(::GFLAGS_NAMESPACE::uint64, U64, name) + +#define DECLARE_double(name) \ + DECLARE_VARIABLE(double, D, name) + +#define DECLARE_string(name) \ + /* We always want to import declared variables, dll or no */ \ + namespace fLS { \ + extern GFLAGS_DLL_DECLARE_FLAG ::fLS::clstring& FLAGS_##name; \ + } \ + using fLS::FLAGS_##name + + +#endif // GFLAGS_DECLARE_H_ diff --git a/projects/llm_framework/include/glog/log_severity.h b/projects/llm_framework/include/glog/log_severity.h new file mode 100644 index 00000000..99945a42 --- /dev/null +++ b/projects/llm_framework/include/glog/log_severity.h @@ -0,0 +1,92 @@ +// Copyright (c) 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef BASE_LOG_SEVERITY_H__ +#define BASE_LOG_SEVERITY_H__ + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// Variables of type LogSeverity are widely taken to lie in the range +// [0, NUM_SEVERITIES-1]. Be careful to preserve this assumption if +// you ever need to change their values or add a new severity. +typedef int LogSeverity; + +const int GLOG_INFO = 0, GLOG_WARNING = 1, GLOG_ERROR = 2, GLOG_FATAL = 3, + NUM_SEVERITIES = 4; +#ifndef GLOG_NO_ABBREVIATED_SEVERITIES +# ifdef ERROR +# error ERROR macro is defined. Define GLOG_NO_ABBREVIATED_SEVERITIES before including logging.h. See the document for detail. 
+# endif +const int INFO = GLOG_INFO, WARNING = GLOG_WARNING, + ERROR = GLOG_ERROR, FATAL = GLOG_FATAL; +#endif + +// DFATAL is FATAL in debug mode, ERROR in normal mode +#ifdef NDEBUG +#define DFATAL_LEVEL ERROR +#else +#define DFATAL_LEVEL FATAL +#endif + +extern GOOGLE_GLOG_DLL_DECL const char* const LogSeverityNames[NUM_SEVERITIES]; + +// NDEBUG usage helpers related to (RAW_)DCHECK: +// +// DEBUG_MODE is for small !NDEBUG uses like +// if (DEBUG_MODE) foo.CheckThatFoo(); +// instead of substantially more verbose +// #ifndef NDEBUG +// foo.CheckThatFoo(); +// #endif +// +// IF_DEBUG_MODE is for small !NDEBUG uses like +// IF_DEBUG_MODE( string error; ) +// DCHECK(Foo(&error)) << error; +// instead of substantially more verbose +// #ifndef NDEBUG +// string error; +// DCHECK(Foo(&error)) << error; +// #endif +// +#ifdef NDEBUG +enum { DEBUG_MODE = 0 }; +#define IF_DEBUG_MODE(x) +#else +enum { DEBUG_MODE = 1 }; +#define IF_DEBUG_MODE(x) x +#endif + +#endif // BASE_LOG_SEVERITY_H__ diff --git a/projects/llm_framework/include/glog/logging.h b/projects/llm_framework/include/glog/logging.h new file mode 100644 index 00000000..4cd247e1 --- /dev/null +++ b/projects/llm_framework/include/glog/logging.h @@ -0,0 +1,1662 @@ +// Copyright (c) 1999, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. 
nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: Ray Sidney +// +// This file contains #include information about logging-related stuff. +// Pretty much everybody needs to #include this file so that they can +// log various happenings. +// +#ifndef _LOGGING_H_ +#define _LOGGING_H_ + +#include +#include +#include +#include +#include +#include +#include +#if 1 +# include +#endif +#include + +#if defined(_MSC_VER) +#define GLOG_MSVC_PUSH_DISABLE_WARNING(n) __pragma(warning(push)) \ + __pragma(warning(disable:n)) +#define GLOG_MSVC_POP_WARNING() __pragma(warning(pop)) +#else +#define GLOG_MSVC_PUSH_DISABLE_WARNING(n) +#define GLOG_MSVC_POP_WARNING() +#endif + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// We care a lot about number of bits things take up. Unfortunately, +// systems define their bit-specific ints in a lot of different ways. 
+// We use our own way, and have a typedef to get there. +// Note: these commands below may look like "#if 1" or "#if 0", but +// that's because they were constructed that way at ./configure time. +// Look at logging.h.in to see how they're calculated (based on your config). +#if 1 +#include // the normal place uint16_t is defined +#endif +#if 1 +#include // the normal place u_int16_t is defined +#endif +#if 1 +#include // a third place for uint16_t or u_int16_t +#endif + +#if 0 +#include +#endif + +namespace google { + +#if 1 // the C99 format +typedef int32_t int32; +typedef uint32_t uint32; +typedef int64_t int64; +typedef uint64_t uint64; +#elif 1 // the BSD format +typedef int32_t int32; +typedef u_int32_t uint32; +typedef int64_t int64; +typedef u_int64_t uint64; +#elif 0 // the windows (vc7) format +typedef __int32 int32; +typedef unsigned __int32 uint32; +typedef __int64 int64; +typedef unsigned __int64 uint64; +#else +#error Do not know how to define a 32-bit integer quantity on your system +#endif + +} + +// The global value of GOOGLE_STRIP_LOG. All the messages logged to +// LOG(XXX) with severity less than GOOGLE_STRIP_LOG will not be displayed. +// If it can be determined at compile time that the message will not be +// printed, the statement will be compiled out. +// +// Example: to strip out all INFO and WARNING messages, use the value +// of 2 below. To make an exception for WARNING messages from a single +// file, add "#define GOOGLE_STRIP_LOG 1" to that file _before_ including +// base/logging.h +#ifndef GOOGLE_STRIP_LOG +#define GOOGLE_STRIP_LOG 0 +#endif + +// GCC can be told that a certain branch is not likely to be taken (for +// instance, a CHECK failure), and use that information in static analysis. +// Giving it this information can help it optimize for the common case in +// the absence of better information (ie. -fprofile-arcs). 
+// +#ifndef GOOGLE_PREDICT_BRANCH_NOT_TAKEN +#if 1 +#define GOOGLE_PREDICT_BRANCH_NOT_TAKEN(x) (__builtin_expect(x, 0)) +#else +#define GOOGLE_PREDICT_BRANCH_NOT_TAKEN(x) x +#endif +#endif + +#ifndef GOOGLE_PREDICT_FALSE +#if 1 +#define GOOGLE_PREDICT_FALSE(x) (__builtin_expect(x, 0)) +#else +#define GOOGLE_PREDICT_FALSE(x) x +#endif +#endif + +#ifndef GOOGLE_PREDICT_TRUE +#if 1 +#define GOOGLE_PREDICT_TRUE(x) (__builtin_expect(!!(x), 1)) +#else +#define GOOGLE_PREDICT_TRUE(x) x +#endif +#endif + + +// Make a bunch of macros for logging. The way to log things is to stream +// things to LOG(). E.g., +// +// LOG(INFO) << "Found " << num_cookies << " cookies"; +// +// You can capture log messages in a string, rather than reporting them +// immediately: +// +// vector errors; +// LOG_STRING(ERROR, &errors) << "Couldn't parse cookie #" << cookie_num; +// +// This pushes back the new error onto 'errors'; if given a NULL pointer, +// it reports the error via LOG(ERROR). +// +// You can also do conditional logging: +// +// LOG_IF(INFO, num_cookies > 10) << "Got lots of cookies"; +// +// You can also do occasional logging (log every n'th occurrence of an +// event): +// +// LOG_EVERY_N(INFO, 10) << "Got the " << google::COUNTER << "th cookie"; +// +// The above will cause log messages to be output on the 1st, 11th, 21st, ... +// times it is executed. Note that the special google::COUNTER value is used +// to identify which repetition is happening. +// +// You can also do occasional conditional logging (log every n'th +// occurrence of an event, when condition is satisfied): +// +// LOG_IF_EVERY_N(INFO, (size > 1024), 10) << "Got the " << google::COUNTER +// << "th big cookie"; +// +// You can log messages the first N times your code executes a line. E.g. +// +// LOG_FIRST_N(INFO, 20) << "Got the " << google::COUNTER << "th cookie"; +// +// Outputs log messages for the first 20 times it is executed. 
+// +// Analogous SYSLOG, SYSLOG_IF, and SYSLOG_EVERY_N macros are available. +// These log to syslog as well as to the normal logs. If you use these at +// all, you need to be aware that syslog can drastically reduce performance, +// especially if it is configured for remote logging! Don't use these +// unless you fully understand this and have a concrete need to use them. +// Even then, try to minimize your use of them. +// +// There are also "debug mode" logging macros like the ones above: +// +// DLOG(INFO) << "Found cookies"; +// +// DLOG_IF(INFO, num_cookies > 10) << "Got lots of cookies"; +// +// DLOG_EVERY_N(INFO, 10) << "Got the " << google::COUNTER << "th cookie"; +// +// All "debug mode" logging is compiled away to nothing for non-debug mode +// compiles. +// +// We also have +// +// LOG_ASSERT(assertion); +// DLOG_ASSERT(assertion); +// +// which is syntactic sugar for {,D}LOG_IF(FATAL, assert fails) << assertion; +// +// There are "verbose level" logging macros. They look like +// +// VLOG(1) << "I'm printed when you run the program with --v=1 or more"; +// VLOG(2) << "I'm printed when you run the program with --v=2 or more"; +// +// These always log at the INFO log level (when they log at all). +// The verbose logging can also be turned on module-by-module. For instance, +// --vmodule=mapreduce=2,file=1,gfs*=3 --v=0 +// will cause: +// a. VLOG(2) and lower messages to be printed from mapreduce.{h,cc} +// b. VLOG(1) and lower messages to be printed from file.{h,cc} +// c. VLOG(3) and lower messages to be printed from files prefixed with "gfs" +// d. VLOG(0) and lower messages to be printed from elsewhere +// +// The wildcarding functionality shown by (c) supports both '*' (match +// 0 or more characters) and '?' (match any single character) wildcards. +// +// There's also VLOG_IS_ON(n) "verbose level" condition macro. 
To be used as +// +// if (VLOG_IS_ON(2)) { +// // do some logging preparation and logging +// // that can't be accomplished with just VLOG(2) << ...; +// } +// +// There are also VLOG_IF, VLOG_EVERY_N and VLOG_IF_EVERY_N "verbose level" +// condition macros for sample cases, when some extra computation and +// preparation for logs is not needed. +// VLOG_IF(1, (size > 1024)) +// << "I'm printed when size is more than 1024 and when you run the " +// "program with --v=1 or more"; +// VLOG_EVERY_N(1, 10) +// << "I'm printed every 10th occurrence, and when you run the program " +// "with --v=1 or more. Present occurrence is " << google::COUNTER; +// VLOG_IF_EVERY_N(1, (size > 1024), 10) +// << "I'm printed on every 10th occurrence of case when size is more " +// "than 1024, when you run the program with --v=1 or more. " +// "Present occurrence is " << google::COUNTER; +// +// The supported severity levels for macros that allow you to specify one +// are (in increasing order of severity) INFO, WARNING, ERROR, and FATAL. +// Note that messages of a given severity are logged not only in the +// logfile for that severity, but also in all logfiles of lower severity. +// E.g., a message of severity FATAL will be logged to the logfiles of +// severity FATAL, ERROR, WARNING, and INFO. +// +// There is also the special severity of DFATAL, which logs FATAL in +// debug mode, ERROR in normal mode. +// +// Very important: logging a message at the FATAL severity level causes +// the program to terminate (after the message is logged). +// +// Unless otherwise specified, logs will be written to the filename +// "<program name>.<hostname>.<user name>.log.<severity level>.", followed +// by the date, time, and pid (you can't prevent the date, time, and pid +// from being in the filename).
+// +// The logging code takes two flags: +// --v=# set the verbose level +// --logtostderr log all the messages to stderr instead of to logfiles + +// LOG LINE PREFIX FORMAT +// +// Log lines have this form: +// +// Lmmdd hh:mm:ss.uuuuuu threadid file:line] msg... +// +// where the fields are defined as follows: +// +// L A single character, representing the log level +// (eg 'I' for INFO) +// mm The month (zero padded; ie May is '05') +// dd The day (zero padded) +// hh:mm:ss.uuuuuu Time in hours, minutes and fractional seconds +// threadid The space-padded thread ID as returned by GetTID() +// (this matches the PID on Linux) +// file The file name +// line The line number +// msg The user-supplied message +// +// Example: +// +// I1103 11:57:31.739339 24395 google.cc:2341] Command line: ./some_prog +// I1103 11:57:31.739403 24395 google.cc:2342] Process id 24395 +// +// NOTE: although the microseconds are useful for comparing events on +// a single machine, clocks on different machines may not be well +// synchronized. Hence, use caution when comparing the low bits of +// timestamps from different machines. + +#ifndef DECLARE_VARIABLE +#define MUST_UNDEF_GFLAGS_DECLARE_MACROS +#define DECLARE_VARIABLE(type, shorttype, name, tn) \ + namespace fL##shorttype { \ + extern GOOGLE_GLOG_DLL_DECL type FLAGS_##name; \ + } \ + using fL##shorttype::FLAGS_##name + +// bool specialization +#define DECLARE_bool(name) \ + DECLARE_VARIABLE(bool, B, name, bool) + +// int32 specialization +#define DECLARE_int32(name) \ + DECLARE_VARIABLE(google::int32, I, name, int32) + +// Special case for string, because we have to specify the namespace +// std::string, which doesn't play nicely with our FLAG__namespace hackery. 
+#define DECLARE_string(name) \ + namespace fLS { \ + extern GOOGLE_GLOG_DLL_DECL std::string& FLAGS_##name; \ + } \ + using fLS::FLAGS_##name +#endif + +// Set whether log messages go to stderr instead of logfiles +DECLARE_bool(logtostderr); + +// Set whether log messages go to stderr in addition to logfiles. +DECLARE_bool(alsologtostderr); + +// Set color messages logged to stderr (if supported by terminal). +DECLARE_bool(colorlogtostderr); + +// Log messages at a level >= this flag are automatically sent to +// stderr in addition to log files. +DECLARE_int32(stderrthreshold); + +// Set whether the log prefix should be prepended to each line of output. +DECLARE_bool(log_prefix); + +// Log messages at a level <= this flag are buffered. +// Log messages at a higher level are flushed immediately. +DECLARE_int32(logbuflevel); + +// Sets the maximum number of seconds which logs may be buffered for. +DECLARE_int32(logbufsecs); + +// Log suppression level: messages logged at a lower level than this +// are suppressed. +DECLARE_int32(minloglevel); + +// If specified, logfiles are written into this directory instead of the +// default logging directory. +DECLARE_string(log_dir); + +// Set the log file mode. +DECLARE_int32(logfile_mode); + +// Sets the path of the directory into which to put additional links +// to the log files. +DECLARE_string(log_link); + +DECLARE_int32(v); // in vlog_is_on.cc + +// Sets the maximum log file size (in MB). +DECLARE_int32(max_log_size); + +// Sets whether to avoid logging to the disk if the disk is full. +DECLARE_bool(stop_logging_if_full_disk); + +#ifdef MUST_UNDEF_GFLAGS_DECLARE_MACROS +#undef MUST_UNDEF_GFLAGS_DECLARE_MACROS +#undef DECLARE_VARIABLE +#undef DECLARE_bool +#undef DECLARE_int32 +#undef DECLARE_string +#endif + +// Log messages below the GOOGLE_STRIP_LOG level will be compiled away for +// security reasons. See LOG(severtiy) below. + +// A few definitions of macros that don't generate much code. 
Since +// LOG(INFO) and its ilk are used all over our code, it's +// better to have compact code for these operations. + +#if GOOGLE_STRIP_LOG == 0 +#define COMPACT_GOOGLE_LOG_INFO google::LogMessage( \ + __FILE__, __LINE__) +#define LOG_TO_STRING_INFO(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_INFO, message) +#else +#define COMPACT_GOOGLE_LOG_INFO google::NullStream() +#define LOG_TO_STRING_INFO(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 1 +#define COMPACT_GOOGLE_LOG_WARNING google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_WARNING) +#define LOG_TO_STRING_WARNING(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_WARNING, message) +#else +#define COMPACT_GOOGLE_LOG_WARNING google::NullStream() +#define LOG_TO_STRING_WARNING(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 2 +#define COMPACT_GOOGLE_LOG_ERROR google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ERROR) +#define LOG_TO_STRING_ERROR(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ERROR, message) +#else +#define COMPACT_GOOGLE_LOG_ERROR google::NullStream() +#define LOG_TO_STRING_ERROR(message) google::NullStream() +#endif + +#if GOOGLE_STRIP_LOG <= 3 +#define COMPACT_GOOGLE_LOG_FATAL google::LogMessageFatal( \ + __FILE__, __LINE__) +#define LOG_TO_STRING_FATAL(message) google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_FATAL, message) +#else +#define COMPACT_GOOGLE_LOG_FATAL google::NullStreamFatal() +#define LOG_TO_STRING_FATAL(message) google::NullStreamFatal() +#endif + +#if defined(NDEBUG) && !defined(DCHECK_ALWAYS_ON) +#define DCHECK_IS_ON() 0 +#else +#define DCHECK_IS_ON() 1 +#endif + +// For DFATAL, we want to use LogMessage (as opposed to +// LogMessageFatal), to be consistent with the original behavior. 
+#if !DCHECK_IS_ON() +#define COMPACT_GOOGLE_LOG_DFATAL COMPACT_GOOGLE_LOG_ERROR +#elif GOOGLE_STRIP_LOG <= 3 +#define COMPACT_GOOGLE_LOG_DFATAL google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_FATAL) +#else +#define COMPACT_GOOGLE_LOG_DFATAL google::NullStreamFatal() +#endif + +#define GOOGLE_LOG_INFO(counter) google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO, counter, &google::LogMessage::SendToLog) +#define SYSLOG_INFO(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_INFO, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_WARNING(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_WARNING(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_WARNING, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_ERROR(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_ERROR(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_FATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_FATAL, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_FATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_FATAL, counter, \ + &google::LogMessage::SendToSyslogAndLog) +#define GOOGLE_LOG_DFATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::DFATAL_LEVEL, counter, \ + &google::LogMessage::SendToLog) +#define SYSLOG_DFATAL(counter) \ + google::LogMessage(__FILE__, __LINE__, google::DFATAL_LEVEL, counter, \ + &google::LogMessage::SendToSyslogAndLog) + +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32__) || defined(__CYGWIN__) || defined(__CYGWIN32__) +// A very useful logging macro to log windows errors: +#define LOG_SYSRESULT(result) \ + if (FAILED(HRESULT_FROM_WIN32(result))) { \ + LPSTR 
message = NULL; \ + LPSTR msg = reinterpret_cast(&message); \ + DWORD message_length = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | \ + FORMAT_MESSAGE_FROM_SYSTEM, \ + 0, result, 0, msg, 100, NULL); \ + if (message_length > 0) { \ + google::LogMessage(__FILE__, __LINE__, google::GLOG_ERROR, 0, \ + &google::LogMessage::SendToLog).stream() \ + << reinterpret_cast(message); \ + LocalFree(message); \ + } \ + } +#endif + +// We use the preprocessor's merging operator, "##", so that, e.g., +// LOG(INFO) becomes the token GOOGLE_LOG_INFO. There's some funny +// subtle difference between ostream member streaming functions (e.g., +// ostream::operator<<(int) and ostream non-member streaming functions +// (e.g., ::operator<<(ostream&, string&): it turns out that it's +// impossible to stream something like a string directly to an unnamed +// ostream. We employ a neat hack by calling the stream() member +// function of LogMessage which seems to avoid the problem. +#define LOG(severity) COMPACT_GOOGLE_LOG_ ## severity.stream() +#define SYSLOG(severity) SYSLOG_ ## severity(0).stream() + +namespace google { + +// They need the definitions of integer types. +#include "glog/log_severity.h" +#include "glog/vlog_is_on.h" + +// Initialize google's logging library. You will see the program name +// specified by argv0 in log outputs. +GOOGLE_GLOG_DLL_DECL void InitGoogleLogging(const char* argv0); + +// Shutdown google's logging library. +GOOGLE_GLOG_DLL_DECL void ShutdownGoogleLogging(); + +// Install a function which will be called after LOG(FATAL). +GOOGLE_GLOG_DLL_DECL void InstallFailureFunction(void (*fail_func)()); + +class LogSink; // defined below + +// If a non-NULL sink pointer is given, we push this message to that sink. +// For LOG_TO_SINK we then do normal LOG(severity) logging as well. +// This is useful for capturing messages and passing/storing them +// somewhere more specific than the global log of the process. 
+// Argument types: +// LogSink* sink; +// LogSeverity severity; +// The cast is to disambiguate NULL arguments. +#define LOG_TO_SINK(sink, severity) \ + google::LogMessage( \ + __FILE__, __LINE__, \ + google::GLOG_ ## severity, \ + static_cast(sink), true).stream() +#define LOG_TO_SINK_BUT_NOT_TO_LOGFILE(sink, severity) \ + google::LogMessage( \ + __FILE__, __LINE__, \ + google::GLOG_ ## severity, \ + static_cast(sink), false).stream() + +// If a non-NULL string pointer is given, we write this message to that string. +// We then do normal LOG(severity) logging as well. +// This is useful for capturing messages and storing them somewhere more +// specific than the global log of the process. +// Argument types: +// string* message; +// LogSeverity severity; +// The cast is to disambiguate NULL arguments. +// NOTE: LOG(severity) expands to LogMessage().stream() for the specified +// severity. +#define LOG_TO_STRING(severity, message) \ + LOG_TO_STRING_##severity(static_cast(message)).stream() + +// If a non-NULL pointer is given, we push the message onto the end +// of a vector of strings; otherwise, we report it with LOG(severity). +// This is handy for capturing messages and perhaps passing them back +// to the caller, rather than reporting them immediately. +// Argument types: +// LogSeverity severity; +// vector *outvec; +// The cast is to disambiguate NULL arguments. +#define LOG_STRING(severity, outvec) \ + LOG_TO_STRING_##severity(static_cast*>(outvec)).stream() + +#define LOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? (void) 0 : google::LogMessageVoidify() & LOG(severity) +#define SYSLOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? 
(void) 0 : google::LogMessageVoidify() & SYSLOG(severity) + +#define LOG_ASSERT(condition) \ + LOG_IF(FATAL, !(condition)) << "Assert failed: " #condition +#define SYSLOG_ASSERT(condition) \ + SYSLOG_IF(FATAL, !(condition)) << "Assert failed: " #condition + +// CHECK dies with a fatal error if condition is not true. It is *not* +// controlled by DCHECK_IS_ON(), so the check will be executed regardless of +// compilation mode. Therefore, it is safe to do things like: +// CHECK(fp->Write(x) == 4) +#define CHECK(condition) \ + LOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!(condition))) \ + << "Check failed: " #condition " " + +// A container for a string pointer which can be evaluated to a bool - +// true iff the pointer is NULL. +struct CheckOpString { + CheckOpString(std::string* str) : str_(str) { } + // No destructor: if str_ is non-NULL, we're about to LOG(FATAL), + // so there's no point in cleaning up str_. + operator bool() const { + return GOOGLE_PREDICT_BRANCH_NOT_TAKEN(str_ != NULL); + } + std::string* str_; +}; + +// Function is overloaded for integral types to allow static const +// integrals declared in classes and not defined to be used as arguments to +// CHECK* macros. It's not encouraged though. 
+template +inline const T& GetReferenceableValue(const T& t) { return t; } +inline char GetReferenceableValue(char t) { return t; } +inline unsigned char GetReferenceableValue(unsigned char t) { return t; } +inline signed char GetReferenceableValue(signed char t) { return t; } +inline short GetReferenceableValue(short t) { return t; } +inline unsigned short GetReferenceableValue(unsigned short t) { return t; } +inline int GetReferenceableValue(int t) { return t; } +inline unsigned int GetReferenceableValue(unsigned int t) { return t; } +inline long GetReferenceableValue(long t) { return t; } +inline unsigned long GetReferenceableValue(unsigned long t) { return t; } +inline long long GetReferenceableValue(long long t) { return t; } +inline unsigned long long GetReferenceableValue(unsigned long long t) { + return t; +} + +// This is a dummy class to define the following operator. +struct DummyClassToDefineOperator {}; + +} + +// Define global operator<< to declare using ::operator<<. +// This declaration will allow use to use CHECK macros for user +// defined classes which have operator<< (e.g., stl_logging.h). +inline std::ostream& operator<<( + std::ostream& out, const google::DummyClassToDefineOperator&) { + return out; +} + +namespace google { + +// This formats a value for a failing CHECK_XX statement. Ordinarily, +// it uses the definition for operator<<, with a few special cases below. +template +inline void MakeCheckOpValueString(std::ostream* os, const T& v) { + (*os) << v; +} + +// Overrides for char types provide readable values for unprintable +// characters. +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const char& v); +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const signed char& v); +template <> GOOGLE_GLOG_DLL_DECL +void MakeCheckOpValueString(std::ostream* os, const unsigned char& v); + +// Build the error message string. Specify no inlining for code size. 
+template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) + __attribute__((noinline)); + +namespace base { +namespace internal { + +// If "s" is less than base_logging::INFO, returns base_logging::INFO. +// If "s" is greater than base_logging::FATAL, returns +// base_logging::ERROR. Otherwise, returns "s". +LogSeverity NormalizeSeverity(LogSeverity s); + +} // namespace internal + +// A helper class for formatting "expr (V1 vs. V2)" in a CHECK_XX +// statement. See MakeCheckOpString for sample usage. Other +// approaches were considered: use of a template method (e.g., +// base::BuildCheckOpString(exprtext, base::Print, &v1, +// base::Print, &v2), however this approach has complications +// related to volatile arguments and function-pointer arguments). +class GOOGLE_GLOG_DLL_DECL CheckOpMessageBuilder { + public: + // Inserts "exprtext" and " (" to the stream. + explicit CheckOpMessageBuilder(const char *exprtext); + // Deletes "stream_". + ~CheckOpMessageBuilder(); + // For inserting the first variable. + std::ostream* ForVar1() { return stream_; } + // For inserting the second variable (adds an intermediate " vs. "). + std::ostream* ForVar2(); + // Get the result (inserts the closing ")"). + std::string* NewString(); + + private: + std::ostringstream *stream_; +}; + +} // namespace base + +template +std::string* MakeCheckOpString(const T1& v1, const T2& v2, const char* exprtext) { + base::CheckOpMessageBuilder comb(exprtext); + MakeCheckOpValueString(comb.ForVar1(), v1); + MakeCheckOpValueString(comb.ForVar2(), v2); + return comb.NewString(); +} + +// Helper functions for CHECK_OP macro. +// The (int, int) specialization works around the issue that the compiler +// will not instantiate the template version of the function on values of +// unnamed enum type - see comment below. 
+#define DEFINE_CHECK_OP_IMPL(name, op) \ + template \ + inline std::string* name##Impl(const T1& v1, const T2& v2, \ + const char* exprtext) { \ + if (GOOGLE_PREDICT_TRUE(v1 op v2)) return NULL; \ + else return MakeCheckOpString(v1, v2, exprtext); \ + } \ + inline std::string* name##Impl(int v1, int v2, const char* exprtext) { \ + return name##Impl(v1, v2, exprtext); \ + } + +// We use the full name Check_EQ, Check_NE, etc. in case the file including +// base/logging.h provides its own #defines for the simpler names EQ, NE, etc. +// This happens if, for example, those are used as token names in a +// yacc grammar. +DEFINE_CHECK_OP_IMPL(Check_EQ, ==) // Compilation error with CHECK_EQ(NULL, x)? +DEFINE_CHECK_OP_IMPL(Check_NE, !=) // Use CHECK(x == NULL) instead. +DEFINE_CHECK_OP_IMPL(Check_LE, <=) +DEFINE_CHECK_OP_IMPL(Check_LT, < ) +DEFINE_CHECK_OP_IMPL(Check_GE, >=) +DEFINE_CHECK_OP_IMPL(Check_GT, > ) +#undef DEFINE_CHECK_OP_IMPL + +// Helper macro for binary operators. +// Don't use this macro directly in your code, use CHECK_EQ et al below. + +#if defined(STATIC_ANALYSIS) +// Only for static analysis tool to know that it is equivalent to assert +#define CHECK_OP_LOG(name, op, val1, val2, log) CHECK((val1) op (val2)) +#elif DCHECK_IS_ON() +// In debug mode, avoid constructing CheckOpStrings if possible, +// to reduce the overhead of CHECK statments by 2x. +// Real DCHECK-heavy tests have seen 1.5x speedups. + +// The meaning of "string" might be different between now and +// when this macro gets invoked (e.g., if someone is experimenting +// with other string implementations that get defined after this +// file is included). Save the current meaning now and use it +// in the macro. 
+typedef std::string _Check_string; +#define CHECK_OP_LOG(name, op, val1, val2, log) \ + while (google::_Check_string* _result = \ + google::Check##name##Impl( \ + google::GetReferenceableValue(val1), \ + google::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + log(__FILE__, __LINE__, \ + google::CheckOpString(_result)).stream() +#else +// In optimized mode, use CheckOpString to hint to compiler that +// the while condition is unlikely. +#define CHECK_OP_LOG(name, op, val1, val2, log) \ + while (google::CheckOpString _result = \ + google::Check##name##Impl( \ + google::GetReferenceableValue(val1), \ + google::GetReferenceableValue(val2), \ + #val1 " " #op " " #val2)) \ + log(__FILE__, __LINE__, _result).stream() +#endif // STATIC_ANALYSIS, DCHECK_IS_ON() + +#if GOOGLE_STRIP_LOG <= 3 +#define CHECK_OP(name, op, val1, val2) \ + CHECK_OP_LOG(name, op, val1, val2, google::LogMessageFatal) +#else +#define CHECK_OP(name, op, val1, val2) \ + CHECK_OP_LOG(name, op, val1, val2, google::NullStreamFatal) +#endif // STRIP_LOG <= 3 + +// Equality/Inequality checks - compare two values, and log a FATAL message +// including the two values when the result is not as expected. The values +// must have operator<<(ostream, ...) defined. +// +// You may append to the error message like so: +// CHECK_NE(1, 2) << ": The world must be ending!"; +// +// We are very careful to ensure that each argument is evaluated exactly +// once, and that anything which is legal to pass as a function argument is +// legal here. In particular, the arguments may be temporary expressions +// which will end up being destroyed at the end of the apparent statement, +// for example: +// CHECK_EQ(string("abc")[1], 'b'); +// +// WARNING: These don't compile correctly if one of the arguments is a pointer +// and the other is NULL. To work around this, simply static_cast NULL to the +// type of the desired pointer. 
+ +#define CHECK_EQ(val1, val2) CHECK_OP(_EQ, ==, val1, val2) +#define CHECK_NE(val1, val2) CHECK_OP(_NE, !=, val1, val2) +#define CHECK_LE(val1, val2) CHECK_OP(_LE, <=, val1, val2) +#define CHECK_LT(val1, val2) CHECK_OP(_LT, < , val1, val2) +#define CHECK_GE(val1, val2) CHECK_OP(_GE, >=, val1, val2) +#define CHECK_GT(val1, val2) CHECK_OP(_GT, > , val1, val2) + +// Check that the input is non NULL. This very useful in constructor +// initializer lists. + +#define CHECK_NOTNULL(val) \ + google::CheckNotNull(__FILE__, __LINE__, "'" #val "' Must be non NULL", (val)) + +// Helper functions for string comparisons. +// To avoid bloat, the definitions are in logging.cc. +#define DECLARE_CHECK_STROP_IMPL(func, expected) \ + GOOGLE_GLOG_DLL_DECL std::string* Check##func##expected##Impl( \ + const char* s1, const char* s2, const char* names); +DECLARE_CHECK_STROP_IMPL(strcmp, true) +DECLARE_CHECK_STROP_IMPL(strcmp, false) +DECLARE_CHECK_STROP_IMPL(strcasecmp, true) +DECLARE_CHECK_STROP_IMPL(strcasecmp, false) +#undef DECLARE_CHECK_STROP_IMPL + +// Helper macro for string comparisons. +// Don't use this macro directly in your code, use CHECK_STREQ et al below. +#define CHECK_STROP(func, op, expected, s1, s2) \ + while (google::CheckOpString _result = \ + google::Check##func##expected##Impl((s1), (s2), \ + #s1 " " #op " " #s2)) \ + LOG(FATAL) << *_result.str_ + + +// String (char*) equality/inequality checks. +// CASE versions are case-insensitive. +// +// Note that "s1" and "s2" may be temporary strings which are destroyed +// by the compiler at the end of the current "full expression" +// (e.g. CHECK_STREQ(Foo().c_str(), Bar().c_str())). 
+ +#define CHECK_STREQ(s1, s2) CHECK_STROP(strcmp, ==, true, s1, s2) +#define CHECK_STRNE(s1, s2) CHECK_STROP(strcmp, !=, false, s1, s2) +#define CHECK_STRCASEEQ(s1, s2) CHECK_STROP(strcasecmp, ==, true, s1, s2) +#define CHECK_STRCASENE(s1, s2) CHECK_STROP(strcasecmp, !=, false, s1, s2) + +#define CHECK_INDEX(I,A) CHECK(I < (sizeof(A)/sizeof(A[0]))) +#define CHECK_BOUND(B,A) CHECK(B <= (sizeof(A)/sizeof(A[0]))) + +#define CHECK_DOUBLE_EQ(val1, val2) \ + do { \ + CHECK_LE((val1), (val2)+0.000000000000001L); \ + CHECK_GE((val1), (val2)-0.000000000000001L); \ + } while (0) + +#define CHECK_NEAR(val1, val2, margin) \ + do { \ + CHECK_LE((val1), (val2)+(margin)); \ + CHECK_GE((val1), (val2)-(margin)); \ + } while (0) + +// perror()..googly style! +// +// PLOG() and PLOG_IF() and PCHECK() behave exactly like their LOG* and +// CHECK equivalents with the addition that they postpend a description +// of the current state of errno to their output lines. + +#define PLOG(severity) GOOGLE_PLOG(severity, 0).stream() + +#define GOOGLE_PLOG(severity, counter) \ + google::ErrnoLogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, counter, \ + &google::LogMessage::SendToLog) + +#define PLOG_IF(severity, condition) \ + static_cast(0), \ + !(condition) ? (void) 0 : google::LogMessageVoidify() & PLOG(severity) + +// A CHECK() macro that postpends errno if the condition is false. E.g. +// +// if (poll(fds, nfds, timeout) == -1) { PCHECK(errno == EINTR); ... } +#define PCHECK(condition) \ + PLOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN(!(condition))) \ + << "Check failed: " #condition " " + +// A CHECK() macro that lets you assert the success of a function that +// returns -1 and sets errno in case of an error. E.g. 
+// +// CHECK_ERR(mkdir(path, 0700)); +// +// or +// +// int fd = open(filename, flags); CHECK_ERR(fd) << ": open " << filename; +#define CHECK_ERR(invocation) \ +PLOG_IF(FATAL, GOOGLE_PREDICT_BRANCH_NOT_TAKEN((invocation) == -1)) \ + << #invocation + +// Use macro expansion to create, for each use of LOG_EVERY_N(), static +// variables with the __LINE__ expansion as part of the variable name. +#define LOG_EVERY_N_VARNAME(base, line) LOG_EVERY_N_VARNAME_CONCAT(base, line) +#define LOG_EVERY_N_VARNAME_CONCAT(base, line) base ## line + +#define LOG_OCCURRENCES LOG_EVERY_N_VARNAME(occurrences_, __LINE__) +#define LOG_OCCURRENCES_MOD_N LOG_EVERY_N_VARNAME(occurrences_mod_n_, __LINE__) + +#define SOME_KIND_OF_LOG_EVERY_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (++LOG_OCCURRENCES_MOD_N > n) LOG_OCCURRENCES_MOD_N -= n; \ + if (LOG_OCCURRENCES_MOD_N == 1) \ + google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_LOG_IF_EVERY_N(severity, condition, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (condition && \ + ((LOG_OCCURRENCES_MOD_N=(LOG_OCCURRENCES_MOD_N + 1) % n) == (1 % n))) \ + google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_PLOG_EVERY_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0, LOG_OCCURRENCES_MOD_N = 0; \ + ++LOG_OCCURRENCES; \ + if (++LOG_OCCURRENCES_MOD_N > n) LOG_OCCURRENCES_MOD_N -= n; \ + if (LOG_OCCURRENCES_MOD_N == 1) \ + google::ErrnoLogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +#define SOME_KIND_OF_LOG_FIRST_N(severity, n, what_to_do) \ + static int LOG_OCCURRENCES = 0; \ + if (LOG_OCCURRENCES <= n) \ + ++LOG_OCCURRENCES; \ + if (LOG_OCCURRENCES <= n) \ + 
google::LogMessage( \ + __FILE__, __LINE__, google::GLOG_ ## severity, LOG_OCCURRENCES, \ + &what_to_do).stream() + +namespace glog_internal_namespace_ { +template +struct CompileAssert { +}; +struct CrashReason; + +// Returns true if FailureSignalHandler is installed. +// Needs to be exported since it's used by the signalhandler_unittest. +GOOGLE_GLOG_DLL_DECL bool IsFailureSignalHandlerInstalled(); +} // namespace glog_internal_namespace_ + +#define LOG_EVERY_N(severity, n) \ + SOME_KIND_OF_LOG_EVERY_N(severity, (n), google::LogMessage::SendToLog) + +#define SYSLOG_EVERY_N(severity, n) \ + SOME_KIND_OF_LOG_EVERY_N(severity, (n), google::LogMessage::SendToSyslogAndLog) + +#define PLOG_EVERY_N(severity, n) \ + SOME_KIND_OF_PLOG_EVERY_N(severity, (n), google::LogMessage::SendToLog) + +#define LOG_FIRST_N(severity, n) \ + SOME_KIND_OF_LOG_FIRST_N(severity, (n), google::LogMessage::SendToLog) + +#define LOG_IF_EVERY_N(severity, condition, n) \ + SOME_KIND_OF_LOG_IF_EVERY_N(severity, (condition), (n), google::LogMessage::SendToLog) + +// We want the special COUNTER value available for LOG_EVERY_X()'ed messages +enum PRIVATE_Counter {COUNTER}; + +#ifdef GLOG_NO_ABBREVIATED_SEVERITIES +// wingdi.h defines ERROR to be 0. When we call LOG(ERROR), it gets +// substituted with 0, and it expands to COMPACT_GOOGLE_LOG_0. To allow us +// to keep using this syntax, we define this macro to do the same thing +// as COMPACT_GOOGLE_LOG_ERROR. +#define COMPACT_GOOGLE_LOG_0 COMPACT_GOOGLE_LOG_ERROR +#define SYSLOG_0 SYSLOG_ERROR +#define LOG_TO_STRING_0 LOG_TO_STRING_ERROR +// Needed for LOG_IS_ON(ERROR). +const LogSeverity GLOG_0 = GLOG_ERROR; +#else +// Users may include windows.h after logging.h without +// GLOG_NO_ABBREVIATED_SEVERITIES nor WIN32_LEAN_AND_MEAN. +// For this case, we cannot detect if ERROR is defined before users +// actually use ERROR. Let's make an undefined symbol to warn users. 
+# define GLOG_ERROR_MSG ERROR_macro_is_defined_Define_GLOG_NO_ABBREVIATED_SEVERITIES_before_including_logging_h_See_the_document_for_detail +# define COMPACT_GOOGLE_LOG_0 GLOG_ERROR_MSG +# define SYSLOG_0 GLOG_ERROR_MSG +# define LOG_TO_STRING_0 GLOG_ERROR_MSG +# define GLOG_0 GLOG_ERROR_MSG +#endif + +// Plus some debug-logging macros that get compiled to nothing for production + +#if DCHECK_IS_ON() + +#define DLOG(severity) LOG(severity) +#define DVLOG(verboselevel) VLOG(verboselevel) +#define DLOG_IF(severity, condition) LOG_IF(severity, condition) +#define DLOG_EVERY_N(severity, n) LOG_EVERY_N(severity, n) +#define DLOG_IF_EVERY_N(severity, condition, n) \ + LOG_IF_EVERY_N(severity, condition, n) +#define DLOG_ASSERT(condition) LOG_ASSERT(condition) + +// debug-only checking. executed if DCHECK_IS_ON(). +#define DCHECK(condition) CHECK(condition) +#define DCHECK_EQ(val1, val2) CHECK_EQ(val1, val2) +#define DCHECK_NE(val1, val2) CHECK_NE(val1, val2) +#define DCHECK_LE(val1, val2) CHECK_LE(val1, val2) +#define DCHECK_LT(val1, val2) CHECK_LT(val1, val2) +#define DCHECK_GE(val1, val2) CHECK_GE(val1, val2) +#define DCHECK_GT(val1, val2) CHECK_GT(val1, val2) +#define DCHECK_NOTNULL(val) CHECK_NOTNULL(val) +#define DCHECK_STREQ(str1, str2) CHECK_STREQ(str1, str2) +#define DCHECK_STRCASEEQ(str1, str2) CHECK_STRCASEEQ(str1, str2) +#define DCHECK_STRNE(str1, str2) CHECK_STRNE(str1, str2) +#define DCHECK_STRCASENE(str1, str2) CHECK_STRCASENE(str1, str2) + +#else // !DCHECK_IS_ON() + +#define DLOG(severity) \ + static_cast(0), \ + true ? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DVLOG(verboselevel) \ + static_cast(0), \ + (true || !VLOG_IS_ON(verboselevel)) ? \ + (void) 0 : google::LogMessageVoidify() & LOG(INFO) + +#define DLOG_IF(severity, condition) \ + static_cast(0), \ + (true || !(condition)) ? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_EVERY_N(severity, n) \ + static_cast(0), \ + true ? 
(void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_IF_EVERY_N(severity, condition, n) \ + static_cast(0), \ + (true || !(condition))? (void) 0 : google::LogMessageVoidify() & LOG(severity) + +#define DLOG_ASSERT(condition) \ + static_cast(0), \ + true ? (void) 0 : LOG_ASSERT(condition) + +// MSVC warning C4127: conditional expression is constant +#define DCHECK(condition) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK(condition) + +#define DCHECK_EQ(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_EQ(val1, val2) + +#define DCHECK_NE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_NE(val1, val2) + +#define DCHECK_LE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_LE(val1, val2) + +#define DCHECK_LT(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_LT(val1, val2) + +#define DCHECK_GE(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_GE(val1, val2) + +#define DCHECK_GT(val1, val2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_GT(val1, val2) + +// You may see warnings in release mode if you don't use the return +// value of DCHECK_NOTNULL. Please just use DCHECK for such cases. 
+#define DCHECK_NOTNULL(val) (val) + +#define DCHECK_STREQ(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STREQ(str1, str2) + +#define DCHECK_STRCASEEQ(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRCASEEQ(str1, str2) + +#define DCHECK_STRNE(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRNE(str1, str2) + +#define DCHECK_STRCASENE(str1, str2) \ + GLOG_MSVC_PUSH_DISABLE_WARNING(4127) \ + while (false) \ + GLOG_MSVC_POP_WARNING() CHECK_STRCASENE(str1, str2) + +#endif // DCHECK_IS_ON() + +// Log only in verbose mode. + +#define VLOG(verboselevel) LOG_IF(INFO, VLOG_IS_ON(verboselevel)) + +#define VLOG_IF(verboselevel, condition) \ + LOG_IF(INFO, (condition) && VLOG_IS_ON(verboselevel)) + +#define VLOG_EVERY_N(verboselevel, n) \ + LOG_IF_EVERY_N(INFO, VLOG_IS_ON(verboselevel), n) + +#define VLOG_IF_EVERY_N(verboselevel, condition, n) \ + LOG_IF_EVERY_N(INFO, (condition) && VLOG_IS_ON(verboselevel), n) + +namespace base_logging { + +// LogMessage::LogStream is a std::ostream backed by this streambuf. +// This class ignores overflow and leaves two bytes at the end of the +// buffer to allow for a '\n' and '\0'. +class GOOGLE_GLOG_DLL_DECL LogStreamBuf : public std::streambuf { + public: + // REQUIREMENTS: "len" must be >= 2 to account for the '\n' and '\0'. + LogStreamBuf(char *buf, int len) { + setp(buf, buf + len - 2); + } + + // This effectively ignores overflow. + virtual int_type overflow(int_type ch) { + return ch; + } + + // Legacy public ostrstream method. + size_t pcount() const { return pptr() - pbase(); } + char* pbase() const { return std::streambuf::pbase(); } +}; + +} // namespace base_logging + +// +// This class more or less represents a particular log message. You +// create an instance of LogMessage and then stream stuff to it. 
+// When you finish streaming to it, ~LogMessage is called and the +// full message gets streamed to the appropriate destination. +// +// You shouldn't actually use LogMessage's constructor to log things, +// though. You should use the LOG() macro (and variants thereof) +// above. +class GOOGLE_GLOG_DLL_DECL LogMessage { +public: + enum { + // Passing kNoLogPrefix for the line number disables the + // log-message prefix. Useful for using the LogMessage + // infrastructure as a printing utility. See also the --log_prefix + // flag for controlling the log-message prefix on an + // application-wide basis. + kNoLogPrefix = -1 + }; + + // LogStream inherit from non-DLL-exported class (std::ostrstream) + // and VC++ produces a warning for this situation. + // However, MSDN says "C4275 can be ignored in Microsoft Visual C++ + // 2005 if you are deriving from a type in the Standard C++ Library" + // http://msdn.microsoft.com/en-us/library/3tdb471s(VS.80).aspx + // Let's just ignore the warning. +GLOG_MSVC_PUSH_DISABLE_WARNING(4275) + class GOOGLE_GLOG_DLL_DECL LogStream : public std::ostream { +GLOG_MSVC_POP_WARNING() + public: + LogStream(char *buf, int len, int ctr) + : std::ostream(NULL), + streambuf_(buf, len), + ctr_(ctr), + self_(this) { + rdbuf(&streambuf_); + } + + int ctr() const { return ctr_; } + void set_ctr(int ctr) { ctr_ = ctr; } + LogStream* self() const { return self_; } + + // Legacy std::streambuf methods. + size_t pcount() const { return streambuf_.pcount(); } + char* pbase() const { return streambuf_.pbase(); } + char* str() const { return pbase(); } + + private: + LogStream(const LogStream&); + LogStream& operator=(const LogStream&); + base_logging::LogStreamBuf streambuf_; + int ctr_; // Counter hack (for the LOG_EVERY_X() macro) + LogStream *self_; // Consistency check hack + }; + +public: + // icc 8 requires this typedef to avoid an internal compiler error. 
+ typedef void (LogMessage::*SendMethod)(); + + LogMessage(const char* file, int line, LogSeverity severity, int ctr, + SendMethod send_method); + + // Two special constructors that generate reduced amounts of code at + // LOG call sites for common cases. + + // Used for LOG(INFO): Implied are: + // severity = INFO, ctr = 0, send_method = &LogMessage::SendToLog. + // + // Using this constructor instead of the more complex constructor above + // saves 19 bytes per call site. + LogMessage(const char* file, int line); + + // Used for LOG(severity) where severity != INFO. Implied + // are: ctr = 0, send_method = &LogMessage::SendToLog + // + // Using this constructor instead of the more complex constructor above + // saves 17 bytes per call site. + LogMessage(const char* file, int line, LogSeverity severity); + + // Constructor to log this message to a specified sink (if not NULL). + // Implied are: ctr = 0, send_method = &LogMessage::SendToSinkAndLog if + // also_send_to_log is true, send_method = &LogMessage::SendToSink otherwise. + LogMessage(const char* file, int line, LogSeverity severity, LogSink* sink, + bool also_send_to_log); + + // Constructor where we also give a vector pointer + // for storing the messages (if the pointer is not NULL). + // Implied are: ctr = 0, send_method = &LogMessage::SaveOrSendToLog. + LogMessage(const char* file, int line, LogSeverity severity, + std::vector* outvec); + + // Constructor where we also give a string pointer for storing the + // message (if the pointer is not NULL). Implied are: ctr = 0, + // send_method = &LogMessage::WriteToStringAndLog. + LogMessage(const char* file, int line, LogSeverity severity, + std::string* message); + + // A special constructor used for check failures + LogMessage(const char* file, int line, const CheckOpString& result); + + ~LogMessage(); + + // Flush a buffered message to the sink set in the constructor. 
Always + // called by the destructor, it may also be called from elsewhere if + // needed. Only the first call is actioned; any later ones are ignored. + void Flush(); + + // An arbitrary limit on the length of a single log message. This + // is so that streaming can be done more efficiently. + static const size_t kMaxLogMessageLen; + + // These should not be called directly outside of logging.*, + // only passed as SendMethod arguments to other LogMessage methods: + void SendToLog(); // Actually dispatch to the logs + void SendToSyslogAndLog(); // Actually dispatch to syslog and the logs + + // Call abort() or similar to perform LOG(FATAL) crash. + static void __attribute__((noreturn)) Fail(); + + std::ostream& stream(); + + int preserved_errno() const; + + // Must be called without the log_mutex held. (L < log_mutex) + static int64 num_messages(int severity); + + struct LogMessageData; + +private: + // Fully internal SendMethod cases: + void SendToSinkAndLog(); // Send to sink if provided and dispatch to the logs + void SendToSink(); // Send to sink if provided, do nothing otherwise. + + // Write to string if provided and dispatch to the logs. + void WriteToStringAndLog(); + + void SaveOrSendToLog(); // Save to stringvec if provided, else to logs + + void Init(const char* file, int line, LogSeverity severity, + void (LogMessage::*send_method)()); + + // Used to fill in crash information during LOG(FATAL) failures. + void RecordCrashReason(glog_internal_namespace_::CrashReason* reason); + + // Counts of messages sent at each priority: + static int64 num_messages_[NUM_SEVERITIES]; // under log_mutex + + // We keep the data in a separate struct so that each instance of + // LogMessage uses less stack space.
+ LogMessageData* allocated_; + LogMessageData* data_; + + friend class LogDestination; + + LogMessage(const LogMessage&); + void operator=(const LogMessage&); +}; + +// This class happens to be thread-hostile because all instances share +// a single data buffer, but since it can only be created just before +// the process dies, we don't worry so much. +class GOOGLE_GLOG_DLL_DECL LogMessageFatal : public LogMessage { + public: + LogMessageFatal(const char* file, int line); + LogMessageFatal(const char* file, int line, const CheckOpString& result); + __attribute__((noreturn)) ~LogMessageFatal(); +}; + +// A non-macro interface to the log facility; (useful +// when the logging level is not a compile-time constant). +inline void LogAtLevel(int const severity, std::string const &msg) { + LogMessage(__FILE__, __LINE__, severity).stream() << msg; +} + +// A macro alternative of LogAtLevel. New code may want to use this +// version since there are two advantages: 1. this version outputs the +// file name and the line number where this macro is put like other +// LOG macros, 2. this macro can be used as C++ stream. +#define LOG_AT_LEVEL(severity) google::LogMessage(__FILE__, __LINE__, severity).stream() + +// Check if it's compiled in C++11 mode. +// +// GXX_EXPERIMENTAL_CXX0X is defined by gcc and clang up to at least +// gcc-4.7 and clang-3.1 (2011-12-13). __cplusplus was defined to 1 +// in gcc before 4.7 (Crosstool 16) and clang before 3.1, but is +// defined according to the language version in effect thereafter. +// Microsoft Visual Studio 14 (2015) sets __cplusplus==199711 despite +// reasonably good C++11 support, so we set LANG_CXX for it and +// newer versions (_MSC_VER >= 1900). +#if (defined(__GXX_EXPERIMENTAL_CXX0X__) || __cplusplus >= 201103L || \ + (defined(_MSC_VER) && _MSC_VER >= 1900)) +// Helper for CHECK_NOTNULL(). +// +// In C++11, all cases can be handled by a single function. 
Since the value +// category of the argument is preserved (also for rvalue references), +// member initializer lists like the one below will compile correctly: +// +// Foo() +// : x_(CHECK_NOTNULL(MethodReturningUniquePtr())) {} +template +T CheckNotNull(const char* file, int line, const char* names, T&& t) { + if (t == nullptr) { + LogMessageFatal(file, line, new std::string(names)); + } + return std::forward(t); +} + +#else + +// A small helper for CHECK_NOTNULL(). +template +T* CheckNotNull(const char *file, int line, const char *names, T* t) { + if (t == NULL) { + LogMessageFatal(file, line, new std::string(names)); + } + return t; +} +#endif + +// Allow folks to put a counter in the LOG_EVERY_X()'ed messages. This +// only works if ostream is a LogStream. If the ostream is not a +// LogStream you'll get an assert saying as much at runtime. +GOOGLE_GLOG_DLL_DECL std::ostream& operator<<(std::ostream &os, + const PRIVATE_Counter&); + + +// Derived class for PLOG*() above. +class GOOGLE_GLOG_DLL_DECL ErrnoLogMessage : public LogMessage { + public: + + ErrnoLogMessage(const char* file, int line, LogSeverity severity, int ctr, + void (LogMessage::*send_method)()); + + // Postpends ": strerror(errno) [errno]". + ~ErrnoLogMessage(); + + private: + ErrnoLogMessage(const ErrnoLogMessage&); + void operator=(const ErrnoLogMessage&); +}; + + +// This class is used to explicitly ignore values in the conditional +// logging macros. This avoids compiler warnings like "value computed +// is not used" and "statement has no effect". + +class GOOGLE_GLOG_DLL_DECL LogMessageVoidify { + public: + LogMessageVoidify() { } + // This has to be an operator with a precedence lower than << but + // higher than ?: + void operator&(std::ostream&) { } +}; + + +// Flushes all log files that contains messages that are at least of +// the specified severity level. Thread-safe. 
+GOOGLE_GLOG_DLL_DECL void FlushLogFiles(LogSeverity min_severity); + +// Flushes all log files that contains messages that are at least of +// the specified severity level. Thread-hostile because it ignores +// locking -- used for catastrophic failures. +GOOGLE_GLOG_DLL_DECL void FlushLogFilesUnsafe(LogSeverity min_severity); + +// +// Set the destination to which a particular severity level of log +// messages is sent. If base_filename is "", it means "don't log this +// severity". Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogDestination(LogSeverity severity, + const char* base_filename); + +// +// Set the basename of the symlink to the latest log file at a given +// severity. If symlink_basename is empty, do not make a symlink. If +// you don't call this function, the symlink basename is the +// invocation name of the program. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogSymlink(LogSeverity severity, + const char* symlink_basename); + +// +// Used to send logs to some other kind of destination +// Users should subclass LogSink and override send to do whatever they want. +// Implementations must be thread-safe because a shared instance will +// be called from whichever thread ran the LOG(XXX) line. +class GOOGLE_GLOG_DLL_DECL LogSink { + public: + virtual ~LogSink(); + + // Sink's logging logic (message_len is such as to exclude '\n' at the end). + // This method can't use LOG() or CHECK() as logging system mutex(s) are held + // during this call. + virtual void send(LogSeverity severity, const char* full_filename, + const char* base_filename, int line, + const struct ::tm* tm_time, + const char* message, size_t message_len) = 0; + + // Redefine this to implement waiting for + // the sink's logging logic to complete. + // It will be called after each send() returns, + // but before that LogMessage exits or crashes. + // By default this function does nothing. 
+ // Using this function one can implement complex logic for send() + // that itself involves logging; and do all this w/o causing deadlocks and + // inconsistent rearrangement of log messages. + // E.g. if a LogSink has thread-specific actions, the send() method + // can simply add the message to a queue and wake up another thread that + // handles real logging while itself making some LOG() calls; + // WaitTillSent() can be implemented to wait for that logic to complete. + // See our unittest for an example. + virtual void WaitTillSent(); + + // Returns the normal text output of the log message. + // Can be useful to implement send(). + static std::string ToString(LogSeverity severity, const char* file, int line, + const struct ::tm* tm_time, + const char* message, size_t message_len); +}; + +// Add or remove a LogSink as a consumer of logging data. Thread-safe. +GOOGLE_GLOG_DLL_DECL void AddLogSink(LogSink *destination); +GOOGLE_GLOG_DLL_DECL void RemoveLogSink(LogSink *destination); + +// +// Specify an "extension" added to the filename specified via +// SetLogDestination. This applies to all severity levels. It's +// often used to append the port we're listening on to the logfile +// name. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetLogFilenameExtension( + const char* filename_extension); + +// +// Make it so that all log messages of at least a particular severity +// are logged to stderr (in addition to logging to the usual log +// file(s)). Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void SetStderrLogging(LogSeverity min_severity); + +// +// Make it so that all log messages go only to stderr. Thread-safe. +// +GOOGLE_GLOG_DLL_DECL void LogToStderr(); + +// +// Make it so that all log messages of at least a particular severity are +// logged via email to a list of addresses (in addition to logging to the +// usual log file(s)). The list of addresses is just a string containing +// the email addresses to send to (separated by spaces, say). Thread-safe. 
+// +GOOGLE_GLOG_DLL_DECL void SetEmailLogging(LogSeverity min_severity, + const char* addresses); + +// A simple function that sends email. dest is a comma-separated +// list of addresses. Thread-safe. +GOOGLE_GLOG_DLL_DECL bool SendEmail(const char *dest, + const char *subject, const char *body); + +GOOGLE_GLOG_DLL_DECL const std::vector& GetLoggingDirectories(); + +// For tests only: Clear the internal [cached] list of logging directories to +// force a refresh the next time GetLoggingDirectories is called. +// Thread-hostile. +void TestOnly_ClearLoggingDirectoriesList(); + +// Returns a set of existing temporary directories, which will be a +// subset of the directories returned by GetLoggingDirectories(). +// Thread-safe. +GOOGLE_GLOG_DLL_DECL void GetExistingTempDirectories( + std::vector* list); + +// Print any fatal message again -- useful to call from signal handler +// so that the last thing in the output is the fatal message. +// Thread-hostile, but a race is unlikely. +GOOGLE_GLOG_DLL_DECL void ReprintFatalMessage(); + +// Truncate a log file that may be the append-only output of multiple +// processes and hence can't simply be renamed/reopened (typically a +// stdout/stderr). If the file "path" is > "limit" bytes, copy the +// last "keep" bytes to offset 0 and truncate the rest. Since we could +// be racing with other writers, this approach has the potential to +// lose very small amounts of data. For security, only follow symlinks +// if the path is /proc/self/fd/* +GOOGLE_GLOG_DLL_DECL void TruncateLogFile(const char *path, + int64 limit, int64 keep); + +// Truncate stdout and stderr if they are over the value specified by +// --max_log_size; keep the final 1MB. This function has the same +// race condition as TruncateLogFile. +GOOGLE_GLOG_DLL_DECL void TruncateStdoutStderr(); + +// Return the string representation of the provided LogSeverity level. +// Thread-safe.
+GOOGLE_GLOG_DLL_DECL const char* GetLogSeverityName(LogSeverity severity); + +// --------------------------------------------------------------------- +// Implementation details that are not useful to most clients +// --------------------------------------------------------------------- + +// A Logger is the interface used by logging modules to emit entries +// to a log. A typical implementation will dump formatted data to a +// sequence of files. We also provide interfaces that will forward +// the data to another thread so that the invoker never blocks. +// Implementations should be thread-safe since the logging system +// will write to them from multiple threads. + +namespace base { + +class GOOGLE_GLOG_DLL_DECL Logger { + public: + virtual ~Logger(); + + // Writes "message[0,message_len-1]" corresponding to an event that + // occurred at "timestamp". If "force_flush" is true, the log file + // is flushed immediately. + // + // The input message has already been formatted as deemed + // appropriate by the higher level logging facility. For example, + // textual log messages already contain timestamps, and the + // file:linenumber header. + virtual void Write(bool force_flush, + time_t timestamp, + const char* message, + int message_len) = 0; + + // Flush any buffered messages + virtual void Flush() = 0; + + // Get the current LOG file size. + // The returned value is approximate since some + // logged data may not have been flushed to disk yet. + virtual uint32 LogSize() = 0; +}; + +// Get the logger for the specified severity level. The logger +// remains the property of the logging module and should not be +// deleted by the caller. Thread-safe. +extern GOOGLE_GLOG_DLL_DECL Logger* GetLogger(LogSeverity level); + +// Set the logger for the specified severity level. The logger +// becomes the property of the logging module and should not +// be deleted by the caller. Thread-safe. 
+extern GOOGLE_GLOG_DLL_DECL void SetLogger(LogSeverity level, Logger* logger); + +} + +// glibc has traditionally implemented two incompatible versions of +// strerror_r(). There is a poorly defined convention for picking the +// version that we want, but it is not clear whether it even works with +// all versions of glibc. +// So, instead, we provide this wrapper that automatically detects the +// version that is in use, and then implements POSIX semantics. +// N.B. In addition to what POSIX says, we also guarantee that "buf" will +// be set to an empty string, if this function failed. This means, in most +// cases, you do not need to check the error code and you can directly +// use the value of "buf". It will never have an undefined value. +// DEPRECATED: Use StrError(int) instead. +GOOGLE_GLOG_DLL_DECL int posix_strerror_r(int err, char *buf, size_t len); + +// A thread-safe replacement for strerror(). Returns a string describing the +// given POSIX error code. +GOOGLE_GLOG_DLL_DECL std::string StrError(int err); + +// A class for which we define operator<<, which does nothing. +class GOOGLE_GLOG_DLL_DECL NullStream : public LogMessage::LogStream { + public: + // Initialize the LogStream so the messages can be written somewhere + // (they'll never be actually displayed). This will be needed if a + // NullStream& is implicitly converted to LogStream&, in which case + // the overloaded NullStream::operator<< will not be invoked. + NullStream() : LogMessage::LogStream(message_buffer_, 2, 0) { } + NullStream(const char* /*file*/, int /*line*/, + const CheckOpString& /*result*/) : + LogMessage::LogStream(message_buffer_, 2, 0) { } + NullStream &stream() { return *this; } + private: + // A minimal buffer for messages (which we discard anyway). The length + // passed to LogStream is 2 because LogStreamBuf requires len >= 2. It is + // needed if NullStream& is converted to LogStream& (e.g. by a conditional). + char message_buffer_[2]; +}; + +// Do nothing.
This operator is inline, allowing the message to be +// compiled away. The message will not be compiled away if we do +// something like (flag ? LOG(INFO) : LOG(ERROR)) << message; when +// SKIP_LOG=WARNING. In those cases, NullStream will be implicitly +// converted to LogStream and the message will be computed and then +// quietly discarded. +template +inline NullStream& operator<<(NullStream &str, const T &) { return str; } + +// Similar to NullStream, but aborts the program (without stack +// trace), like LogMessageFatal. +class GOOGLE_GLOG_DLL_DECL NullStreamFatal : public NullStream { + public: + NullStreamFatal() { } + NullStreamFatal(const char* file, int line, const CheckOpString& result) : + NullStream(file, line, result) { } + __attribute__((noreturn)) ~NullStreamFatal() throw () { _exit(1); } +}; + +// Install a signal handler that will dump signal information and a stack +// trace when the program crashes on certain signals. We'll install the +// signal handler for the following signals. +// +// SIGSEGV, SIGILL, SIGFPE, SIGABRT, SIGBUS, and SIGTERM. +// +// By default, the signal handler will write the failure dump to the +// standard error. You can customize the destination by installing your +// own writer function by InstallFailureWriter() below. +// +// Note on threading: +// +// The function should be called before threads are created, if you want +// to use the failure signal handler for all threads. The stack trace +// will be shown only for the thread that receives the signal. In other +// words, stack traces of other threads won't be shown. +GOOGLE_GLOG_DLL_DECL void InstallFailureSignalHandler(); + +// Installs a function that is used for writing the failure dump. "data" +// is the pointer to the beginning of a message to be written, and "size" +// is the size of the message. You should not expect the data is +// terminated with '\0'. 
+GOOGLE_GLOG_DLL_DECL void InstallFailureWriter( + void (*writer)(const char* data, int size)); + +} + +#endif // _LOGGING_H_ diff --git a/projects/llm_framework/include/glog/raw_logging.h b/projects/llm_framework/include/glog/raw_logging.h new file mode 100644 index 00000000..cf3f27d9 --- /dev/null +++ b/projects/llm_framework/include/glog/raw_logging.h @@ -0,0 +1,180 @@ +// Copyright (c) 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Author: Maxim Lifantsev +// +// Thread-safe logging routines that do not allocate any memory or +// acquire any locks, and can therefore be used by low-level memory +// allocation and synchronization code. + +#ifndef BASE_RAW_LOGGING_H_ +#define BASE_RAW_LOGGING_H_ + +#include + +namespace google { + +#include "glog/log_severity.h" +#include "glog/vlog_is_on.h" + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +// This is similar to LOG(severity) << format... and VLOG(level) << format.., +// but +// * it is to be used ONLY by low-level modules that can't use normal LOG() +// * it is designed to be a low-level logger that does not allocate any +// memory and does not need any locks, hence: +// * it logs straight and ONLY to STDERR w/o buffering +// * it uses an explicit format and arguments list +// * it will silently chop off really long message strings +// Usage example: +// RAW_LOG(ERROR, "Failed foo with %i: %s", status, error); +// RAW_VLOG(3, "status is %i", status); +// These will print almost standard log lines like this to stderr only: +// E0821 211317 file.cc:123] RAW: Failed foo with 22: bad_file +// I0821 211317 file.cc:142] RAW: status is 20 +#define RAW_LOG(severity, ...) \ + do { \ + switch (google::GLOG_ ## severity) { \ + case 0: \ + RAW_LOG_INFO(__VA_ARGS__); \ + break; \ + case 1: \ + RAW_LOG_WARNING(__VA_ARGS__); \ + break; \ + case 2: \ + RAW_LOG_ERROR(__VA_ARGS__); \ + break; \ + case 3: \ + RAW_LOG_FATAL(__VA_ARGS__); \ + break; \ + default: \ + break; \ + } \ + } while (0) + +// The following STRIP_LOG testing is performed in the header file so that it's +// possible to completely compile out the logging code and the log messages. +#if STRIP_LOG == 0 +#define RAW_VLOG(verboselevel, ...)
\ + do { \ + if (VLOG_IS_ON(verboselevel)) { \ + RAW_LOG_INFO(__VA_ARGS__); \ + } \ + } while (0) +#else +#define RAW_VLOG(verboselevel, ...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG == 0 + +#if STRIP_LOG == 0 +#define RAW_LOG_INFO(...) google::RawLog__(google::GLOG_INFO, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_INFO(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG == 0 + +#if STRIP_LOG <= 1 +#define RAW_LOG_WARNING(...) google::RawLog__(google::GLOG_WARNING, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_WARNING(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG <= 1 + +#if STRIP_LOG <= 2 +#define RAW_LOG_ERROR(...) google::RawLog__(google::GLOG_ERROR, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_ERROR(...) google::RawLogStub__(0, __VA_ARGS__) +#endif // STRIP_LOG <= 2 + +#if STRIP_LOG <= 3 +#define RAW_LOG_FATAL(...) google::RawLog__(google::GLOG_FATAL, \ + __FILE__, __LINE__, __VA_ARGS__) +#else +#define RAW_LOG_FATAL(...) \ + do { \ + google::RawLogStub__(0, __VA_ARGS__); \ + exit(1); \ + } while (0) +#endif // STRIP_LOG <= 3 + +// Similar to CHECK(condition) << message, +// but for low-level modules: we use only RAW_LOG that does not allocate memory. +// We do not want to provide args list here to encourage this usage: +// if (!cond) RAW_LOG(FATAL, "foo ...", hard_to_compute_args); +// so that the args are not computed when not needed. +#define RAW_CHECK(condition, message) \ + do { \ + if (!(condition)) { \ + RAW_LOG(FATAL, "Check %s failed: %s", #condition, message); \ + } \ + } while (0) + +// Debug versions of RAW_LOG and RAW_CHECK +#ifndef NDEBUG + +#define RAW_DLOG(severity, ...) RAW_LOG(severity, __VA_ARGS__) +#define RAW_DCHECK(condition, message) RAW_CHECK(condition, message) + +#else // NDEBUG + +#define RAW_DLOG(severity, ...)
\ + while (false) \ + RAW_LOG(severity, __VA_ARGS__) +#define RAW_DCHECK(condition, message) \ + while (false) \ + RAW_CHECK(condition, message) + +#endif // NDEBUG + +// Stub log function used to work around for unused variable warnings when +// building with STRIP_LOG > 0. +static inline void RawLogStub__(int /* ignored */, ...) { +} + +// Helper function to implement RAW_LOG and RAW_VLOG +// Logs format... at "severity" level, reporting it +// as called from file:line. +// This does not allocate memory or acquire locks. +GOOGLE_GLOG_DLL_DECL void RawLog__(LogSeverity severity, + const char* file, + int line, + const char* format, ...) + ; + +} + +#endif // BASE_RAW_LOGGING_H_ diff --git a/projects/llm_framework/include/glog/stl_logging.h b/projects/llm_framework/include/glog/stl_logging.h new file mode 100644 index 00000000..40a15aa4 --- /dev/null +++ b/projects/llm_framework/include/glog/stl_logging.h @@ -0,0 +1,220 @@ +// Copyright (c) 2003, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Stream output operators for STL containers; to be used for logging *only*. +// Inclusion of this file lets you do: +// +// list x; +// LOG(INFO) << "data: " << x; +// vector v1, v2; +// CHECK_EQ(v1, v2); +// +// If you want to use this header file with hash maps or slist, you +// need to define macros before including this file: +// +// - GLOG_STL_LOGGING_FOR_UNORDERED - and +// - GLOG_STL_LOGGING_FOR_TR1_UNORDERED - +// - GLOG_STL_LOGGING_FOR_EXT_HASH - +// - GLOG_STL_LOGGING_FOR_EXT_SLIST - +// + +#ifndef UTIL_GTL_STL_LOGGING_INL_H_ +#define UTIL_GTL_STL_LOGGING_INL_H_ + +#if !1 +# error We do not support stl_logging for this compiler +#endif + +#include +#include +#include +#include +#include +#include +#include + +#ifdef GLOG_STL_LOGGING_FOR_UNORDERED +# include +# include +#endif + +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +# include +# include +#endif + +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +# include +# include +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_SLIST +# include +#endif + +// Forward declare these two, and define them after all the container streams +// operators so that we can recurse from pair -> container -> container -> pair +// properly. 
+template +std::ostream& operator<<(std::ostream& out, const std::pair& p); + +namespace google { + +template +void PrintSequence(std::ostream& out, Iter begin, Iter end); + +} + +#define OUTPUT_TWO_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_TWO_ARG_CONTAINER(std::vector) +OUTPUT_TWO_ARG_CONTAINER(std::deque) +OUTPUT_TWO_ARG_CONTAINER(std::list) +#ifdef GLOG_STL_LOGGING_FOR_EXT_SLIST +OUTPUT_TWO_ARG_CONTAINER(__gnu_cxx::slist) +#endif + +#undef OUTPUT_TWO_ARG_CONTAINER + +#define OUTPUT_THREE_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_THREE_ARG_CONTAINER(std::set) +OUTPUT_THREE_ARG_CONTAINER(std::multiset) + +#undef OUTPUT_THREE_ARG_CONTAINER + +#define OUTPUT_FOUR_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +OUTPUT_FOUR_ARG_CONTAINER(std::map) +OUTPUT_FOUR_ARG_CONTAINER(std::multimap) +#ifdef GLOG_STL_LOGGING_FOR_UNORDERED +OUTPUT_FOUR_ARG_CONTAINER(std::unordered_set) +OUTPUT_FOUR_ARG_CONTAINER(std::unordered_multiset) +#endif +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +OUTPUT_FOUR_ARG_CONTAINER(std::tr1::unordered_set) +OUTPUT_FOUR_ARG_CONTAINER(std::tr1::unordered_multiset) +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +OUTPUT_FOUR_ARG_CONTAINER(__gnu_cxx::hash_set) +OUTPUT_FOUR_ARG_CONTAINER(__gnu_cxx::hash_multiset) +#endif + +#undef OUTPUT_FOUR_ARG_CONTAINER + +#define OUTPUT_FIVE_ARG_CONTAINER(Sequence) \ +template \ +inline std::ostream& operator<<(std::ostream& out, \ + const Sequence& seq) { \ + google::PrintSequence(out, seq.begin(), seq.end()); \ + return out; \ +} + +#ifdef 
GLOG_STL_LOGGING_FOR_UNORDERED +OUTPUT_FIVE_ARG_CONTAINER(std::unordered_map) +OUTPUT_FIVE_ARG_CONTAINER(std::unordered_multimap) +#endif +#ifdef GLOG_STL_LOGGING_FOR_TR1_UNORDERED +OUTPUT_FIVE_ARG_CONTAINER(std::tr1::unordered_map) +OUTPUT_FIVE_ARG_CONTAINER(std::tr1::unordered_multimap) +#endif +#ifdef GLOG_STL_LOGGING_FOR_EXT_HASH +OUTPUT_FIVE_ARG_CONTAINER(__gnu_cxx::hash_map) +OUTPUT_FIVE_ARG_CONTAINER(__gnu_cxx::hash_multimap) +#endif + +#undef OUTPUT_FIVE_ARG_CONTAINER + +template +inline std::ostream& operator<<(std::ostream& out, + const std::pair& p) { + out << '(' << p.first << ", " << p.second << ')'; + return out; +} + +namespace google { + +template +inline void PrintSequence(std::ostream& out, Iter begin, Iter end) { + // Output at most 100 elements -- appropriate if used for logging. + for (int i = 0; begin != end && i < 100; ++i, ++begin) { + if (i > 0) out << ' '; + out << *begin; + } + if (begin != end) { + out << " ..."; + } +} + +} + +// Note that this is technically undefined behavior! We are adding things into +// the std namespace for a reason though -- we are providing new operations on +// types which are themselves defined with this namespace. Without this, these +// operator overloads cannot be found via ADL. If these definitions are not +// found via ADL, they must be #included before they're used, which requires +// this header to be included before apparently independent other headers. +// +// For example, base/logging.h defines various template functions to implement +// CHECK_EQ(x, y) and stream x and y into the log in the event the check fails. +// It does so via the function template MakeCheckOpValueString: +// template +// void MakeCheckOpValueString(strstream* ss, const T& v) { +// (*ss) << v; +// } +// Because 'glog/logging.h' is included before 'glog/stl_logging.h', +// subsequent CHECK_EQ(v1, v2) for vector<...> typed variable v1 and v2 can only +// find these operator definitions via ADL. 
+// +// Even this solution has problems -- it may pull unintended operators into the +// namespace as well, allowing them to also be found via ADL, and creating code +// that only works with a particular order of includes. Long term, we need to +// move all of the *definitions* into namespace std, but we need to ensure no +// one references them first. This lets us take that step. We cannot define them +// in both because that would create ambiguous overloads when both are found. +namespace std { using ::operator<<; } + +#endif // UTIL_GTL_STL_LOGGING_INL_H_ diff --git a/projects/llm_framework/include/glog/vlog_is_on.h b/projects/llm_framework/include/glog/vlog_is_on.h new file mode 100644 index 00000000..02b0b867 --- /dev/null +++ b/projects/llm_framework/include/glog/vlog_is_on.h @@ -0,0 +1,129 @@ +// Copyright (c) 1999, 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: Ray Sidney and many others +// +// Defines the VLOG_IS_ON macro that controls the variable-verbosity +// conditional logging. +// +// It's used by VLOG and VLOG_IF in logging.h +// and by RAW_VLOG in raw_logging.h to trigger the logging. +// +// It can also be used directly e.g. like this: +// if (VLOG_IS_ON(2)) { +// // do some logging preparation and logging +// // that can't be accomplished e.g. via just VLOG(2) << ...; +// } +// +// The truth value that VLOG_IS_ON(level) returns is determined by +// the three verbosity level flags: +// --v= Gives the default maximal active V-logging level; +// 0 is the default. +// Normally positive values are used for V-logging levels. +// --vmodule= Gives the per-module maximal V-logging levels to override +// the value given by --v. +// E.g. "my_module=2,foo*=3" would change the logging level +// for all code in source files "my_module.*" and "foo*.*" +// ("-inl" suffixes are also disregarded for this matching). +// +// SetVLOGLevel helper function is provided to do limited dynamic control over +// V-logging by overriding the per-module settings given via --vmodule flag. +// +// CAVEAT: --vmodule functionality is not available in non gcc compilers. 
+// + +#ifndef BASE_VLOG_IS_ON_H_ +#define BASE_VLOG_IS_ON_H_ + +#include "glog/log_severity.h" + +// Annoying stuff for windows -- makes sure clients can import these functions +#ifndef GOOGLE_GLOG_DLL_DECL +# if defined(_WIN32) && !defined(__CYGWIN__) +# define GOOGLE_GLOG_DLL_DECL __declspec(dllimport) +# else +# define GOOGLE_GLOG_DLL_DECL +# endif +#endif + +#if defined(__GNUC__) +// We emit an anonymous static int* variable at every VLOG_IS_ON(n) site. +// (Normally) the first time every VLOG_IS_ON(n) site is hit, +// we determine what variable will dynamically control logging at this site: +// it's either FLAGS_v or an appropriate internal variable +// matching the current source file that represents results of +// parsing of --vmodule flag and/or SetVLOGLevel calls. +#define VLOG_IS_ON(verboselevel) \ + __extension__ \ + ({ static google::int32* vlocal__ = &google::kLogSiteUninitialized; \ + google::int32 verbose_level__ = (verboselevel); \ + (*vlocal__ >= verbose_level__) && \ + ((vlocal__ != &google::kLogSiteUninitialized) || \ + (google::InitVLOG3__(&vlocal__, &FLAGS_v, \ + __FILE__, verbose_level__))); }) +#else +// GNU extensions not available, so we do not support --vmodule. +// Dynamic value of FLAGS_v always controls the logging level. +#define VLOG_IS_ON(verboselevel) (FLAGS_v >= (verboselevel)) +#endif + +// Set VLOG(_IS_ON) level for module_pattern to log_level. +// This lets us dynamically control what is normally set by the --vmodule flag. +// Returns the level that previously applied to module_pattern. +// NOTE: To change the log level for VLOG(_IS_ON) sites +// that have already executed after/during InitGoogleLogging, +// one needs to supply the exact --vmodule pattern that applied to them. +// (If no --vmodule pattern applied to them +// the value of FLAGS_v will continue to control them.) 
+extern GOOGLE_GLOG_DLL_DECL int SetVLOGLevel(const char* module_pattern, + int log_level); + +// Various declarations needed for VLOG_IS_ON above: ========================= + +// Special value used to indicate that a VLOG_IS_ON site has not been +// initialized. We make this a large value, so the common-case check +// of "*vlocal__ >= verbose_level__" in VLOG_IS_ON definition +// passes in such cases and InitVLOG3__ is then triggered. +extern google::int32 kLogSiteUninitialized; + +// Helper routine which determines the logging info for a particular VLOG site. +// site_flag is the address of the site-local pointer to the controlling +// verbosity level +// site_default is the default to use for *site_flag +// fname is the current source file name +// verbose_level is the argument to VLOG_IS_ON +// We will return the return value for VLOG_IS_ON +// and if possible set *site_flag appropriately. +extern GOOGLE_GLOG_DLL_DECL bool InitVLOG3__( + google::int32** site_flag, + google::int32* site_default, + const char* fname, + google::int32 verbose_level); + +#endif // BASE_VLOG_IS_ON_H_ diff --git a/projects/llm_framework/main_melotts/SConstruct b/projects/llm_framework/main_melotts/SConstruct index 6663ca30..87886e09 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -25,9 +25,12 @@ REQUIREMENTS += ['samplerate'] INCLUDE += [ADir('../include')] INCLUDE += [ADir('src/runner'), ADir('../include/onnxruntime/core/session')] - +LINK_SEARCH_PATH += [ADir('../static_lib/wetext')] LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a'] +LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a','-l:libglog.so','-l:libfst.so'] + + +LDFLAGS += [] STATIC_FILES += Glob('mode_*.json') diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-default.json b/projects/llm_framework/main_melotts/mode_melotts-en-default.json index 8b161169..54c64b1c 100644 ---
a/projects/llm_framework/main_melotts/mode_melotts-en-default.json +++ b/projects/llm_framework/main_melotts/mode_melotts-en-default.json @@ -21,6 +21,8 @@ "gbin": "g-en-def.bin", "tokens": "tokens-en.txt", "lexicon": "lexicon-en.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", "spacker_speed": 1.2, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-us.json b/projects/llm_framework/main_melotts/mode_melotts-en-us.json index 6a375c93..d6320873 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-en-us.json +++ b/projects/llm_framework/main_melotts/mode_melotts-en-us.json @@ -1,9 +1,9 @@ { "mode": "melotts-en-us", "type": "tts", - "homepage":"https://huggingface.co/myshell-ai/MeloTTS-English", - "compile_flage":"pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder-en --output_name decoder-en.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", - "pulsar_version":"3.4-3dfd5692", + "homepage": "https://huggingface.co/myshell-ai/MeloTTS-English", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder-en --output_name decoder-en.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "3.4-3dfd5692", "capabilities": [ "tts", "English" @@ -21,6 +21,8 @@ "gbin": "g-en.bin", "tokens": "tokens.txt", "lexicon": "lexicon.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", "spacker_speed": 1.0, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json b/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json index d2df3e12..0b93f91e 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json +++ b/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json @@ -21,6 +21,8 @@ "gbin": "g-jp.bin", "tokens": "tokens-jp.txt", "lexicon": 
"lexicon-jp.txt", + "tagger": "ja_tn_tagger.fst", + "verbalizer": "ja_tn_verbalizer.fst", "spacker_speed": 1.1, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json b/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json index b5edfe02..17867b92 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json +++ b/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json @@ -21,6 +21,8 @@ "gbin": "g-zh_mix_en.bin", "tokens": "tokens.txt", "lexicon": "lexicon.txt", + "tagger": "zh_tn_tagger.fst", + "verbalizer": "zh_tn_verbalizer.fst", "spacker_speed": 1.1, "mode_rate": 44100, "audio_rate": 16000, diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index 5c371d97..6400ead6 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -9,7 +9,6 @@ #include "Lexicon.hpp" #include #include "AudioFile.h" -#include "Lexicon.hpp" #include #include @@ -44,6 +43,8 @@ typedef struct { std::string tokens; std::string gbin; std::string sentence; + std::string tagger; + std::string verbalizer; float spacker_speed = 1.0; int mode_rate = 44100; int audio_rate = 16000; @@ -169,17 +170,22 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], length_scale); CONFIG_AUTO_SET(file_body["mode_param"], noise_scale_w); CONFIG_AUTO_SET(file_body["mode_param"], sdp_ratio); - mode_config_.tokens = base_model + mode_config_.tokens; - mode_config_.gbin = base_model + mode_config_.gbin; - mode_config_.encoder = base_model + mode_config_.encoder; - mode_config_.decoder = base_model + mode_config_.decoder; - mode_config_.lexicon = base_model + mode_config_.lexicon; + CONFIG_AUTO_SET(file_body["mode_param"], tagger); + CONFIG_AUTO_SET(file_body["mode_param"], verbalizer); + mode_config_.tokens = base_model + mode_config_.tokens; + mode_config_.gbin = base_model + mode_config_.gbin; + 
mode_config_.encoder = base_model + mode_config_.encoder; + mode_config_.decoder = base_model + mode_config_.decoder; + mode_config_.lexicon = base_model + mode_config_.lexicon; + mode_config_.tagger = base_model + mode_config_.tagger; + mode_config_.verbalizer = base_model + mode_config_.verbalizer; if (config_body.contains("awake_delay")) awake_delay_ = config_body["awake_delay"].get(); else if (file_body["mode_param"].contains("awake_delay")) awake_delay_ = file_body["mode_param"]["awake_delay"]; // Load lexicon - lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens); + lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens, mode_config_.tagger, + mode_config_.verbalizer); // Read g.bin g_matrix.resize(256, 0); FILE *fp = fopen(mode_config_.gbin.c_str(), "rb"); diff --git a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp index 29c31817..134f64c4 100644 --- a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp +++ b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp @@ -8,8 +8,10 @@ #include #include #include "../../../../../SDK/components/utilities/include/sample_log.h" +#include "processor/wetext_processor.h" + // Debug logging switch - set to true to enable debug logs -static bool DEBUG_LOGGING = true; +static bool DEBUG_LOGGING = false; // Macro for debug logging #define DEBUG_LOG(fmt, ...) 
\ do { \ @@ -36,16 +38,23 @@ class Lexicon { std::pair, std::vector> unknown_token; std::unordered_map reverse_tokens; + wetext::Processor* m_processor; + public: // Setter for debug logging static void setDebugLogging(bool enable) { DEBUG_LOGGING = enable; } - Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0) + Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename, const std::string& tagger_filename, + const std::string& verbalizer_filename) + : max_phrase_length(0) { - DEBUG_LOG("Dictionary loading: %s Pronunciation table loading: %s", tokens_filename.c_str(), - lexicon_filename.c_str()); + DEBUG_LOG("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", + tokens_filename.c_str(), lexicon_filename.c_str(), tagger_filename.c_str(), + verbalizer_filename.c_str()); + + m_processor = new wetext::Processor(tagger_filename, verbalizer_filename); std::unordered_map tokens; std::ifstream ifs(tokens_filename); @@ -198,6 +207,12 @@ class Lexicon { void convert(const std::string& text, std::vector& phones, std::vector& tones) { DEBUG_LOG("\nStarting text processing: \"%s\"", text.c_str()); + + std::string taggedText = m_processor->Tag(text); + DEBUG_LOG("\taggedText processing: \"%s\"", taggedText.c_str()); + std::string normalizedText = m_processor->Verbalize(taggedText); + DEBUG_LOG("\normalizedText processing: \"%s\"", normalizedText.c_str()); + DEBUG_LOG("=======Matching Results======="); DEBUG_LOG("Unit\t|\tPhonemes\t|\tTones"); DEBUG_LOG("-----------------------------"); @@ -205,7 +220,7 @@ class Lexicon { tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); DEBUG_LOG("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), tonesToString(unknown_token.second).c_str()); - auto chars = splitEachChar(text); + auto chars = splitEachChar(normalizedText); int i = 0; while (i < chars.size()) { if 
(is_english(chars[i])) { diff --git a/projects/llm_framework/main_melotts/src/runner/base64.cpp b/projects/llm_framework/main_melotts/src/runner/base64.cpp index 5e0fd6ad..e8e1add3 100644 --- a/projects/llm_framework/main_melotts/src/runner/base64.cpp +++ b/projects/llm_framework/main_melotts/src/runner/base64.cpp @@ -1,17 +1,13 @@ #include "base64.h" static uint8 alphabet_map[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -static uint8 reverse_map[] = -{ -255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, - 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, 255, 255, 255, 63, - 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255, - 255, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, - 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 255, 255, 255, 255, 255, - 255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, - 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255 -}; +static uint8 reverse_map[] = { + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 62, + 255, 255, 255, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 255, 255, 255, 255, 255, 255, 255, 0, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 23, 24, 25, 255, 255, 255, 255, 255, 255, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, + 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 255, 255, 255, 255, 255}; // //GB2312到UTF-8的转换 // char* G2U(const char* gb2312) @@ -42,26 +38,27 @@ static uint8 reverse_map[] = // return str; // } -// uint32 base64_encode(char* input, uint8* encode) +// base64_uint32 base64_encode(char* input, uint8* encode) // { // //1、包含中文的字符串 字符编码(windows默认是gbk)转换成unicode - + // //2、字符编码方式是utf-8的二进制 // 
// uint8* text = (uint8*)G2U(input); -// uint32 text_len = (uint32)strlen((char*)input); +// base64_uint32 text_len = (base64_uint32)strlen((char*)input); -// uint32 i, j; +// base64_uint32 i, j; // for (i = 0, j = 0; i + 3 <= text_len; i += 3) // { -// encode[j++] = alphabet_map[text[i] >> 2]; //取出第一个字符的前6位并找出对应的结果字符 -// encode[j++] = alphabet_map[((text[i] << 4) & 0x30) | (text[i + 1] >> 4)]; //将第一个字符的后2位与第二个字符的前4位进行组合并找到对应的结果字符 -// encode[j++] = alphabet_map[((text[i + 1] << 2) & 0x3c) | (text[i + 2] >> 6)]; //将第二个字符的后4位与第三个字符的前2位组合并找出对应的结果字符 -// encode[j++] = alphabet_map[text[i + 2] & 0x3f]; //取出第三个字符的后6位并找出结果字符 +// encode[j++] = alphabet_map[text[i] >> 2]; //取出第一个字符的前6位并找出对应的结果字符 encode[j++] = +// alphabet_map[((text[i] << 4) & 0x30) | (text[i + 1] >> 4)]; +// //将第一个字符的后2位与第二个字符的前4位进行组合并找到对应的结果字符 encode[j++] = alphabet_map[((text[i + 1] << 2) & +// 0x3c) | (text[i + 2] >> 6)]; //将第二个字符的后4位与第三个字符的前2位组合并找出对应的结果字符 encode[j++] = +// alphabet_map[text[i + 2] & 0x3f]; //取出第三个字符的后6位并找出结果字符 // } // if (i < text_len) // { -// uint32 tail = text_len - i; +// base64_uint32 tail = text_len - i; // if (tail == 1) // { // encode[j++] = alphabet_map[text[i] >> 2]; @@ -81,40 +78,41 @@ static uint8 reverse_map[] = // return j; // } -int base64_decode(const uint8* code, uint32 code_len, char* str) +int base64_decode(const uint8* code, base64_uint32 code_len, char* str) { - uint8 plain[1024]; - assert((code_len & 0x03) == 0); //如果它的条件返回错误,则终止程序执行。4的倍数。 + uint8 plain[1024]; + assert((code_len & 0x03) == 0); // 如果它的条件返回错误,则终止程序执行。4的倍数。 - uint32 i, j = 0; - uint8 quad[4]; - for (i = 0; i < code_len; i += 4) - { - for (uint32 k = 0; k < 4; k++) - { - quad[k] = reverse_map[code[i + k]];//分组,每组四个分别依次转换为base64表内的十进制数 - } + base64_uint32 i, j = 0; + uint8 quad[4]; + for (i = 0; i < code_len; i += 4) { + for (base64_uint32 k = 0; k < 4; k++) { + quad[k] = reverse_map[code[i + k]]; // 分组,每组四个分别依次转换为base64表内的十进制数 + } - assert(quad[0] < 64 && quad[1] < 64); + assert(quad[0] < 64 && quad[1] 
< 64); - plain[j++] = (quad[0] << 2) | (quad[1] >> 4); //取出第一个字符对应base64表的十进制数的前6位与第二个字符对应base64表的十进制数的前2位进行组合 + plain[j++] = + (quad[0] << 2) | + (quad[1] >> + 4); // 取出第一个字符对应base64表的十进制数的前6位与第二个字符对应base64表的十进制数的前2位进行组合 - if (quad[2] >= 64) - break; - else if (quad[3] >= 64) - { - plain[j++] = (quad[1] << 4) | (quad[2] >> 2); //取出第二个字符对应base64表的十进制数的后4位与第三个字符对应base64表的十进制数的前4位进行组合 - break; - } - else - { - plain[j++] = (quad[1] << 4) | (quad[2] >> 2); - plain[j++] = (quad[2] << 6) | quad[3];//取出第三个字符对应base64表的十进制数的后2位与第4个字符进行组合 - } - } - plain[j] = 0; - // char str[1024] = ""; - strcpy(str, (char*)plain); - // strcpy_s(str, sizeof(plain), U2G(str)); - return j; + if (quad[2] >= 64) + break; + else if (quad[3] >= 64) { + plain[j++] = + (quad[1] << 4) | + (quad[2] >> + 2); // 取出第二个字符对应base64表的十进制数的后4位与第三个字符对应base64表的十进制数的前4位进行组合 + break; + } else { + plain[j++] = (quad[1] << 4) | (quad[2] >> 2); + plain[j++] = (quad[2] << 6) | quad[3]; // 取出第三个字符对应base64表的十进制数的后2位与第4个字符进行组合 + } + } + plain[j] = 0; + // char str[1024] = ""; + strcpy(str, (char*)plain); + // strcpy_s(str, sizeof(plain), U2G(str)); + return j; } \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/src/runner/base64.h b/projects/llm_framework/main_melotts/src/runner/base64.h index 8e3dcb6c..f7a01c7e 100644 --- a/projects/llm_framework/main_melotts/src/runner/base64.h +++ b/projects/llm_framework/main_melotts/src/runner/base64.h @@ -2,10 +2,10 @@ #include #include -#include +#include #include #include -typedef unsigned char uint8; -typedef unsigned long uint32; -// uint32 base64_encode(char* input, uint8* encode); -int base64_decode(const uint8* code, uint32 code_len, char* str); \ No newline at end of file +typedef unsigned char uint8; +typedef unsigned long base64_uint32; +// base64_uint32 base64_encode(char* input, uint8* encode); +int base64_decode(const uint8* code, base64_uint32 code_len, char* str); \ No newline at end of file diff --git 
a/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt b/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt new file mode 100644 index 00000000..a2e8d97c --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/CMakeLists.txt @@ -0,0 +1,13 @@ +add_library(wetext_processor STATIC + wetext_processor.cc + wetext_token_parser.cc +) +if(ANDROID) + target_link_libraries(wetext_processor PUBLIC fst wetext_utils) +else() + if(MSVC) + target_link_libraries(wetext_processor PUBLIC fst wetext_utils) + else() + target_link_libraries(wetext_processor PUBLIC dl fst wetext_utils) + endif() +endif() diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc new file mode 100644 index 00000000..eec45a24 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.cc @@ -0,0 +1,86 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "processor/wetext_processor.h" + +using fst::StringTokenType; + +namespace wetext { +Processor::Processor(const std::string& tagger_path, const std::string& verbalizer_path) +{ + tagger_.reset(StdVectorFst::Read(tagger_path)); + verbalizer_.reset(StdVectorFst::Read(verbalizer_path)); + compiler_ = std::make_shared>(StringTokenType::BYTE); + printer_ = std::make_shared>(StringTokenType::BYTE); + + if (tagger_path.find("zh_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_TN; + } else if (tagger_path.find("zh_itn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_ITN; + } else if (tagger_path.find("en_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kEN_TN; + } else if (tagger_path.find("ja_tn_") != tagger_path.npos) { + parse_type_ = ParseType::kZH_TN; // 如果是日语的文件开始,也使用中文的规则进行转换 + } else { + LOG(FATAL) << "Invalid fst prefix, prefix should contain" << " either \"_tn_\" or \"_itn_\"."; + } +} + +std::string Processor::ShortestPath(const StdVectorFst& lattice) +{ + StdVectorFst shortest_path; + fst::ShortestPath(lattice, &shortest_path, 1, true); + + std::string output; + printer_->operator()(shortest_path, &output); + return output; +} + +std::string Processor::Compose(const std::string& input, const StdVectorFst* fst) +{ + StdVectorFst input_fst; + compiler_->operator()(input, &input_fst); + + StdVectorFst lattice; + fst::Compose(input_fst, *fst, &lattice); + return ShortestPath(lattice); +} + +std::string Processor::Tag(const std::string& input) +{ + if (input.empty()) { + return ""; + } + return Compose(input, tagger_.get()); +} + +std::string Processor::Verbalize(const std::string& input) +{ + if (input.empty()) { + return ""; + } + TokenParser parser(parse_type_); + std::string output = parser.Reorder(input); + + output = Compose(output, verbalizer_.get()); + output.erase(std::remove(output.begin(), output.end(), '\0'), output.end()); + return output; +} + +std::string Processor::Normalize(const std::string& input) +{ + return 
Verbalize(Tag(input)); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h new file mode 100644 index 00000000..e11d307e --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_processor.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef PROCESSOR_WETEXT_PROCESSOR_H_ +#define PROCESSOR_WETEXT_PROCESSOR_H_ + +#include +#include + +#include "fst/fstlib.h" + +#include "processor/wetext_token_parser.h" + +using fst::StdArc; +using fst::StdVectorFst; +using fst::StringCompiler; +using fst::StringPrinter; + +namespace wetext { +class Processor { + public: + Processor(const std::string& tagger_path, const std::string& verbalizer_path); + std::string Tag(const std::string& input); + std::string Verbalize(const std::string& input); + std::string Normalize(const std::string& input); + + private: + std::string ShortestPath(const StdVectorFst& lattice); + std::string Compose(const std::string& input, const StdVectorFst* fst); + + ParseType parse_type_; + std::shared_ptr tagger_ = nullptr; + std::shared_ptr verbalizer_ = nullptr; + std::shared_ptr> compiler_ = nullptr; + std::shared_ptr> printer_ = nullptr; +}; + +} // namespace wetext + +#endif // PROCESSOR_WETEXT_PROCESSOR_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc new file mode 100644 index 00000000..a600eead --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.cc @@ -0,0 +1,161 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "processor/wetext_token_parser.h" + +#include "utils/wetext_log.h" +#include "utils/wetext_string.h" + +namespace wetext { +const char EOS[] = ""; +const std::set UTF8_WHITESPACE = {" ", "\t", "\n", "\r", + "\x0b\x0c"}; +const std::set ASCII_LETTERS = { + "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", + "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", "A", "B", + "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", + "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "_"}; +const std::unordered_map> ZH_TN_ORDERS = { + {"date", {"year", "month", "day"}}, + {"fraction", {"denominator", "numerator"}}, + {"measure", {"denominator", "numerator", "value"}}, + {"money", {"value", "currency"}}, + {"time", {"noon", "hour", "minute", "second"}}}; +const std::unordered_map> EN_TN_ORDERS = { + {"date", {"preserve_order", "text", "day", "month", "year"}}, + {"money", {"integer_part", "fractional_part", "quantity", "currency_maj"}}}; +const std::unordered_map> ZH_ITN_ORDERS = + {{"date", {"year", "month", "day"}}, + {"fraction", {"sign", "numerator", "denominator"}}, + {"measure", {"numerator", "denominator", "value"}}, + {"money", {"currency", "value", "decimal"}}, + {"time", {"hour", "minute", "second", "noon"}}}; + +TokenParser::TokenParser(ParseType type) { + if (type == ParseType::kZH_TN) { + orders_ = ZH_TN_ORDERS; + } else if (type == ParseType::kZH_ITN) { + orders_ = ZH_ITN_ORDERS; + } else if (type == ParseType::kEN_TN) { + orders_ = EN_TN_ORDERS; + } else { + LOG(FATAL) << "Invalid order"; + } +} + +void TokenParser::Load(const std::string& input) { + wetext::SplitUTF8StringToChars(input, &text_); + CHECK_GT(text_.size(), 0); + index_ = 0; + ch_ = text_[0]; +} + +bool TokenParser::Read() { + if (index_ < text_.size() - 1) { + index_ += 1; + ch_ = text_[index_]; + return true; + } + ch_ = EOS; + return false; +} + +bool TokenParser::ParseWs() { + bool not_eos = ch_ != EOS; + while (not_eos && ch_ == " ") { + 
not_eos = Read(); + } + return not_eos; +} + +bool TokenParser::ParseChar(const std::string& exp) { + if (ch_ == exp) { + Read(); + return true; + } + return false; +} + +bool TokenParser::ParseChars(const std::string& exp) { + bool ok = false; + std::vector chars; + wetext::SplitUTF8StringToChars(exp, &chars); + for (const auto& x : chars) { + ok |= ParseChar(x); + } + return ok; +} + +std::string TokenParser::ParseKey() { + CHECK_NE(ch_, EOS); + CHECK_EQ(UTF8_WHITESPACE.count(ch_), 0); + + std::string key = ""; + while (ASCII_LETTERS.count(ch_) > 0) { + key += ch_; + Read(); + } + return key; +} + +std::string TokenParser::ParseValue() { + CHECK_NE(ch_, EOS); + bool escape = false; + + std::string value = ""; + while (ch_ != "\"") { + value += ch_; + escape = ch_ == "\\"; + Read(); + if (escape) { + escape = false; + value += ch_; + Read(); + } + } + return value; +} + +void TokenParser::Parse(const std::string& input) { + Load(input); + while (ParseWs()) { + std::string name = ParseKey(); + ParseChars(" { "); + + Token token(name); + while (ParseWs()) { + if (ch_ == "}") { + ParseChar("}"); + break; + } + std::string key = ParseKey(); + ParseChars(": \""); + std::string value = ParseValue(); + ParseChar("\""); + token.Append(key, value); + } + tokens_.emplace_back(token); + } +} + +std::string TokenParser::Reorder(const std::string& input) { + Parse(input); + std::string output = ""; + for (auto& token : tokens_) { + output += token.String(orders_) + " "; + } + return Trim(output); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h new file mode 100644 index 00000000..34aba979 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/processor/wetext_token_parser.h @@ -0,0 +1,94 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the 
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PROCESSOR_WETEXT_TOKEN_PARSER_H_ +#define PROCESSOR_WETEXT_TOKEN_PARSER_H_ + +#include +#include +#include +#include + +namespace wetext { + +extern const char EOS[]; +extern const std::set UTF8_WHITESPACE; +extern const std::set ASCII_LETTERS; +extern const std::unordered_map> + ZH_TN_ORDERS; +extern const std::unordered_map> + ZH_ITN_ORDERS; +extern const std::unordered_map> + EN_TN_ORDERS; + +struct Token { + std::string name; + std::vector order; + std::unordered_map members; + + explicit Token(const std::string& name) : name(name) {} + + void Append(const std::string& key, const std::string& value) { + order.emplace_back(key); + members[key] = value; + } + + std::string String( + const std::unordered_map>& orders) { + std::string output = name + " {"; + if (orders.count(name) > 0) { + order = orders.at(name); + } + + for (const auto& key : order) { + if (members.count(key) == 0) { + continue; + } + output += " " + key + ": \"" + members[key] + "\""; + } + return output + " }"; + } +}; + +enum ParseType { + kZH_TN = 0x00, // Chinese Text Normalization + kZH_ITN = 0x01, // Chinese Inverse Text Normalization + kEN_TN = 0x02 // English Text Normalization +}; + +class TokenParser { + public: + explicit TokenParser(ParseType type); + std::string Reorder(const std::string& input); + + private: + void Load(const std::string& input); + bool Read(); + bool ParseWs(); + bool ParseChar(const std::string& exp); + bool ParseChars(const std::string& exp); + 
std::string ParseKey(); + std::string ParseValue(); + void Parse(const std::string& input); + + int index_; + std::string ch_; + std::vector text_; + std::vector tokens_; + std::unordered_map> orders_; +}; + +} // namespace wetext + +#endif // PROCESSOR_WETEXT_TOKEN_PARSER_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt b/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt new file mode 100644 index 00000000..30071f4c --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/CMakeLists.txt @@ -0,0 +1,3 @@ +add_library(wetext_utils STATIC wetext_string.cc) + +target_link_libraries(wetext_utils PUBLIC glog) diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h new file mode 100644 index 00000000..c1d30df3 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_flags.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_FLAGS_H_ +#define UTILS_WETEXT_FLAGS_H_ + +// Because openfst is a dynamic library compiled with gflags/glog, we must use +// the gflags/glog from openfst to avoid them linked both statically and +// dynamically into the executable. 
+#include "fst/flags.h" + +#endif // UTILS_WETEXT_FLAGS_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h new file mode 100644 index 00000000..b47a6a48 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_log.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_LOG_H_ +#define UTILS_WETEXT_LOG_H_ + +// Because openfst is a dynamic library compiled with gflags/glog, we must use +// the gflags/glog from openfst to avoid them linked both statically and +// dynamically into the executable. +#include "fst/log.h" + +#endif // UTILS_WETEXT_LOG_H_ diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc new file mode 100644 index 00000000..4df9ec91 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.cc @@ -0,0 +1,89 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "utils/wetext_string.h" + +#include "utils/wetext_log.h" + +namespace wetext { +const char* WHITESPACE = " \n\r\t\f\v"; + +int UTF8CharLength(char ch) { + int num_bytes = 1; + CHECK_LE((ch & 0xF8), 0xF0); + if ((ch & 0x80) == 0x00) { + // The first 128 characters (US-ASCII) in UTF-8 format only need one byte. + num_bytes = 1; + } else if ((ch & 0xE0) == 0xC0) { + // The next 1,920 characters need two bytes to encode, + // which covers the remainder of almost all Latin-script alphabets. + num_bytes = 2; + } else if ((ch & 0xF0) == 0xE0) { + // Three bytes are needed for characters in the rest of + // the Basic Multilingual Plane, which contains virtually all characters + // in common use, including most Chinese, Japanese and Korean characters. + num_bytes = 3; + } else if ((ch & 0xF8) == 0xF0) { + // Four bytes are needed for characters in the other planes of Unicode, + // which include less common CJK characters, various historic scripts, + // mathematical symbols, and emoji (pictographic symbols). 
+ num_bytes = 4; + } + return num_bytes; +} + +int UTF8StringLength(const std::string& str) { + int len = 0; + int num_bytes = 1; + for (size_t i = 0; i < str.length(); i += num_bytes) { + num_bytes = UTF8CharLength(str[i]); + ++len; + } + return len; +} + +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars) { + chars->clear(); + int num_bytes = 1; + for (size_t i = 0; i < str.length(); i += num_bytes) { + num_bytes = UTF8CharLength(str[i]); + chars->push_back(str.substr(i, num_bytes)); + } +} + +std::string Ltrim(const std::string& str) { + size_t start = str.find_first_not_of(WHITESPACE); + return (start == std::string::npos) ? "" : str.substr(start); +} + +std::string Rtrim(const std::string& str) { + size_t end = str.find_last_not_of(WHITESPACE); + return end == std::string::npos ? "" : str.substr(0, end + 1); +} + +std::string Trim(const std::string& str) { return Rtrim(Ltrim(str)); } + +void Split(const std::string& str, const std::string& delim, + std::vector* output) { + std::string s = str; + size_t pos = 0; + while ((pos = s.find(delim)) != std::string::npos) { + output->emplace_back(s.substr(0, pos)); + s.erase(0, pos + delim.length()); + } + output->emplace_back(s); +} + +} // namespace wetext diff --git a/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h new file mode 100644 index 00000000..ae890d60 --- /dev/null +++ b/projects/llm_framework/main_melotts/src/runner/utils/wetext_string.h @@ -0,0 +1,42 @@ +// Copyright (c) 2022 Zhendong Peng (pzd17@tsinghua.org.cn) +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef UTILS_WETEXT_STRING_H_ +#define UTILS_WETEXT_STRING_H_ + +#include +#include + +namespace wetext { +extern const char* WHITESPACE; + +int UTF8CharLength(char ch); + +int UTF8StringLength(const std::string& str); + +void SplitUTF8StringToChars(const std::string& str, + std::vector* chars); + +std::string Ltrim(const std::string& str); + +std::string Rtrim(const std::string& str); + +std::string Trim(const std::string& str); + +void Split(const std::string& str, const std::string& delim, + std::vector* output); + +} // namespace wetext + +#endif // UTILS_WETEXT_STRING_H_ From 9a20dd028de0dd4d98da4ab14af576e9f6d68b3f Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 16 May 2025 14:13:49 +0800 Subject: [PATCH 04/79] [update] update melotts, update static_lib verison --- projects/llm_framework/SConstruct | 2 +- projects/llm_framework/main_melotts/SConstruct | 2 +- projects/llm_framework/tools/llm_pack.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/llm_framework/SConstruct b/projects/llm_framework/SConstruct index 5d014ef6..7282f87e 100644 --- a/projects/llm_framework/SConstruct +++ b/projects/llm_framework/SConstruct @@ -5,7 +5,7 @@ import shutil os.environ['SDK_PATH'] = os.path.normpath(str(Path(os.getcwd())/'..'/'..'/'SDK')) os.environ['EXT_COMPONENTS_PATH'] = os.path.normpath(str(Path(os.getcwd())/'..'/'..'/'ext_components')) -version = 'v0.0.7' +version = 'v0.0.8' static_lib = 'static_lib' update = False diff --git a/projects/llm_framework/main_melotts/SConstruct 
b/projects/llm_framework/main_melotts/SConstruct index 87886e09..b9993031 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -34,7 +34,7 @@ LDFLAGS += [] STATIC_FILES += Glob('mode_*.json') -env['COMPONENTS'].append({'target':'llm_melotts-1.7', +env['COMPONENTS'].append({'target':'llm_melotts-1.8', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 9fd6c580..3db22980 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -359,7 +359,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-asr':[create_bin_deb,'llm-asr', '1.6', src_folder, revision], 'llm-llm':[create_bin_deb,'llm-llm', '1.8', src_folder, revision], 'llm-tts':[create_bin_deb,'llm-tts', '1.6', src_folder, revision], - 'llm-melotts':[create_bin_deb,'llm-melotts', '1.7', src_folder, revision], + 'llm-melotts':[create_bin_deb,'llm-melotts', '1.8', src_folder, revision], 'llm-camera':[create_bin_deb,'llm-camera', '1.8', src_folder, revision, 'lib-llm'], 'llm-vlm':[create_bin_deb,'llm-vlm', '1.7', src_folder, revision], 'llm-yolo':[create_bin_deb,'llm-yolo', '1.8', src_folder, revision], From daeaf4b830d0d60e5ed9456f53e38de38d4ae3d0 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 16 May 2025 16:20:19 +0800 Subject: [PATCH 05/79] [update] update lib-llm version, update melotts model version. 
--- projects/llm_framework/main/SConstruct | 2 ++ projects/llm_framework/main_melotts/SConstruct | 6 ++---- projects/llm_framework/tools/llm_pack.py | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/projects/llm_framework/main/SConstruct b/projects/llm_framework/main/SConstruct index 77237295..a95f5367 100644 --- a/projects/llm_framework/main/SConstruct +++ b/projects/llm_framework/main/SConstruct @@ -24,6 +24,8 @@ STATIC_FILES += [AFile('../static_lib/sherpa/ncnn/libsherpa-ncnn-core.so'), AFile('../static_lib/sherpa/ncnn/libncnn.so'), AFile('../static_lib/libtts.so'), AFile('../static_lib/sherpa/ncnn/libkaldi-native-fbank-core.so'), + AFile('../static_lib/wetext/libglog.so.0'), + AFile('../static_lib/wetext/libfst.so.16'), ] env['COMPONENTS'].append({'target':'static_file-1.0', diff --git a/projects/llm_framework/main_melotts/SConstruct b/projects/llm_framework/main_melotts/SConstruct index b9993031..8faa8d86 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -27,10 +27,8 @@ INCLUDE += [ADir('../include')] INCLUDE += [ADir('src/runner'), ADir('../include/onnxruntime/core/session')] LINK_SEARCH_PATH += [ADir('../static_lib/wetext')] LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a','-l:libglog.so','-l:libfst.so'] - - -LDFLAGS += [] +LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a'] +REQUIREMENTS += ['glog', 'fst'] STATIC_FILES += Glob('mode_*.json') diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 3db22980..20a04248 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -225,7 +225,7 @@ def create_data_deb(package_name, version, src_folder, revision = 'm5stack1', de shutil.rmtree(deb_folder) return package_name + " creat success!" 
-def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', depends = 'lib-llm (>= 1.7)'): +def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', depends = 'lib-llm (>= 1.8)'): bin_files = glob.glob(os.path.join(src_folder, package_name.replace("-", "_") + "-*")) version_info = 0.0 print(os.path.join(src_folder, package_name + "-*")) @@ -352,7 +352,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep cpu_count = cpu_count - 2 # cpu_count = 50 Tasks = { - 'lib-llm':[create_lib_deb,'lib-llm', '1.7', src_folder, revision], + 'lib-llm':[create_lib_deb,'lib-llm', '1.8', src_folder, revision], 'llm-sys':[create_bin_deb,'llm-sys', '1.6', src_folder, revision], 'llm-audio':[create_bin_deb,'llm-audio', '1.6', src_folder, revision], 'llm-kws':[create_bin_deb,'llm-kws', '1.7', src_folder, revision], @@ -376,10 +376,10 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-model-sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01':[create_data_deb,'llm-model-sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01', '0.3', src_folder, revision], 'llm-model-single-speaker-english-fast':[create_data_deb,'llm-model-single-speaker-english-fast', '0.3', src_folder, revision], 'llm-model-single-speaker-fast':[create_data_deb,'llm-model-single-speaker-fast', '0.3', src_folder, revision], - 'llm-model-melotts-zh-cn':[create_data_deb,'llm-model-melotts-zh-cn', '0.5', src_folder, revision], + 'llm-model-melotts-zh-cn':[create_data_deb,'llm-model-melotts-zh-cn', '0.6', src_folder, revision], 'llm-model-melotts-en-us':[create_data_deb,'llm-model-melotts-en-us', '0.5', src_folder, revision], - 'llm-model-melotts-en-default':[create_data_deb,'llm-model-melotts-en-default', '0.5', src_folder, revision], - 'llm-model-melotts-ja-jp':[create_data_deb,'llm-model-melotts-ja-jp', '0.5', src_folder, revision], + 'llm-model-melotts-en-default':[create_data_deb,'llm-model-melotts-en-default', 
'0.6', src_folder, revision], + 'llm-model-melotts-ja-jp':[create_data_deb,'llm-model-melotts-ja-jp', '0.6', src_folder, revision], 'llm-model-yolo11n':[create_data_deb,'llm-model-yolo11n', data_version, src_folder, revision], 'llm-model-yolo11n-pose':[create_data_deb,'llm-model-yolo11n-pose', '0.3', src_folder, revision], 'llm-model-yolo11n-hand-pose':[create_data_deb,'llm-model-yolo11n-hand-pose', '0.3', src_folder, revision], From 00c0533b86a15dacf7a5446cfd1c153b4fd32856 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 16 May 2025 18:48:49 +0800 Subject: [PATCH 06/79] [update] update libonnxruntime.so --- projects/llm_framework/main/SConstruct | 1 + projects/llm_framework/main_kws/SConstruct | 6 ++++-- projects/llm_framework/main_melotts/SConstruct | 4 ++-- projects/llm_framework/main_vad/SConstruct | 7 ++++--- projects/llm_framework/tools/llm_pack.py | 4 ++-- 5 files changed, 13 insertions(+), 9 deletions(-) diff --git a/projects/llm_framework/main/SConstruct b/projects/llm_framework/main/SConstruct index a95f5367..72d44864 100644 --- a/projects/llm_framework/main/SConstruct +++ b/projects/llm_framework/main/SConstruct @@ -26,6 +26,7 @@ STATIC_FILES += [AFile('../static_lib/sherpa/ncnn/libsherpa-ncnn-core.so'), AFile('../static_lib/sherpa/ncnn/libkaldi-native-fbank-core.so'), AFile('../static_lib/wetext/libglog.so.0'), AFile('../static_lib/wetext/libfst.so.16'), + AFile('../static_lib/libonnxruntime.so.1'), ] env['COMPONENTS'].append({'target':'static_file-1.0', diff --git a/projects/llm_framework/main_kws/SConstruct b/projects/llm_framework/main_kws/SConstruct index c09ca41a..b417fc2b 100644 --- a/projects/llm_framework/main_kws/SConstruct +++ b/projects/llm_framework/main_kws/SConstruct @@ -29,10 +29,12 @@ INCLUDE += [ADir('../include/sherpa'), ] LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a', +LDFLAGS += ['-l:libcargs.a', '-l:libsherpa-onnx-core.a', '-l:libkaldi-native-fbank-core.a', 
'-l:libkaldi-decoder-core.a', '-l:libssentencepiece_core.a'] +REQUIREMENTS += ['onnxruntime'] + STATIC_FILES += [os.path.join(python_venv, 'sherpa-onnx')] STATIC_FILES += Glob('llm-kws_text2token.py') STATIC_FILES += Glob('mode_*.json') @@ -55,7 +57,7 @@ ignore['ignore'] = list(set(ignore['ignore'])) with open('../dist/fileignore', 'w') as f: json.dump(ignore, f, indent=4) -env['COMPONENTS'].append({'target':'llm_kws-1.7', +env['COMPONENTS'].append({'target':'llm_kws-1.8', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/main_melotts/SConstruct b/projects/llm_framework/main_melotts/SConstruct index 8faa8d86..69d9d464 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -26,10 +26,10 @@ REQUIREMENTS += ['samplerate'] INCLUDE += [ADir('../include')] INCLUDE += [ADir('src/runner'), ADir('../include/onnxruntime/core/session')] LINK_SEARCH_PATH += [ADir('../static_lib/wetext')] -LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a'] REQUIREMENTS += ['glog', 'fst'] +REQUIREMENTS += ['onnxruntime'] + STATIC_FILES += Glob('mode_*.json') env['COMPONENTS'].append({'target':'llm_melotts-1.8', diff --git a/projects/llm_framework/main_vad/SConstruct b/projects/llm_framework/main_vad/SConstruct index 2493e159..bbc289d5 100644 --- a/projects/llm_framework/main_vad/SConstruct +++ b/projects/llm_framework/main_vad/SConstruct @@ -23,12 +23,13 @@ LINK_SEARCH_PATH += [ADir('../static_lib')] INCLUDE += [ADir('../include/sherpa')] LINK_SEARCH_PATH += [ADir('../static_lib/sherpa/onnx')] -LDFLAGS += ['-l:libsherpa-onnx-core.a', - '-l:libonnxruntime.a'] +LDFLAGS += ['-l:libsherpa-onnx-core.a'] + +REQUIREMENTS += ['onnxruntime'] STATIC_FILES += Glob('mode_*.json') -env['COMPONENTS'].append({'target':'llm_vad-1.6', +env['COMPONENTS'].append({'target':'llm_vad-1.7', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 
'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 20a04248..e593f873 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -355,7 +355,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'lib-llm':[create_lib_deb,'lib-llm', '1.8', src_folder, revision], 'llm-sys':[create_bin_deb,'llm-sys', '1.6', src_folder, revision], 'llm-audio':[create_bin_deb,'llm-audio', '1.6', src_folder, revision], - 'llm-kws':[create_bin_deb,'llm-kws', '1.7', src_folder, revision], + 'llm-kws':[create_bin_deb,'llm-kws', '1.8', src_folder, revision], 'llm-asr':[create_bin_deb,'llm-asr', '1.6', src_folder, revision], 'llm-llm':[create_bin_deb,'llm-llm', '1.8', src_folder, revision], 'llm-tts':[create_bin_deb,'llm-tts', '1.6', src_folder, revision], @@ -365,7 +365,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-yolo':[create_bin_deb,'llm-yolo', '1.8', src_folder, revision], 'llm-skel':[create_bin_deb,'llm-skel', version, src_folder, revision], 'llm-depth-anything':[create_bin_deb,'llm-depth-anything', '1.6', src_folder, revision], - 'llm-vad':[create_bin_deb,'llm-vad', '1.6', src_folder, revision], + 'llm-vad':[create_bin_deb,'llm-vad', '1.7', src_folder, revision], 'llm-whisper':[create_bin_deb,'llm-whisper', '1.7', src_folder, revision], 'llm-openai-api':[create_bin_deb,'llm-openai-api', '1.7', src_folder, revision], 'llm-model-audio-en-us':[create_data_deb,'llm-model-audio-en-us', data_version, src_folder, revision], From f775786033855e254d0710651b2d073b21a339f7 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 20 May 2025 14:19:44 +0800 Subject: [PATCH 07/79] [update] add en-au, en-br, en-india, en-us model. Format code. 
--- .../main_melotts/mode_melotts-en-au.json | 31 +++++++ .../main_melotts/mode_melotts-en-br.json | 31 +++++++ .../main_melotts/mode_melotts-en-india.json | 31 +++++++ .../main_melotts/mode_melotts-en-us.json | 16 ++-- .../llm_framework/main_melotts/src/main.cpp | 9 +- .../main_melotts/src/runner/Lexicon.hpp | 85 ++++++++----------- projects/llm_framework/tools/llm_pack.py | 5 +- 7 files changed, 145 insertions(+), 63 deletions(-) create mode 100644 projects/llm_framework/main_melotts/mode_melotts-en-au.json create mode 100644 projects/llm_framework/main_melotts/mode_melotts-en-br.json create mode 100644 projects/llm_framework/main_melotts/mode_melotts-en-india.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-au.json b/projects/llm_framework/main_melotts/mode_melotts-en-au.json new file mode 100644 index 00000000..17b71a94 --- /dev/null +++ b/projects/llm_framework/main_melotts/mode_melotts-en-au.json @@ -0,0 +1,31 @@ +{ + "mode": "melotts-en-au", + "type": "tts", + "homepage": "https://github.com/ml-inory/melotts.axera/tree/main/model_convert", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder_en-au --output_name decoder-en-au.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "4.0-64a0e58f", + "capabilities": [ + "tts", + "English" + ], + "input_type": [ + "tts.utf-8" + ], + "output_type": [ + "tts.wav", + "sys.play.0_1" + ], + "mode_param": { + "encoder": "encoder-en-au.ort", + "decoder": "decoder-en-au.axmodel", + "gbin": "g-en-au.bin", + "tokens": "tokens-en.txt", + "lexicon": "lexicon-en.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", + "spacker_speed": 1.2, + "mode_rate": 44100, + "audio_rate": 16000, + "awake_delay": 1000 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-br.json b/projects/llm_framework/main_melotts/mode_melotts-en-br.json new file mode 
100644 index 00000000..d5e68979 --- /dev/null +++ b/projects/llm_framework/main_melotts/mode_melotts-en-br.json @@ -0,0 +1,31 @@ +{ + "mode": "melotts-en-br", + "type": "tts", + "homepage": "https://github.com/ml-inory/melotts.axera/tree/main/model_convert", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder_en-br --output_name decoder-en-br.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "4.0-64a0e58f", + "capabilities": [ + "tts", + "English" + ], + "input_type": [ + "tts.utf-8" + ], + "output_type": [ + "tts.wav", + "sys.play.0_1" + ], + "mode_param": { + "encoder": "encoder-en-br.ort", + "decoder": "decoder-en-br.axmodel", + "gbin": "g-en-br.bin", + "tokens": "tokens-en.txt", + "lexicon": "lexicon-en.txt", + "tagger": "en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", + "spacker_speed": 1.2, + "mode_rate": 44100, + "audio_rate": 16000, + "awake_delay": 1000 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-india.json b/projects/llm_framework/main_melotts/mode_melotts-en-india.json new file mode 100644 index 00000000..e39c1318 --- /dev/null +++ b/projects/llm_framework/main_melotts/mode_melotts-en-india.json @@ -0,0 +1,31 @@ +{ + "mode": "melotts-en-india", + "type": "tts", + "homepage": "https://github.com/ml-inory/melotts.axera/tree/main/model_convert", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder_en-india --output_name decoder-en-india.axmodel --target_hardware AX620E", + "pulsar_version": "4.0-64a0e58f", + "capabilities": [ + "tts", + "English" + ], + "input_type": [ + "tts.utf-8" + ], + "output_type": [ + "tts.wav", + "sys.play.0_1" + ], + "mode_param": { + "encoder": "encoder-en-india.ort", + "decoder": "decoder-en-india.axmodel", + "gbin": "g-en-india.bin", + "tokens": "tokens-en.txt", + "lexicon": "lexicon-en.txt", + "tagger": 
"en_tn_tagger.fst", + "verbalizer": "en_tn_verbalizer.fst", + "spacker_speed": 1.2, + "mode_rate": 44100, + "audio_rate": 16000, + "awake_delay": 1000 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-us.json b/projects/llm_framework/main_melotts/mode_melotts-en-us.json index d6320873..52eed57e 100644 --- a/projects/llm_framework/main_melotts/mode_melotts-en-us.json +++ b/projects/llm_framework/main_melotts/mode_melotts-en-us.json @@ -2,8 +2,8 @@ "mode": "melotts-en-us", "type": "tts", "homepage": "https://huggingface.co/myshell-ai/MeloTTS-English", - "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder-en --output_name decoder-en.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", - "pulsar_version": "3.4-3dfd5692", + "compile_flage": "pulsar2 build --input decoder-en.onnx --config config_decoder_u16.json --output_dir decoder_en-us --output_name decoder-en-us.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "4.0-64a0e58f", "capabilities": [ "tts", "English" @@ -16,14 +16,14 @@ "sys.play.0_1" ], "mode_param": { - "encoder": "encoder-en.ort", - "decoder": "decoder-en.axmodel", - "gbin": "g-en.bin", - "tokens": "tokens.txt", - "lexicon": "lexicon.txt", + "encoder": "encoder-en-us.ort", + "decoder": "decoder-en-us.axmodel", + "gbin": "g-en-us.bin", + "tokens": "tokens-en.txt", + "lexicon": "lexicon-en.txt", "tagger": "en_tn_tagger.fst", "verbalizer": "en_tn_verbalizer.fst", - "spacker_speed": 1.0, + "spacker_speed": 1.2, "mode_rate": 44100, "audio_rate": 16000, "awake_delay": 1000 diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index 0875be97..f80d9074 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -190,7 +190,7 @@ class llm_task { g_matrix.resize(256, 0); FILE *fp = 
fopen(mode_config_.gbin.c_str(), "rb"); if (!fp) { - printf("Open %s failed!\n", mode_config_.gbin.c_str()); + SLOGE("Open %s failed!", mode_config_.gbin.c_str()); return -3; } fread(g_matrix.data(), sizeof(float), g_matrix.size(), fp); @@ -198,11 +198,11 @@ class llm_task { encoder_ = std::make_unique(); decoder_ = std::make_unique(); if (0 != encoder_->Init(mode_config_.encoder)) { - printf("encoder init failed!\n"); + SLOGE("encoder init failed!"); return -4; } if (0 != decoder_->Init(mode_config_.decoder.c_str())) { - printf("Init decoder model failed!\n"); + SLOGE("Init decoder model failed!"); return -5; } } catch (...) { @@ -398,7 +398,6 @@ class llm_task { } } - int aligned_start = audio_start + best_offset; std::vector crossfade_region(sola_buffer_frame); @@ -457,7 +456,6 @@ class llm_task { pcmlist.resize(audio_len); } - double src_ratio = static_cast(mode_config_.audio_rate) / static_cast(mode_config_.mode_rate); std::vector tmp_pcm((pcmlist.size() * src_ratio + 1)); @@ -465,7 +463,6 @@ class llm_task { resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio); - wav_pcm_data.reserve(len); std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data), [](const auto val) { return static_cast(val * INT16_MAX); }); diff --git a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp index 134f64c4..de28da8d 100644 --- a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp +++ b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp @@ -10,15 +10,6 @@ #include "../../../../../SDK/components/utilities/include/sample_log.h" #include "processor/wetext_processor.h" -// Debug logging switch - set to true to enable debug logs -static bool DEBUG_LOGGING = false; -// Macro for debug logging -#define DEBUG_LOG(fmt, ...) 
\ - do { \ - if (DEBUG_LOGGING) { \ - SLOGI(fmt, ##__VA_ARGS__); \ - } \ - } while (0) std::vector split(const std::string& s, char delim) { std::vector result; @@ -31,6 +22,7 @@ std::vector split(const std::string& s, char delim) } return result; } + class Lexicon { private: std::unordered_map, std::vector>> lexicon; @@ -41,18 +33,12 @@ class Lexicon { wetext::Processor* m_processor; public: - // Setter for debug logging - static void setDebugLogging(bool enable) - { - DEBUG_LOGGING = enable; - } Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename, const std::string& tagger_filename, const std::string& verbalizer_filename) : max_phrase_length(0) { - DEBUG_LOG("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", - tokens_filename.c_str(), lexicon_filename.c_str(), tagger_filename.c_str(), - verbalizer_filename.c_str()); + SLOGD("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", + tokens_filename.c_str(), lexicon_filename.c_str(), tagger_filename.c_str(), verbalizer_filename.c_str()); m_processor = new wetext::Processor(tagger_filename, verbalizer_filename); @@ -106,8 +92,8 @@ class Lexicon { lexicon["。"] = lexicon["."]; lexicon["!"] = lexicon["!"]; lexicon["?"] = lexicon["?"]; - DEBUG_LOG("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), - max_phrase_length); + SLOGD("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), + max_phrase_length); } std::vector splitEachChar(const std::string& text) @@ -136,15 +122,17 @@ class Lexicon { { return s.size() == 1 && ((s[0] >= 'A' && s[0] <= 'Z') || (s[0] >= 'a' && s[0] <= 'z')); } + bool is_english_token_char(const std::string& s) { if (s.size() != 1) return false; char c = s[0]; return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || c == '-' || c == '_'; } + void 
process_unknown_english(const std::string& word, std::vector& phones, std::vector& tones) { - DEBUG_LOG("Processing unknown term: %s", word.c_str()); + SLOGD("Processing unknown term: %s", word.c_str()); std::string orig_word = word; std::vector parts; std::vector phonetic_parts; @@ -163,7 +151,7 @@ class Lexicon { tones.insert(tones.end(), sub_tones.begin(), sub_tones.end()); parts.push_back(sub_word); phonetic_parts.push_back(phonesToString(sub_phones)); - DEBUG_LOG(" Matched: '%s' -> %s", sub_word.c_str(), phonesToString(sub_phones).c_str()); + SLOGD(" Matched: '%s' -> %s", sub_word.c_str(), phonesToString(sub_phones).c_str()); start += len; matched = true; break; @@ -180,13 +168,13 @@ class Lexicon { tones.insert(tones.end(), char_tones.begin(), char_tones.end()); parts.push_back(single_char); phonetic_parts.push_back(phonesToString(char_phones)); - DEBUG_LOG(" Single char: '%s' -> %s", single_char.c_str(), phonesToString(char_phones).c_str()); + SLOGD(" Single char: '%s' -> %s", single_char.c_str(), phonesToString(char_phones).c_str()); } else { phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end()); tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); parts.push_back(single_char); phonetic_parts.push_back("_unknown_"); - DEBUG_LOG(" Unknown: '%s'", single_char.c_str()); + SLOGD(" Unknown: '%s'", single_char.c_str()); } start++; } @@ -200,26 +188,25 @@ class Lexicon { parts_str += parts[i]; phonetic_str += phonetic_parts[i]; } - DEBUG_LOG("%s\t|\tDecomposed: %s\t|\tPhonetics: %s", orig_word.c_str(), parts_str.c_str(), - phonetic_str.c_str()); + SLOGD("%s\t|\tDecomposed: %s\t|\tPhonetics: %s", orig_word.c_str(), parts_str.c_str(), phonetic_str.c_str()); } void convert(const std::string& text, std::vector& phones, std::vector& tones) { - DEBUG_LOG("\nStarting text processing: \"%s\"", text.c_str()); + SLOGD("\nStarting text processing: \"%s\"", text.c_str()); std::string taggedText = 
m_processor->Tag(text); - DEBUG_LOG("\taggedText processing: \"%s\"", taggedText.c_str()); + SLOGD("\taggedText processing: \"%s\"", taggedText.c_str()); std::string normalizedText = m_processor->Verbalize(taggedText); - DEBUG_LOG("\normalizedText processing: \"%s\"", normalizedText.c_str()); + SLOGD("\normalizedText processing: \"%s\"", normalizedText.c_str()); - DEBUG_LOG("=======Matching Results======="); - DEBUG_LOG("Unit\t|\tPhonemes\t|\tTones"); - DEBUG_LOG("-----------------------------"); + SLOGD("=======Matching Results======="); + SLOGD("Unit\t|\tPhonemes\t|\tTones"); + SLOGD("-----------------------------"); phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end()); tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); - DEBUG_LOG("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), - tonesToString(unknown_token.second).c_str()); + SLOGD("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), + tonesToString(unknown_token.second).c_str()); auto chars = splitEachChar(normalizedText); int i = 0; while (i < chars.size()) { @@ -236,8 +223,8 @@ class Lexicon { auto& [eng_phones, eng_tones] = lexicon[eng_word]; phones.insert(phones.end(), eng_phones.begin(), eng_phones.end()); tones.insert(tones.end(), eng_tones.begin(), eng_tones.end()); - DEBUG_LOG("%s\t|\t%s\t|\t%s", orig_word.c_str(), phonesToString(eng_phones).c_str(), - tonesToString(eng_tones).c_str()); + SLOGD("%s\t|\t%s\t|\t%s", orig_word.c_str(), phonesToString(eng_phones).c_str(), + tonesToString(eng_tones).c_str()); } else { process_unknown_english(orig_word, phones, tones); } @@ -256,8 +243,8 @@ class Lexicon { auto& [phrase_phones, phrase_tones] = lexicon[phrase]; phones.insert(phones.end(), phrase_phones.begin(), phrase_phones.end()); tones.insert(tones.end(), phrase_tones.begin(), phrase_tones.end()); - DEBUG_LOG("%s\t|\t%s\t|\t%s", phrase.c_str(), phonesToString(phrase_phones).c_str(), - 
tonesToString(phrase_tones).c_str()); + SLOGD("%s\t|\t%s\t|\t%s", phrase.c_str(), phonesToString(phrase_phones).c_str(), + tonesToString(phrase_tones).c_str()); i += len; matched = true; break; @@ -279,25 +266,25 @@ class Lexicon { auto& [char_phones, char_tones] = lexicon[s]; phones.insert(phones.end(), char_phones.begin(), char_phones.end()); tones.insert(tones.end(), char_tones.begin(), char_tones.end()); - DEBUG_LOG("%s\t|\t%s\t|\t%s", orig_char.c_str(), phonesToString(char_phones).c_str(), - tonesToString(char_tones).c_str()); + SLOGD("%s\t|\t%s\t|\t%s", orig_char.c_str(), phonesToString(char_phones).c_str(), + tonesToString(char_tones).c_str()); } else { phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end()); tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); - DEBUG_LOG("%s\t|\t%s (Not matched)\t|\t%s", orig_char.c_str(), - phonesToString(unknown_token.first).c_str(), tonesToString(unknown_token.second).c_str()); + SLOGD("%s\t|\t%s (Not matched)\t|\t%s", orig_char.c_str(), + phonesToString(unknown_token.first).c_str(), tonesToString(unknown_token.second).c_str()); } } } phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end()); tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); - DEBUG_LOG("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), - tonesToString(unknown_token.second).c_str()); - DEBUG_LOG("\nProcessing Summary:"); - DEBUG_LOG("Original text: %s", text.c_str()); - DEBUG_LOG("Phonemes: %s", phonesToString(phones).c_str()); - DEBUG_LOG("Tones: %s", tonesToString(tones).c_str()); - DEBUG_LOG("===================="); + SLOGD("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), + tonesToString(unknown_token.second).c_str()); + SLOGD("\nProcessing Summary:"); + SLOGD("Original text: %s", text.c_str()); + SLOGD("Phonemes: %s", phonesToString(phones).c_str()); + SLOGD("Tones: %s", tonesToString(tones).c_str()); + 
SLOGD("===================="); } private: @@ -316,6 +303,7 @@ class Lexicon { phones.insert(phones.end(), phones_and_tones.first.begin(), phones_and_tones.first.end()); tones.insert(tones.end(), phones_and_tones.second.begin(), phones_and_tones.second.end()); } + std::string phonesToString(const std::vector& phones) { std::string result; @@ -329,6 +317,7 @@ class Lexicon { } return result; } + std::string tonesToString(const std::vector& tones) { std::string result; diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index e593f873..c188bfe8 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -377,8 +377,11 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-model-single-speaker-english-fast':[create_data_deb,'llm-model-single-speaker-english-fast', '0.3', src_folder, revision], 'llm-model-single-speaker-fast':[create_data_deb,'llm-model-single-speaker-fast', '0.3', src_folder, revision], 'llm-model-melotts-zh-cn':[create_data_deb,'llm-model-melotts-zh-cn', '0.6', src_folder, revision], - 'llm-model-melotts-en-us':[create_data_deb,'llm-model-melotts-en-us', '0.5', src_folder, revision], + 'llm-model-melotts-en-au':[create_data_deb,'llm-model-melotts-en-au', '0.6', src_folder, revision], + 'llm-model-melotts-en-br':[create_data_deb,'llm-model-melotts-en-br', '0.6', src_folder, revision], 'llm-model-melotts-en-default':[create_data_deb,'llm-model-melotts-en-default', '0.6', src_folder, revision], + 'llm-model-melotts-en-india':[create_data_deb,'llm-model-melotts-en-india', '0.6', src_folder, revision], + 'llm-model-melotts-en-us':[create_data_deb,'llm-model-melotts-en-us', '0.6', src_folder, revision], 'llm-model-melotts-ja-jp':[create_data_deb,'llm-model-melotts-ja-jp', '0.6', src_folder, revision], 'llm-model-yolo11n':[create_data_deb,'llm-model-yolo11n', data_version, src_folder, revision], 
'llm-model-yolo11n-pose':[create_data_deb,'llm-model-yolo11n-pose', '0.3', src_folder, revision], From 8acb179b634afc26c1a723c8dd948524918e65bb Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 20 May 2025 15:06:11 +0800 Subject: [PATCH 08/79] [fix] Handles the situation where Either tagger or verbalizer file does not exist. --- .../llm_framework/main_melotts/src/main.cpp | 12 ++- .../main_melotts/src/runner/Lexicon.hpp | 78 +++++++++++++++++-- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index f80d9074..6b8e1c15 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -183,10 +183,14 @@ class llm_task { awake_delay_ = config_body["awake_delay"].get(); else if (file_body["mode_param"].contains("awake_delay")) awake_delay_ = file_body["mode_param"]["awake_delay"]; - // Load lexicon - lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens, mode_config_.tagger, - mode_config_.verbalizer); - // Read g.bin + + if (!std::filesystem::exists(mode_config_.tagger) || !std::filesystem::exists(mode_config_.verbalizer)) { + SLOGW("Either tagger or verbalizer file does not exist, using alternative lexicon."); + lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens); + } else { + lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens, mode_config_.tagger, + mode_config_.verbalizer); + } g_matrix.resize(256, 0); FILE *fp = fopen(mode_config_.gbin.c_str(), "rb"); if (!fp) { diff --git a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp index de28da8d..1884c929 100644 --- a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp +++ b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp @@ -30,7 +30,7 @@ class Lexicon { std::pair, std::vector> unknown_token; std::unordered_map 
reverse_tokens; - wetext::Processor* m_processor; + wetext::Processor* m_processor = nullptr; public: Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename, const std::string& tagger_filename, @@ -96,6 +96,65 @@ class Lexicon { max_phrase_length); } + Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0) + { + SLOGD("Dictionary loading: %s Pronunciation table loading: %s", tokens_filename.c_str(), + lexicon_filename.c_str()); + + std::unordered_map tokens; + std::ifstream ifs(tokens_filename); + assert(ifs.is_open()); + std::string line; + while (std::getline(ifs, line)) { + auto splitted_line = split(line, ' '); + if (splitted_line.size() >= 2) { + int token_id = std::stoi(splitted_line[1]); + tokens.insert({splitted_line[0], token_id}); + reverse_tokens[token_id] = splitted_line[0]; + } + } + ifs.close(); + ifs.open(lexicon_filename); + assert(ifs.is_open()); + while (std::getline(ifs, line)) { + auto splitted_line = split(line, ' '); + if (splitted_line.empty()) continue; + std::string word_or_phrase = splitted_line[0]; + auto chars = splitEachChar(word_or_phrase); + max_phrase_length = std::max(max_phrase_length, chars.size()); + size_t phone_tone_len = splitted_line.size() - 1; + size_t half_len = phone_tone_len / 2; + std::vector phones, tones; + for (size_t i = 0; i < phone_tone_len; i++) { + auto phone_or_tone = splitted_line[i + 1]; + if (i < half_len) { + if (tokens.find(phone_or_tone) != tokens.end()) { + phones.push_back(tokens[phone_or_tone]); + } + } else { + tones.push_back(std::stoi(phone_or_tone)); + } + } + lexicon[word_or_phrase] = std::make_pair(phones, tones); + } + const std::vector punctuation{"!", "?", "…", ",", ".", "'", "-"}; + for (const auto& p : punctuation) { + if (tokens.find(p) != tokens.end()) { + int i = tokens[p]; + lexicon[p] = std::make_pair(std::vector{i}, std::vector{0}); + } + } + assert(tokens.find("_") != tokens.end()); + unknown_token = 
std::make_pair(std::vector{tokens["_"]}, std::vector{0}); + lexicon[" "] = unknown_token; + lexicon[","] = lexicon[","]; + lexicon["。"] = lexicon["."]; + lexicon["!"] = lexicon["!"]; + lexicon["?"] = lexicon["?"]; + SLOGD("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), + max_phrase_length); + } + std::vector splitEachChar(const std::string& text) { std::vector words; @@ -195,14 +254,17 @@ class Lexicon { { SLOGD("\nStarting text processing: \"%s\"", text.c_str()); - std::string taggedText = m_processor->Tag(text); - SLOGD("\taggedText processing: \"%s\"", taggedText.c_str()); - std::string normalizedText = m_processor->Verbalize(taggedText); - SLOGD("\normalizedText processing: \"%s\"", normalizedText.c_str()); + std::string normalizedText; + if (m_processor) { + std::string taggedText = m_processor->Tag(text); + SLOGD("\taggedText processing: \"%s\"", taggedText.c_str()); + normalizedText = m_processor->Verbalize(taggedText); + SLOGD("\tnormalizedText processing: \"%s\"", normalizedText.c_str()); + } else { + SLOGD("m_processor is not initialized, skipping tag and verbalize steps."); + normalizedText = text; + } - SLOGD("=======Matching Results======="); - SLOGD("Unit\t|\tPhonemes\t|\tTones"); - SLOGD("-----------------------------"); phones.insert(phones.end(), unknown_token.first.begin(), unknown_token.first.end()); tones.insert(tones.end(), unknown_token.second.begin(), unknown_token.second.end()); SLOGD("\t|\t%s\t|\t%s", phonesToString(unknown_token.first).c_str(), From f67506c66238d74278c38cc1bc61961207f29211 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 20 May 2025 16:16:32 +0800 Subject: [PATCH 09/79] [update] update melotts-es-es model --- .../main_melotts/mode_melotts-es-es.json | 29 +++++++++++++++++++ projects/llm_framework/tools/llm_pack.py | 1 + 2 files changed, 30 insertions(+) create mode 100644 projects/llm_framework/main_melotts/mode_melotts-es-es.json diff --git 
a/projects/llm_framework/main_melotts/mode_melotts-es-es.json b/projects/llm_framework/main_melotts/mode_melotts-es-es.json new file mode 100644 index 00000000..4a65a57e --- /dev/null +++ b/projects/llm_framework/main_melotts/mode_melotts-es-es.json @@ -0,0 +1,29 @@ +{ + "mode": "melotts-es-es", + "type": "tts", + "homepage": "https://github.com/ml-inory/melotts.axera/tree/main/model_convert", + "compile_flage": "pulsar2 build --input decoder-es.onnx --config config_decoder_u16.json --output_dir decoder_es --output_name decoder-es.axmodel --target_hardware AX620E --npu_mode NPU2 --compiler.check 0", + "pulsar_version": "4.0-64a0e58f", + "capabilities": [ + "tts", + "Spanish" + ], + "input_type": [ + "tts.utf-8" + ], + "output_type": [ + "tts.wav", + "sys.play.0_1" + ], + "mode_param": { + "encoder": "encoder-es.ort", + "decoder": "decoder-es.axmodel", + "gbin": "g-es.bin", + "tokens": "tokens-es.txt", + "lexicon": "lexicon-es.txt", + "spacker_speed": 1.2, + "mode_rate": 44100, + "audio_rate": 16000, + "awake_delay": 1000 + } +} \ No newline at end of file diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index c188bfe8..b27cd597 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -383,6 +383,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-model-melotts-en-india':[create_data_deb,'llm-model-melotts-en-india', '0.6', src_folder, revision], 'llm-model-melotts-en-us':[create_data_deb,'llm-model-melotts-en-us', '0.6', src_folder, revision], 'llm-model-melotts-ja-jp':[create_data_deb,'llm-model-melotts-ja-jp', '0.6', src_folder, revision], + 'llm-model-melotts-es-es':[create_data_deb,'llm-model-melotts-es-es', '0.5', src_folder, revision], 'llm-model-yolo11n':[create_data_deb,'llm-model-yolo11n', data_version, src_folder, revision], 'llm-model-yolo11n-pose':[create_data_deb,'llm-model-yolo11n-pose', '0.3', src_folder, revision], 
'llm-model-yolo11n-hand-pose':[create_data_deb,'llm-model-yolo11n-hand-pose', '0.3', src_folder, revision], From 764bca125836471f12cd5f274a668066e6bafbac Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 20 May 2025 17:57:37 +0800 Subject: [PATCH 10/79] [update] update model list --- README_zh.md | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/README_zh.md b/README_zh.md index 4216b0aa..e80293e6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -58,7 +58,44 @@ StackFlow 语音助手的主要工作模式: ## 模型列表 | 模型名 | 模型类型 | 模型大小 | 模型能力 | 模型配置文件 | 计算单元 | | :----: | :----: | :----: | :----: | :----: | :----: | +| [silero-vad](https://github.com/snakers4/silero-vad) | VAD | 3.3M | 语音活动检测 | [mode_silero-vad.json](projects/llm_framework/main_vad/mode_silero-vad.json) | CPU | | [sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01](https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.tar.bz2) | KWS | 6.4M | 关键词识别 | [mode_sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.json](projects/llm_framework/main_kws/mode_sherpa-onnx-kws-zipformer-gigaspeech-3.3M-2024-01-01.json) | CPU | +| [sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01](https://github.com/k2-fsa/sherpa-onnx/releases/download/kws-models/sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.tar.bz2) | KWS | 5.7M | 关键词识别 | [mode_sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.json](projects/llm_framework/main_kws/mode_sherpa-onnx-kws-zipformer-wenetspeech-3.3M-2024-01-01.json) | CPU | +| [sherpa-ncnn-streaming-zipformer-20M-2023-02-17](https://huggingface.co/desh2608/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-small) | ASR | 40M | 语音识别 | [mode_sherpa-ncnn-streaming-zipformer-20M-2023-02-17.json](projects/llm_framework/main_asr/mode_sherpa-ncnn-streaming-zipformer-20M-2023-02-17.json) | CPU | +| 
[sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23](https://github.com/k2-fsa/icefall/tree/master/egs/librispeech/ASR/pruned_transducer_stateless7_streaming) | ASR | 24M | 语音识别 | [mode_sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.json](projects/llm_framework/main_asr/mode_sherpa-ncnn-streaming-zipformer-zh-14M-2023-02-23.json) | CPU | +| [whisper-tiny](https://huggingface.co/openai/whisper-tiny) | ASR | 201M | 语音识别 | [mode_whisper-tiny.json](projects/llm_framework/main_whisper/mode_whisper-tiny.json) | NPU | +| [whisper-base](https://huggingface.co/openai/whisper-base) | ASR | 309M | 语音识别 | [mode_whisper-base.json](projects/llm_framework/main_whisper/mode_whisper-base.json) | NPU | +| [whisper-small](https://huggingface.co/openai/whisper-small) | ASR | 725M | 语音识别 | [mode_whisper-small.json](projects/llm_framework/main_whisper/mode_whisper-small.json) | NPU | +| [single-speaker-fast](https://github.com/huakunyang/SummerTTS) | TTS | 77M | 语音生成 | [mode_whisper-tiny.json](projects/llm_framework/main_tts/mode_single-speaker-fast.json) | CPU | +| [single-speaker-english-fast](https://github.com/huakunyang/SummerTTS) | TTS | 60M | 语音生成 | [mode_whisper-tiny.json](projects/llm_framework/main_tts/mode_single-speaker-english-fast.json) | CPU | +| [melotts-en-au](https://huggingface.co/myshell-ai/MeloTTS-English) | TTS | 102M | 语音生成 | [mode_melotts-en-au.json](projects/llm_framework/main_melotts/mode_melotts-en-au.json) | NPU | +| [melotts-en-br](https://huggingface.co/myshell-ai/MeloTTS-English) | TTS | 102M | 语音生成 | [mode_melotts-en-au.json](projects/llm_framework/main_melotts/mode_melotts-en-br.json) | NPU | +| [melotts-en-default](https://huggingface.co/myshell-ai/MeloTTS-English) | TTS | 102M | 语音生成 | [mode_melotts-en-india.json](projects/llm_framework/main_melotts/mode_melotts-en-default.json) | NPU | +| [melotts-en-us](https://huggingface.co/myshell-ai/MeloTTS-English) | TTS | 102M | 语音生成 | 
[mode_melotts-en-au.json](projects/llm_framework/main_melotts/mode_melotts-en-us.json) | NPU | +| [melotts-es-es](https://huggingface.co/myshell-ai/MeloTTS-Spanish) | TTS | 83M | 语音生成 | [mode_melotts-es-es.json](projects/llm_framework/main_melotts/mode_melotts-es-es.json) | NPU | +| [melotts-ja-jp](https://huggingface.co/myshell-ai/MeloTTS-Japanese) | TTS | 83M | 语音生成 | [mode_melotts-ja-jp.json](projects/llm_framework/main_melotts/mode_melotts-ja-jp.json) | NPU | +| [melotts-zh-cn](https://huggingface.co/myshell-ai/MeloTTS-Chinese) | TTS | 86M | 语音生成 | [mode_melotts-zh-cn.json](projects/llm_framework/main_melotts/mode_melotts-zh-cn.json) | NPU | +| [deepseek-r1-1.5B-ax630c](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | LLM | 2.0G | 文本生成 | [mode_deepseek-r1-1.5B-ax630c.json](projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-ax630c.json) | NPU | +| [deepseek-r1-1.5B-p256-ax630c](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B) | LLM | 2.0G | 文本生成 | [mode_deepseek-r1-1.5B-p256-ax630c.json](projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-p256-ax630c.json) | NPU | +| [llama3.2-1B-p256-ax630c](https://huggingface.co/meta-llama/Llama-3.2-1B) | LLM | 1.7G | 文本生成 | [mode_llama3.2-1B-p256-ax630c.json](projects/llm_framework/main_llm/models/mode_llama3.2-1B-p256-ax630c.json) | NPU | +| [llama3.2-1B-prefill-ax630c](https://huggingface.co/meta-llama/Llama-3.2-1B) | LLM | 1.7G | 文本生成 | [mode_llama3.2-1B-prefill-ax630c.json](projects/llm_framework/main_llm/models/mode_llama3.2-1B-prefill-ax630c.json) | NPU | +| [openbuddy-llama3.2-1B-ax630c](https://huggingface.co/OpenBuddy/openbuddy-llama3.2-1b-v23.1-131k) | LLM | 1.7G | 文本生成 | [mode_openbuddy-llama3.2-1B-ax630c.json](projects/llm_framework/main_llm/models/mode_openbuddy-llama3.2-1B-ax630c.json) | NPU | +| [qwen2.5-0.5B-Int4-ax630c](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct-GPTQ-Int4) | LLM | 626M | 文本生成 | 
[mode_qwen2.5-0.5B-Int4-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-Int4-ax630c.json) | NPU | +| [qwen2.5-0.5B-p256-ax630c](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | LLM | 760M | 文本生成 | [mode_qwen2.5-0.5B-p256-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-p256-ax630c.json) | NPU | +| [qwen2.5-0.5B-prefill-20e](https://huggingface.co/Qwen/Qwen2.5-0.5B-Instruct) | LLM | 758M | 文本生成 | [mode_qwen2.5-0.5B-prefill-20e.json](projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-prefill-20e.json) | NPU | +| [qwen2.5-1.5B-Int4-ax630c](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GPTQ-Int4) | LLM | 1.5G | 文本生成 | [mode_qwen2.5-1.5B-Int4-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-Int4-ax630c.json) | NPU | +| [qwen2.5-1.5B-p256-ax630c](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | LLM | 2.0G | 文本生成 | [mode_qwen2.5-1.5B-p256-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-p256-ax630c.json) | NPU | +| [qwen2.5-1.5B-ax630c](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) | LLM | 2.0G | 文本生成 | [mode_qwen2.5-1.5B-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-ax630c.json) | NPU | +| [qwen2.5-coder-0.5B-ax630c](https://huggingface.co/Qwen/Qwen2.5-Coder-0.5B-Instruct) | LLM | 756M | 文本生成 | [mode_qwen2.5-coder-0.5B-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen2.5-coder-0.5B-ax630c.json) | NPU | +| [qwen3-0.6B-ax630c](https://huggingface.co/AXERA-TECH/InternVL2_5-1B) | LLM | 917M | 文本生成 | [mode_qwen3-0.6B-ax630c.json](projects/llm_framework/main_llm/models/mode_qwen3-0.6B-ax630c.json) | NPU | +| [mode_internvl2.5-1B-364-ax630c](https://huggingface.co/Qwen/Qwen3-0.6B) | VLM | 1.2G | 多模态文本生成 | [mode_internvl2.5-1B-364-ax630c.json](projects/llm_framework/main_vlm/models/mode_internvl2.5-1B-364-ax630c.json) | NPU | +| [smolvlm-256M-ax630c](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | VLM | 330M | 多模态文本生成 | 
[mode_smolvlm-256M-ax630c.json](projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax630c.json) | NPU | +| [smolvlm-500M-ax630c](https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct) | VLM | 605M | 多模态文本生成 | [mode_smolvlm-256M-ax630c.json](projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax630c.json) | NPU | +| [yolo11n](https://github.com/ultralytics/ultralytics) | CV | 2.8M | 目标检测 | [mode_yolo11n.json](projects/llm_framework/main_yolo/mode_yolo11n.json) | NPU | +| [yolo11n-seg](https://github.com/ultralytics/ultralytics) | CV | 3.0M | 实例分割 | [mode_yolo11n-seg.json](projects/llm_framework/main_yolo/mode_yolo11n-seg.json) | NPU | +| [yolo11n-pose](https://github.com/ultralytics/ultralytics) | CV | 3.1M | 姿态检测 | [mode_yolo11n-pose.json](projects/llm_framework/main_yolo/mode_yolo11n-pose.json) | NPU | +| [yolo11n-hand-pose](https://github.com/ultralytics/ultralytics) | CV | 3.2M | 姿态检测 | [mode_yolo11n-hand-pose.json](projects/llm_framework/main_yolo/mode_yolo11n-hand-pose.json) | NPU | +| [depth-anything-ax630c](https://github.com/DepthAnything/Depth-Anything-V2) | CV | 29M | 单目深度估计 | [mode_depth-anything-ax630c.json](projects/llm_framework/main_depth_anything/mode_depth-anything-ax630c.json) | NPU | ## 环境要求 ## 当前 StackFlow 的 AI 单元是建立在 AXERA 加速平台之上的,主要的芯片平台为 ax630c、ax650n。系统要求为 ubuntu。 From aa10381ad85c4b6c8b98152cecc4465529a05074 Mon Sep 17 00:00:00 2001 From: nyasu3w <125797829+nyasu3w@users.noreply.github.com> Date: Tue, 3 Jun 2025 23:26:19 +0900 Subject: [PATCH 11/79] add trigger method to llm_kws --- projects/llm_framework/main_kws/src/main.cpp | 42 ++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/projects/llm_framework/main_kws/src/main.cpp b/projects/llm_framework/main_kws/src/main.cpp index 9c05ce82..38e030c7 100644 --- a/projects/llm_framework/main_kws/src/main.cpp +++ b/projects/llm_framework/main_kws/src/main.cpp @@ -247,6 +247,13 @@ class llm_task { } } + void trigger() + { + if (out_callback_) { + 
out_callback_("True"); + } + } + bool delete_model() { spotter_.reset(); @@ -284,6 +291,9 @@ class llm_kws : public StackFlow { llm_kws() : StackFlow("kws") { task_count_ = 1; + rpc_ctx_->register_rpc_action("trigger", + std::bind(&llm_kws::trigger, this, std::placeholders::_1, std::placeholders::_2)); + } void play_awake_wav(const std::string &wav_file) @@ -529,6 +539,38 @@ class llm_kws : public StackFlow { return 0; } + std::string trigger(pzmq *_pzmq, const std::shared_ptr& rawdata0) + { + const std::string rawdata = rawdata0->string(); + int pos = rawdata.find("{"); +// SLOGI("llm_kws::trigger:json:%s", rawdata.substr(pos).c_str()); + + nlohmann::json error_body; + nlohmann::json data; + try { + data = nlohmann::json::parse(rawdata.substr(pos)); + } catch (...) { + SLOGE("setup json format error."); + error_body["code"] = -2; + error_body["message"] = "json format error."; + send("None", "None", error_body, "kws"); + return LLM_NONE; + } + auto work_id = data["work_id"].get(); + + int work_id_num = sample_get_work_id_num(work_id); + if (llm_task_.find(work_id_num) == llm_task_.end()) { + error_body["code"] = -6; + error_body["message"] = "Unit Does Not Exist"; + send("None", "None", error_body, work_id); + return LLM_NONE; + } + + llm_task_[work_id_num]->trigger(); + return LLM_NONE; + } + + ~llm_kws() { while (1) { From b9401b2817a1874c79cc012302ad4b362e90aea2 Mon Sep 17 00:00:00 2001 From: dianjixz <18637716021@163.com> Date: Thu, 5 Jun 2025 19:28:50 +0800 Subject: [PATCH 12/79] [update] llm trigger Standardization. 
--- projects/llm_framework/main_kws/src/main.cpp | 88 ++++++++++++-------- 1 file changed, 55 insertions(+), 33 deletions(-) diff --git a/projects/llm_framework/main_kws/src/main.cpp b/projects/llm_framework/main_kws/src/main.cpp index 38e030c7..4ea6abcd 100644 --- a/projects/llm_framework/main_kws/src/main.cpp +++ b/projects/llm_framework/main_kws/src/main.cpp @@ -283,6 +283,7 @@ class llm_task { class llm_kws : public StackFlow { private: + enum { EVENT_TRIGGER = EVENT_EXPORT + 1 }; int task_count_; std::string audio_url_; std::unordered_map> llm_task_; @@ -291,9 +292,13 @@ class llm_kws : public StackFlow { llm_kws() : StackFlow("kws") { task_count_ = 1; - rpc_ctx_->register_rpc_action("trigger", - std::bind(&llm_kws::trigger, this, std::placeholders::_1, std::placeholders::_2)); - + event_queue_.appendListener(EVENT_TRIGGER, std::bind(&llm_kws::trigger, this, std::placeholders::_1)); + rpc_ctx_->register_rpc_action( + "trigger", [this](pzmq *_pzmq, const std::shared_ptr &data) -> std::string { + this->event_queue_.enqueue(EVENT_TRIGGER, + std::make_shared(data->get_param(0), data->get_param(1))); + return LLM_NONE; + }); } void play_awake_wav(const std::string &wav_file) @@ -440,9 +445,11 @@ class llm_kws : public StackFlow { return -1; } - int work_id_num = sample_get_work_id_num(work_id); - auto llm_channel = get_channel(work_id); - auto llm_task_obj = std::make_shared(work_id); + int work_id_num = sample_get_work_id_num(work_id); + auto llm_channel = get_channel(work_id); + auto llm_task_obj = std::make_shared(work_id); + std::weak_ptr _llm_task_obj = llm_task_obj; + std::weak_ptr _llm_channel = llm_channel; nlohmann::json config_body; try { config_body = nlohmann::json::parse(data); @@ -458,17 +465,22 @@ class llm_kws : public StackFlow { llm_channel->set_output(llm_task_obj->enoutput_); llm_channel->set_stream(llm_task_obj->enstream_); llm_task_obj->play_awake_wav = std::bind(&llm_kws::play_awake_wav, this, std::placeholders::_1); - 
llm_task_obj->set_output([llm_task_obj, llm_channel](const std::string &data) { - llm_channel->send(llm_task_obj->response_format_, true, LLM_NO_ERROR); + llm_task_obj->set_output([_llm_task_obj, _llm_channel](const std::string &data) { + auto llm_task_obj = _llm_task_obj.lock(); + auto llm_channel = _llm_channel.lock(); + if (llm_task_obj && llm_channel) { + llm_channel->send(llm_task_obj->response_format_, true, LLM_NO_ERROR); + } }); for (const auto input : llm_task_obj->inputs_) { if (input.find("sys") != std::string::npos) { - audio_url_ = unit_call("audio", "cap", "None"); - std::weak_ptr _llm_task_obj = llm_task_obj; - llm_channel->subscriber(audio_url_, [_llm_task_obj](pzmq *_pzmq, const std::shared_ptr &raw) { - _llm_task_obj.lock()->sys_pcm_on_data(raw->string()); - }); + audio_url_ = unit_call("audio", "cap", "None"); + llm_channel->subscriber(audio_url_, + [_llm_task_obj](pzmq *_pzmq, const std::shared_ptr &raw) { + auto llm_task_obj = _llm_task_obj.lock(); + if (llm_task_obj) llm_task_obj->sys_pcm_on_data(raw->string()); + }); llm_task_obj->audio_flage_ = true; } else if (input.find("kws") != std::string::npos) { llm_channel->subscriber_work_id( @@ -539,38 +551,48 @@ class llm_kws : public StackFlow { return 0; } - std::string trigger(pzmq *_pzmq, const std::shared_ptr& rawdata0) + std::string trigger(const std::shared_ptr &arg) { - const std::string rawdata = rawdata0->string(); - int pos = rawdata.find("{"); -// SLOGI("llm_kws::trigger:json:%s", rawdata.substr(pos).c_str()); - - nlohmann::json error_body; - nlohmann::json data; - try { - data = nlohmann::json::parse(rawdata.substr(pos)); - } catch (...) 
{ - SLOGE("setup json format error."); - error_body["code"] = -2; - error_body["message"] = "json format error."; - send("None", "None", error_body, "kws"); + std::shared_ptr originalPtr = std::static_pointer_cast(arg); + std::string zmq_url = originalPtr->string(0); + std::string data = originalPtr->string(1); + std::string work_id = sample_json_str_get(data, "work_id"); + if (work_id.length() == 0) { + nlohmann::json out_body; + out_body["request_id"] = sample_json_str_get(data, "request_id"); + out_body["work_id"] = "kws"; + out_body["created"] = time(NULL); + out_body["object"] = ""; + out_body["data"] = ""; + out_body["error"]["code"] = -2; + out_body["error"]["message"] = "json format error."; + pzmq _zmq(zmq_url, ZMQ_PUSH); + std::string out = out_body.dump(); + out += "\n"; + _zmq.send_data(out); return LLM_NONE; } - auto work_id = data["work_id"].get(); int work_id_num = sample_get_work_id_num(work_id); if (llm_task_.find(work_id_num) == llm_task_.end()) { - error_body["code"] = -6; - error_body["message"] = "Unit Does Not Exist"; - send("None", "None", error_body, work_id); + nlohmann::json out_body; + out_body["request_id"] = sample_json_str_get(data, "request_id"); + out_body["work_id"] = "kws"; + out_body["created"] = time(NULL); + out_body["object"] = ""; + out_body["data"] = ""; + out_body["error"]["code"] = -6; + out_body["error"]["message"] = "Unit Does Not Exist"; + pzmq _zmq(zmq_url, ZMQ_PUSH); + std::string out = out_body.dump(); + out += "\n"; + _zmq.send_data(out); return LLM_NONE; } - llm_task_[work_id_num]->trigger(); return LLM_NONE; } - ~llm_kws() { while (1) { From 61c69a33623081885e6fb990120f03a4c86d3c46 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 10 Jun 2025 09:21:08 +0800 Subject: [PATCH 13/79] [update] update docs --- projects/llm_framework/README.md | 2 +- projects/llm_framework/tools/test_tools/test-melo.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/llm_framework/README.md 
b/projects/llm_framework/README.md index e99981c3..10699700 100644 --- a/projects/llm_framework/README.md +++ b/projects/llm_framework/README.md @@ -98,7 +98,7 @@ send : "action": "setup", "object": "melotts.setup", "data": { - "model": "melotts_zh-cn", + "model": "melotts-zh-cn", "response_format": "sys.pcm", "input": "tts.utf-8", "enoutput": false diff --git a/projects/llm_framework/tools/test_tools/test-melo.py b/projects/llm_framework/tools/test_tools/test-melo.py index b257f2fc..02135bfa 100644 --- a/projects/llm_framework/tools/test_tools/test-melo.py +++ b/projects/llm_framework/tools/test_tools/test-melo.py @@ -45,7 +45,7 @@ def create_melotts_setup_data(request_id="melotts_setup"): "action": "setup", "object": "melotts.setup", "data": { - "model": "melotts_zh-cn", + "model": "melotts-zh-cn", "response_format": "sys.pcm", "input": "tts.utf-8", "enoutput": False From 2d0cd697a9005e7075109b0a6c54475c6cbbd152 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 18 Jun 2025 11:41:23 +0800 Subject: [PATCH 14/79] [update] vlm add task_camera_data --- projects/llm_framework/main_vlm/src/main.cpp | 118 ++++++++++++++++--- 1 file changed, 102 insertions(+), 16 deletions(-) diff --git a/projects/llm_framework/main_vlm/src/main.cpp b/projects/llm_framework/main_vlm/src/main.cpp index b625b5f9..e7aeb36f 100644 --- a/projects/llm_framework/main_vlm/src/main.cpp +++ b/projects/llm_framework/main_vlm/src/main.cpp @@ -15,19 +15,25 @@ #include #include #include "../../../../SDK/components/utilities/include/sample_log.h" +#include "thread_safe_list.h" using namespace StackFlows; int main_exit_flage = 0; static void __sigint(int iSigNo) { - SLOGW("llm_sys will be exit!"); + SLOGW("llm_vlm will be exit!"); main_exit_flage = 1; } static std::string base_model_path_; static std::string base_model_config_path_; +typedef struct { + cv::Mat inference_src; + bool inference_bgr2rgb; +} inference_async_par; + typedef std::function task_callback_t; #define CONFIG_AUTO_SET(obj, key) \ 
@@ -56,6 +62,8 @@ class llm_task { task_callback_t out_callback_; bool enoutput_; bool enstream_; + bool encamera_; + thread_safe::list async_list_; void set_output(task_callback_t out_callback) { @@ -222,10 +230,46 @@ class llm_task { return oss_prompt.str(); } + int inference_async(cv::Mat &src, bool bgr2rgb = true) + { + if (async_list_.size() < 1) { + inference_async_par par; + par.inference_src = src.clone(); + par.inference_bgr2rgb = bgr2rgb; + async_list_.put(par); + } + return async_list_.size(); + } + + bool inference_raw_yuv(const std::string &msg) + { + if (msg.size() != 320 * 320 * 2) { + throw std::string("img size error"); + } + cv::Mat camera_data(320, 320, CV_8UC2, (void *)msg.data()); + cv::Mat rgb; + cv::cvtColor(camera_data, rgb, cv::COLOR_YUV2RGB_YUYV); + return inference_async(rgb, true) ? false : true; + } + void inference(const std::string &msg) { try { - if (image_data_.empty()) { + if (encamera_) { + inference_async_par par; + async_list_.get(); // discard buffered frames + par = async_list_.get(); + if (par.inference_src.empty()) return; + if (par.inference_bgr2rgb) { + cv::Mat rgb; + cv::cvtColor(par.inference_src, rgb, cv::COLOR_BGR2RGB); + par.inference_src = rgb; + } + lLaMa_->Encode(par.inference_src, img_embed); + lLaMa_->Encode(img_embed, prompt_data_, prompt_complete(msg)); + std::string out = lLaMa_->Run(prompt_data_); + if (out_callback_) out_callback_(out, true); + } else if (image_data_.empty()) { lLaMa_->Encode(prompt_data_, prompt_complete(msg)); std::string out = lLaMa_->Run(prompt_data_); if (out_callback_) out_callback_(out, true); @@ -302,13 +346,13 @@ std::atomic llm_task::next_port_{8090}; #undef CONFIG_AUTO_SET -class llm_llm : public StackFlow { +class llm_vlm : public StackFlow { private: int task_count_; std::unordered_map> llm_task_; public: - llm_llm() : StackFlow("vlm") + llm_vlm() : StackFlow("vlm") { task_count_ = 2; } @@ -447,6 +491,23 @@ class llm_llm : public StackFlow { llm_task_obj->lLaMa_->Stop(); } + 
void task_camera_data(const std::weak_ptr llm_task_obj_weak, + const std::weak_ptr llm_channel_weak, const std::string &data) + { + nlohmann::json error_body; + auto llm_task_obj = llm_task_obj_weak.lock(); + auto llm_channel = llm_channel_weak.lock(); + if (!(llm_task_obj && llm_channel)) { + SLOGE("Model run failed."); + return; + } + try { + llm_task_obj->inference_raw_yuv(data); + } catch (...) { + SLOGE("data format error"); + } + } + int setup(const std::string &work_id, const std::string &object, const std::string &data) override { nlohmann::json error_body; @@ -476,26 +537,38 @@ class llm_llm : public StackFlow { llm_channel->set_output(llm_task_obj->enoutput_); llm_channel->set_stream(llm_task_obj->enstream_); - llm_task_obj->set_output(std::bind(&llm_llm::task_output, this, std::weak_ptr(llm_task_obj), + llm_task_obj->set_output(std::bind(&llm_vlm::task_output, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); for (const auto input : llm_task_obj->inputs_) { if (input.find("vlm") != std::string::npos) { llm_channel->subscriber_work_id( - "", std::bind(&llm_llm::task_user_data, this, std::weak_ptr(llm_task_obj), + "", std::bind(&llm_vlm::task_user_data, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); } else if (input.find("asr") != std::string::npos) { llm_channel->subscriber_work_id( - input, std::bind(&llm_llm::task_asr_data, this, std::weak_ptr(llm_task_obj), + input, std::bind(&llm_vlm::task_asr_data, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); } else if (input.find("kws") != std::string::npos) { llm_channel->subscriber_work_id( - input, std::bind(&llm_llm::kws_awake, this, std::weak_ptr(llm_task_obj), + input, std::bind(&llm_vlm::kws_awake, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); + } else if 
(input.find("camera") != std::string::npos) { + llm_task_obj->encamera_ = true; + std::string input_url_name = input + ".out_port"; + std::string input_url = unit_call("sys", "sql_select", input_url_name); + if (!input_url.empty()) { + std::weak_ptr _llm_task_obj = llm_task_obj; + std::weak_ptr _llm_channel = llm_channel; + llm_channel->subscriber(input_url, [this, _llm_task_obj, _llm_channel]( + pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); + } } } llm_task_[work_id_num] = llm_task_obj; @@ -513,7 +586,7 @@ class llm_llm : public StackFlow { void link(const std::string &work_id, const std::string &object, const std::string &data) override { - SLOGI("llm_llm::link:%s", data.c_str()); + SLOGI("llm_vlm::link:%s", data.c_str()); int ret = 1; nlohmann::json error_body; int work_id_num = sample_get_work_id_num(work_id); @@ -528,15 +601,28 @@ class llm_llm : public StackFlow { if (data.find("asr") != std::string::npos) { ret = llm_channel->subscriber_work_id( data, - std::bind(&llm_llm::task_asr_data, this, std::weak_ptr(llm_task_obj), + std::bind(&llm_vlm::task_asr_data, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); llm_task_obj->inputs_.push_back(data); } else if (data.find("kws") != std::string::npos) { ret = llm_channel->subscriber_work_id( data, - std::bind(&llm_llm::kws_awake, this, std::weak_ptr(llm_task_obj), + std::bind(&llm_vlm::kws_awake, this, std::weak_ptr(llm_task_obj), std::weak_ptr(llm_channel), std::placeholders::_1, std::placeholders::_2)); llm_task_obj->inputs_.push_back(data); + } else if (data.find("camera") != std::string::npos) { + llm_task_obj->encamera_ = true; + std::string input_url_name = data + ".out_port"; + std::string input_url = unit_call("sys", "sql_select", input_url_name); + if (!input_url.empty()) { + std::weak_ptr _llm_task_obj = llm_task_obj; + std::weak_ptr _llm_channel = llm_channel; + 
llm_channel->subscriber( + input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); + } + llm_task_obj->inputs_.push_back(data); } if (ret) { error_body["code"] = -20; @@ -550,7 +636,7 @@ class llm_llm : public StackFlow { void unlink(const std::string &work_id, const std::string &object, const std::string &data) override { - SLOGI("llm_llm::unlink:%s", data.c_str()); + SLOGI("llm_vlm::unlink:%s", data.c_str()); int ret = 0; nlohmann::json error_body; int work_id_num = sample_get_work_id_num(work_id); @@ -575,7 +661,7 @@ class llm_llm : public StackFlow { void taskinfo(const std::string &work_id, const std::string &object, const std::string &data) override { - SLOGI("llm_llm::taskinfo:%s", data.c_str()); + SLOGI("llm_vlm::taskinfo:%s", data.c_str()); nlohmann::json req_body; int work_id_num = sample_get_work_id_num(work_id); if (WORK_ID_NONE == work_id_num) { @@ -602,7 +688,7 @@ class llm_llm : public StackFlow { int exit(const std::string &work_id, const std::string &object, const std::string &data) override { - SLOGI("llm_llm::exit:%s", data.c_str()); + SLOGI("llm_vlm::exit:%s", data.c_str()); nlohmann::json error_body; int work_id_num = sample_get_work_id_num(work_id); @@ -621,7 +707,7 @@ class llm_llm : public StackFlow { return 0; } - ~llm_llm() + ~llm_vlm() { while (1) { auto iteam = llm_task_.begin(); @@ -641,7 +727,7 @@ int main(int argc, char *argv[]) signal(SIGTERM, __sigint); signal(SIGINT, __sigint); mkdir("/tmp/llm", 0777); - llm_llm llm; + llm_vlm llm; while (!main_exit_flage) { sleep(1); } From 357a6f11574c27b15ead26b35a8529dab55c833e Mon Sep 17 00:00:00 2001 From: dianjixz <18637716021@163.com> Date: Mon, 23 Jun 2025 11:54:54 +0800 Subject: [PATCH 15/79] [update] llm-camera axera camera add custom_config --- .../llm_framework/main_camera/camera.json | 4 +- .../main_camera/src/axera_camera.c | 55 +- .../main_camera/src/axera_camera.h | 41 +- 
.../llm_framework/main_camera/src/camera.h | 3 +- .../llm_framework/main_camera/src/main.cpp | 690 +++++++++--------- .../main_camera/src/v4l2_camera.c | 2 +- 6 files changed, 419 insertions(+), 376 deletions(-) diff --git a/projects/llm_framework/main_camera/camera.json b/projects/llm_framework/main_camera/camera.json index 22fbd880..081fe9ba 100644 --- a/projects/llm_framework/main_camera/camera.json +++ b/projects/llm_framework/main_camera/camera.json @@ -13,7 +13,9 @@ "image.jpeg.base64" ], "cap_param": { - "None": "None" + "VinParam.eSysMode": 1, + "VinParam.eHdrMode": 1, + "VinParam.bAiispEnable": 1 }, "jpeg_config_param": { "stVencAttr.enType": 26, diff --git a/projects/llm_framework/main_camera/src/axera_camera.c b/projects/llm_framework/main_camera/src/axera_camera.c index 841a6b56..00172883 100644 --- a/projects/llm_framework/main_camera/src/axera_camera.c +++ b/projects/llm_framework/main_camera/src/axera_camera.c @@ -115,27 +115,7 @@ AX_VIN_CHN_ATTR_T gSc850slChn0Attr = { .tFrameRateCtrl = {AX_INVALID_FRMRATE, AX_INVALID_FRMRATE}, }; -typedef enum { - SAMPLE_VIN_NONE = -1, - SAMPLE_VIN_SINGLE_DUMMY = 0, - SAMPLE_VIN_SINGLE_OS04A10 = 1, - SAMPLE_VIN_DOUBLE_OS04A10 = 2, - SAMPLE_VIN_SINGLE_SC450AI = 3, - SAMPLE_VIN_DOUBLE_SC450AI = 4, - SAMPLE_VIN_DOUBLE_OS04A10_AND_BT656 = 5, - SAMPLE_VIN_SINGLE_S5KJN1SQ03 = 6, - SAMPLE_VIN_SINGLE_OS04A10_DCG_HDR = 7, - SAMPLE_VIN_SINGLE_OS04A10_DCG_VS_HDR = 8, - SYS_CASE_SINGLE_DVP = 20, - SYS_CASE_SINGLE_BT601 = 21, - SYS_CASE_SINGLE_BT656 = 22, - SYS_CASE_SINGLE_BT1120 = 23, - SYS_CASE_SINGLE_LVDS = 24, - SYS_CASE_SINGLE_OS04A10_ONLINE = 25, - SMARTSENS_SC850SL = 13, - SAMPLE_VIN_SINGLE_SC850SL = 26, - SAMPLE_VIN_BUTT -} SAMPLE_VIN_CASE_E; + struct axera_camera_index_t { char name[48]; @@ -157,14 +137,7 @@ struct axera_camera_index_t { {"axera_single_os04a10_online", SYS_CASE_SINGLE_OS04A10_ONLINE}, {"axera_single_sc850sl", SAMPLE_VIN_SINGLE_SC850SL}}; -typedef struct { - SAMPLE_VIN_CASE_E eSysCase; - 
COMMON_VIN_MODE_E eSysMode; - AX_SNS_HDR_MODE_E eHdrMode; - SAMPLE_LOAD_RAW_NODE_E eLoadRawNode; - AX_BOOL bAiispEnable; - AX_S32 nDumpFrameNum; -} SAMPLE_VIN_PARAM_T; + /* comm pool */ COMMON_SYS_POOL_CFG_T gtSysCommPoolSingleDummySdr[] = { @@ -1194,11 +1167,13 @@ struct axera_camera_t { AX_VIDEO_FRAME_T out_img; int Chn; AX_VENC_CHN_ATTR_T stVencChnAttr; + AX_VENC_CHN_ATTR_T stJPEGVencChnAttr; AX_IVPS_PIPELINE_ATTR_T stPipelineAttr; AX_RTSP_HANDLE pRtspHandle; AX_RTSP_ATTR_T stRtspAttr[MAX_RTSP_MAX_CHANNEL_NUM]; pthread_t venc_thread_id_; int venc_run_; + int venc_jpeg_run_; } axera_obj = {0}; static int camera_capture_callback_set(struct camera_t *camera, vcamera_frame_get pcallback) @@ -1448,6 +1423,12 @@ void init_rtsp(AX_VENC_CHN_ATTR_T *stVencChnAttr) axera_obj.venc_run_ = 1; } +void init_jpeg(AX_VENC_CHN_ATTR_T *stVencChnAttr) +{ + axera_obj.stJPEGVencChnAttr = *stVencChnAttr; + axera_obj.venc_jpeg_run_ = 1; +} + static int SAMPLE_IVPS_Init(AX_S32 nGrpId, camera_t *camera) { AX_S32 s32Ret = 0, nChn; @@ -1611,6 +1592,7 @@ int axera_camera_open_from(camera_t *camera) int Ret = -1; AX_S32 axRet; if (camera == NULL) return -1; + axera_config_t *axera_config = (axera_config_t *)camera->custom_config_; /* Check whether the camera is already open or in an error state */ SLOGI("Open camera %s...", camera->dev_name_); if (camera->state_ & AX_SENSOR_CAM_OPEN) { @@ -1629,9 +1611,9 @@ int axera_camera_open_from(camera_t *camera) return -10; } - axera_obj.VinParam.eSysMode = COMMON_VIN_SENSOR; - axera_obj.VinParam.eHdrMode = AX_SNS_LINEAR_MODE; - axera_obj.VinParam.bAiispEnable = AX_TRUE; + axera_obj.VinParam.eSysMode = axera_config->VinParam.eSysMode; // COMMON_VIN_SENSOR; + axera_obj.VinParam.eHdrMode = axera_config->VinParam.eHdrMode; // AX_SNS_LINEAR_MODE; + axera_obj.VinParam.bAiispEnable = axera_config->VinParam.bAiispEnable; // AX_TRUE; // axera_obj.gCams.tChnAttr __sample_case_config(&axera_obj.gCams, &axera_obj.VinParam, &axera_obj.tCommonArgs, 
&axera_obj.tPrivArgs); COMMON_SYS_Init(&axera_obj.tCommonArgs); @@ -1682,7 +1664,7 @@ int axera_camera_open_from(camera_t *camera) return -1; } -camera_t *axera_camera_open(const char *pdev_name, int width, int height, int fps) +camera_t *axera_camera_open(const char *pdev_name, int width, int height, int fps, void *config) { int Ret = -1; camera_t *camera = (camera_t *)malloc(sizeof(camera_t)); @@ -1698,9 +1680,10 @@ camera_t *axera_camera_open(const char *pdev_name, int width, int height, int fp memset(camera->dev_name_, 0, CONFIG_DEVNAME_LEN); memcpy(camera->dev_name_, pdev_name, CopyLen); - camera->width_ = width; - camera->height_ = height; - camera->capture_fps_ = fps; + camera->width_ = width; + camera->height_ = height; + camera->capture_fps_ = fps; + camera->custom_config_ = config; Ret = axera_camera_open_from(camera); if (Ret) { diff --git a/projects/llm_framework/main_camera/src/axera_camera.h b/projects/llm_framework/main_camera/src/axera_camera.h index 6b8644d3..124693e2 100644 --- a/projects/llm_framework/main_camera/src/axera_camera.h +++ b/projects/llm_framework/main_camera/src/axera_camera.h @@ -6,15 +6,52 @@ #ifndef AXERA_CAMERA_H #define AXERA_CAMERA_H #include "common_venc.h" +#include "common_vin.h" #if __cplusplus extern "C" { #endif + +typedef enum { + SAMPLE_VIN_NONE = -1, + SAMPLE_VIN_SINGLE_DUMMY = 0, + SAMPLE_VIN_SINGLE_OS04A10 = 1, + SAMPLE_VIN_DOUBLE_OS04A10 = 2, + SAMPLE_VIN_SINGLE_SC450AI = 3, + SAMPLE_VIN_DOUBLE_SC450AI = 4, + SAMPLE_VIN_DOUBLE_OS04A10_AND_BT656 = 5, + SAMPLE_VIN_SINGLE_S5KJN1SQ03 = 6, + SAMPLE_VIN_SINGLE_OS04A10_DCG_HDR = 7, + SAMPLE_VIN_SINGLE_OS04A10_DCG_VS_HDR = 8, + SYS_CASE_SINGLE_DVP = 20, + SYS_CASE_SINGLE_BT601 = 21, + SYS_CASE_SINGLE_BT656 = 22, + SYS_CASE_SINGLE_BT1120 = 23, + SYS_CASE_SINGLE_LVDS = 24, + SYS_CASE_SINGLE_OS04A10_ONLINE = 25, + SMARTSENS_SC850SL = 13, + SAMPLE_VIN_SINGLE_SC850SL = 26, + SAMPLE_VIN_BUTT +} SAMPLE_VIN_CASE_E; + +typedef struct { + SAMPLE_VIN_CASE_E eSysCase; + 
COMMON_VIN_MODE_E eSysMode; + AX_SNS_HDR_MODE_E eHdrMode; + SAMPLE_LOAD_RAW_NODE_E eLoadRawNode; + AX_BOOL bAiispEnable; + AX_S32 nDumpFrameNum; +} SAMPLE_VIN_PARAM_T; + +typedef struct axera_config_t { + SAMPLE_VIN_PARAM_T VinParam; +} axera_config_t; + /** * Open the axera_camera * @pdev_name Device node * Return value: NULL for failure */ -camera_t* axera_camera_open(const char* pdev_name, int width, int height, int fps); +camera_t* axera_camera_open(const char* pdev_name, int width, int height, int fps, void* config); /** * Open the axera_camera from config @@ -30,7 +67,7 @@ int axera_camera_open_from(camera_t* camera); int axera_camera_close(camera_t* camera); void init_rtsp(AX_VENC_CHN_ATTR_T *stVencChnAttr); -void init_jpeg(); +void init_jpeg(AX_VENC_CHN_ATTR_T *stVencChnAttr); #if __cplusplus } diff --git a/projects/llm_framework/main_camera/src/camera.h b/projects/llm_framework/main_camera/src/camera.h index 8aab5310..607c9ffa 100644 --- a/projects/llm_framework/main_camera/src/camera.h +++ b/projects/llm_framework/main_camera/src/camera.h @@ -36,6 +36,7 @@ typedef struct camera_t { 3 Error */ pthread_t capture_thread_id_; void* ctx_; + void* custom_config_; /** * Set capture frame callback * Return value: 0 for success, -1 for failure @@ -59,7 +60,7 @@ typedef struct camera_t { * @pdev_name Device node * Return value: NULL for failure */ -camera_t* camera_open(const char* pdev_name, int width, int height, int fps); +camera_t* camera_open(const char* pdev_name, int width, int height, int fps, void* config); /** * Open the camera from config diff --git a/projects/llm_framework/main_camera/src/main.cpp b/projects/llm_framework/main_camera/src/main.cpp index dad26993..3565b01d 100644 --- a/projects/llm_framework/main_camera/src/main.cpp +++ b/projects/llm_framework/main_camera/src/main.cpp @@ -61,20 +61,34 @@ static void __sigint(int iSigNo) typedef std::function task_callback_t; -typedef camera_t *(*hal_camera_open_fun)(const char *pdev_name, int width, int 
height, int fps); +typedef camera_t *(*hal_camera_open_fun)(const char *pdev_name, int width, int height, int fps, void *config); typedef int (*hal_camera_close_fun)(camera_t *camera); +typedef bool (*hal_parse_config_fun)(const nlohmann::json &config_body, const nlohmann::json &file_body, + void **custom_config); + +#define CONFIG_OBJECT stVencChnAttr #define CONFIG_AUTO_SET(obj, key) \ if (config_body.contains(#key)) \ - stVencChnAttr.key = config_body[#key]; \ + CONFIG_OBJECT.key = config_body[#key]; \ else if (obj.contains(#key)) \ - stVencChnAttr.key = obj[#key]; + CONFIG_OBJECT.key = obj[#key]; + +#define CONFIG_AUTO_SET_DEFAULT(obj, key, default_value) \ + if (config_body.contains(#key)) \ + CONFIG_OBJECT.key = config_body[#key]; \ + else if (obj.contains(#key)) \ + CONFIG_OBJECT.key = obj[#key]; \ + else \ + CONFIG_OBJECT.key = default_value; class llm_task { private: camera_t *cam; + void *cam_config; hal_camera_open_fun hal_camera_open; hal_camera_close_fun hal_camera_close; + hal_parse_config_fun hal_parse_config; public: std::string response_format_; @@ -159,6 +173,311 @@ class llm_task { out_callback_ = out_callback; } + static bool parse_axera_config(const nlohmann::json &config_body, const nlohmann::json &file_body, + void **custom_config) + { + std::string rtsp_config; + static axera_config_t axera_config; + memset(&axera_config, 0, sizeof(axera_config_t)); + if (config_body.contains("rtsp")) { + rtsp_config = config_body.at("rtsp"); + } + int frame_width = config_body.at("frame_width"); + int frame_height = config_body.at("frame_height"); + *custom_config = (void *)&axera_config; + + { +#undef CONFIG_OBJECT +#define CONFIG_OBJECT axera_config + CONFIG_AUTO_SET_DEFAULT(file_body["cap_param"], VinParam.eSysMode, COMMON_VIN_SENSOR); + CONFIG_AUTO_SET_DEFAULT(file_body["cap_param"], VinParam.eHdrMode, AX_SNS_LINEAR_MODE); + CONFIG_AUTO_SET_DEFAULT(file_body["cap_param"], VinParam.bAiispEnable, AX_TRUE); + } + + if (rtsp_config.empty() == false) { + 
AX_VENC_CHN_ATTR_T stVencChnAttr; + memset(&stVencChnAttr, 0, sizeof(AX_VENC_CHN_ATTR_T)); +#undef CONFIG_OBJECT +#define CONFIG_OBJECT stVencChnAttr + if (rtsp_config.find("h264") != std::string::npos) { + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32MaxPicWidth); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32MaxPicHeight); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enMemSource); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32BufSize); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enProfile); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enLevel); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enTier); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32PicWidthSrc); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32PicHeightSrc); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.bEnable); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.s32X); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.s32Y); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.u32Width); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.u32Height); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enRotation); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enLinkMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.bDeBreathEffect); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.bRefRingbuf); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.s32StopWaitTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u8InFifoDepth); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u8OutFifoDepth); + CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32SliceNum); + 
CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stAttrH265e.bRcnRefShareBuf); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.enRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.s32FirstFrameStartQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stFrameRate.fSrcFrameRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stFrameRate.fDstFrameRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32BitRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxIprop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinIprop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32StatTime); + 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.enVQ); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32IdrQpDeltaRange); + 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32TargetBitRate); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxIprop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinIprop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32ShortTermStatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermStatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], 
stRcAttr.stH264CVbr.u32LongTermMaxBitrate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermMinBitrate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32ExtraBitPercent); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermStatTimeUnit); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32IQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32PQp); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32BQp); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32Gop); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32StatTime); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32TargetBitRate); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.enGopMode); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stNormalP.stPicConfig.s32QpOffset); + 
CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stNormalP.stPicConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicConfig.s32QpOffset); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.s32QpOffset); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.s32Interval); + CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stSvcT.u32GopSize); + } else if (rtsp_config.find("h265") != std::string::npos) { + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32MaxPicWidth); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32MaxPicHeight); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enMemSource); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32BufSize); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enProfile); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enLevel); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enTier); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32PicWidthSrc); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32PicHeightSrc); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.bEnable); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.s32X); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.s32Y); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.u32Width); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.u32Height); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enRotation); + 
CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enLinkMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.bDeBreathEffect); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.bRefRingbuf); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.s32StopWaitTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u8InFifoDepth); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u8OutFifoDepth); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32SliceNum); + CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stAttrH265e.bRcnRefShareBuf); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.enRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.s32FirstFrameStartQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stFrameRate.fSrcFrameRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stFrameRate.fDstFrameRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32BitRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxIprop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinIprop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h265_config_param"], 
stRcAttr.stH265Cbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.enVQ); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxQp); + 
CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32TargetBitRate); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinIQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.s32DeBreathQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], 
stRcAttr.stH265CVbr.u32IdrQpDeltaRange); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxIprop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinIprop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxBitRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32ShortTermStatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermStatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermMaxBitrate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermMinBitrate); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32ExtraBitPercent); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermStatTimeUnit); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.s32IntraQpDelta); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32IQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32PQp); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32BQp); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32Gop); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32StatTime); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32TargetBitRate); + CONFIG_AUTO_SET(file_body["h265_config_param"], 
stRcAttr.stH265QpMap.stQpmapInfo.enCtbRcMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.stQpmapInfo.enQpmapQpType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.stQpmapInfo.enQpmapBlockType); + CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.stQpmapInfo.enQpmapBlockUnit); + + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.enGopMode); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stNormalP.stPicConfig.s32QpOffset); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stNormalP.stPicConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicConfig.s32QpOffset); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.s32QpOffset); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.f32QpFactor); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicSpecialConfig.s32Interval); + CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stSvcT.u32GopSize); + } + + try { + std::regex pattern(R"(rtsp\.(\d+)[xX-](\d+)\.h(264|265))"); + std::smatch matches; + if (std::regex_search(rtsp_config, matches, pattern)) { + if (matches.size() >= 3) { + stVencChnAttr.stVencAttr.u32PicWidthSrc = std::stoi(matches[1].str()); + stVencChnAttr.stVencAttr.u32PicHeightSrc = std::stoi(matches[2].str()); + } + } + } catch (...) 
{ + return true; + } + if ((stVencChnAttr.stVencAttr.u32PicWidthSrc < frame_width) || + (stVencChnAttr.stVencAttr.u32PicHeightSrc < frame_height)) { + return true; + } + init_rtsp(&stVencChnAttr); + } + return false; + } + bool parse_config(const nlohmann::json &config_body) { try { @@ -167,9 +486,6 @@ class llm_task { devname_ = config_body.at("input"); frame_width_ = config_body.at("frame_width"); frame_height_ = config_body.at("frame_height"); - if (config_body.contains("rtsp")) { - rtsp_config_ = config_body.at("rtsp"); - } if (config_body.contains("enable_webstream")) { enable_webstream_ = config_body.at("enable_webstream"); } else { @@ -185,340 +501,43 @@ class llm_task { if (devname_.find("/dev/video") != std::string::npos) { hal_camera_open = camera_open; hal_camera_close = camera_close; + hal_parse_config = NULL; } else if (devname_.find("axera_") != std::string::npos) { hal_camera_open = axera_camera_open; hal_camera_close = axera_camera_close; - if (!rtsp_config_.empty()) { - nlohmann::json error_body; - nlohmann::json file_body; - std::string base_model_path; - std::string base_model_config_path; - std::list config_file_paths = - get_config_file_paths(base_model_path, base_model_config_path, "camera"); - try { - for (auto file_name : config_file_paths) { - std::ifstream config_file(file_name); - if (!config_file.is_open()) { - SLOGW("config file :%s miss", file_name.c_str()); - continue; - } - SLOGI("config file :%s read", file_name.c_str()); - config_file >> file_body; - config_file.close(); - break; - } - if (file_body.empty()) { - SLOGE("all config file miss"); - return true; - } - AX_VENC_CHN_ATTR_T stVencChnAttr; - memset(&stVencChnAttr, 0, sizeof(AX_VENC_CHN_ATTR_T)); - if (rtsp_config_.find("h264") != std::string::npos) { - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enType); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32MaxPicWidth); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32MaxPicHeight); - 
CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enMemSource); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32BufSize); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enProfile); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enLevel); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enTier); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32PicWidthSrc); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32PicHeightSrc); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.bEnable); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.s32X); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.s32Y); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.u32Width); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stCropCfg.stRect.u32Height); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enRotation); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.enLinkMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.bDeBreathEffect); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.bRefRingbuf); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.s32StopWaitTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u8InFifoDepth); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u8OutFifoDepth); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.u32SliceNum); - CONFIG_AUTO_SET(file_body["h264_config_param"], stVencAttr.stAttrH265e.bRcnRefShareBuf); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.enRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.s32FirstFrameStartQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stFrameRate.fSrcFrameRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stFrameRate.fDstFrameRate); - 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32BitRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MaxIprop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32MinIprop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Cbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264Cbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264Cbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.enVQ); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MaxIQp); - 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264Vbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264Vbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264Vbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264AVbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264AVbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264AVbr.stQpmapInfo.enQpmapBlockUnit); - - 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QVbr.u32TargetBitRate); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxIprop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MinIprop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32ShortTermStatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermStatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermMaxBitrate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermMinBitrate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32ExtraBitPercent); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.u32LongTermStatTimeUnit); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.s32IntraQpDelta); - 
CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264CVbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264CVbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264CVbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32IQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32PQp); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264FixQp.u32BQp); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32Gop); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32StatTime); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.u32TargetBitRate); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stRcAttr.stH264QpMap.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264QpMap.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stRcAttr.stH264QpMap.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.enGopMode); - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stNormalP.stPicConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stNormalP.stPicConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stOneLTR.stPicConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stGopAttr.stOneLTR.stPicSpecialConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h264_config_param"], - 
stGopAttr.stOneLTR.stPicSpecialConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h264_config_param"], - stGopAttr.stOneLTR.stPicSpecialConfig.s32Interval); - CONFIG_AUTO_SET(file_body["h264_config_param"], stGopAttr.stSvcT.u32GopSize); - } else if (rtsp_config_.find("h265") != std::string::npos) { - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enType); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32MaxPicWidth); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32MaxPicHeight); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enMemSource); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32BufSize); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enProfile); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enLevel); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enTier); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32PicWidthSrc); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32PicHeightSrc); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.bEnable); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.s32X); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.s32Y); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.u32Width); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stCropCfg.stRect.u32Height); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enRotation); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.enLinkMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.bDeBreathEffect); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.bRefRingbuf); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.s32StopWaitTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u8InFifoDepth); - CONFIG_AUTO_SET(file_body["h265_config_param"], 
stVencAttr.u8OutFifoDepth); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.u32SliceNum); - CONFIG_AUTO_SET(file_body["h265_config_param"], stVencAttr.stAttrH265e.bRcnRefShareBuf); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.enRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.s32FirstFrameStartQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stFrameRate.fSrcFrameRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stFrameRate.fDstFrameRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32BitRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MaxIprop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32MinIprop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Cbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265Cbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265Cbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32Gop); - 
CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.enVQ); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265Vbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265Vbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265Vbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.s32DeBreathQpDelta); - 
CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265AVbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265AVbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265AVbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QVbr.u32TargetBitRate); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinIQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.s32DeBreathQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32IdrQpDeltaRange); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxIprop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MinIprop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32MaxBitRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32ShortTermStatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], 
stRcAttr.stH265CVbr.u32LongTermStatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermMaxBitrate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermMinBitrate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32ExtraBitPercent); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.u32LongTermStatTimeUnit); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.s32IntraQpDelta); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265CVbr.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265CVbr.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265CVbr.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32IQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32PQp); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265FixQp.u32BQp); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32Gop); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32StatTime); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.u32TargetBitRate); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.stQpmapInfo.enCtbRcMode); - CONFIG_AUTO_SET(file_body["h265_config_param"], stRcAttr.stH265QpMap.stQpmapInfo.enQpmapQpType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265QpMap.stQpmapInfo.enQpmapBlockType); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stRcAttr.stH265QpMap.stQpmapInfo.enQpmapBlockUnit); - - CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.enGopMode); - 
CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stNormalP.stPicConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stNormalP.stPicConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stOneLTR.stPicConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stGopAttr.stOneLTR.stPicSpecialConfig.s32QpOffset); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stGopAttr.stOneLTR.stPicSpecialConfig.f32QpFactor); - CONFIG_AUTO_SET(file_body["h265_config_param"], - stGopAttr.stOneLTR.stPicSpecialConfig.s32Interval); - CONFIG_AUTO_SET(file_body["h265_config_param"], stGopAttr.stSvcT.u32GopSize); - } - try { - std::regex pattern(R"(rtsp\.(\d+)[xX-](\d+)\.h(264|265))"); - std::smatch matches; - if (std::regex_search(rtsp_config_, matches, pattern)) { - if (matches.size() >= 3) { - stVencChnAttr.stVencAttr.u32PicWidthSrc = std::stoi(matches[1].str()); - stVencChnAttr.stVencAttr.u32PicHeightSrc = std::stoi(matches[2].str()); - } - } - } catch (...) { - return true; - } - if ((stVencChnAttr.stVencAttr.u32PicWidthSrc < frame_width_) || - (stVencChnAttr.stVencAttr.u32PicHeightSrc < frame_height_)) { - return true; + hal_parse_config = llm_task::parse_axera_config; + } else { + return true; + } + { + nlohmann::json error_body; + nlohmann::json file_body; + std::string base_model_path; + std::string base_model_config_path; + std::list config_file_paths = + get_config_file_paths(base_model_path, base_model_config_path, "camera"); + try { + for (auto file_name : config_file_paths) { + std::ifstream config_file(file_name); + if (!config_file.is_open()) { + SLOGW("config file :%s miss", file_name.c_str()); + continue; } - init_rtsp(&stVencChnAttr); - } catch (...) 
{ + SLOGI("config file :%s read", file_name.c_str()); + config_file >> file_body; + config_file.close(); + break; + } + if (file_body.empty()) { + SLOGE("all config file miss"); return true; } + if (hal_parse_config) { + if (hal_parse_config(config_body, file_body, &cam_config)) return true; + } + } catch (...) { + return true; } - } else { - return true; } return false; @@ -530,7 +549,7 @@ class llm_task { return -1; } try { - cam = hal_camera_open(devname_.c_str(), frame_width_, frame_height_, 30); + cam = hal_camera_open(devname_.c_str(), frame_width_, frame_height_, 30, cam_config); if (cam == NULL) { printf("Camera open failed \n"); return -1; @@ -553,7 +572,8 @@ class llm_task { llm_task(const std::string &workid) { - cam = NULL; + cam = NULL; + cam_config = NULL; } void start() diff --git a/projects/llm_framework/main_camera/src/v4l2_camera.c b/projects/llm_framework/main_camera/src/v4l2_camera.c index c9efbf71..33f3fe46 100644 --- a/projects/llm_framework/main_camera/src/v4l2_camera.c +++ b/projects/llm_framework/main_camera/src/v4l2_camera.c @@ -317,7 +317,7 @@ int camera_open_from(camera_t* camera) return -1; } -camera_t* camera_open(const char* pdev_name, int width, int height, int fps) +camera_t* camera_open(const char* pdev_name, int width, int height, int fps, void* config) { int Ret = -1; camera_t* camera = (camera_t*)malloc(sizeof(camera_t)); From 57b143756f24fa2818997a826dc72108160ce46e Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 24 Jun 2025 19:23:49 +0800 Subject: [PATCH 16/79] [update] depth_anything use async inference, move ax_engine init.fix link return value. 
--- .../main_depth_anything/src/EngineWrapper.cpp | 11 +- .../main_depth_anything/src/EngineWrapper.hpp | 2 +- .../main_depth_anything/src/main.cpp | 115 ++++++++++-------- projects/llm_framework/main_vlm/src/main.cpp | 1 + .../main_yolo/src/EngineWrapper.cpp | 11 +- .../main_yolo/src/EngineWrapper.hpp | 2 +- projects/llm_framework/main_yolo/src/main.cpp | 30 ++--- 7 files changed, 102 insertions(+), 70 deletions(-) diff --git a/projects/llm_framework/main_depth_anything/src/EngineWrapper.cpp b/projects/llm_framework/main_depth_anything/src/EngineWrapper.cpp index 7a7ec61f..21c51e49 100644 --- a/projects/llm_framework/main_depth_anything/src/EngineWrapper.cpp +++ b/projects/llm_framework/main_depth_anything/src/EngineWrapper.cpp @@ -178,10 +178,19 @@ static AX_S32 CheckModelVNpu(const std::string &strModel, const AX_ENGINE_MODEL_ } #endif -int EngineWrapper::Init(const char *strModelPath, uint32_t nNpuType) +int EngineWrapper::Init(const char *strModelPath, uint32_t nNpuType, uint32_t npuMode) { AX_S32 ret = 0; + // 0. Init AX_ENGINE + AX_ENGINE_NPU_ATTR_T npu_attr; + memset(&npu_attr, 0, sizeof(npu_attr)); + npu_attr.eHardMode = static_cast(npuMode); + ret = AX_ENGINE_Init(&npu_attr); + if (0 != ret) { + fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); + } + // 1. 
load model AX_BOOL bLoadModelUseCmm = AX_TRUE; AX_CHAR *pModelBufferVirAddr = nullptr; diff --git a/projects/llm_framework/main_depth_anything/src/EngineWrapper.hpp b/projects/llm_framework/main_depth_anything/src/EngineWrapper.hpp index c3f848ea..07eb43d6 100644 --- a/projects/llm_framework/main_depth_anything/src/EngineWrapper.hpp +++ b/projects/llm_framework/main_depth_anything/src/EngineWrapper.hpp @@ -30,7 +30,7 @@ class EngineWrapper { Release(); } - int Init(const char* strModelPath, uint32_t nNpuType = 0); + int Init(const char* strModelPath, uint32_t nNpuType = 0, uint32_t npuMode = 0); int SetInput(void* pInput, int index); diff --git a/projects/llm_framework/main_depth_anything/src/main.cpp b/projects/llm_framework/main_depth_anything/src/main.cpp index 7685bc83..ed82a61a 100644 --- a/projects/llm_framework/main_depth_anything/src/main.cpp +++ b/projects/llm_framework/main_depth_anything/src/main.cpp @@ -9,8 +9,8 @@ #include #include #include - #include "../../../../SDK/components/utilities/include/sample_log.h" +#include "thread_safe_list.h" using namespace StackFlows; @@ -27,13 +27,15 @@ static std::string base_model_config_path_; typedef struct { std::string depth_anything_model; std::string model_type = "detect"; - std::vector cls_name; - int img_h = 640; - int img_w = 640; - int cls_num = 80; - float pron_threshold = 0.45f; - float nms_threshold = 0.45; -} yolo_config; + int img_h = 640; + int img_w = 640; + uint32_t npu_type = 0; +} depth_anything_config; + +typedef struct { + cv::Mat inference_src; + bool inference_bgr2rgb; +} inference_async_par; typedef std::function task_callback_t; @@ -46,7 +48,7 @@ typedef std::function task_callback_ class llm_task { private: public: - yolo_config mode_config_; + depth_anything_config mode_config_; std::string model_; std::unique_ptr depth_anything_; std::string response_format_; @@ -57,7 +59,8 @@ class llm_task { static int ax_init_flage_; task_callback_t out_callback_; std::atomic_bool camera_flage_; - 
std::mutex inference_mtx_; + std::unique_ptr inference_run_; + thread_safe::list async_list_; bool parse_config(const nlohmann::json &config_body) { @@ -113,10 +116,11 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], img_h); CONFIG_AUTO_SET(file_body["mode_param"], img_w); CONFIG_AUTO_SET(file_body["mode_param"], model_type); + CONFIG_AUTO_SET(file_body["mode_param"], npu_type); mode_config_.depth_anything_model = base_model + mode_config_.depth_anything_model; depth_anything_ = std::make_unique(); - if (0 != depth_anything_->Init(mode_config_.depth_anything_model.c_str())) { - SLOGE("Init yolo_model model failed!\n"); + if (0 != depth_anything_->Init(mode_config_.depth_anything_model.c_str(), 0, mode_config_.npu_type)) { + SLOGE("Init depth_anything_model model failed!\n"); return -5; } } catch (...) { @@ -140,54 +144,63 @@ class llm_task { bool inference_decode(const std::string &msg) { - if (inference_mtx_.try_lock()) - std::lock_guard guard(inference_mtx_, std::adopt_lock); - else - return true; cv::Mat src = cv::imdecode(std::vector(msg.begin(), msg.end()), cv::IMREAD_COLOR); if (src.empty()) return true; - return inference(src); + return inference_async(src) ? false : true; } bool inference_raw_yuv(const std::string &msg) { - if (inference_mtx_.try_lock()) - std::lock_guard guard(inference_mtx_, std::adopt_lock); - else - return true; if (msg.size() != mode_config_.img_w * mode_config_.img_h * 2) { throw std::string("img size error"); } cv::Mat camera_data(mode_config_.img_h, mode_config_.img_w, CV_8UC2, (void *)msg.data()); cv::Mat rgb; cv::cvtColor(camera_data, rgb, cv::COLOR_YUV2RGB_YUYV); - return inference(rgb, true); + return inference_async(rgb, true) ? 
false : true; } bool inference_raw_rgb(const std::string &msg) { - if (inference_mtx_.try_lock()) - std::lock_guard guard(inference_mtx_, std::adopt_lock); - else - return true; if (msg.size() != mode_config_.img_w * mode_config_.img_h * 3) { throw std::string("img size error"); } cv::Mat camera_data(mode_config_.img_h, mode_config_.img_w, CV_8UC3, (void *)msg.data()); - return inference(camera_data, false); + return inference_async(camera_data, false) ? false : true; } bool inference_raw_bgr(const std::string &msg) { - if (inference_mtx_.try_lock()) - std::lock_guard guard(inference_mtx_, std::adopt_lock); - else - return true; if (msg.size() != mode_config_.img_w * mode_config_.img_h * 3) { throw std::string("img size error"); } cv::Mat camera_data(mode_config_.img_h, mode_config_.img_w, CV_8UC3, (void *)msg.data()); - return inference(camera_data); + return inference_async(camera_data) ? false : true; + } + + void run() + { + inference_async_par par; + for (;;) { + { + par = async_list_.get(); + if (par.inference_src.empty()) break; + inference(par.inference_src, par.inference_bgr2rgb); + } + } + } + + int inference_async(cv::Mat &src, bool bgr2rgb = true) + { + if (async_list_.size() < 3) { + inference_async_par par; + par.inference_src = src.clone(); + par.inference_bgr2rgb = bgr2rgb; + async_list_.put(par); + } else { + SLOGE("inference list is full\n"); + } + return async_list_.size(); } bool inference(cv::Mat &src, bool bgr2rgb = true) @@ -195,7 +208,7 @@ class llm_task { try { int ret = -1; std::vector image(mode_config_.img_w * mode_config_.img_h * 3, 0); - common::get_input_data_no_letterbox(src, image, mode_config_.img_h, mode_config_.img_w, bgr2rgb); + common::get_input_data_letterbox(src, image, mode_config_.img_h, mode_config_.img_w, bgr2rgb); cv::Mat img_mat(mode_config_.img_h, mode_config_.img_w, CV_8UC3, image.data()); depth_anything_->SetInput((void *)image.data(), 0); if (0 != depth_anything_->Run()) { @@ -206,7 +219,7 @@ class llm_task { 
depth_anything_->Post_Process(img_mat, mode_config_.model_type, depth_anything_output); if (out_callback_) out_callback_(depth_anything_output, true); } catch (...) { - SLOGW("yolo_->Run have error!"); + SLOGW("depth_anything_->Run have error!"); return true; } return false; @@ -219,12 +232,6 @@ class llm_task { if (0 != ret) { fprintf(stderr, "AX_SYS_Init failed! ret = 0x%x\n", ret); } - AX_ENGINE_NPU_ATTR_T npu_attr; - memset(&npu_attr, 0, sizeof(npu_attr)); - ret = AX_ENGINE_Init(&npu_attr); - if (0 != ret) { - fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); - } } ax_init_flage_++; } @@ -234,7 +241,6 @@ class llm_task { if (ax_init_flage_ > 0) { --ax_init_flage_; if (!ax_init_flage_) { - AX_ENGINE_Deinit(); AX_SYS_Deinit(); } } @@ -243,15 +249,25 @@ class llm_task { llm_task(const std::string &workid) { _ax_init(); + inference_run_ = std::make_unique(std::bind(&llm_task::run, this)); } void start() { + if (!inference_run_) { + inference_run_ = std::make_unique(std::bind(&llm_task::run, this)); + } } void stop() { + if (inference_run_) { + inference_async_par par; + async_list_.put(par); + inference_run_->join(); + inference_run_.reset(); } + } ~llm_task() { @@ -418,10 +434,10 @@ class llm_depth_anything : public StackFlow { if (!input_url.empty()) { std::weak_ptr _llm_task_obj = llm_task_obj; std::weak_ptr _llm_channel = llm_channel; - llm_channel->subscriber( - input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { - this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); - }); + llm_channel->subscriber(input_url, [this, _llm_task_obj, _llm_channel]( + pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); } } llm_task_[work_id_num] = llm_task_obj; @@ -465,10 +481,11 @@ class llm_depth_anything : public StackFlow { if (!input_url.empty()) { std::weak_ptr _llm_task_obj = llm_task_obj; std::weak_ptr _llm_channel = llm_channel; - 
llm_channel->subscriber(input_url, - [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { - this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); - }); + llm_channel->subscriber( + input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); + ret = 0; } llm_task_obj->inputs_.push_back(data); } diff --git a/projects/llm_framework/main_vlm/src/main.cpp b/projects/llm_framework/main_vlm/src/main.cpp index e7aeb36f..6ee6dd67 100644 --- a/projects/llm_framework/main_vlm/src/main.cpp +++ b/projects/llm_framework/main_vlm/src/main.cpp @@ -621,6 +621,7 @@ class llm_vlm : public StackFlow { input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); }); + ret = 0; } llm_task_obj->inputs_.push_back(data); } diff --git a/projects/llm_framework/main_yolo/src/EngineWrapper.cpp b/projects/llm_framework/main_yolo/src/EngineWrapper.cpp index 84085db0..99e448d1 100644 --- a/projects/llm_framework/main_yolo/src/EngineWrapper.cpp +++ b/projects/llm_framework/main_yolo/src/EngineWrapper.cpp @@ -179,10 +179,19 @@ static AX_S32 CheckModelVNpu(const std::string& strModel, const AX_ENGINE_MODEL_ } #endif -int EngineWrapper::Init(const char* strModelPath, uint32_t nNpuType) +int EngineWrapper::Init(const char* strModelPath, uint32_t nNpuType, uint32_t npuMode) { AX_S32 ret = 0; + // 0. Init AX_ENGINE + AX_ENGINE_NPU_ATTR_T npu_attr; + memset(&npu_attr, 0, sizeof(npu_attr)); + npu_attr.eHardMode = static_cast(npuMode); + ret = AX_ENGINE_Init(&npu_attr); + if (0 != ret) { + fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); + } + // 1. 
load model AX_BOOL bLoadModelUseCmm = AX_TRUE; AX_CHAR* pModelBufferVirAddr = nullptr; diff --git a/projects/llm_framework/main_yolo/src/EngineWrapper.hpp b/projects/llm_framework/main_yolo/src/EngineWrapper.hpp index 5d42a07e..b049a6fe 100644 --- a/projects/llm_framework/main_yolo/src/EngineWrapper.hpp +++ b/projects/llm_framework/main_yolo/src/EngineWrapper.hpp @@ -43,7 +43,7 @@ class EngineWrapper { Release(); } - int Init(const char* strModelPath, uint32_t nNpuType = 0); + int Init(const char* strModelPath, uint32_t nNpuType = 0, uint32_t npuMode = 0); int SetInput(void* pInput, int index); diff --git a/projects/llm_framework/main_yolo/src/main.cpp b/projects/llm_framework/main_yolo/src/main.cpp index 208a3757..bf259999 100644 --- a/projects/llm_framework/main_yolo/src/main.cpp +++ b/projects/llm_framework/main_yolo/src/main.cpp @@ -40,6 +40,7 @@ typedef struct { int point_num = 17; float pron_threshold = 0.45f; float nms_threshold = 0.45; + uint32_t npu_type = 0; } yolo_config; typedef struct { @@ -131,9 +132,10 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], cls_num); CONFIG_AUTO_SET(file_body["mode_param"], point_num); CONFIG_AUTO_SET(file_body["mode_param"], model_type); + CONFIG_AUTO_SET(file_body["mode_param"], npu_type); mode_config_.yolo_model = base_model + mode_config_.yolo_model; yolo_ = std::make_unique(); - if (0 != yolo_->Init(mode_config_.yolo_model.c_str())) { + if (0 != yolo_->Init(mode_config_.yolo_model.c_str(), 0, mode_config_.npu_type)) { SLOGE("Init yolo_model model failed!\n"); return -5; } @@ -293,12 +295,6 @@ class llm_task { if (0 != ret) { fprintf(stderr, "AX_SYS_Init failed! 
ret = 0x%x\n", ret); } - AX_ENGINE_NPU_ATTR_T npu_attr; - memset(&npu_attr, 0, sizeof(npu_attr)); - ret = AX_ENGINE_Init(&npu_attr); - if (0 != ret) { - fprintf(stderr, "Init ax-engine failed{0x%8x}.\n", ret); - } } ax_init_flage_++; } @@ -308,7 +304,6 @@ class llm_task { if (ax_init_flage_ > 0) { --ax_init_flage_; if (!ax_init_flage_) { - AX_ENGINE_Deinit(); AX_SYS_Deinit(); } } @@ -355,7 +350,7 @@ class llm_yolo : public StackFlow { public: llm_yolo() : StackFlow("yolo") { - task_count_ = 1; + task_count_ = 2; } void task_output(const std::weak_ptr llm_task_obj_weak, @@ -503,10 +498,10 @@ class llm_yolo : public StackFlow { if (!input_url.empty()) { std::weak_ptr _llm_task_obj = llm_task_obj; std::weak_ptr _llm_channel = llm_channel; - llm_channel->subscriber( - input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { - this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); - }); + llm_channel->subscriber(input_url, [this, _llm_task_obj, _llm_channel]( + pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); } } llm_task_[work_id_num] = llm_task_obj; @@ -550,10 +545,11 @@ class llm_yolo : public StackFlow { if (!input_url.empty()) { std::weak_ptr _llm_task_obj = llm_task_obj; std::weak_ptr _llm_channel = llm_channel; - llm_channel->subscriber(input_url, - [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { - this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); - }); + llm_channel->subscriber( + input_url, [this, _llm_task_obj, _llm_channel](pzmq *_pzmq, const std::shared_ptr &raw) { + this->task_camera_data(_llm_task_obj, _llm_channel, raw->string()); + }); + ret = 0; } llm_task_obj->inputs_.push_back(data); } From b7e62dd6d61848e9d77f6d374c97f38ff072ed3f Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 25 Jun 2025 17:09:35 +0800 Subject: [PATCH 17/79] [update] update llm-depth-anything version, llm-yolo 
version. fix libonnxruntime.so Missing --- projects/llm_framework/SConstruct | 2 +- projects/llm_framework/main/SConstruct | 1 + projects/llm_framework/main_depth_anything/SConstruct | 2 +- projects/llm_framework/main_yolo/SConstruct | 2 +- projects/llm_framework/tools/llm_pack.py | 9 +++++++-- 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/projects/llm_framework/SConstruct b/projects/llm_framework/SConstruct index 7282f87e..6c4575b5 100644 --- a/projects/llm_framework/SConstruct +++ b/projects/llm_framework/SConstruct @@ -5,7 +5,7 @@ import shutil os.environ['SDK_PATH'] = os.path.normpath(str(Path(os.getcwd())/'..'/'..'/'SDK')) os.environ['EXT_COMPONENTS_PATH'] = os.path.normpath(str(Path(os.getcwd())/'..'/'..'/'ext_components')) -version = 'v0.0.8' +version = 'v0.0.9' static_lib = 'static_lib' update = False diff --git a/projects/llm_framework/main/SConstruct b/projects/llm_framework/main/SConstruct index 72d44864..0c82732f 100644 --- a/projects/llm_framework/main/SConstruct +++ b/projects/llm_framework/main/SConstruct @@ -27,6 +27,7 @@ STATIC_FILES += [AFile('../static_lib/sherpa/ncnn/libsherpa-ncnn-core.so'), AFile('../static_lib/wetext/libglog.so.0'), AFile('../static_lib/wetext/libfst.so.16'), AFile('../static_lib/libonnxruntime.so.1'), + AFile('../static_lib/libonnxruntime.so.1.14.0') ] env['COMPONENTS'].append({'target':'static_file-1.0', diff --git a/projects/llm_framework/main_depth_anything/SConstruct b/projects/llm_framework/main_depth_anything/SConstruct index a90887f6..66caf78a 100644 --- a/projects/llm_framework/main_depth_anything/SConstruct +++ b/projects/llm_framework/main_depth_anything/SConstruct @@ -30,7 +30,7 @@ STATIC_LIB += static_file * 2 STATIC_FILES += Glob('mode_*.json') -env['COMPONENTS'].append({'target':'llm_depth_anything-1.6', +env['COMPONENTS'].append({'target':'llm_depth_anything-1.7', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/main_yolo/SConstruct 
b/projects/llm_framework/main_yolo/SConstruct index 99f1e1ef..7c588e53 100644 --- a/projects/llm_framework/main_yolo/SConstruct +++ b/projects/llm_framework/main_yolo/SConstruct @@ -39,7 +39,7 @@ STATIC_FILES += Glob('mode_*.json') # AFile('../static_lib/libbz2.so.1.0')] # DEFINITIONS += ["-DENABLE_BACKWARD"] -env['COMPONENTS'].append({'target':'llm_yolo-1.8', +env['COMPONENTS'].append({'target':'llm_yolo-1.9', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index b27cd597..1e954272 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -362,9 +362,9 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-melotts':[create_bin_deb,'llm-melotts', '1.8', src_folder, revision], 'llm-camera':[create_bin_deb,'llm-camera', '1.8', src_folder, revision, 'lib-llm'], 'llm-vlm':[create_bin_deb,'llm-vlm', '1.7', src_folder, revision], - 'llm-yolo':[create_bin_deb,'llm-yolo', '1.8', src_folder, revision], + 'llm-yolo':[create_bin_deb,'llm-yolo', '1.9', src_folder, revision], 'llm-skel':[create_bin_deb,'llm-skel', version, src_folder, revision], - 'llm-depth-anything':[create_bin_deb,'llm-depth-anything', '1.6', src_folder, revision], + 'llm-depth-anything':[create_bin_deb,'llm-depth-anything', '1.7', src_folder, revision], 'llm-vad':[create_bin_deb,'llm-vad', '1.7', src_folder, revision], 'llm-whisper':[create_bin_deb,'llm-whisper', '1.7', src_folder, revision], 'llm-openai-api':[create_bin_deb,'llm-openai-api', '1.7', src_folder, revision], @@ -385,10 +385,15 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-model-melotts-ja-jp':[create_data_deb,'llm-model-melotts-ja-jp', '0.6', src_folder, revision], 'llm-model-melotts-es-es':[create_data_deb,'llm-model-melotts-es-es', '0.5', src_folder, revision], 
'llm-model-yolo11n':[create_data_deb,'llm-model-yolo11n', data_version, src_folder, revision], + 'llm-model-yolo11n-npu1':[create_data_deb,'llm-model-yolo11n-npu1', '0.4', src_folder, revision], 'llm-model-yolo11n-pose':[create_data_deb,'llm-model-yolo11n-pose', '0.3', src_folder, revision], + 'llm-model-yolo11n-pose-npu1':[create_data_deb,'llm-model-yolo11n-pose-npu1', '0.4', src_folder, revision], 'llm-model-yolo11n-hand-pose':[create_data_deb,'llm-model-yolo11n-hand-pose', '0.3', src_folder, revision], + 'llm-model-yolo11n-hand-pose-npu1':[create_data_deb,'llm-model-yolo11n-hand-pose-npu1', '0.4', src_folder, revision], 'llm-model-yolo11n-seg':[create_data_deb,'llm-model-yolo11n-seg', '0.3', src_folder, revision], + 'llm-model-yolo11n-seg-npu1':[create_data_deb,'llm-model-yolo11n-seg-npu1', '0.4', src_folder, revision], 'llm-model-depth-anything-ax630c':[create_data_deb,'llm-model-depth-anything-ax630c', '0.4', src_folder, revision], + 'llm-model-depth-anything-npu1-ax630c':[create_data_deb,'llm-model-depth-anything-npu1-ax630c', '0.4', src_folder, revision], 'llm-model-whisper-tiny':[create_data_deb,'llm-model-whisper-tiny', '0.4', src_folder, revision], 'llm-model-whisper-base':[create_data_deb,'llm-model-whisper-base', '0.4', src_folder, revision], 'llm-model-whisper-small':[create_data_deb,'llm-model-whisper-small', '0.4', src_folder, revision], From 592fd9e16c564cd14b349b3236fbd143fe31c63d Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Wed, 25 Jun 2025 18:04:17 +0800 Subject: [PATCH 18/79] [update] update llm-camera version --- projects/llm_framework/main_camera/SConstruct | 2 +- projects/llm_framework/tools/llm_pack.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/llm_framework/main_camera/SConstruct b/projects/llm_framework/main_camera/SConstruct index eb5190c7..92f525e9 100644 --- a/projects/llm_framework/main_camera/SConstruct +++ b/projects/llm_framework/main_camera/SConstruct @@ -68,7 +68,7 @@ STATIC_LIB += 
static_file * 4 STATIC_FILES += [AFile('camera.json')] STATIC_FILES += Glob('mode_*.json') -env['COMPONENTS'].append({'target':'llm_camera-1.8', +env['COMPONENTS'].append({'target':'llm_camera-1.9', 'SRCS':SRCS, 'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 1e954272..599676c9 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -360,7 +360,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-llm':[create_bin_deb,'llm-llm', '1.8', src_folder, revision], 'llm-tts':[create_bin_deb,'llm-tts', '1.6', src_folder, revision], 'llm-melotts':[create_bin_deb,'llm-melotts', '1.8', src_folder, revision], - 'llm-camera':[create_bin_deb,'llm-camera', '1.8', src_folder, revision, 'lib-llm'], + 'llm-camera':[create_bin_deb,'llm-camera', '1.9', src_folder, revision, 'lib-llm'], 'llm-vlm':[create_bin_deb,'llm-vlm', '1.7', src_folder, revision], 'llm-yolo':[create_bin_deb,'llm-yolo', '1.9', src_folder, revision], 'llm-skel':[create_bin_deb,'llm-skel', version, src_folder, revision], From cccddd2273017a8d6e94046dc01abf0bcc5231c2 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Thu, 26 Jun 2025 10:44:43 +0800 Subject: [PATCH 19/79] [update] update llm-vlm version --- projects/llm_framework/main_vlm/SConstruct | 2 +- projects/llm_framework/tools/llm_pack.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/llm_framework/main_vlm/SConstruct b/projects/llm_framework/main_vlm/SConstruct index d1046f23..bba83f8f 100644 --- a/projects/llm_framework/main_vlm/SConstruct +++ b/projects/llm_framework/main_vlm/SConstruct @@ -73,7 +73,7 @@ ignore['ignore'] = list(set(ignore['ignore'])) with open('../dist/fileignore', 'w') as f: json.dump(ignore, f, indent=4) -env['COMPONENTS'].append({'target':'llm_vlm-1.7', +env['COMPONENTS'].append({'target':'llm_vlm-1.8', 'SRCS':SRCS, 
'INCLUDE':INCLUDE, 'PRIVATE_INCLUDE':PRIVATE_INCLUDE, diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 599676c9..1d2a89f0 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -361,7 +361,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-tts':[create_bin_deb,'llm-tts', '1.6', src_folder, revision], 'llm-melotts':[create_bin_deb,'llm-melotts', '1.8', src_folder, revision], 'llm-camera':[create_bin_deb,'llm-camera', '1.9', src_folder, revision, 'lib-llm'], - 'llm-vlm':[create_bin_deb,'llm-vlm', '1.7', src_folder, revision], + 'llm-vlm':[create_bin_deb,'llm-vlm', '1.8', src_folder, revision], 'llm-yolo':[create_bin_deb,'llm-yolo', '1.9', src_folder, revision], 'llm-skel':[create_bin_deb,'llm-skel', version, src_folder, revision], 'llm-depth-anything':[create_bin_deb,'llm-depth-anything', '1.7', src_folder, revision], From b0743f068a6b86dff84e3ac685b679c92870a99f Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 27 Jun 2025 15:19:53 +0800 Subject: [PATCH 20/79] [update] update model list & add npu1 model. 
--- README_zh.md | 5 + .../mode_depth-anything-npu1-ax630c.json | 25 ++++ projects/llm_framework/main_yolo/SConstruct | 2 +- .../models/mode_yolo11n-hand-pose-npu1.json | 33 ++++++ .../{ => models}/mode_yolo11n-hand-pose.json | 0 .../main_yolo/models/mode_yolo11n-npu1.json | 111 ++++++++++++++++++ .../models/mode_yolo11n-pose-npu1.json | 33 ++++++ .../{ => models}/mode_yolo11n-pose.json | 0 .../models/mode_yolo11n-seg-npu1.json | 111 ++++++++++++++++++ .../{ => models}/mode_yolo11n-seg.json | 0 .../main_yolo/{ => models}/mode_yolo11n.json | 0 11 files changed, 319 insertions(+), 1 deletion(-) create mode 100644 projects/llm_framework/main_depth_anything/mode_depth-anything-npu1-ax630c.json create mode 100644 projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose-npu1.json rename projects/llm_framework/main_yolo/{ => models}/mode_yolo11n-hand-pose.json (100%) create mode 100644 projects/llm_framework/main_yolo/models/mode_yolo11n-npu1.json create mode 100644 projects/llm_framework/main_yolo/models/mode_yolo11n-pose-npu1.json rename projects/llm_framework/main_yolo/{ => models}/mode_yolo11n-pose.json (100%) create mode 100644 projects/llm_framework/main_yolo/models/mode_yolo11n-seg-npu1.json rename projects/llm_framework/main_yolo/{ => models}/mode_yolo11n-seg.json (100%) rename projects/llm_framework/main_yolo/{ => models}/mode_yolo11n.json (100%) diff --git a/README_zh.md b/README_zh.md index e80293e6..a2f7e4a6 100644 --- a/README_zh.md +++ b/README_zh.md @@ -92,10 +92,15 @@ StackFlow 语音助手的主要工作模式: | [smolvlm-256M-ax630c](https://huggingface.co/HuggingFaceTB/SmolVLM-256M-Instruct) | VLM | 330M | 多模态文本生成 | [mode_smolvlm-256M-ax630c.json](projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax630c.json) | NPU | | [smolvlm-500M-ax630c](https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct) | VLM | 605M | 多模态文本生成 | [mode_smolvlm-256M-ax630c.json](projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax630c.json) | NPU | | 
[yolo11n](https://github.com/ultralytics/ultralytics) | CV | 2.8M | 目标检测 | [mode_yolo11n.json](projects/llm_framework/main_yolo/mode_yolo11n.json) | NPU | +| [yolo11n-npu1](https://github.com/ultralytics/ultralytics) | CV | 2.8M | 目标检测 | [mode_yolo11n-npu1.json](projects/llm_framework/main_yolo/mode_yolo11n-npu1.json) | NPU | | [yolo11n-seg](https://github.com/ultralytics/ultralytics) | CV | 3.0M | 实例分割 | [mode_yolo11n-seg.json](projects/llm_framework/main_yolo/mode_yolo11n-seg.json) | NPU | +| [yolo11n-seg-npu1](https://github.com/ultralytics/ultralytics) | CV | 3.0M | 实例分割 | [mode_yolo11n-seg-npu1.json](projects/llm_framework/main_yolo/mode_yolo11n-seg-npu1.json) | NPU | | [yolo11n-pose](https://github.com/ultralytics/ultralytics) | CV | 3.1M | 姿态检测 | [mode_yolo11n-pose.json](projects/llm_framework/main_yolo/mode_yolo11n-pose.json) | NPU | +| [yolo11n-pose-npu1](https://github.com/ultralytics/ultralytics) | CV | 3.1M | 姿态检测 | [mode_yolo11n-pose-npu1.json](projects/llm_framework/main_yolo/mode_yolo11n-pose-npu1.json) | NPU | | [yolo11n-hand-pose](https://github.com/ultralytics/ultralytics) | CV | 3.2M | 姿态检测 | [mode_yolo11n-hand-pose.json](projects/llm_framework/main_yolo/mode_yolo11n-hand-pose.json) | NPU | +| [yolo11n-hand-pose-npu1](https://github.com/ultralytics/ultralytics) | CV | 3.2M | 姿态检测 | [mode_yolo11n-hand-pose-npu1.json](projects/llm_framework/main_yolo/mode_yolo11n-hand-pose-npu1.json) | NPU | | [depth-anything-ax630c](https://github.com/DepthAnything/Depth-Anything-V2) | CV | 29M | 单目深度估计 | [mode_depth-anything-ax630c.json](projects/llm_framework/main_depth_anything/mode_depth-anything-ax630c.json) | NPU | +| [depth-anything-npu1-ax630c](https://github.com/DepthAnything/Depth-Anything-V2) | CV | 29M | 单目深度估计 | [mode_depth-anything-npu1-ax630c.json](projects/llm_framework/main_depth_anything/mode_depth-anything-npu1-ax630c.json) | NPU | ## 环境要求 ## 当前 StackFlow 的 AI 单元是建立在 AXERA 加速平台之上的,主要的芯片平台为 ax630c、ax650n。系统要求为 ubuntu。 diff --git 
a/projects/llm_framework/main_depth_anything/mode_depth-anything-npu1-ax630c.json b/projects/llm_framework/main_depth_anything/mode_depth-anything-npu1-ax630c.json new file mode 100644 index 00000000..079ff56d --- /dev/null +++ b/projects/llm_framework/main_depth_anything/mode_depth-anything-npu1-ax630c.json @@ -0,0 +1,25 @@ +{ + "mode":"depth-anything-npu1-ax630c", + "type":"cv", + "homepage":"https://github.com/DepthAnything/Depth-Anything-V2", + "compile_flage":"pulsar2 build --input cv/model/depth_anything_vits14-sim-384x256.onnx --config cv/config/depth_anything_u8.json --output_dir cv/axmodel/depth-anything-u8-npu1 --output_name depth_anything.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU1", + "pulsar_version":"3.3-f0b32d03", + "capabilities":[ + "Segmentation" + ], + "input_type":[ + "cv.jpeg.base64" + ], + "output_type":[ + "cv.jpeg.base64" + ], + "mode_param":{ + "depth_anything_model":"depth-anything-npu1-ax630c.axmodel", + "model_type":"segment", + "img_h":256, + "img_w":384, + "npu_type":1 + }, + "mode_param_bak":{ + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_yolo/SConstruct b/projects/llm_framework/main_yolo/SConstruct index 7c588e53..54502a54 100644 --- a/projects/llm_framework/main_yolo/SConstruct +++ b/projects/llm_framework/main_yolo/SConstruct @@ -27,7 +27,7 @@ static_file = Glob('../static_lib/module-llm/libabsl_*') static_file = Glob('../static_lib/libopencv-4.6-aarch64-none/lib/lib*') STATIC_LIB += static_file * 2 -STATIC_FILES += Glob('mode_*.json') +STATIC_FILES += Glob('models/mode_*.json') diff --git a/projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose-npu1.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose-npu1.json new file mode 100644 index 00000000..877e8f33 --- /dev/null +++ b/projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose-npu1.json @@ -0,0 +1,33 @@ +{ + "mode":"yolo11n-hand-pose-npu1", + "type":"cv", + 
"homepage":"https://github.com/ultralytics/ultralytics", + "compile_flage":"pulsar2 build --input cv/model/yolo11n-hand-pose-cut.onnx --config cv/config/yolo11_hand_pose_config.json --output_dir cv/axmodel/yolo11n-hand-pose-npu1 --output_name yolo11n-hand-pose-npu1.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU1", + "pulsar_version":"3.3-f0b32d03", + "capabilities":[ + "Pose" + ], + "input_type":[ + "yolo.jpeg.base64" + ], + "output_type":[ + "yolo.box", + "yolo.boxV2" + ], + "mode_param":{ + "yolo_model":"yolo11n-hand-pose-npu1.axmodel", + "model_type":"pose", + "img_h":320, + "img_w":320, + "cls_num":1, + "point_num":21, + "pron_threshold":0.45, + "nms_threshold":0.45, + "npu_type":1, + "cls_name":[ + "hand" + ] + }, + "mode_param_bak":{ + } +} diff --git a/projects/llm_framework/main_yolo/mode_yolo11n-hand-pose.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose.json similarity index 100% rename from projects/llm_framework/main_yolo/mode_yolo11n-hand-pose.json rename to projects/llm_framework/main_yolo/models/mode_yolo11n-hand-pose.json diff --git a/projects/llm_framework/main_yolo/models/mode_yolo11n-npu1.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-npu1.json new file mode 100644 index 00000000..76fee735 --- /dev/null +++ b/projects/llm_framework/main_yolo/models/mode_yolo11n-npu1.json @@ -0,0 +1,111 @@ +{ + "mode":"yolo11n-npu1", + "type":"cv", + "homepage":"https://github.com/ultralytics/ultralytics", + "compile_flage":"pulsar2 build --input cv/model/yolo11n.onnx --config cv/config/yolo11_config.json --output_dir cv/axmodel/yolo11n-npu1 --output_name yolo11n-npu1.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU1", + "pulsar_version":"3.3-f0b32d03", + "capabilities":[ + "Detection" + ], + "input_type":[ + "yolo.jpeg.base64" + ], + "output_type":[ + "yolo.box", + "yolo.boxV2" + ], + "mode_param":{ + "yolo_model":"yolo11n-npu1.axmodel", + "model_type":"detect", + "img_h":320, + "img_w":320, + 
"cls_num":80, + "pron_threshold":0.45, + "nms_threshold":0.45, + "npu_type":1, + "cls_name":[ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + "tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush" + ] + }, + "mode_param_bak":{ + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_yolo/models/mode_yolo11n-pose-npu1.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-pose-npu1.json new file mode 100644 index 00000000..8a7cbeaf --- /dev/null +++ b/projects/llm_framework/main_yolo/models/mode_yolo11n-pose-npu1.json @@ -0,0 +1,33 @@ +{ + "mode":"yolo11n-pose-npu1", + "type":"cv", + "homepage":"https://github.com/ultralytics/ultralytics", + "compile_flage":"pulsar2 build --input cv/model/yolo11n-pose-cut.onnx --config cv/config/yolo11_pose_config.json --output_dir cv/axmodel/yolo11n-pose-npu1 --output_name yolo11n-pose-npu1.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU1", + "pulsar_version":"3.3-f0b32d03", + "capabilities":[ + "Pose" + ], + "input_type":[ + "yolo.jpeg.base64" + ], + "output_type":[ + "yolo.box", + 
"yolo.boxV2" + ], + "mode_param":{ + "yolo_model":"yolo11n-pose-npu1.axmodel", + "model_type":"pose", + "img_h":320, + "img_w":320, + "cls_num":1, + "point_num":17, + "pron_threshold":0.45, + "nms_threshold":0.45, + "npu_type":1, + "cls_name":[ + "person" + ] + }, + "mode_param_bak":{ + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_yolo/mode_yolo11n-pose.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-pose.json similarity index 100% rename from projects/llm_framework/main_yolo/mode_yolo11n-pose.json rename to projects/llm_framework/main_yolo/models/mode_yolo11n-pose.json diff --git a/projects/llm_framework/main_yolo/models/mode_yolo11n-seg-npu1.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-seg-npu1.json new file mode 100644 index 00000000..a0b8df2d --- /dev/null +++ b/projects/llm_framework/main_yolo/models/mode_yolo11n-seg-npu1.json @@ -0,0 +1,111 @@ +{ + "mode":"yolo11n-seg-npu1", + "type":"cv", + "homepage":"https://github.com/ultralytics/ultralytics", + "compile_flage":"pulsar2 build --input cv/model/yolo11n-seg-cut.onnx --config cv/config/yolo11_seg_config.json --output_dir cv/axmodel/yolo11n-seg-npu1 --output_name yolo11n-seg-npu1.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU1", + "pulsar_version":"3.3-f0b32d03", + "capabilities":[ + "Segmentation" + ], + "input_type":[ + "yolo.jpeg.base64" + ], + "output_type":[ + "yolo.box", + "yolo.boxV2" + ], + "mode_param":{ + "yolo_model":"yolo11n-seg-npu1.axmodel", + "model_type":"segment", + "img_h":320, + "img_w":320, + "cls_num":80, + "pron_threshold":0.45, + "nms_threshold":0.45, + "npu_type":1, + "cls_name":[ + "person", + "bicycle", + "car", + "motorcycle", + "airplane", + "bus", + "train", + "truck", + "boat", + "traffic light", + "fire hydrant", + "stop sign", + "parking meter", + "bench", + "bird", + "cat", + "dog", + "horse", + "sheep", + "cow", + "elephant", + "bear", + "zebra", + "giraffe", + "backpack", + "umbrella", + "handbag", + 
"tie", + "suitcase", + "frisbee", + "skis", + "snowboard", + "sports ball", + "kite", + "baseball bat", + "baseball glove", + "skateboard", + "surfboard", + "tennis racket", + "bottle", + "wine glass", + "cup", + "fork", + "knife", + "spoon", + "bowl", + "banana", + "apple", + "sandwich", + "orange", + "broccoli", + "carrot", + "hot dog", + "pizza", + "donut", + "cake", + "chair", + "couch", + "potted plant", + "bed", + "dining table", + "toilet", + "tv", + "laptop", + "mouse", + "remote", + "keyboard", + "cell phone", + "microwave", + "oven", + "toaster", + "sink", + "refrigerator", + "book", + "clock", + "vase", + "scissors", + "teddy bear", + "hair drier", + "toothbrush" + ] + }, + "mode_param_bak":{ + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_yolo/mode_yolo11n-seg.json b/projects/llm_framework/main_yolo/models/mode_yolo11n-seg.json similarity index 100% rename from projects/llm_framework/main_yolo/mode_yolo11n-seg.json rename to projects/llm_framework/main_yolo/models/mode_yolo11n-seg.json diff --git a/projects/llm_framework/main_yolo/mode_yolo11n.json b/projects/llm_framework/main_yolo/models/mode_yolo11n.json similarity index 100% rename from projects/llm_framework/main_yolo/mode_yolo11n.json rename to projects/llm_framework/main_yolo/models/mode_yolo11n.json From 629e822e0aad37af3c130b042959f8cb6dd4512d Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 27 Jun 2025 18:25:09 +0800 Subject: [PATCH 21/79] [update] update docs --- .../llm_camera_en.md | 1 + .../llm_camera_zh.md | 1 + doc/projects_llm_framework_doc/llm_vlm_en.md | 40 +++++++++++++++++-- doc/projects_llm_framework_doc/llm_vlm_zh.md | 40 +++++++++++++++++-- 4 files changed, 74 insertions(+), 8 deletions(-) diff --git a/doc/projects_llm_framework_doc/llm_camera_en.md b/doc/projects_llm_framework_doc/llm_camera_en.md index 25a0f05c..ddac27e0 100644 --- a/doc/projects_llm_framework_doc/llm_camera_en.md +++ b/doc/projects_llm_framework_doc/llm_camera_en.md @@ -37,6 
+37,7 @@ Send JSON: - enoutput: Whether to enable user result output. If you do not need to obtain camera images, do not enable this parameter, as the video stream will increase the communication pressure on the channel. - enable_webstream: Whether to enable webstream output, webstream will listen on tcp:8989 port, and once a client connection is received, it will push jpeg images in HTTP protocol multipart/x-mixed-replace type. - rtsp: Whether to enable rtsp stream output, rtsp will establish an RTSP TCP server at rtsp://{DevIp}:8554/axstream0, and you can pull the video stream from this port using the RTSP protocol. The video stream format is 1280x720 H265. Note that this video stream is only valid on the AX630C MIPI camera, and the UVC camera cannot use RTSP. +- VinParam.bAiispEnable: Whether to enable AI-ISP, enabled by default. Set to 0 to disable, only valid when using AX630C MIPI camera. Response JSON: diff --git a/doc/projects_llm_framework_doc/llm_camera_zh.md b/doc/projects_llm_framework_doc/llm_camera_zh.md index 5675df39..4bb3cb6c 100644 --- a/doc/projects_llm_framework_doc/llm_camera_zh.md +++ b/doc/projects_llm_framework_doc/llm_camera_zh.md @@ -37,6 +37,7 @@ - enoutput:是否起用用户结果输出。如果不需要获取摄像头图片,请不要开启该参数,视频流会增加信道的通信压力。 - enable_webstream:是否启用 webstream 流输出,webstream 会监听 tcp:8989 端口,一但收到客户端连接,将会以 HTTP 协议 multipart/x-mixed-replace 类型推送 jpeg 图片。 - rtsp:是否启用 rtsp 流输出,rtsp 会建立一个 rtsp://{DevIp}:8554/axstream0 RTSP TCP 服务端,可使用RTSP 协议向该端口拉取视频流。视频流的格式为 1280x720 H265。注意,该视频流只在 AX630C MIPI 摄像头上有效,UVC 摄像头无法使用 RTSP。 +- VinParam.bAiispEnable:是否开启 AI-ISP,默认开启。关闭为 0,仅在使用 AX630C MIPI 摄像头时有效。 响应 json: diff --git a/doc/projects_llm_framework_doc/llm_vlm_en.md b/doc/projects_llm_framework_doc/llm_vlm_en.md index 76196696..bbaa0d9c 100644 --- a/doc/projects_llm_framework_doc/llm_vlm_en.md +++ b/doc/projects_llm_framework_doc/llm_vlm_en.md @@ -15,7 +15,7 @@ Send the following JSON: "action": "setup", "object": "vlm.setup", "data": { - "model": "internvl2.5-1B-ax630c", + 
"model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream", "input": "vlm.utf-8", "enoutput": true, @@ -29,7 +29,7 @@ Send the following JSON: - work_id: Set to `vlm` when configuring the unit. - action: The method being called is `setup`. - object: Data type being transferred is `vlm.setup`. -- model: The model used is `internvl2.5-1B-ax630c`, a multimodal model. +- model: The model used is `internvl2.5-1B-364-ax630c`, a multimodal model. - response_format: The output is in `vlm.utf-8.stream`, a UTF-8 stream format. - input: The input is `vlm.utf-8`, representing user input. - enoutput: Specifies whether to enable user output. @@ -250,7 +250,7 @@ Example: "action": "setup", "object": "vlm.setup", "data": { - "model": "internvl2.5-1B-ax630c", + "model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream", "input": [ "vlm.utf-8", @@ -264,6 +264,38 @@ Example: } ``` +Linking the Output of the llm-camera Unit. + +Sending JSON: + +```json +{ + "request_id": "3", + "work_id": "vlm.1003", + "action": "link", + "object": "work_id", + "data": "camera.1000" +} +``` + +Response JSON: + +```json +{ + "created": 1750992545, + "data": "None", + "error": { + "code": 0, + "message": "" + }, + "object": "None", + "request_id": "3", + "work_id": "vlm.1003" +} +``` + +> **Ensure that the camera is properly configured and ready for operation when performing the link action. If using the AX630C MIPI camera, configure it in AI-ISP disabled mode during the initialization of llm-camera.** + ## unlink Unlink units. 
@@ -447,7 +479,7 @@ Response JSON: "vlm.utf-8", "kws.1000" ], - "model": "internvl2.5-1B-ax630c", + "model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream" }, "error": { diff --git a/doc/projects_llm_framework_doc/llm_vlm_zh.md b/doc/projects_llm_framework_doc/llm_vlm_zh.md index 161f48c1..ce797e5e 100644 --- a/doc/projects_llm_framework_doc/llm_vlm_zh.md +++ b/doc/projects_llm_framework_doc/llm_vlm_zh.md @@ -15,7 +15,7 @@ "action": "setup", "object": "vlm.setup", "data": { - "model": "internvl2.5-1B-ax630c", + "model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream", "input": "vlm.utf-8", "enoutput": true, @@ -29,7 +29,7 @@ - work_id:配置单元时,为 `vlm`。 - action:调用的方法为 `setup`。 - object:传输的数据类型为 `vlm.setup`。 -- model:使用的模型为 `internvl2.5-1B-ax630c` 多模态模型。 +- model:使用的模型为 `internvl2.5-1B-364-ax630c` 多模态模型。 - response_format:返回结果为 `vlm.utf-8.stream`, utf-8 的流式输出。 - input:输入的为 `vlm.utf-8`,代表的是从用户输入。 - enoutput:是否起用用户结果输出。 @@ -248,7 +248,7 @@ error::code 为 0 表示执行成功。 "action": "setup", "object": "vlm.setup", "data": { - "model": "internvl2.5-1B-ax630c", + "model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream", "input": [ "vlm.utf-8", @@ -262,6 +262,38 @@ error::code 为 0 表示执行成功。 } ``` +链接 llm-camera 单元的输出。 + +发送 json: + +```json +{ + "request_id": "3", + "work_id": "vlm.1003", + "action": "link", + "object": "work_id", + "data": "camera.1000" +} +``` + +响应 json: + +```json +{ + "created": 1750992545, + "data": "None", + "error": { + "code": 0, + "message": "" + }, + "object": "None", + "request_id": "3", + "work_id": "vlm.1003" +} +``` + +> **link 时必须保证 camera 此时已经配置好进入工作状态。当同时使用 AX630C MIPI 摄像头,需要在 llm-camera 初始化的时候配置为 AI-ISP 关闭模式。** + ## unlink 取消链接。 @@ -445,7 +477,7 @@ error::code 为 0 表示执行成功。 "vlm.utf-8", "kws.1000" ], - "model": "internvl2.5-1B-ax630c", + "model": "internvl2.5-1B-364-ax630c", "response_format": "vlm.utf-8.stream" }, "error": { From d29e074afcac048fbdeb7c34b300c27fe6ca2a12 Mon Sep 17 00:00:00 2001 
From: LittleMouse Date: Fri, 27 Jun 2025 18:30:33 +0800 Subject: [PATCH 22/79] [update] update ax650 model config, melotts model. --- .../mode_deepseek-r1-1.5B-Int4-ax650.json | 35 +++ .../mode_deepseek-r1-7B-Int4-ax650.json | 35 +++ .../models/mode_qwen2.5-0.5B-Int4-ax650.json | 35 +++ .../models/mode_qwen2.5-1.5B-Int4-ax650.json | 35 +++ .../models/mode_qwen2.5-3B-Int4-ax650.json | 35 +++ .../models/mode_qwen2.5-7B-Int4-ax650.json | 35 +++ .../tokenizer_deepseek-r1-1.5B-Int4-ax650.py | 131 +++++++++ .../tokenizer_deepseek-r1-7B-Int4-ax650.py | 131 +++++++++ .../tokenizer_qwen2.5-0.5B-Int4-ax650.py | 131 +++++++++ .../tokenizer_qwen2.5-1.5B-Int4-ax650.py | 131 +++++++++ .../tokenizer_qwen2.5-3B-Int4-ax650.py | 131 +++++++++ .../tokenizer_qwen2.5-7B-Int4-ax650.py | 131 +++++++++ .../llm_framework/main_melotts/SConstruct | 2 +- .../{ => models}/mode_melotts-en-au.json | 0 .../{ => models}/mode_melotts-en-br.json | 0 .../{ => models}/mode_melotts-en-default.json | 0 .../{ => models}/mode_melotts-en-india.json | 0 .../{ => models}/mode_melotts-en-us.json | 0 .../{ => models}/mode_melotts-es-es.json | 0 .../{ => models}/mode_melotts-ja-jp.json | 0 .../{ => models}/mode_melotts-zh-cn.json | 0 .../models/mode_internvl2.5-1B-448-ax650.json | 35 +++ .../models/mode_smolvlm-256M-ax650.json | 35 +++ .../models/mode_smolvlm-500M-ax650.json | 37 +++ .../tokenizer_internvl2.5-1B-448-ax650.py | 138 ++++++++++ .../scripts/tokenizer_smolvlm-256M-ax650.py | 248 ++++++++++++++++++ .../llm_framework/main_whisper/SConstruct | 2 +- .../models/mode_whisper-base-ax650.json | 42 +++ .../{ => models}/mode_whisper-base.json | 0 .../models/mode_whisper-small-ax650.json | 42 +++ .../{ => models}/mode_whisper-small.json | 0 .../models/mode_whisper-tiny-ax650.json | 42 +++ .../{ => models}/mode_whisper-tiny.json | 0 33 files changed, 1617 insertions(+), 2 deletions(-) create mode 100644 projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-Int4-ax650.json create mode 100644 
projects/llm_framework/main_llm/models/mode_deepseek-r1-7B-Int4-ax650.json create mode 100644 projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-Int4-ax650.json create mode 100644 projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-Int4-ax650.json create mode 100644 projects/llm_framework/main_llm/models/mode_qwen2.5-3B-Int4-ax650.json create mode 100644 projects/llm_framework/main_llm/models/mode_qwen2.5-7B-Int4-ax650.json create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-1.5B-Int4-ax650.py create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-7B-Int4-ax650.py create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-0.5B-Int4-ax650.py create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-1.5B-Int4-ax650.py create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-3B-Int4-ax650.py create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-7B-Int4-ax650.py rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-en-au.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-en-br.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-en-default.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-en-india.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-en-us.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-es-es.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-ja-jp.json (100%) rename projects/llm_framework/main_melotts/{ => models}/mode_melotts-zh-cn.json (100%) create mode 100644 projects/llm_framework/main_vlm/models/mode_internvl2.5-1B-448-ax650.json create mode 100644 projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax650.json create mode 100644 projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax650.json create mode 
100644 projects/llm_framework/main_vlm/scripts/tokenizer_internvl2.5-1B-448-ax650.py create mode 100644 projects/llm_framework/main_vlm/scripts/tokenizer_smolvlm-256M-ax650.py create mode 100644 projects/llm_framework/main_whisper/models/mode_whisper-base-ax650.json rename projects/llm_framework/main_whisper/{ => models}/mode_whisper-base.json (100%) create mode 100644 projects/llm_framework/main_whisper/models/mode_whisper-small-ax650.json rename projects/llm_framework/main_whisper/{ => models}/mode_whisper-small.json (100%) create mode 100644 projects/llm_framework/main_whisper/models/mode_whisper-tiny-ax650.json rename projects/llm_framework/main_whisper/{ => models}/mode_whisper-tiny.json (100%) diff --git a/projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-Int4-ax650.json new file mode 100644 index 00000000..8eeeb580 --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_deepseek-r1-1.5B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"deepseek-r1-1.5B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/DeepSeek-R1-Distill-Qwen-1.5B-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":28, + "tokens_embed_num":151936, + "tokens_embed_size":1536, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_deepseek-r1-1.5B-Int4-ax650.py"] + } +} \ No newline at end of file diff 
--git a/projects/llm_framework/main_llm/models/mode_deepseek-r1-7B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_deepseek-r1-7B-Int4-ax650.json new file mode 100644 index 00000000..757683c0 --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_deepseek-r1-7B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"deepseek-r1-7B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/DeepSeek-R1-Distill-Qwen-7B-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":28, + "tokens_embed_num":152064, + "tokens_embed_size":3584, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_deepseek-r1-7B-Int4-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-Int4-ax650.json new file mode 100644 index 00000000..528ae793 --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_qwen2.5-0.5B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"qwen2.5-0.5B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/Qwen2.5-0.5B-Instruct-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + 
"filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":24, + "tokens_embed_num":151936, + "tokens_embed_size":896, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_qwen2.5-0.5B-Int4-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-Int4-ax650.json new file mode 100644 index 00000000..6e50ca2c --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_qwen2.5-1.5B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"qwen2.5-1.5B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/Qwen2.5-1.5B-Instruct-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":28, + "tokens_embed_num":151936, + "tokens_embed_size":1536, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_qwen2.5-1.5B-Int4-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/models/mode_qwen2.5-3B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_qwen2.5-3B-Int4-ax650.json new file mode 100644 index 00000000..e11b7d82 --- /dev/null 
+++ b/projects/llm_framework/main_llm/models/mode_qwen2.5-3B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"qwen2.5-3B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/Qwen2.5-3B-Instruct-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":36, + "tokens_embed_num":151936, + "tokens_embed_size":2048, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_qwen2.5-3B-Int4-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/models/mode_qwen2.5-7B-Int4-ax650.json b/projects/llm_framework/main_llm/models/mode_qwen2.5-7B-Int4-ax650.json new file mode 100644 index 00000000..8da02c1b --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_qwen2.5-7B-Int4-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"qwen2.5-7B-Int4-ax650", + "type":"llm", + "homepage":"https://huggingface.co/AXERA-TECH/Qwen2.5-7B-Instruct-GPTQ-Int4", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + 
"b_bos":false, + "b_eos":false, + "axmodel_num":28, + "tokens_embed_num":152064, + "tokens_embed_size":3584, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_qwen2.5-7B-Int4-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-1.5B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-1.5B-Int4-ax650.py new file mode 100644 index 00000000..3af58240 --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-1.5B-Int4-ax650.py @@ -0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + 
else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='deepseek_tokenizer') + args.add_argument('--content', type=str, default='You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git 
a/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-7B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-7B-Int4-ax650.py new file mode 100644 index 00000000..3af58240 --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_deepseek-r1-7B-Int4-ax650.py @@ -0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + 
self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='deepseek_tokenizer') + args.add_argument('--content', type=str, default='You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-0.5B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-0.5B-Int4-ax650.py new file mode 100644 index 00000000..4fded69c --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-0.5B-Int4-ax650.py @@ 
-0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") 
#设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='qwen2.5_coder_tokenizer') + args.add_argument('--content', type=str, default='You are Qwen, created by Alibaba Cloud. You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-1.5B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-1.5B-Int4-ax650.py new file mode 100644 index 00000000..4fded69c --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-1.5B-Int4-ax650.py @@ -0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = 
AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = 
json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='qwen2.5_coder_tokenizer') + args.add_argument('--content', type=str, default='You are Qwen, created by Alibaba Cloud. You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-3B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-3B-Int4-ax650.py new file mode 100644 index 00000000..4fded69c --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-3B-Int4-ax650.py @@ -0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + 
add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + 
msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='qwen2.5_coder_tokenizer') + args.add_argument('--content', type=str, default='You are Qwen, created by Alibaba Cloud. You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-7B-Int4-ax650.py b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-7B-Int4-ax650.py new file mode 100644 index 00000000..4fded69c --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-7B-Int4-ax650.py @@ -0,0 +1,131 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +class Tokenizer_Http(): + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + + def encode(self, prompt, content): + messages = [ + {"role": "system", "content": content}, + {"role": "user", "content": prompt} + ] + text = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + print(text) + token_ids = self.tokenizer.encode(text) + return token_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def 
eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + #通过类继承,新定义类 + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print(self.path) + #在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/bos_id': + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({'bos_id': -1}) + else: + msg = json.dumps({'bos_id': bos_id}) + elif self.path == '/eos_id': + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({'eos_id': -1}) + else: + msg = json.dumps({'eos_id': eos_id}) + else: + msg = 'error' + + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + def do_POST(self): + #在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read(int( + self.headers['content-length'])) #获取从客户端传入的参数(byte格式) + data = data.decode() #将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") #设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == '/encode': + req = json.loads(data) + prompt = req['text'] + + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({'token_ids': -1}) + else: + msg = json.dumps({'token_ids': token_ids}) + + elif self.path == '/decode': + req = json.loads(data) + token_ids = req['token_ids'] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({'text': ""}) + else: + msg = json.dumps({'text': text}) + else: + msg = 'error' + print(msg) + msg = str(msg).encode() #转为str再转为byte格式 + + self.wfile.write(msg) #将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument('--host', type=str, default='localhost') + 
args.add_argument('--port', type=int, default=8080) + args.add_argument('--model_id', type=str, default='qwen2.5_coder_tokenizer') + args.add_argument('--content', type=str, default='You are Qwen, created by Alibaba Cloud. You are a helpful assistant.') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) #设定地址与端口号,'localhost'等价于'127.0.0.1' + print('http://%s:%s' % host) + server = HTTPServer(host, Request) #根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() #开启服务 diff --git a/projects/llm_framework/main_melotts/SConstruct b/projects/llm_framework/main_melotts/SConstruct index 69d9d464..1701cb76 100644 --- a/projects/llm_framework/main_melotts/SConstruct +++ b/projects/llm_framework/main_melotts/SConstruct @@ -30,7 +30,7 @@ REQUIREMENTS += ['glog', 'fst'] REQUIREMENTS += ['onnxruntime'] -STATIC_FILES += Glob('mode_*.json') +STATIC_FILES += Glob('models/mode_*.json') env['COMPONENTS'].append({'target':'llm_melotts-1.8', 'SRCS':SRCS, diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-au.json b/projects/llm_framework/main_melotts/models/mode_melotts-en-au.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-en-au.json rename to projects/llm_framework/main_melotts/models/mode_melotts-en-au.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-br.json b/projects/llm_framework/main_melotts/models/mode_melotts-en-br.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-en-br.json rename to projects/llm_framework/main_melotts/models/mode_melotts-en-br.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-default.json b/projects/llm_framework/main_melotts/models/mode_melotts-en-default.json similarity index 100% rename from 
projects/llm_framework/main_melotts/mode_melotts-en-default.json rename to projects/llm_framework/main_melotts/models/mode_melotts-en-default.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-india.json b/projects/llm_framework/main_melotts/models/mode_melotts-en-india.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-en-india.json rename to projects/llm_framework/main_melotts/models/mode_melotts-en-india.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-en-us.json b/projects/llm_framework/main_melotts/models/mode_melotts-en-us.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-en-us.json rename to projects/llm_framework/main_melotts/models/mode_melotts-en-us.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-es-es.json b/projects/llm_framework/main_melotts/models/mode_melotts-es-es.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-es-es.json rename to projects/llm_framework/main_melotts/models/mode_melotts-es-es.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-ja-jp.json b/projects/llm_framework/main_melotts/models/mode_melotts-ja-jp.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-ja-jp.json rename to projects/llm_framework/main_melotts/models/mode_melotts-ja-jp.json diff --git a/projects/llm_framework/main_melotts/mode_melotts-zh-cn.json b/projects/llm_framework/main_melotts/models/mode_melotts-zh-cn.json similarity index 100% rename from projects/llm_framework/main_melotts/mode_melotts-zh-cn.json rename to projects/llm_framework/main_melotts/models/mode_melotts-zh-cn.json diff --git a/projects/llm_framework/main_vlm/models/mode_internvl2.5-1B-448-ax650.json b/projects/llm_framework/main_vlm/models/mode_internvl2.5-1B-448-ax650.json new file mode 100644 index 00000000..ecc6c071 --- /dev/null +++ 
b/projects/llm_framework/main_vlm/models/mode_internvl2.5-1B-448-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"internvl2.5-1B-448-ax650", + "type":"vlm", + "homepage":"https://huggingface.co/AXERA-TECH/InternVL2_5-1B", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "vlm.chat_completion", + "vlm.chat_completion.stream" + ], + "output_type":[ + "vlm.utf-8", + "vlm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p320_l%d_together.axmodel", + "filename_vpm_resampler_axmodedl":"vit_intern_2_5_sim_space2depth_nhwc.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":24, + "tokens_embed_num":151674, + "img_token_id":151667, + "tokens_embed_size":896, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_internvl2.5-1B-448-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax650.json b/projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax650.json new file mode 100644 index 00000000..179fa6b6 --- /dev/null +++ b/projects/llm_framework/main_vlm/models/mode_smolvlm-256M-ax650.json @@ -0,0 +1,35 @@ +{ + "mode":"smolvlm-256M-ax650", + "type":"vlm", + "homepage":"https://huggingface.co/AXERA-TECH/SmolVLM-256M-Instruct", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "vlm.chat_completion", + "vlm.chat_completion.stream" + ], + "output_type":[ + "vlm.utf-8", + "vlm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"llama_post.axmodel", + "template_filename_axmodel":"llama_p128_l%d_together.axmodel", + 
"filename_vpm_resampler_axmodedl":"SmolVLM-256M-Instruct_vision_nhwc.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":30, + "tokens_embed_num":49280, + "img_token_id":49190, + "tokens_embed_size":576, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "ext_scripts":["tokenizer_smolvlm-256M-ax650.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax650.json b/projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax650.json new file mode 100644 index 00000000..3ce09a08 --- /dev/null +++ b/projects/llm_framework/main_vlm/models/mode_smolvlm-500M-ax650.json @@ -0,0 +1,37 @@ +{ + "mode":"smolvlm-500M-ax630c", + "type":"vlm", + "homepage":"https://huggingface.co/HuggingFaceTB/SmolVLM-500M-Instruct", + "compile_flage":"pulsar2 build --input HuggingFaceTB/SmolVLM-500M-w8a16/SmolVLM-500M-Instruct_vision.onnx --config AXERA/SmolVLM-256M-Instruct.axera/model_convert/config.json --output_dir HuggingFaceTB/SmolVLM-500M-w8a16/build-output --output_name SmolVLM-500M-Instruct_vision.axmodel --target_hardware AX620E --compiler.check 0 --npu_mode NPU2", + "pulsar_version":"3.4-983bb35e", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "vlm.chat_completion", + "vlm.chat_completion.stream" + ], + "output_type":[ + "vlm.utf-8", + "vlm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "filename_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"llama_post.axmodel", + "template_filename_axmodel":"llama_p128_l%d_together.axmodel", + "filename_vpm_resampler_axmodedl":"SmolVLM-500M-Instruct_vision.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":32, + "tokens_embed_num":49280, + "img_token_id":49190, + "tokens_embed_size":960, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + 
"ext_scripts":["tokenizer_smolvlm-500M-ax630c.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/scripts/tokenizer_internvl2.5-1B-448-ax650.py b/projects/llm_framework/main_vlm/scripts/tokenizer_internvl2.5-1B-448-ax650.py new file mode 100644 index 00000000..bb27674b --- /dev/null +++ b/projects/llm_framework/main_vlm/scripts/tokenizer_internvl2.5-1B-448-ax650.py @@ -0,0 +1,138 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + + +class Tokenizer_Http: + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=True, use_fast=False + ) + + def encode(self, prompt, content): + prompt = f"<|im_start|>system\n{content}<|im_end|><|im_start|>user\n{prompt}<|im_end|><|im_start|>assistant\n" + input_ids = self.tokenizer.encode(prompt) + return input_ids + + def encode_vpm(self, prompt, content="Please describe the image shortly."): + prompt = f"<|im_start|>system\n{content}<|im_end|><|im_start|>user\n" + "" * 256 + f"\n{prompt}<|im_end|><|im_start|>assistant\n" + input_ids = self.tokenizer.encode(prompt) + return input_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + # 通过类继承,新定义类 + timeout = 5 + server_version = "Apache" + + def do_GET(self): + print(self.path) + # 在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") # 设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == "/bos_id": + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + 
if bos_id is None: + msg = json.dumps({"bos_id": -1}) + else: + msg = json.dumps({"bos_id": bos_id}) + elif self.path == "/eos_id": + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({"eos_id": -1}) + else: + msg = json.dumps({"eos_id": eos_id}) + else: + msg = "error" + + print(msg) + msg = str(msg).encode() # 转为str再转为byte格式 + + self.wfile.write(msg) # 将byte格式的信息返回给客户端 + + def do_POST(self): + # 在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read( + int(self.headers["content-length"]) + ) # 获取从客户端传入的参数(byte格式) + data = data.decode() # 将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") # 设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == "/encode": + req = json.loads(data) + print(req) + prompt = req["text"] + b_img_prompt = False + if "img_prompt" in req: + b_img_prompt = req["img_prompt"] + if b_img_prompt: + token_ids = tokenizer.encode_vpm(prompt) + else: + token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({"token_ids": -1}) + else: + msg = json.dumps({"token_ids": token_ids}) + + elif self.path == "/decode": + req = json.loads(data) + token_ids = req["token_ids"] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({"text": ""}) + else: + msg = json.dumps({"text": text}) + else: + msg = "error" + print(msg) + msg = str(msg).encode() # 转为str再转为byte格式 + + self.wfile.write(msg) # 将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument("--host", type=str, default="localhost") + args.add_argument("--port", type=int, default=8080) + args.add_argument('--model_id', type=str, default='internvl2_tokenizer') + args.add_argument('--content', type=str, default='你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # 
print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) # 设定地址与端口号,'localhost'等价于'127.0.0.1' + print("http://%s:%s" % host) + server = HTTPServer(host, Request) # 根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() # 开启服务 diff --git a/projects/llm_framework/main_vlm/scripts/tokenizer_smolvlm-256M-ax650.py b/projects/llm_framework/main_vlm/scripts/tokenizer_smolvlm-256M-ax650.py new file mode 100644 index 00000000..560a71f3 --- /dev/null +++ b/projects/llm_framework/main_vlm/scripts/tokenizer_smolvlm-256M-ax650.py @@ -0,0 +1,248 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from transformers.tokenization_utils_base import AddedToken +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse + +def _prompt_split_image( + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, +): + """Prompt with expanded image tokens for when the image is split into patches.""" + text_split_images = "" + for n_h in range(image_rows): + for n_w in range(image_cols): + text_split_images += ( + f"{fake_token_around_image}" + + f"" + + f"{image_token}" * image_seq_len + ) + text_split_images += "\n" + + text_split_images += ( + f"\n{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + return text_split_images + + +def _prompt_single_image( + image_seq_len, fake_token_around_image, image_token, global_img_token +): + """Prompt with expanded image tokens for a single image.""" + return ( + f"{fake_token_around_image}" + + f"{global_img_token}" + + f"{image_token}" * image_seq_len + + f"{fake_token_around_image}" + ) + + +def get_image_prompt_string( + image_rows, + image_cols, + image_seq_len, + fake_token_around_image, + image_token, + global_img_token, +): + if image_rows == 0 and image_cols == 0: + return _prompt_single_image( + image_seq_len, + 
fake_token_around_image=fake_token_around_image, + image_token=image_token, + global_img_token=global_img_token, + ) + return _prompt_split_image( + image_seq_len, + image_rows, + image_cols, + fake_token_around_image, + image_token, + global_img_token, + ) + +class Tokenizer_Http: + + def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained( + model_id, trust_remote_code=True, use_fast=False + ) + + def encode(self, prompt, content): + prompt = f"<|im_start|>User:{content}\nAssistant:" + input_ids = self.tokenizer(prompt) + return input_ids["input_ids"] + + def encode_vpm(self, prompt, content="Please describe the image shortly."): + prompt = f"<|im_start|>User:{prompt}\nAssistant:" + text = [prompt] + image_rows = [[0]] + image_cols = [[0]] + image_seq_len = 64 + image_token = "" + fake_image_token = "" + global_img_token = "" + prompt_strings = [] + for sample, sample_rows, sample_cols in zip(text, image_rows, image_cols): + # Replace the image token with fake tokens around the expanded image token sequence of length `image_seq_len` + image_prompt_strings = [] + for n_rows, n_cols in zip(sample_rows, sample_cols): + image_prompt_string = get_image_prompt_string( + n_rows, + n_cols, + image_seq_len, + image_token=image_token, + fake_token_around_image=fake_image_token, + global_img_token=global_img_token, + ) + image_prompt_strings.append(image_prompt_string) + + split_sample = sample.split(image_token) + if len(split_sample) == 0: + raise ValueError("The image token should be present in the text.") + + # Place in the image prompt strings where the image tokens are + sample = split_sample[0] + for i, image_prompt_string in enumerate(image_prompt_strings): + sample += image_prompt_string + split_sample[i + 1] + prompt_strings.append(sample) + + fake_image_token = AddedToken(fake_image_token, normalized=False, special=True) + image_token = AddedToken(image_token, normalized=False, special=True) + end_of_utterance_token = AddedToken( + "", 
normalized=False, special=True + ) + tokens_to_add = { + "additional_special_tokens": [ + fake_image_token, + image_token, + end_of_utterance_token, + ] + } + self.tokenizer.add_special_tokens(tokens_to_add) + + input_ids = self.tokenizer(prompt_strings)["input_ids"][0] + return input_ids + + def decode(self, token_ids): + return self.tokenizer.decode(token_ids, clean_up_tokenization_spaces=False) + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + +class Request(BaseHTTPRequestHandler): + # 通过类继承,新定义类 + timeout = 5 + server_version = "Apache" + + def do_GET(self): + print(self.path) + # 在新类中定义get的内容(当客户端向该服务端使用get请求时,本服务端将如下运行) + self.send_response(200) + self.send_header("type", "get") # 设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == "/bos_id": + bos_id = tokenizer.bos_id + # print(bos_id) + # to json + if bos_id is None: + msg = json.dumps({"bos_id": -1}) + else: + msg = json.dumps({"bos_id": bos_id}) + elif self.path == "/eos_id": + eos_id = tokenizer.eos_id + if eos_id is None: + msg = json.dumps({"eos_id": -1}) + else: + msg = json.dumps({"eos_id": eos_id}) + else: + msg = "error" + + print(msg) + msg = str(msg).encode() # 转为str再转为byte格式 + + self.wfile.write(msg) # 将byte格式的信息返回给客户端 + + def do_POST(self): + # 在新类中定义post的内容(当客户端向该服务端使用post请求时,本服务端将如下运行) + data = self.rfile.read( + int(self.headers["content-length"]) + ) # 获取从客户端传入的参数(byte格式) + data = data.decode() # 将byte格式转为str格式 + + self.send_response(200) + self.send_header("type", "post") # 设置响应头,可省略或设置多个 + self.end_headers() + + if self.path == "/encode": + req = json.loads(data) + print(req) + prompt = req["text"] + b_img_prompt = False + if "img_prompt" in req: + b_img_prompt = req["img_prompt"] + if b_img_prompt: + token_ids = tokenizer.encode_vpm(prompt) + else: + 
token_ids = tokenizer.encode(prompt, args.content) + if token_ids is None: + msg = json.dumps({"token_ids": -1}) + else: + msg = json.dumps({"token_ids": token_ids}) + + elif self.path == "/decode": + req = json.loads(data) + token_ids = req["token_ids"] + text = tokenizer.decode(token_ids) + if text is None: + msg = json.dumps({"text": ""}) + else: + msg = json.dumps({"text": text}) + else: + msg = "error" + print(msg) + msg = str(msg).encode() # 转为str再转为byte格式 + + self.wfile.write(msg) # 将byte格式的信息返回给客户端 + + +if __name__ == "__main__": + + args = argparse.ArgumentParser() + args.add_argument("--host", type=str, default="localhost") + args.add_argument("--port", type=int, default=8080) + args.add_argument('--model_id', type=str, default='internvl2_tokenizer') + args.add_argument('--content', type=str, default='') + args = args.parse_args() + + tokenizer = Tokenizer_Http(args.model_id) + + + # print(tokenizer.bos_id, tokenizer.bos_token, tokenizer.eos_id, tokenizer.eos_token) + # print(tokenizer.encode("hello world", args.content)) + + host = (args.host, args.port) # 设定地址与端口号,'localhost'等价于'127.0.0.1' + print("http://%s:%s" % host) + server = HTTPServer(host, Request) # 根据地址端口号和新定义的类,创建服务器实例 + server.serve_forever() # 开启服务 diff --git a/projects/llm_framework/main_whisper/SConstruct b/projects/llm_framework/main_whisper/SConstruct index 4c61edce..6395e923 100644 --- a/projects/llm_framework/main_whisper/SConstruct +++ b/projects/llm_framework/main_whisper/SConstruct @@ -31,7 +31,7 @@ LINK_SEARCH_PATH += [ADir('../static_lib/opencc/lib')] # LDFLAGS += ['-l:libcargs.a', '-l:libonnxruntime.a'] LDFLAGS += ['-l:libopencc.a', '-l:libmarisa.a'] -STATIC_FILES += Glob('mode_*.json') +STATIC_FILES += Glob('models/mode_*.json') env['COMPONENTS'].append({'target':'llm_whisper-1.7', 'SRCS':SRCS, diff --git a/projects/llm_framework/main_whisper/models/mode_whisper-base-ax650.json b/projects/llm_framework/main_whisper/models/mode_whisper-base-ax650.json new file mode 100644 index 
00000000..a9c22be6 --- /dev/null +++ b/projects/llm_framework/main_whisper/models/mode_whisper-base-ax650.json @@ -0,0 +1,42 @@ +{ + "mode": "whisper-base-ax650", + "type": "asr", + "homepage":"https://huggingface.co/openai/whisper-base", + "capabilities": [ + "Automatic_Speech_Recognition", + "English", + "Chinese", + "Japanese" + ], + "input_type": [ + "sys.pcm" + ], + "output_type": [ + "asr.utf-8" + ], + "mode_param": { + "model_type": "base", + "language": "en", + "encoder": "base-encoder.axmodel", + "decoder_main": "base-decoder-main.axmodel", + "decoder_loop": "base-decoder-loop.axmodel", + "positional_embedding": "base-positional_embedding.bin", + "tokens": "base-tokens.txt", + "t2s": "t2s.json", + "whisper_sample_rate": 16000, + "whisper_n_fft": 400, + "awake_delay": 1000, + "whisper_hop_length": 160, + "whisper_chunk_size": 30, + "whisper_n_mels": 80, + "whisper_sot": 50258, + "whisper_eot": 50257, + "whisper_blank": 220, + "whisper_no_timestamps": 50363, + "whisper_no_speech": 50362, + "whisper_translate": 50358, + "whisper_transcribe": 50359, + "whisper_vocab_size": 51865, + "whisper_n_text_ctx": 448 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_whisper/mode_whisper-base.json b/projects/llm_framework/main_whisper/models/mode_whisper-base.json similarity index 100% rename from projects/llm_framework/main_whisper/mode_whisper-base.json rename to projects/llm_framework/main_whisper/models/mode_whisper-base.json diff --git a/projects/llm_framework/main_whisper/models/mode_whisper-small-ax650.json b/projects/llm_framework/main_whisper/models/mode_whisper-small-ax650.json new file mode 100644 index 00000000..8ef8f1c7 --- /dev/null +++ b/projects/llm_framework/main_whisper/models/mode_whisper-small-ax650.json @@ -0,0 +1,42 @@ +{ + "mode": "whisper-small-ax650", + "type": "asr", + "homepage":"https://huggingface.co/openai/whisper-small", + "capabilities": [ + "Automatic_Speech_Recognition", + "English", + "Chinese", + "Japanese" + 
], + "input_type": [ + "sys.pcm" + ], + "output_type": [ + "asr.utf-8" + ], + "mode_param": { + "model_type": "small", + "language": "en", + "encoder": "small-encoder.axmodel", + "decoder_main": "small-decoder-main.axmodel", + "decoder_loop": "small-decoder-loop.axmodel", + "positional_embedding": "small-positional_embedding.bin", + "tokens": "small-tokens.txt", + "t2s": "t2s.json", + "whisper_sample_rate": 16000, + "whisper_n_fft": 400, + "awake_delay": 1000, + "whisper_hop_length": 160, + "whisper_chunk_size": 30, + "whisper_n_mels": 80, + "whisper_sot": 50258, + "whisper_eot": 50257, + "whisper_blank": 220, + "whisper_no_timestamps": 50363, + "whisper_no_speech": 50362, + "whisper_translate": 50358, + "whisper_transcribe": 50359, + "whisper_vocab_size": 51865, + "whisper_n_text_ctx": 448 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_whisper/mode_whisper-small.json b/projects/llm_framework/main_whisper/models/mode_whisper-small.json similarity index 100% rename from projects/llm_framework/main_whisper/mode_whisper-small.json rename to projects/llm_framework/main_whisper/models/mode_whisper-small.json diff --git a/projects/llm_framework/main_whisper/models/mode_whisper-tiny-ax650.json b/projects/llm_framework/main_whisper/models/mode_whisper-tiny-ax650.json new file mode 100644 index 00000000..b44e7b47 --- /dev/null +++ b/projects/llm_framework/main_whisper/models/mode_whisper-tiny-ax650.json @@ -0,0 +1,42 @@ +{ + "mode": "whisper-tiny-ax650", + "type": "asr", + "homepage":"https://huggingface.co/openai/whisper-tiny", + "capabilities": [ + "Automatic_Speech_Recognition", + "English", + "Chinese", + "Japanese" + ], + "input_type": [ + "sys.pcm" + ], + "output_type": [ + "asr.utf-8" + ], + "mode_param": { + "model_type": "tiny", + "language": "en", + "encoder": "tiny-encoder.axmodel", + "decoder_main": "tiny-decoder-main.axmodel", + "decoder_loop": "tiny-decoder-loop.axmodel", + "positional_embedding": "tiny-positional_embedding.bin", + 
"tokens": "tiny-tokens.txt", + "t2s": "t2s.json", + "whisper_sample_rate": 16000, + "whisper_n_fft": 400, + "awake_delay": 1000, + "whisper_hop_length": 160, + "whisper_chunk_size": 30, + "whisper_n_mels": 80, + "whisper_sot": 50258, + "whisper_eot": 50257, + "whisper_blank": 220, + "whisper_no_timestamps": 50363, + "whisper_no_speech": 50362, + "whisper_translate": 50358, + "whisper_transcribe": 50359, + "whisper_vocab_size": 51865, + "whisper_n_text_ctx": 448 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_whisper/mode_whisper-tiny.json b/projects/llm_framework/main_whisper/models/mode_whisper-tiny.json similarity index 100% rename from projects/llm_framework/main_whisper/mode_whisper-tiny.json rename to projects/llm_framework/main_whisper/models/mode_whisper-tiny.json From 90fae78817446612fb174087191d095333ad0dea Mon Sep 17 00:00:00 2001 From: dianjixz <18637716021@163.com> Date: Tue, 1 Jul 2025 12:17:26 +0800 Subject: [PATCH 23/79] [update] main_audio add 630c kit default param && StackFlow add send_raw_to_pub --- .../StackFlow/stackflow/StackFlow.cpp | 5 + .../StackFlow/stackflow/StackFlow.h | 1 + projects/llm_framework/main_audio/SConstruct | 2 +- .../llm_framework/main_audio/audio_kit.json | 115 ++++++++++++++++++ .../llm_framework/main_audio/src/main.cpp | 22 ++-- 5 files changed, 130 insertions(+), 15 deletions(-) create mode 100644 projects/llm_framework/main_audio/audio_kit.json diff --git a/ext_components/StackFlow/stackflow/StackFlow.cpp b/ext_components/StackFlow/stackflow/StackFlow.cpp index 2991162e..561e20f8 100644 --- a/ext_components/StackFlow/stackflow/StackFlow.cpp +++ b/ext_components/StackFlow/stackflow/StackFlow.cpp @@ -117,6 +117,11 @@ int llm_channel_obj::send_raw_to_pub(const std::string &raw) return zmq_[-1]->send_data(raw); } +int llm_channel_obj::send_raw_to_pub(const char *data, int size) +{ + return zmq_[-1]->send_data(data, size); +} + int llm_channel_obj::send_raw_to_usr(const std::string &raw) { if 
(zmq_[-2]) { diff --git a/ext_components/StackFlow/stackflow/StackFlow.h b/ext_components/StackFlow/stackflow/StackFlow.h index 25ac6743..c3424fa2 100644 --- a/ext_components/StackFlow/stackflow/StackFlow.h +++ b/ext_components/StackFlow/stackflow/StackFlow.h @@ -106,6 +106,7 @@ class llm_channel_obj { void stop_subscriber(const std::string &zmq_url); int check_zmq_errno(void *ctx, void *com, int code); int send_raw_to_pub(const std::string &raw); + int send_raw_to_pub(const char *data, int size); int send_raw_to_usr(const std::string &raw); template int output_data(const std::string &object, const T &data, const U &error_msg) diff --git a/projects/llm_framework/main_audio/SConstruct b/projects/llm_framework/main_audio/SConstruct index 27a04b68..18d7ca05 100644 --- a/projects/llm_framework/main_audio/SConstruct +++ b/projects/llm_framework/main_audio/SConstruct @@ -26,7 +26,7 @@ LINK_SEARCH_PATH += [ADir('../static_lib')] REQUIREMENTS += ['ax_sys', 'ax_interpreter', 'ax_audio', 'ax_audio_3a', 'ax_fdk', 'ax_opus'] REQUIREMENTS += ['tinyalsa', 'opus', 'samplerate', 'fdk-aac'] -STATIC_FILES += [AFile('audio.json')] +STATIC_FILES += [AFile('audio.json'), AFile('audio_kit.json')] STATIC_FILES += Glob('mode_*.json') env['COMPONENTS'].append({'target':'llm_audio-1.6', diff --git a/projects/llm_framework/main_audio/audio_kit.json b/projects/llm_framework/main_audio/audio_kit.json new file mode 100644 index 00000000..1c7a85fd --- /dev/null +++ b/projects/llm_framework/main_audio/audio_kit.json @@ -0,0 +1,115 @@ +{ + "mode": "None", + "type": "audio", + "capabilities": [ + "play", + "cap" + ], + "input_type": [ + "rpc.audio.wav.base64", + "rpc.audio.pcm.base64" + ], + "output_type": [ + "audio.pcm.stream" + ], + "play_param": { + "card": 0, + "device": 1, + "volume": 0.5, + "channel": 1, + "rate": 16000, + "bit": 16, + "stPoolConfig.MetaSize": 8192, + "stPoolConfig.BlkSize": 32768, + "stPoolConfig.BlkCnt": 37, + "stPoolConfig.IsMergeMode": 0, + "stPoolConfig.CacheMode": 0, + 
"stPoolConfig.PartitionName": "anonymous", + "stAttr.enBitwidth": 1, + "stAttr.enSoundmode": 0, + "stAttr.u32ChnCnt": 2, + "stAttr.enLinkMode": 0, + "stAttr.enSamplerate": 16000, + "stAttr.U32Depth": 30, + "stAttr.u32PeriodSize": 160, + "stAttr.u32PeriodCount": 8, + "stAttr.bInsertSilence": 0, + "stVqeAttr.s32SampleRate": 16000, + "stVqeAttr.u32FrameSamples": 160, + "stVqeAttr.stNsCfg.bNsEnable": 0, + "stVqeAttr.stNsCfg.enAggressivenessLevel": 2, + "stVqeAttr.stAgcCfg.bAgcEnable": 0, + "stVqeAttr.stAgcCfg.enAgcMode": 2, + "stVqeAttr.stAgcCfg.s16TargetLevel": -3, + "stVqeAttr.stAgcCfg.s16Gain": 9, + "stHpfAttr.bEnable": 0, + "stHpfAttr.s32GainDb": -3, + "stHpfAttr.s32Samplerate": 16000, + "stHpfAttr.s32Freq": 200, + "stLpfAttr.bEnable": 0, + "stLpfAttr.s32GainDb": 0, + "stLpfAttr.s32Samplerate": 16000, + "stLpfAttr.s32Freq": 3000, + "stEqAttr.bEnable": 0, + "stEqAttr.s32GainDb[0]": -10, + "stEqAttr.s32GainDb[1]": -3, + "stEqAttr.s32GainDb[2]": 3, + "stEqAttr.s32GainDb[3]": 5, + "stEqAttr.s32GainDb[4]": 10, + "stEqAttr.s32Samplerate": 16000, + "gResample": 0, + "enInSampleRate": 16000, + "gInstant": 0, + "gInsertSilence": 0 + }, + "cap_param": { + "sys_pcm_cap_channel": "ipc:///tmp/llm/pcm.cap.socket", + "card": 0, + "device": 0, + "volume": 1.0, + "channel": 1, + "rate": 16000, + "bit": 16, + "stPoolConfig.MetaSize": 8192, + "stPoolConfig.BlkSize": 7680, + "stPoolConfig.BlkCnt": 33, + "stPoolConfig.IsMergeMode": 0, + "stPoolConfig.CacheMode": 0, + "stPoolConfig.PartitionName": "anonymous", + "aistAttr.enBitwidth": 1, + "aistAttr.enLinkMode": 0, + "aistAttr.enSamplerate": 16000, + "aistAttr.enLayoutMode": 1, + "aistAttr.U32Depth": 30, + "aistAttr.u32PeriodSize": 160, + "aistAttr.u32PeriodCount": 8, + "aistAttr.u32ChnCnt": 2, + "aistVqeAttr.s32SampleRate": 16000, + "aistVqeAttr.u32FrameSamples": 160, + "aistVqeAttr.stNsCfg.bNsEnable": 1, + "aistVqeAttr.stNsCfg.enAggressivenessLevel": 2, + "aistVqeAttr.stAgcCfg.bAgcEnable": 0, + "aistVqeAttr.stAgcCfg.enAgcMode": 2, + 
"aistVqeAttr.stAgcCfg.s16TargetLevel": -3, + "aistVqeAttr.stAgcCfg.s16Gain": 9, + "aistVqeAttr.stAecCfg.enAecMode": 2, + "stHpfAttr.bEnable": 0, + "stHpfAttr.s32GainDb": -3, + "stHpfAttr.s32Samplerate": 16000, + "stHpfAttr.s32Freq": 200, + "stLpfAttr.bEnable": 0, + "stLpfAttr.s32GainDb": 0, + "stLpfAttr.s32Samplerate": 16000, + "stLpfAttr.s32Freq": 3000, + "stEqAttr.bEnable": 0, + "stEqAttr.s32GainDb[0]": -10, + "stEqAttr.s32GainDb[1]": -3, + "stEqAttr.s32GainDb[2]": 3, + "stEqAttr.s32GainDb[3]": 5, + "stEqAttr.s32GainDb[4]": 10, + "stEqAttr.s32Samplerate": 16000, + "gResample": 0, + "enOutSampleRate": 16000, + "gDbDetection": 0 + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_audio/src/main.cpp b/projects/llm_framework/main_audio/src/main.cpp index 80da841c..e398e8ad 100644 --- a/projects/llm_framework/main_audio/src/main.cpp +++ b/projects/llm_framework/main_audio/src/main.cpp @@ -142,8 +142,13 @@ class llm_audio : public StackFlow { nlohmann::json error_body; std::string base_model_path; std::string base_model_config_path; - std::list config_file_paths = - get_config_file_paths(base_model_path, base_model_config_path, "audio"); + std::list config_file_paths; + if (access("/sys/devices/platform/soc/4851000.i2c/i2c-1/1-0043", F_OK) == 0) { + config_file_paths = get_config_file_paths(base_model_path, base_model_config_path, "audio_kit"); + } else { + config_file_paths = get_config_file_paths(base_model_path, base_model_config_path, "audio"); + } + try { config_body = nlohmann::json::parse(data); for (auto file_name : config_file_paths) { @@ -241,18 +246,7 @@ class llm_audio : public StackFlow { CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.enBitwidth); CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.enLinkMode); CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.enSamplerate); - - if (config_body.contains("aistAttr.enLayoutMode")) - mode_config_.aistAttr.enLayoutMode = config_body["aistAttr.enLayoutMode"]; - else if 
(file_body["cap_param"].contains("aistAttr.enLayoutMode")) { - mode_config_.aistAttr.enLayoutMode = file_body["cap_param"]["aistAttr.enLayoutMode"]; - if (access("/sys/devices/platform/soc/4851000.i2c/i2c-1/1-0043", F_OK) == 0) { - if (mode_config_.aistAttr.enLayoutMode == AX_AI_REF_MIC) { - mode_config_.aistAttr.enLayoutMode = AX_AI_MIC_REF; - } - } - } - + CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.enLayoutMode); CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.U32Depth); CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.u32PeriodSize); CONFIG_AUTO_SET(file_body["cap_param"], aistAttr.u32PeriodCount); From 59958869a0adec852067acce4a6a61b941f680b4 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 1 Jul 2025 15:40:30 +0800 Subject: [PATCH 24/79] [update] main_audio add tinyalsa API cap function. --- ext_components/ax_msp/SConstruct | 1 + .../llm_framework/main_audio/src/alsa_audio.c | 136 ++++++++++++++++++ .../llm_framework/main_audio/src/alsa_audio.h | 15 ++ .../llm_framework/main_audio/src/main.cpp | 1 + 4 files changed, 153 insertions(+) create mode 100644 projects/llm_framework/main_audio/src/alsa_audio.c create mode 100644 projects/llm_framework/main_audio/src/alsa_audio.h diff --git a/ext_components/ax_msp/SConstruct b/ext_components/ax_msp/SConstruct index a77c7d46..3c5ecbe3 100644 --- a/ext_components/ax_msp/SConstruct +++ b/ext_components/ax_msp/SConstruct @@ -29,6 +29,7 @@ if 'CONFIG_AX_620E_MSP_ENABLED' in os.environ: for dirn in third_party: INCLUDE.append(os.path.join(MSP_PATH,'third-party',dirn,'include')) LINK_SEARCH_PATH.append(os.path.join(MSP_PATH,'third-party',dirn,'lib/arm64/glibc')) + INCLUDE.append(os.path.join(MSP_PATH,'third-party/tinyalsa/include/tinyalsa')) INCLUDE.append(os.path.join(MSP_PATH,'third-party/live/out/arm64/glibc/include')) LINK_SEARCH_PATH.append(os.path.join(MSP_PATH,'third-party/live/out/arm64/glibc/lib')) INCLUDE.append(os.path.join(MSP_PATH,'third-party/openssl/arm64/include')) diff --git 
a/projects/llm_framework/main_audio/src/alsa_audio.c b/projects/llm_framework/main_audio/src/alsa_audio.c new file mode 100644 index 00000000..73aca0f4 --- /dev/null +++ b/projects/llm_framework/main_audio/src/alsa_audio.c @@ -0,0 +1,136 @@ +#include "alsa_audio.h" +#include "samplerate.h" +#include +#include +#include +#include +#include + +static int gcapLoopExit = 0; + +void alsa_cap_start(unsigned int card, unsigned int device, float Volume, int channel, int rate, int bit, + AUDIOCallback callback) +{ + struct pcm_config config; + unsigned int pcm_open_flags; + struct pcm *pcm; + char *buffer; + unsigned int size; + unsigned int frames_read; + unsigned int total_frames_read; + unsigned int bytes_per_frame; + + memset(&config, 0, sizeof(config)); + config.channels = channel; + config.rate = 48000; // TODO: 部分USB MIC仅支持48k,暂时固定采集为48k + config.period_size = 512; + config.period_count = 4; + config.format = PCM_FORMAT_S16_LE; + config.start_threshold = 0; + config.stop_threshold = 0; + config.silence_threshold = 0; + + pcm_open_flags = PCM_IN; + if (1) pcm_open_flags |= PCM_MMAP; + + pcm = pcm_open(card, device, pcm_open_flags, &config); + if (!pcm || !pcm_is_ready(pcm)) { + fprintf(stderr, "Unable to open PCM device (%s)\n", pcm_get_error(pcm)); + return; + } + + size = pcm_frames_to_bytes(pcm, pcm_get_buffer_size(pcm)); + buffer = malloc(size); + if (!buffer) { + fprintf(stderr, "Unable to allocate %u bytes\n", size); + pcm_close(pcm); + return; + } + + if (1) { + printf("Capturing sample: %u ch, %u hz, %u bit\n", channel, rate, pcm_format_to_bits(PCM_FORMAT_S16_LE)); + } + + bytes_per_frame = pcm_frames_to_bytes(pcm, 1); + total_frames_read = 0; + + SRC_STATE *src_state = NULL; + float *in_float = NULL, *out_float = NULL; + int in_frames = pcm_get_buffer_size(pcm); + int out_frames = (int)((float)in_frames * ((float)rate / 48000.0f) + 1); + int out_bytes = out_frames * channel * sizeof(short); + + if (rate != 48000) { + src_state = src_new(SRC_SINC_FASTEST, 
channel, NULL); + in_float = malloc(in_frames * channel * sizeof(float)); + out_float = malloc(out_frames * channel * sizeof(float)); + if (!src_state || !in_float || !out_float) { + fprintf(stderr, "Unable to allocate resample buffers\n"); + free(buffer); + if (in_float) free(in_float); + if (out_float) free(out_float); + if (src_state) src_delete(src_state); + pcm_close(pcm); + return; + } + } + + while (!gcapLoopExit) { + int ret = pcm_readi(pcm, buffer, in_frames); + if (ret < 0) { + fprintf(stderr, "Error capturing samples - %d (%s)\n", errno, strerror(errno)); + break; + } + frames_read = ret; + total_frames_read += frames_read; + + if (rate == 48000) { + callback(buffer, frames_read * bytes_per_frame); + } else { + short *in_short = (short *)buffer; + for (int i = 0; i < frames_read * channel; ++i) { + in_float[i] = in_short[i] / 32768.0f; + } + SRC_DATA src_data; + src_data.data_in = in_float; + src_data.input_frames = frames_read; + src_data.data_out = out_float; + src_data.output_frames = out_frames; + src_data.src_ratio = (double)rate / 48000.0; + src_data.end_of_input = 0; + int error = src_process(src_state, &src_data); + if (error) { + fprintf(stderr, "SRC error: %s\n", src_strerror(error)); + break; + } + // float转short + short *out_short = malloc(src_data.output_frames_gen * channel * sizeof(short)); + for (int i = 0; i < src_data.output_frames_gen * channel; ++i) { + float sample = out_float[i]; + if (sample > 1.0f) sample = 1.0f; + if (sample < -1.0f) sample = -1.0f; + out_short[i] = (short)(sample * 32767.0f); + } + callback((const char *)out_short, src_data.output_frames_gen * channel * sizeof(short)); + free(out_short); + } + } + + if (rate != 48000) { + free(in_float); + free(out_float); + src_delete(src_state); + } + free(buffer); + pcm_close(pcm); +} + +void alsa_close_cap() +{ + gcapLoopExit = 1; +} + +int alsa_cap_status() +{ + return gcapLoopExit; +} \ No newline at end of file diff --git 
a/projects/llm_framework/main_audio/src/alsa_audio.h b/projects/llm_framework/main_audio/src/alsa_audio.h new file mode 100644 index 00000000..95c469ab --- /dev/null +++ b/projects/llm_framework/main_audio/src/alsa_audio.h @@ -0,0 +1,15 @@ +#pragma once + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void (*AUDIOCallback)(const char *data, int size); + +void alsa_cap_start(unsigned int card, unsigned int device, float Volume, int channel, int rate, int bit, + AUDIOCallback callback); +void alsa_close_cap(); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/projects/llm_framework/main_audio/src/main.cpp b/projects/llm_framework/main_audio/src/main.cpp index e398e8ad..e7007538 100644 --- a/projects/llm_framework/main_audio/src/main.cpp +++ b/projects/llm_framework/main_audio/src/main.cpp @@ -22,6 +22,7 @@ static void __sigint(int iSigNo) } #include "sample_audio.h" +#include "alsa_audio.h" #define CONFIG_AUTO_SET(obj, key) \ if (config_body.contains(#key)) \ From 9abe069db6c7d1ff29e46ccba6f5b2e5e19bd78a Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 18 Jul 2025 14:12:12 +0800 Subject: [PATCH 25/79] [update] KWS sets multiple keywords, fix melotts File missing, Optimize generated audio --- projects/llm_framework/main_kws/src/main.cpp | 76 +++++-- .../llm_framework/main_melotts/src/main.cpp | 212 +++++++----------- .../main_melotts/src/runner/Lexicon.hpp | 8 +- 3 files changed, 141 insertions(+), 155 deletions(-) diff --git a/projects/llm_framework/main_kws/src/main.cpp b/projects/llm_framework/main_kws/src/main.cpp index 4ea6abcd..5b8a21bb 100644 --- a/projects/llm_framework/main_kws/src/main.cpp +++ b/projects/llm_framework/main_kws/src/main.cpp @@ -34,6 +34,8 @@ static void __sigint(int iSigNo) static std::string base_model_path_; static std::string base_model_config_path_; +typedef std::function task_callback_t; + #define CONFIG_AUTO_SET(obj, key) \ if (config_body.contains(#key)) \ mode_config_.key = config_body[#key]; \ @@ 
-50,16 +52,17 @@ class llm_task { std::string model_; std::string response_format_; std::vector inputs_; - std::string kws_; + std::vector kws_; bool enoutput_; + bool enoutput_json_; bool enstream_; bool enwake_audio_; std::atomic_bool audio_flage_; + task_callback_t out_callback_; int delay_audio_frame_ = 100; buffer_t *pcmdata; std::string wake_wav_file_; - std::function out_callback_; std::function play_awake_wav; bool parse_config(const nlohmann::json &config_body) @@ -68,7 +71,6 @@ class llm_task { model_ = config_body.at("model"); response_format_ = config_body.at("response_format"); enoutput_ = config_body.at("enoutput"); - kws_ = config_body.at("kws"); if (config_body.contains("enwake_audio")) { enwake_audio_ = config_body["enwake_audio"]; } else { @@ -83,6 +85,16 @@ class llm_task { } } } + if (config_body.contains("kws")) { + if (config_body["kws"].is_string()) { + kws_.push_back(config_body["kws"].get()); + } else if (config_body["kws"].is_array()) { + for (auto _in : config_body["kws"]) { + kws_.push_back(_in.get()); + } + } + } + enoutput_json_ = response_format_.find("json") == std::string::npos ? false : true; } catch (...) 
{ SLOGE("setup config_body error"); return true; @@ -173,7 +185,9 @@ class llm_task { mode_config_.keywords_file = base_model + mode_config_.keywords_file; std::ofstream temp_awake_key("/tmp/kws_awake.txt.tmp"); - temp_awake_key << kws_; + for (const auto &keyword : kws_) { + temp_awake_key << keyword << std::endl; + } temp_awake_key.close(); std::ostringstream awake_key_compile_cmd; if (file_exists("/opt/m5stack/scripts/text2token.py")) @@ -206,7 +220,7 @@ class llm_task { return 0; } - void set_output(std::function out_callback) + void set_output(task_callback_t out_callback) { out_callback_ = out_callback; } @@ -242,16 +256,17 @@ class llm_task { play_awake_wav(wake_wav_file_); } if (out_callback_) { - out_callback_("True"); + if (enoutput_json_) + out_callback_(r.AsJsonString(), true); + else + out_callback_("", true); } } } void trigger() { - if (out_callback_) { - out_callback_("True"); - } + if (out_callback_) out_callback_("", true); } bool delete_model() @@ -301,6 +316,39 @@ class llm_kws : public StackFlow { }); } + void task_output(const std::weak_ptr llm_task_obj_weak, + const std::weak_ptr llm_channel_weak, const std::string &data, bool finish) + { + auto llm_task_obj = llm_task_obj_weak.lock(); + auto llm_channel = llm_channel_weak.lock(); + if (!(llm_task_obj && llm_channel)) { + return; + } + std::string tmp_msg1; + const std::string *next_data = &data; + if (data.empty()) { + llm_channel->send(llm_task_obj->response_format_, true, LLM_NO_ERROR); + return; + } + if (finish) { + tmp_msg1 = data + "."; + next_data = &tmp_msg1; + } + if (llm_channel->enstream_) { + static int count = 0; + nlohmann::json data_body; + data_body["index"] = count++; + data_body["delta"] = (*next_data); + data_body["finish"] = finish; + if (finish) count = 0; + SLOGI("send stream:%s", next_data->c_str()); + llm_channel->send(llm_task_obj->response_format_, data_body, LLM_NO_ERROR); + } else if (finish) { + SLOGI("send utf-8:%s", next_data->c_str()); + 
llm_channel->send(llm_task_obj->response_format_, (*next_data), LLM_NO_ERROR); + } + } + void play_awake_wav(const std::string &wav_file) { FILE *fp = fopen(wav_file.c_str(), "rb"); @@ -465,13 +513,9 @@ class llm_kws : public StackFlow { llm_channel->set_output(llm_task_obj->enoutput_); llm_channel->set_stream(llm_task_obj->enstream_); llm_task_obj->play_awake_wav = std::bind(&llm_kws::play_awake_wav, this, std::placeholders::_1); - llm_task_obj->set_output([_llm_task_obj, _llm_channel](const std::string &data) { - auto llm_task_obj = _llm_task_obj.lock(); - auto llm_channel = _llm_channel.lock(); - if (llm_task_obj && llm_channel) { - llm_channel->send(llm_task_obj->response_format_, true, LLM_NO_ERROR); - } - }); + llm_task_obj->set_output(std::bind(&llm_kws::task_output, this, std::weak_ptr(llm_task_obj), + std::weak_ptr(llm_channel), std::placeholders::_1, + std::placeholders::_2)); for (const auto input : llm_task_obj->inputs_) { if (input.find("sys") != std::string::npos) { diff --git a/projects/llm_framework/main_melotts/src/main.cpp b/projects/llm_framework/main_melotts/src/main.cpp index 6b8e1c15..c4e2122c 100644 --- a/projects/llm_framework/main_melotts/src/main.cpp +++ b/projects/llm_framework/main_melotts/src/main.cpp @@ -184,7 +184,10 @@ class llm_task { else if (file_body["mode_param"].contains("awake_delay")) awake_delay_ = file_body["mode_param"]["awake_delay"]; - if (!std::filesystem::exists(mode_config_.tagger) || !std::filesystem::exists(mode_config_.verbalizer)) { + if (!std::filesystem::exists(mode_config_.tagger) || + !std::filesystem::is_regular_file(mode_config_.tagger) || + !std::filesystem::exists(mode_config_.verbalizer) || + !std::filesystem::is_regular_file(mode_config_.verbalizer)) { SLOGW("Either tagger or verbalizer file does not exist, using alternative lexicon."); lexicon_ = std::make_unique(mode_config_.lexicon, mode_config_.tokens); } else { @@ -281,195 +284,132 @@ class llm_task { int dec_len = zp_size / zp_shape[1]; int 
audio_slice_len = decoder_->GetOutputSize(0) / sizeof(float); - const int pad_frames = 24; - const int samples_per_frame = 512; + const int overlap_size = 1024; + const int fade_size = 512; - const int effective_frames = dec_len - 2 * pad_frames; + int dec_slice_num = static_cast(std::ceil(static_cast(zp_shape[2]) / dec_len)); - int dec_slice_num = - static_cast(std::ceil(static_cast(zp_shape[2]) / static_cast(effective_frames))); - - const int sola_buffer_frame = pad_frames * samples_per_frame; - const int sola_search_frame = pad_frames * samples_per_frame; - const int block_frame = (dec_len - 2 * pad_frames) * samples_per_frame; - - std::vector fade_in_window(sola_buffer_frame); - std::vector fade_out_window(sola_buffer_frame); - - for (int i = 0; i < sola_buffer_frame; i++) { - fade_in_window[i] = static_cast(i) / sola_buffer_frame; - fade_out_window[i] = 1.0f - fade_in_window[i]; + std::vector fade_in(fade_size); + std::vector fade_out(fade_size); + for (int i = 0; i < fade_size; i++) { + float t = static_cast(i) / fade_size; + fade_in[i] = t; + fade_out[i] = 1.0f - t; } - std::vector sola_buffer(sola_buffer_frame, 0.0f); - bool first_frame = true; - std::vector pcmlist; + std::vector previous_tail; for (int i = 0; i < dec_slice_num; i++) { - int input_start = i * effective_frames; - if (i > 0) { - input_start -= pad_frames; - } - input_start = std::max(0, input_start); - - int actual_len = std::min(dec_len, static_cast(zp_shape[2] - input_start)); - - int output_start_frame, output_end_frame; - - if (i == 0) { - output_start_frame = 0; - output_end_frame = effective_frames - 1; - } else if (i == dec_slice_num - 1) { - output_start_frame = i * effective_frames; - output_end_frame = static_cast(zp_shape[2]) - 1; - } else { - output_start_frame = i * effective_frames; - output_end_frame = (i + 1) * effective_frames - 1; - } + int input_start = i * dec_len; + int actual_size = std::min(dec_len, static_cast(zp_shape[2] - input_start)); std::vector zp(zp_size, 0); - 
for (int n = 0; n < zp_shape[1]; n++) { - int copy_size = std::min(actual_len, static_cast(zp_shape[2] - input_start)); - if (copy_size > 0) { + if (actual_size > 0) { memcpy(zp.data() + n * dec_len, zp_data + n * zp_shape[2] + input_start, - sizeof(float) * copy_size); + sizeof(float) * actual_size); } } - std::vector decoder_output(audio_slice_len); decoder_->SetInput(zp.data(), 0); decoder_->SetInput(g_matrix.data(), 1); if (0 != decoder_->Run()) { + SLOGE("Decoder run failed at slice %d", i); throw std::string("decoder_ RunSync error"); } + std::vector decoder_output(audio_slice_len); decoder_->GetOutput(decoder_output.data(), 0); - if (first_frame) { - int audio_start = 0; - int audio_len = decoder_output.size() - sola_buffer_frame; - audio_len = std::max(0, audio_len); - - if (audio_len > 0) { - pcmlist.insert(pcmlist.end(), decoder_output.begin() + audio_start, - decoder_output.begin() + audio_start + audio_len); - } + if (i == 0) { + int main_part_size = static_cast(decoder_output.size()) - overlap_size; + main_part_size = std::max(0, main_part_size); - int buffer_start = audio_len; + pcmlist.insert(pcmlist.end(), decoder_output.begin(), decoder_output.begin() + main_part_size); - if (buffer_start + sola_buffer_frame <= decoder_output.size()) { - std::copy(decoder_output.begin() + buffer_start, - decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin()); - } else { - int available = static_cast(decoder_output.size() - buffer_start); - if (available > 0) { - std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), sola_buffer.begin()); - std::fill(sola_buffer.begin() + available, sola_buffer.end(), 0.0f); - } else { - std::fill(sola_buffer.begin(), sola_buffer.end(), 0.0f); - } + if (decoder_output.size() > main_part_size) { + previous_tail.assign(decoder_output.begin() + main_part_size, decoder_output.end()); } - first_frame = false; - } else { - int audio_start = pad_frames * samples_per_frame; - - std::vector 
search_window(sola_buffer_frame + sola_search_frame); - std::copy(decoder_output.begin() + audio_start, - decoder_output.begin() + audio_start + search_window.size(), search_window.begin()); - - int best_offset = 0; - float best_correlation = -1.0; - - for (int offset = 0; offset <= sola_search_frame; offset++) { - float correlation = 0.0; - float energy = 0.0; - - for (int j = 0; j < sola_buffer_frame; j++) { - correlation += sola_buffer[j] * search_window[j + offset]; - energy += search_window[j + offset] * search_window[j + offset]; - } + if (previous_tail.empty()) { + pcmlist.insert(pcmlist.end(), decoder_output.begin(), decoder_output.end()); + continue; + } - float normalized_correlation = (energy > 1e-8) ? correlation / std::sqrt(energy) : 0.0f; + int blend_size = std::min( + {fade_size, static_cast(previous_tail.size()), static_cast(decoder_output.size())}); - if (normalized_correlation > best_correlation) { - best_correlation = normalized_correlation; - best_offset = offset; - } + std::vector blended_region(blend_size); + for (int j = 0; j < blend_size; j++) { + blended_region[j] = previous_tail[j] * fade_out[j * fade_size / blend_size] + + decoder_output[j] * fade_in[j * fade_size / blend_size]; } - int aligned_start = audio_start + best_offset; - - std::vector crossfade_region(sola_buffer_frame); + pcmlist.insert(pcmlist.end(), blended_region.begin(), blended_region.end()); - for (int j = 0; j < sola_buffer_frame; j++) { - crossfade_region[j] = - decoder_output[aligned_start + j] * fade_in_window[j] + sola_buffer[j] * fade_out_window[j]; + if (static_cast(previous_tail.size()) > blend_size) { + pcmlist.insert(pcmlist.end(), previous_tail.begin() + blend_size, previous_tail.end()); } - pcmlist.insert(pcmlist.end(), crossfade_region.begin(), crossfade_region.end()); - - int remaining_start = aligned_start + sola_buffer_frame; + int current_remaining_start = blend_size; + int current_remaining_size = static_cast(decoder_output.size()) - 
current_remaining_start; if (i == dec_slice_num - 1) { - int total_expected_samples = audio_len * samples_per_frame / 512; - int processed_samples = static_cast(pcmlist.size()); - int remaining_needed = total_expected_samples - processed_samples; - remaining_needed = std::max(0, remaining_needed); + int total_expected = audio_len; + int current_total = static_cast(pcmlist.size()); + current_remaining_size = std::min(current_remaining_size, total_expected - current_total); + } - int remaining_len = - std::min(remaining_needed, static_cast(decoder_output.size() - remaining_start)); + if (current_remaining_size > overlap_size && i < dec_slice_num - 1) { + int main_part_size = current_remaining_size - overlap_size; - if (remaining_len > 0) { - pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start, - decoder_output.begin() + remaining_start + remaining_len); - } + pcmlist.insert(pcmlist.end(), decoder_output.begin() + current_remaining_start, + decoder_output.begin() + current_remaining_start + main_part_size); + previous_tail.assign(decoder_output.begin() + current_remaining_start + main_part_size, + decoder_output.begin() + current_remaining_start + current_remaining_size); } else { - int remaining_len = (dec_len - 2 * pad_frames) * samples_per_frame - sola_buffer_frame; - remaining_len = - std::min(remaining_len, static_cast(decoder_output.size() - remaining_start)); - - if (remaining_len > 0) { - pcmlist.insert(pcmlist.end(), decoder_output.begin() + remaining_start, - decoder_output.begin() + remaining_start + remaining_len); - } - - int buffer_start = remaining_start + remaining_len; - - if (buffer_start + sola_buffer_frame <= decoder_output.size()) { - std::copy(decoder_output.begin() + buffer_start, - decoder_output.begin() + buffer_start + sola_buffer_frame, sola_buffer.begin()); - } else { - int avail = static_cast(decoder_output.size() - buffer_start); - if (avail > 0) { - std::copy(decoder_output.begin() + buffer_start, decoder_output.end(), 
- sola_buffer.begin()); - } - std::fill(sola_buffer.begin() + avail, sola_buffer.end(), 0.0f); + if (current_remaining_size > 0) { + pcmlist.insert(pcmlist.end(), decoder_output.begin() + current_remaining_start, + decoder_output.begin() + current_remaining_start + current_remaining_size); } + previous_tail.clear(); } } + + if (static_cast(pcmlist.size()) >= audio_len) { + break; + } } - if (pcmlist.size() > audio_len) { + if (static_cast(pcmlist.size()) > audio_len) { pcmlist.resize(audio_len); } + float max_val = 0.0f; + int clip_count = 0; + for (float sample : pcmlist) { + max_val = std::max(max_val, std::abs(sample)); + if (std::abs(sample) > 0.95f) clip_count++; + } + double src_ratio = static_cast(mode_config_.audio_rate) / static_cast(mode_config_.mode_rate); - std::vector tmp_pcm((pcmlist.size() * src_ratio + 1)); + std::vector tmp_pcm(static_cast(pcmlist.size() * src_ratio + 1)); int len; - resample_audio(pcmlist.data(), pcmlist.size(), tmp_pcm.data(), &len, src_ratio); wav_pcm_data.reserve(len); - std::transform(tmp_pcm.begin(), tmp_pcm.begin() + len, std::back_inserter(wav_pcm_data), - [](const auto val) { return static_cast(val * INT16_MAX); }); + for (int i = 0; i < len; i++) { + float val = tmp_pcm[i]; + if (std::abs(val) > 0.95f) { + val = val > 0 ? 0.95f : -0.95f; + } + wav_pcm_data.push_back(static_cast(val * INT16_MAX)); + } if (out_callback_) { out_callback_( @@ -478,8 +418,10 @@ class llm_task { } } catch (const std::exception &e) { + SLOGE("Exception: %s", e.what()); return true; } catch (...) 
{ + SLOGE("Unknown exception occurred"); return true; } return false; diff --git a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp index 1884c929..cf2c1fcb 100644 --- a/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp +++ b/projects/llm_framework/main_melotts/src/runner/Lexicon.hpp @@ -37,7 +37,7 @@ class Lexicon { const std::string& verbalizer_filename) : max_phrase_length(0) { - SLOGD("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", + SLOGI("Dictionary loading: %s Pronunciation table loading: %s tagger_filename: %s verbalizer_filename: %s", tokens_filename.c_str(), lexicon_filename.c_str(), tagger_filename.c_str(), verbalizer_filename.c_str()); m_processor = new wetext::Processor(tagger_filename, verbalizer_filename); @@ -92,13 +92,13 @@ class Lexicon { lexicon["。"] = lexicon["."]; lexicon["!"] = lexicon["!"]; lexicon["?"] = lexicon["?"]; - SLOGD("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), + SLOGI("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), max_phrase_length); } Lexicon(const std::string& lexicon_filename, const std::string& tokens_filename) : max_phrase_length(0) { - SLOGD("Dictionary loading: %s Pronunciation table loading: %s", tokens_filename.c_str(), + SLOGI("Dictionary loading: %s Pronunciation table loading: %s", tokens_filename.c_str(), lexicon_filename.c_str()); std::unordered_map tokens; @@ -151,7 +151,7 @@ class Lexicon { lexicon["。"] = lexicon["."]; lexicon["!"] = lexicon["!"]; lexicon["?"] = lexicon["?"]; - SLOGD("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), + SLOGI("Dictionary loading complete, containing %zu entries, longest phrase length: %zu", lexicon.size(), max_phrase_length); } From c25b4f0a093fcca647b2851b3472c36e9510384f Mon Sep 17 00:00:00 2001 
From: LittleMouse Date: Wed, 23 Jul 2025 14:26:22 +0800 Subject: [PATCH 26/79] [fix] Fix caching causing audio issues --- .../llm_framework/main_melotts/src/runner/EngineWrapper.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/llm_framework/main_melotts/src/runner/EngineWrapper.cpp b/projects/llm_framework/main_melotts/src/runner/EngineWrapper.cpp index 0dda3e7b..0394ee48 100644 --- a/projects/llm_framework/main_melotts/src/runner/EngineWrapper.cpp +++ b/projects/llm_framework/main_melotts/src/runner/EngineWrapper.cpp @@ -292,7 +292,7 @@ int EngineWrapper::Init(const char* strModelPath, uint32_t nNpuType) // 6. prepare io // AX_U32 nIoDepth = (stCtx.vecOutputBufferFlag.size() == 0) ? 1 : stCtx.vecOutputBufferFlag.size(); - ret = utils::prepare_io(strModelPath, m_io_info, m_io, utils::IO_BUFFER_STRATEGY_CACHED); + ret = utils::prepare_io(strModelPath, m_io_info, m_io, utils::IO_BUFFER_STRATEGY_DEFAULT); if (0 != ret) { printf("prepare io failed!\n"); utils::free_io(m_io); From cfbfd62c4acf4f915b4498a173fb6c7ac0394a41 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Thu, 14 Aug 2025 09:15:49 +0800 Subject: [PATCH 27/79] [update] update docs --- doc/projects_llm_framework_doc/llm_kws_en.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/projects_llm_framework_doc/llm_kws_en.md b/doc/projects_llm_framework_doc/llm_kws_en.md index 4ab85b47..84b2a666 100644 --- a/doc/projects_llm_framework_doc/llm_kws_en.md +++ b/doc/projects_llm_framework_doc/llm_kws_en.md @@ -34,7 +34,7 @@ Send JSON: - response_format: The result returned is in `kws.bool` format. - input: The input is `sys.pcm`, representing system audio. - enoutput: Whether to enable user result output. -- kws: The Chinese wake-up word is `"你好你好"`. +- kws: The English wake-up word is `"HELLO"`. It must be capital letters. - enwake_audio: Whether to enable wake-up audio output. Default is true. 
Response JSON: From a916ca04d64746153da363dc966ebc6730468a1d Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Thu, 21 Aug 2025 19:08:20 +0800 Subject: [PATCH 28/79] [update] Reduce buffer frames --- projects/llm_framework/main_asr/src/main.cpp | 6 +++--- projects/llm_framework/main_kws/src/main.cpp | 4 ++-- projects/llm_framework/main_vad/src/main.cpp | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/projects/llm_framework/main_asr/src/main.cpp b/projects/llm_framework/main_asr/src/main.cpp index c3bd64f2..4a5f9703 100644 --- a/projects/llm_framework/main_asr/src/main.cpp +++ b/projects/llm_framework/main_asr/src/main.cpp @@ -58,7 +58,7 @@ class llm_task { std::atomic_bool audio_flage_; std::atomic_bool awake_flage_; int awake_delay_ = 50; - int delay_audio_frame_ = 100; + int delay_audio_frame_ = 10; buffer_t *pcmdata; std::function pause; @@ -183,11 +183,11 @@ class llm_task { { static int count = 0; if (count < delay_audio_frame_) { - buffer_write_char(pcmdata, raw.c_str(), raw.length()); + buffer_write_char(pcmdata, raw.data(), raw.length()); count++; return; } - buffer_write_char(pcmdata, raw.c_str(), raw.length()); + buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); count = 0; std::vector floatSamples; diff --git a/projects/llm_framework/main_kws/src/main.cpp b/projects/llm_framework/main_kws/src/main.cpp index 5b8a21bb..18e4dfd2 100644 --- a/projects/llm_framework/main_kws/src/main.cpp +++ b/projects/llm_framework/main_kws/src/main.cpp @@ -59,7 +59,7 @@ class llm_task { bool enwake_audio_; std::atomic_bool audio_flage_; task_callback_t out_callback_; - int delay_audio_frame_ = 100; + int delay_audio_frame_ = 10; buffer_t *pcmdata; std::string wake_wav_file_; @@ -233,7 +233,7 @@ class llm_task { count++; return; } - buffer_write_char(pcmdata, raw.c_str(), raw.length()); + buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); count = 0; std::vector floatSamples; diff 
--git a/projects/llm_framework/main_vad/src/main.cpp b/projects/llm_framework/main_vad/src/main.cpp index 3bd53f3f..20a67eb3 100644 --- a/projects/llm_framework/main_vad/src/main.cpp +++ b/projects/llm_framework/main_vad/src/main.cpp @@ -60,7 +60,7 @@ class llm_task { std::string superior_id_; task_callback_t out_callback_; int awake_delay_ = 50; - int delay_audio_frame_ = 100; + int delay_audio_frame_ = 3; buffer_t *pcmdata; std::string wake_wav_file_; @@ -158,7 +158,7 @@ class llm_task { count++; return; } - buffer_write_char(pcmdata, raw.c_str(), raw.length()); + buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); count = 0; std::vector floatSamples; From 73c4a49b05125c2484940bc15ce2e1948366f7c6 Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Fri, 22 Aug 2025 17:38:39 +0800 Subject: [PATCH 29/79] [update] ModuleLLM support ctx model, add HomeAssistant model, add model post process config. --- .../mode_qwen2.5-HA-0.5B-ctx-ax630c.json | 38 + .../tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py | 203 ++++ projects/llm_framework/main_llm/src/main.cpp | 148 ++- .../llm_framework/main_llm/src/runner/LLM.hpp | 1066 +++++++++++++++-- .../main_llm/src/runner/LLMPostprocess.hpp | 46 +- .../src/runner/Tokenizer/Tokenizer.cpp | 618 +++++----- .../src/runner/Tokenizer/Tokenizer.hpp | 90 +- .../ax_model_runner/ax_model_runner.hpp | 3 + .../main_llm/src/runner/utils/http_utils.hpp | 102 ++ projects/llm_framework/tools/llm_pack.py | 1 + 10 files changed, 1826 insertions(+), 489 deletions(-) create mode 100644 projects/llm_framework/main_llm/models/mode_qwen2.5-HA-0.5B-ctx-ax630c.json create mode 100644 projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py create mode 100644 projects/llm_framework/main_llm/src/runner/utils/http_utils.hpp diff --git a/projects/llm_framework/main_llm/models/mode_qwen2.5-HA-0.5B-ctx-ax630c.json b/projects/llm_framework/main_llm/models/mode_qwen2.5-HA-0.5B-ctx-ax630c.json new file mode 100644 
index 00000000..1481d454 --- /dev/null +++ b/projects/llm_framework/main_llm/models/mode_qwen2.5-HA-0.5B-ctx-ax630c.json @@ -0,0 +1,38 @@ +{ + "mode":"qwen2.5-HA-0.5B-ctx-ax630c", + "type":"llm", + "homepage":"https://huggingface.co/yunyu1258/qwen2.5-0.5b-ha", + "compile_flage":"pulsar2 llm_build --input_path Qwen/qwen2.5-0.5b-ha --output_path Qwen/qwen2.5-0.5B-p1024-ha-ax630c --hidden_state_type bf16 --prefill_len 128 --kv_cache_len 1280 --last_kv_cache_len 128 --last_kv_cache_len 512 --last_kv_cache_len 1024 --chip AX620E --parallel 24", + "pulsar_version":"4.1-patch1-c37957c7", + "capabilities":[ + "text_generation", + "chat" + ], + "input_type":[ + "llm.utf-8", + "llm.utf-8.stream", + "llm.chat_completion", + "llm.chat_completion.stream" + ], + "output_type":[ + "llm.utf-8", + "llm.utf-8.stream" + ], + "mode_param":{ + "tokenizer_type":2, + "url_tokenizer_model":"http://localhost:8080", + "filename_tokens_embed":"model.embed_tokens.weight.bfloat16.bin", + "filename_post_axmodel":"qwen2_post.axmodel", + "template_filename_axmodel":"qwen2_p128_l%d_together.axmodel", + "b_use_topk":false, + "b_bos":false, + "b_eos":false, + "axmodel_num":24, + "tokens_embed_num":151936, + "tokens_embed_size":896, + "b_use_mmap_load_embed":true, + "b_dynamic_load_axmodel_layer":false, + "precompute_len":1202, + "ext_scripts":["tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py"] + } +} \ No newline at end of file diff --git a/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py new file mode 100644 index 00000000..4a63e88a --- /dev/null +++ b/projects/llm_framework/main_llm/scripts/tokenizer_qwen2.5-HA-0.5B-ctx-ax630c.py @@ -0,0 +1,203 @@ +from transformers import AutoTokenizer, PreTrainedTokenizerFast +from http.server import HTTPServer, BaseHTTPRequestHandler +import json +import argparse +import uuid + +# 全局字典:存储 uid 到 Tokenizer_Http 实例的映射 +tokenizers = {} + +class Tokenizer_Http(): 
+ def __init__(self, model_id): + self.tokenizer = AutoTokenizer.from_pretrained(model_id) + self.messages = [ + {"role": "system", "content": "You are Qwen, created by Alibaba Cloud. You are a helpful assistant."}, + ] + self.token_ids = [] + + self.token_ids_cache = [] + + def encode(self, prompt, last_reply=None): + if last_reply is not None: + self.messages.append({"role": "assistant", "content": last_reply}) + text = self.tokenizer.apply_chat_template( + self.messages, + tokenize=False, + add_generation_prompt=True + ) + # print("生成的文本:\n============\n", text, "============\n") + self.token_ids = self.tokenizer.encode(text)[:-3] + self.messages.append({"role": "user", "content": prompt}) + + text = self.tokenizer.apply_chat_template( + self.messages, + tokenize=False, + add_generation_prompt=True + ) + print("生成的文本:\n============\n", text, "============\n") + token_ids = self.tokenizer.encode(text) + # 找出新增部分 + diff = token_ids[len(self.token_ids):] + self.token_ids = token_ids + print(self.decode(diff)) + return token_ids, diff + + def decode(self, token_ids): + self.token_ids_cache += token_ids + text = self.tokenizer.decode(self.token_ids_cache) + if "\ufffd" in text: + print("text 中包含非法字符") + return "" + else: + self.token_ids_cache.clear() + return text + + + @property + def bos_id(self): + return self.tokenizer.bos_token_id + + @property + def eos_id(self): + return self.tokenizer.eos_token_id + + @property + def bos_token(self): + return self.tokenizer.bos_token + + @property + def eos_token(self): + return self.tokenizer.eos_token + + def reset(self, system_prompt=None): + if system_prompt is None: + system_prompt = args.content + self.messages = [ + {"role": "system", "content": system_prompt}, + ] + text = self.tokenizer.apply_chat_template( + self.messages, + tokenize=False, + add_generation_prompt=True + ) + token_ids = self.tokenizer.encode(text)[:-3] + self.token_ids = token_ids + print(self.decode(token_ids)) + return token_ids + + +class 
Request(BaseHTTPRequestHandler): + timeout = 5 + server_version = 'Apache' + + def do_GET(self): + print("GET 请求路径:", self.path) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + + # 新增接口:获取 uid + if '/get_uid' in self.path: + new_uid = str(uuid.uuid4()) + print("新 uid:", new_uid) + # 为该 uid 创建一个新的 Tokenizer_Http 实例 + tokenizers[new_uid] = Tokenizer_Http(args.model_id) + msg = json.dumps({'uid': new_uid}) + elif '/bos_id' in self.path: + # 获取 uid 参数(例如 ?uid=xxx) + uid = self.get_query_param("uid") + instance: Tokenizer_Http = tokenizers.get(uid) + if instance is None: + msg = json.dumps({'error': 'Invalid uid'}) + else: + bos_id = instance.bos_id + msg = json.dumps({'bos_id': bos_id if bos_id is not None else -1}) + elif '/eos_id' in self.path: + uid = self.get_query_param("uid") + instance: Tokenizer_Http = tokenizers.get(uid) + if instance is None: + msg = json.dumps({'error': 'Invalid uid'}) + else: + eos_id = instance.eos_id + msg = json.dumps({'eos_id': eos_id if eos_id is not None else -1}) + else: + msg = json.dumps({'error': 'Invalid GET endpoint'}) + + print("响应消息:", msg) + self.wfile.write(msg.encode()) + + def do_POST(self): + content_length = int(self.headers.get('content-length', 0)) + data = self.rfile.read(content_length).decode() + print("POST 请求路径:", self.path) + print("接收到的数据:", data) + req = json.loads(data) + + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + + if '/encode' in self.path: + # 请求数据中必须包含 uid, text, 和可选的 last_reply + uid = req.get('uid') + prompt = req.get('text') + last_reply = req.get('last_reply') + instance: Tokenizer_Http = tokenizers.get(uid) + if instance is None: + msg = json.dumps({'error': 'Invalid uid'}) + else: + token_ids, diff = instance.encode(prompt, last_reply) + msg = json.dumps({'token_ids': token_ids, 'diff': diff}) + elif '/decode' in self.path: + uid = req.get('uid') + token_ids = req.get('token_ids') + 
instance: Tokenizer_Http = tokenizers.get(uid) + if instance is None: + msg = json.dumps({'error': 'Invalid uid'}) + else: + text = instance.decode(token_ids) + msg = json.dumps({'text': text}) + elif '/reset' in self.path: + uid = req.get("uid") + system_prompt = req.get("system_prompt") + instance: Tokenizer_Http = tokenizers.get(uid) + if instance is None: + msg = json.dumps({'error': 'Invalid uid'}) + else: + if system_prompt is not None: + print("system_prompt:", system_prompt) + token_ids = instance.reset(system_prompt) + msg = json.dumps({'token_ids': token_ids}) + else: + token_ids = instance.reset() + msg = json.dumps({'token_ids': token_ids}) + else: + msg = json.dumps({'error': 'Invalid POST endpoint'}) + + print("响应消息:", msg) + self.wfile.write(msg.encode()) + + def get_query_param(self, key): + """ + 辅助函数:从 GET 请求的 URL 中获取查询参数的值 + 例如:/bos_id?uid=xxx + """ + from urllib.parse import urlparse, parse_qs + query = urlparse(self.path).query + params = parse_qs(query) + values = params.get(key) + return values[0] if values else None + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--host', type=str, default='0.0.0.0') + parser.add_argument('--port', type=int, default=12345) + parser.add_argument('--model_id', type=str, default='qwen3_1.7B_tokenizer') + parser.add_argument('--content', type=str, default='You are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.') + + args = parser.parse_args() + + host = (args.host, args.port) + print('Server running at http://%s:%s' % host) + server = HTTPServer(host, Request) + server.serve_forever() diff --git a/projects/llm_framework/main_llm/src/main.cpp b/projects/llm_framework/main_llm/src/main.cpp index 7cbb9fbe..e198cc9c 100644 --- a/projects/llm_framework/main_llm/src/main.cpp +++ b/projects/llm_framework/main_llm/src/main.cpp @@ -55,10 +55,17 @@ class llm_task { enum inference_status { INFERENCE_NONE = 0, INFERENCE_RUNNING }; LLMAttrType mode_config_; std::unique_ptr lLaMa_; + std::unique_ptr lLaMa_ctx_; std::string model_; std::string response_format_; std::vector inputs_; std::string prompt_; + std::string last_reply; + std::vector prompt_data; + std::vector tokens_ids, tokens_diff; + std::vector> k_caches, v_caches; + int precompute_len = 0; + std::vector _token_ids; task_callback_t out_callback_; bool enoutput_; bool enstream_; @@ -125,10 +132,10 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], tokenizer_type); CONFIG_AUTO_SET(file_body["mode_param"], filename_tokenizer_model); + CONFIG_AUTO_SET(file_body["mode_param"], url_tokenizer_model); CONFIG_AUTO_SET(file_body["mode_param"], filename_tokens_embed); CONFIG_AUTO_SET(file_body["mode_param"], filename_post_axmodel); CONFIG_AUTO_SET(file_body["mode_param"], template_filename_axmodel); - CONFIG_AUTO_SET(file_body["mode_param"], b_use_topk); CONFIG_AUTO_SET(file_body["mode_param"], b_bos); CONFIG_AUTO_SET(file_body["mode_param"], b_eos); CONFIG_AUTO_SET(file_body["mode_param"], axmodel_num); @@ -137,61 +144,119 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], b_use_mmap_load_embed); CONFIG_AUTO_SET(file_body["mode_param"], b_dynamic_load_axmodel_layer); CONFIG_AUTO_SET(file_body["mode_param"], max_token_len); + CONFIG_AUTO_SET(file_body["mode_param"], enable_temperature); CONFIG_AUTO_SET(file_body["mode_param"], temperature); + CONFIG_AUTO_SET(file_body["mode_param"], 
enable_top_p_sampling); CONFIG_AUTO_SET(file_body["mode_param"], top_p); + CONFIG_AUTO_SET(file_body["mode_param"], enable_top_k_sampling); + CONFIG_AUTO_SET(file_body["mode_param"], top_k); + CONFIG_AUTO_SET(file_body["mode_param"], enable_repetition_penalty); + CONFIG_AUTO_SET(file_body["mode_param"], repetition_penalty); + CONFIG_AUTO_SET(file_body["mode_param"], penalty_window); + CONFIG_AUTO_SET(file_body["mode_param"], precompute_len); + { + auto has_http = [](const std::string &s) { return s.find("http") != std::string::npos; }; + + auto find_tokenizer_file = [this]() -> std::string { + const std::string base = "/opt/m5stack/scripts/"; + const std::string a = base + model_ + "_tokenizer.py"; + if (file_exists(a)) return a; + const std::string b = base + "tokenizer_" + model_ + ".py"; + if (file_exists(b)) return b; + SLOGE("%s or %s not found!", a.c_str(), b.c_str()); + return {}; + }; + + auto start_tokenizer_server = [&](const std::string &tokenizer_file) { + if (tokenizer_file.empty()) return; + if (tokenizer_server_flage_.load()) return; - if (mode_config_.filename_tokenizer_model.find("http:") != std::string::npos) { - mode_config_.filename_tokenizer_model = "http://localhost:" + std::to_string(port_); - std::string tokenizer_file; - if (file_exists(std::string("/opt/m5stack/scripts/") + model_ + std::string("_tokenizer.py"))) { - tokenizer_file = std::string("/opt/m5stack/scripts/") + model_ + std::string("_tokenizer.py"); - } else if (file_exists(std::string("/opt/m5stack/scripts/") + std::string("tokenizer_") + model_ + - std::string(".py"))) { - tokenizer_file = - std::string("/opt/m5stack/scripts/") + std::string("tokenizer_") + model_ + std::string(".py"); - } else { - std::string __log = model_ + std::string("_tokenizer.py"); - __log += " or "; - __log += std::string("tokenizer_") + model_ + std::string(".py"); - __log += " not found!"; - SLOGE("%s", __log.c_str()); - } - if (!tokenizer_server_flage_.load()) { tokenizer_pid_ = fork(); if 
(tokenizer_pid_ == 0) { setenv("PYTHONPATH", "/opt/m5stack/lib/llm/site-packages", 1); + const std::string port_str = std::to_string(port_); + const std::string model_id = base_model + "tokenizer"; + execl("/usr/bin/python3", "python3", tokenizer_file.c_str(), "--host", "localhost", "--port", - std::to_string(port_).c_str(), "--model_id", (base_model + "tokenizer").c_str(), - "--content", ("'" + prompt_ + "'").c_str(), nullptr); + port_str.c_str(), "--model_id", model_id.c_str(), "--content", prompt_.c_str(), + (char *)nullptr); + perror("execl failed"); - exit(1); + _exit(1); } + tokenizer_server_flage_.store(true); SLOGI("port_=%s model_id=%s content=%s", std::to_string(port_).c_str(), - (base_model + "tokenizer").c_str(), ("'" + prompt_ + "'").c_str()); + (base_model + std::string("tokenizer")).c_str(), prompt_.c_str()); + std::this_thread::sleep_for(std::chrono::seconds(15)); + }; + + auto process_field = [&](std::string &field, const char *name_for_log) -> bool { + if (!has_http(field)) return false; + + field = "http://localhost:" + std::to_string(port_); + const std::string tokenizer_file = find_tokenizer_file(); + start_tokenizer_server(tokenizer_file); + SLOGI("%s: %s", name_for_log, field.c_str()); + return true; + }; + + if (!process_field(mode_config_.filename_tokenizer_model, "filename_tokenizer_model") && + !process_field(mode_config_.url_tokenizer_model, "url_tokenizer_model")) { + mode_config_.filename_tokenizer_model = base_model + mode_config_.filename_tokenizer_model; + SLOGE("filename_tokenizer_model: %s", mode_config_.filename_tokenizer_model.c_str()); } - } else { - mode_config_.filename_tokenizer_model = base_model + mode_config_.filename_tokenizer_model; } - SLOGI("filename_tokenizer_model: %s", mode_config_.filename_tokenizer_model.c_str()); mode_config_.filename_tokens_embed = base_model + mode_config_.filename_tokens_embed; mode_config_.filename_post_axmodel = base_model + mode_config_.filename_post_axmodel; 
mode_config_.template_filename_axmodel = base_model + mode_config_.template_filename_axmodel; - mode_config_.runing_callback = [this](int *p_token, int n_token, const char *p_str, float token_per_sec, void *reserve) { if (this->out_callback_) { this->out_callback_(std::string(p_str), false); } }; - lLaMa_ = std::make_unique(); - if (!lLaMa_->Init(mode_config_)) { - lLaMa_->Deinit(); - lLaMa_.reset(); - return -2; + + if (mode_config_.precompute_len > 0) { + lLaMa_ctx_ = std::make_unique(); + if (!lLaMa_ctx_->Init(mode_config_)) { + lLaMa_ctx_->Deinit(); + lLaMa_ctx_.reset(); + return -2; + } + } else { + lLaMa_ = std::make_unique(); + if (!lLaMa_->Init(mode_config_)) { + lLaMa_->Deinit(); + lLaMa_.reset(); + return -2; + } } + if (lLaMa_ctx_) { + lLaMa_ctx_->SetSystemPrompt(mode_config_.system_prompt, _token_ids); + std::string kvcache_path = "/tmp/.llm/"; + if (!kvcache_path.empty() && kvcache_path != "") { + if (lLaMa_ctx_->load_kvcache(kvcache_path, mode_config_.axmodel_num, k_caches, v_caches, + mode_config_.system_prompt, precompute_len)) { + ALOGI("load kvcache from path: %s success,precompute_len: %d", kvcache_path.c_str(), + precompute_len); + } else { + ALOGW("load kvcache from path: %s failed,generate kvcache", kvcache_path.c_str()); + lLaMa_ctx_->GenerateKVCachePrefill(_token_ids, k_caches, v_caches, precompute_len); + if (!lLaMa_ctx_->save_kvcache(kvcache_path, mode_config_.system_prompt, precompute_len, + k_caches, v_caches)) { + ALOGE("save kvcache failed"); + } + ALOGI("generate kvcache to path: %s", kvcache_path.c_str()); + } + } else { + lLaMa_ctx_->GenerateKVCachePrefill(_token_ids, k_caches, v_caches, precompute_len); + } + ALOGI("precompute_len: %d", precompute_len); + ALOGI("system_prompt: %s", mode_config_.system_prompt.c_str()); + } } catch (...) 
{ SLOGE("config false"); return -3; @@ -253,8 +318,25 @@ class llm_task { { #if 1 try { - std::string out = lLaMa_->Run(prompt_complete(msg)); - if (out_callback_) out_callback_(out, true); + if (lLaMa_) { + std::string out = lLaMa_->Run(prompt_complete(msg)); + if (out_callback_) out_callback_(out, true); + } + + if (lLaMa_ctx_) { + lLaMa_ctx_->Encode(prompt_data, prompt_complete(msg), last_reply, tokens_ids, tokens_diff); + if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size()); + ret != 0) { + ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret); + // raise; + lLaMa_ctx_->SetSystemPrompt(mode_config_.system_prompt, _token_ids); + lLaMa_ctx_->GenerateKVCachePrefill(_token_ids, k_caches, v_caches, precompute_len); + lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size()); + } + last_reply = lLaMa_ctx_->Run(prompt_data); + lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len); + if (out_callback_) out_callback_(last_reply, true); + } } catch (...) 
{ SLOGW("lLaMa_->Run have error!"); } diff --git a/projects/llm_framework/main_llm/src/runner/LLM.hpp b/projects/llm_framework/main_llm/src/runner/LLM.hpp index 7d700a53..5639fb89 100644 --- a/projects/llm_framework/main_llm/src/runner/LLM.hpp +++ b/projects/llm_framework/main_llm/src/runner/LLM.hpp @@ -15,48 +15,57 @@ #include #include +#define ALIGN_DOWN(x, a) ((x) & ~((a) - 1)) // typedef void (*LLMRuningCallback)(int *p_token, int n_token, const char *p_str, float token_per_sec, void *reserve); typedef std::function LLMRuningCallback; struct LLMAttrType { + std::string system_prompt; + std::string template_filename_axmodel = "tinyllama-int8/tinyllama_l%d.axmodel"; + std::string post_config_path = "post_config.json"; int axmodel_num = 22; - // std::string template_prefill_filename_axmodel = "minicpmv/prefill_axmodel/minicpm_p96_l%d.axmodel"; - // int prefill_axmodel_num = 40; - int prefill_token_num = 96; // auto calc + int prefill_token_num = 96; + int prefill_max_token_num = 512; std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel"; - bool b_use_topk = false; - - // std::string filename_vpm_resampler_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel"; - // int vpm_width = 280; - // int vpm_height = 280; - TokenizerType tokenizer_type = TKT_LLaMa; std::string filename_tokenizer_model = "tokenizer.model"; - bool b_bos = true, b_eos = false; + std::string url_tokenizer_model; + bool b_bos = true; + bool b_eos = false; std::string filename_tokens_embed = "tinyllama.model.embed_tokens.weight.bfloat16.bin"; int tokens_embed_num = 32000; int tokens_embed_size = 2048; - int max_token_len = 127; // auto calc + int max_token_len = 127; + int kv_cache_num = 1024; + int kv_cache_size = 256; - int kv_cache_num = 1024; // auto calc - int kv_cache_size = 256; // auto calc + int precompute_len = 0; + std::vector prefill_max_kv_cache_num_grp; + int prefill_grpid = -1; - float temperature = 0.7f; - float top_p = 0.9f; - bool b_use_mmap_load_embed = 
false; - bool b_dynamic_load_axmodel_layer = false; + bool enable_temperature = false; + float temperature = 0.7f; + + bool enable_top_p_sampling = false; + float top_p = 0.7f; - bool b_use_mmap_load_layer = true; + bool enable_top_k_sampling = false; + int top_k = 50; - std::string post_config_path = "post_config.json"; + bool enable_repetition_penalty = false; + float repetition_penalty = 1.2f; + int penalty_window = 50; + + bool b_use_mmap_load_embed = false; + bool b_dynamic_load_axmodel_layer = false; + bool b_use_mmap_load_layer = true; - // bool b_live_print = true; LLMRuningCallback runing_callback = nullptr; void *reserve = nullptr; }; @@ -97,28 +106,7 @@ class LLM { logits[i] = *reinterpret_cast(&proc); } - // postprocess.set_temperature(true, 0.9f); - // // postprocess.set_repetition_penalty(true, 1.1f); - // postprocess.set_top_k_sampling(true, 10); - // // postprocess.set_top_p_sampling(true, 0.9f); - return postprocess.apply(logits, history); - - // float max_val = -MAXFLOAT; - // int max_index = 0; - // for (int i = 0; i < n; i++) - // { - // unsigned int proc = p[i] << 16; - // float tmp = *reinterpret_cast(&proc); - // if (tmp > max_val) - // { - // max_val = tmp; - // max_index = i; - // } - // } - // if (val) - // *val = max_val; - // return max_index; } public: @@ -133,17 +121,6 @@ class LLM { return false; } update_cqdm(&cqdm, 0, "count", "tokenizer init ok"); - // test code - // { - // std::vector output; - // tokenizer.Encode("Today is National", output); - // // print output - // for (size_t i = 0; i < output.size(); i++) - // { - // printf("%d ", output[i]); - // } - // printf("\n"); - // } if (!embed_selector.Init(attr.filename_tokens_embed, attr.tokens_embed_num, attr.tokens_embed_size, attr.b_use_mmap_load_embed)) { @@ -152,20 +129,7 @@ class LLM { return false; } update_cqdm(&cqdm, 1, "count", "embed_selector init ok"); - // test code - // { - // std::vector embed = embed_selector.getByIndex(123); - // printf("embed size: %d\n", 
embed.size()); - // for (int i = 0; i < embed.size(); i++) - // { - // bfloat16 bf16 = bfloat16(embed[i]); - // float val = bf16; - // printf("%d %0.22f\n", embed[i], val); - // } - // } - llama_layers.resize(attr.axmodel_num); - // prefill_layers.resize(attr.prefill_axmodel_num); char axmodel_path[1024]; for (int i = 0; i < attr.axmodel_num; i++) { @@ -227,8 +191,6 @@ class LLM { int max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1; _attr.max_token_len = max_token_len > _attr.max_token_len ? _attr.max_token_len : max_token_len; ALOGI("max_token_len : %d", _attr.max_token_len); - // auto &input_k_cache = llama_layers[0].layer.get_input("K_cache"); - // auto &output_k_cache_out = llama_layers[0].layer.get_output("K_cache_out"); _attr.kv_cache_size = llama_layers[0].layer.get_output("K_cache_out").nSize / sizeof(unsigned short); _attr.kv_cache_num = llama_layers[0].layer.get_input("K_cache").nSize / _attr.kv_cache_size / sizeof(unsigned short); @@ -245,11 +207,29 @@ class LLM { auto &layer = llama_layers[0]; layer.layer.deinit(); } + nlohmann::json dynamic_config; + + dynamic_config["enable_temperature"] = _attr.enable_temperature; + dynamic_config["temperature"] = _attr.temperature; + + dynamic_config["enable_repetition_penalty"] = _attr.enable_repetition_penalty; + dynamic_config["repetition_penalty"] = _attr.repetition_penalty; + dynamic_config["penalty_window"] = _attr.penalty_window; + + dynamic_config["enable_top_p_sampling"] = _attr.enable_top_p_sampling; + dynamic_config["top_p"] = _attr.top_p; + + dynamic_config["enable_top_k_sampling"] = _attr.enable_top_k_sampling; + dynamic_config["top_k"] = _attr.top_k; if (!postprocess.load_config(attr.post_config_path)) { ALOGW("load postprocess config(%s) failed", attr.post_config_path.c_str()); } + if (!postprocess.load_config(dynamic_config)) { + ALOGW("load postprocess config(%s) failed", dynamic_config.dump(4).c_str()); + } + // Reset(); ALOGI("LLM init ok"); return true; 
@@ -269,12 +249,6 @@ class LLM { embed_selector.Deinit(); } - // void Reset() - // { - // k_caches.resize(_attr.axmodel_num, std::vector(_attr.kv_cache_num * _attr.kv_cache_size, 0)); - // v_caches.resize(_attr.axmodel_num, std::vector(_attr.kv_cache_num * _attr.kv_cache_size, 0)); - // } - void Stop() { b_stop = true; @@ -282,9 +256,11 @@ class LLM { int Encode(std::vector &out_embed, std::string prompt = "What is in the image?") { - std::vector input_ids = tokenizer->Encode(prompt, true); + ImageInfo img_info; + img_info.img_prompt = false; + std::vector input_ids = tokenizer->Encode(prompt, img_info); if (input_ids.size() > _attr.prefill_token_num) { - ALOGE("input_ids(%d) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num); + ALOGE("input_ids(%ld) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num); return -1; } out_embed.resize(input_ids.size() * _attr.tokens_embed_size); @@ -306,7 +282,7 @@ class LLM { return Run(test_embed); } - std::string Run(std::vector test_embed) + std::string Run(std::vector &test_embed) { b_stop = false; std::string final_out; @@ -391,20 +367,8 @@ class LLM { if (_attr.b_dynamic_load_axmodel_layer) { layer.layer.deinit(); } - // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), - // bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32()); } - // ALOGI("prefill time cost: %.2f s", t_cost.cost() / 1000); - - // print token_ids - // printf("%s\n", input_str.c_str()); - // for (size_t i = 0; i < token_ids.size(); i++) - // { - // printf("%d ", token_ids[i]); - // } - // printf("\n"); - int next_token = -1; t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32); std::vector embed(_attr.tokens_embed_size, 0); @@ -417,19 +381,15 @@ class LLM { auto &input = llama_post.get_input("input"); memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); llama_post.inference(); + int max_index; - if (_attr.b_use_topk) { - 
AX_SYS_MinvalidateCache(llama_post.get_output("indices").phyAddr, - llama_post.get_output("indices").pVirAddr, - llama_post.get_output("indices").nSize); - max_index = *(int *)llama_post.get_output("indices").pVirAddr; - } else { - auto &output_post = llama_post.get_output("output"); - AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); - unsigned short *post_out = (unsigned short *)output_post.pVirAddr; - float max_val = -MAXFLOAT; - max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, &max_val); - } + + auto &output_post = llama_post.get_output("output"); + AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, &max_val); + next_token = max_index; token_ids.push_back(max_index); @@ -513,18 +473,13 @@ class LLM { memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); llama_post.inference(); int max_index; - if (_attr.b_use_topk) { - AX_SYS_MinvalidateCache(llama_post.get_output("indices").phyAddr, - llama_post.get_output("indices").pVirAddr, - llama_post.get_output("indices").nSize); - max_index = *(int *)llama_post.get_output("indices").pVirAddr; - } else { - auto &output_post = llama_post.get_output("output"); - AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); - unsigned short *post_out = (unsigned short *)output_post.pVirAddr; - float max_val = -MAXFLOAT; - max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, &max_val); - } + + auto &output_post = llama_post.get_output("output"); + AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + max_index = 
post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, &max_val); + next_token = max_index; if (tokenizer->isEnd(max_index)) { @@ -574,3 +529,890 @@ class LLM { return final_out; } }; + +class LLM_CTX { +private: + std::shared_ptr tokenizer; + LLaMaEmbedSelector embed_selector; + + LLMAttrType _attr; + + struct LLMLayer { + ax_runner_ax650 layer; + std::string filename; + MMap layer_buffer; + std::vector layer_buffer_vec; + }; + + std::vector llama_layers; + ax_runner_ax650 llama_post; + + // + int decode_grpid = 0; + + // ax_runner_ax650 vpm_resampler; + + // std::vector> k_caches, v_caches; + + bool b_stop = false; + + LLMPostprocess postprocess; + static int post_process(LLMPostprocess &postprocess, unsigned short *p, int n, std::vector &history, + float *val = 0) + { + std::vector logits(n); + for (int i = 0; i < n; i++) { + unsigned int proc = p[i] << 16; + logits[i] = *reinterpret_cast(&proc); + } + + return postprocess.apply(logits, history); + } + +public: + bool Init(LLMAttrType attr) + { + ALOGI("LLM init start"); + t_cqdm cqdm = create_cqdm(attr.axmodel_num + 3, 32); + this->_attr = attr; + tokenizer = CreateTokenizer(attr.tokenizer_type); + if (!tokenizer->Init(attr.url_tokenizer_model)) { + ALOGE("tokenizer.Init(%s) failed", attr.url_tokenizer_model.c_str()); + return false; + } + std::vector _token_ids; + tokenizer->Reset(attr.system_prompt, _token_ids); + update_cqdm(&cqdm, 0, "count", "tokenizer init ok"); + + if (!embed_selector.Init(attr.filename_tokens_embed, attr.tokens_embed_num, attr.tokens_embed_size, + attr.b_use_mmap_load_embed)) { + ALOGE("embed_selector.Init(%s, %d, %d) failed", attr.filename_tokens_embed.c_str(), attr.tokens_embed_num, + attr.tokens_embed_size); + return false; + } + update_cqdm(&cqdm, 1, "count", "embed_selector init ok"); + + llama_layers.resize(attr.axmodel_num); + // prefill_layers.resize(attr.prefill_axmodel_num); + + char axmodel_path[1024]; + for (int i = 0; i < attr.axmodel_num; i++) { + 
sprintf(axmodel_path, attr.template_filename_axmodel.c_str(), i); + llama_layers[i].filename = axmodel_path; + + int ret = llama_layers[i].layer.init(llama_layers[i].filename.c_str(), false); + if (ret != 0) { + ALOGE("init axmodel(%s) failed", llama_layers[i].filename.c_str()); + return false; + } + int remain_cmm = get_remaining_cmm_size(); + sprintf(axmodel_path, "init %d axmodel ok,remain_cmm(%d MB)", i, remain_cmm); + update_cqdm(&cqdm, i + 2, "count", axmodel_path); + } + + int ret = llama_post.init(attr.filename_post_axmodel.c_str(), false); + if (ret != 0) { + ALOGE("init post axmodel(%s) failed", attr.filename_post_axmodel.c_str()); + return false; + } + int remain_cmm = get_remaining_cmm_size(); + sprintf(axmodel_path, "init post axmodel ok,remain_cmm(%d MB)", remain_cmm); + update_cqdm(&cqdm, attr.axmodel_num + 2, "count", axmodel_path); + + { + _attr.max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1; + ALOGI("max_token_len : %d", _attr.max_token_len); + // auto &input_k_cache = llama_layers[0].layer.get_input("K_cache"); + // auto &output_k_cache_out = llama_layers[0].layer.get_output("K_cache_out"); + _attr.kv_cache_size = llama_layers[0].layer.get_output("K_cache_out").nSize / sizeof(unsigned short); + _attr.kv_cache_num = + llama_layers[0].layer.get_input("K_cache").nSize / _attr.kv_cache_size / sizeof(unsigned short); + ALOGI("kv_cache_size : %d, kv_cache_num: %d", _attr.kv_cache_size, _attr.kv_cache_num); + if (_attr.max_token_len > _attr.kv_cache_num) { + ALOGE("max_token_len(%d) > kv_cache_num(%d)", _attr.max_token_len, _attr.kv_cache_num); + return false; + } + + _attr.prefill_token_num = llama_layers[0].layer.get_input(1, "indices").vShape[1]; + ALOGI("prefill_token_num : %d", _attr.prefill_token_num); + for (size_t i = 0; i < llama_layers[0].layer.get_num_input_groups() - 1; i++) { + int prefill_max_kv_cache_num = llama_layers[0].layer.get_input(i + 1, "K_cache").vShape[1]; + ALOGI("grp: %d, 
prefill_max_token_num : %d", i + 1, prefill_max_kv_cache_num); + _attr.prefill_max_kv_cache_num_grp.push_back(prefill_max_kv_cache_num); + } + _attr.prefill_max_token_num = + _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + ALOGI("prefill_max_token_num : %d", _attr.prefill_max_token_num); + } + + nlohmann::json dynamic_config; + + dynamic_config["enable_temperature"] = _attr.enable_temperature; + dynamic_config["temperature"] = _attr.temperature; + + dynamic_config["enable_repetition_penalty"] = _attr.enable_repetition_penalty; + dynamic_config["repetition_penalty"] = _attr.repetition_penalty; + dynamic_config["penalty_window"] = _attr.penalty_window; + + dynamic_config["enable_top_p_sampling"] = _attr.enable_top_p_sampling; + dynamic_config["top_p"] = _attr.top_p; + + dynamic_config["enable_top_k_sampling"] = _attr.enable_top_k_sampling; + dynamic_config["top_k"] = _attr.top_k; + + if (!postprocess.load_config(attr.post_config_path)) { + ALOGW("load postprocess config(%s) failed", attr.post_config_path.c_str()); + } + + if (!postprocess.load_config(dynamic_config)) { + ALOGW("load postprocess config(%s) failed", dynamic_config.dump(4).c_str()); + } + + // Reset(); + ALOGI("LLM init ok"); + return true; + } + + LLMAttrType *getAttr() + { + return &_attr; + } + + LLMPostprocess *getPostprocess() + { + return &postprocess; + } + + void Deinit() + { + for (int i = 0; i < _attr.axmodel_num; i++) { + llama_layers[i].layer.release(); + } + llama_post.release(); + embed_selector.Deinit(); + } + + void Stop() + { + b_stop = true; + } + + int SetSystemPrompt(std::string system_prompt, std::vector &_token_ids) + { + tokenizer->Reset(system_prompt, _token_ids); + _attr.system_prompt = system_prompt; + _attr.prefill_max_token_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + return 0; + } + + int GenerateKVCachePrefill(std::vector &_token_ids, std::vector> &k_caches, + std::vector> &v_caches, int 
&precompute_len) + { + bfloat16 bf16 = -65536.f; + int input_embed_num = _token_ids.size(); + precompute_len = _token_ids.size(); + + k_caches.resize(_attr.axmodel_num); + v_caches.resize(_attr.axmodel_num); + int prefill_split_num = ceil((double)input_embed_num / _attr.prefill_token_num); + + int prefill_grpid = _attr.prefill_max_kv_cache_num_grp.size(); + + for (size_t i = 0; i < _attr.prefill_max_kv_cache_num_grp.size(); i++) { + if (input_embed_num <= _attr.prefill_max_kv_cache_num_grp[i]) { + prefill_grpid = i + 1; + break; + } + } + ALOGI("input token num : %d, prefill_split_num : %d prefill_grpid : %d", input_embed_num, prefill_split_num, + prefill_grpid); + + // clear kv cache + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(prefill_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(prefill_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(prefill_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(prefill_grpid, "V_cache").nSize); + } + + if (input_embed_num == 0) { + for (size_t i = 0; i < _attr.axmodel_num; i++) { + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + } + ALOGI("input token num is 0, skip"); + return 0; + } + + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[prefill_grpid - 1]; + + std::vector test_embed; + test_embed.resize(_token_ids.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < _token_ids.size(); i++) { + embed_selector.getByIndex(_token_ids[i], test_embed.data() + i * _attr.tokens_embed_size); + } + + for (size_t p = 0; p < prefill_split_num; p++) { + std::vector mask_tmp; + mask_tmp.resize(1 * _attr.prefill_token_num * (kv_cache_num + _attr.prefill_token_num), bf16.data); + int input_num_token = _attr.prefill_token_num; + if (p == prefill_split_num - 1) { + input_num_token = input_embed_num - p * _attr.prefill_token_num; + } + + 
ALOGI("input_num_token:%d", input_num_token); + for (size_t i = 0; i < _attr.prefill_token_num; i++) { + if (i < input_num_token) { + int mask_current_start = kv_cache_num; + auto mask_ptr = mask_tmp.data() + i * (kv_cache_num + _attr.prefill_token_num); + + for (int j = 0; j < p * _attr.prefill_token_num; j++) { + mask_ptr[j] = 0; + } + + for (int j = mask_current_start; j < mask_current_start + i + 1; j++) { + mask_ptr[j] = 0; + } + } + } + + std::vector embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0); + if (p == (prefill_split_num - 1)) { + memcpy( + embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short)); + } else { + memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short)); + } + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + auto &layer = llama_layers[m]; + // set indices + auto &input_indices = layer.layer.get_input(prefill_grpid, "indices"); + unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr; + memset(input_indices_ptr, 0, input_indices.nSize); + int idx = 0; + for (unsigned int i = p * _attr.prefill_token_num; i < (p + 1) * _attr.prefill_token_num; i++) { + input_indices_ptr[idx] = i; + idx++; + } + + // set mask + auto &input_mask = layer.layer.get_input(prefill_grpid, "mask"); + memcpy((void *)input_mask.pVirAddr, (void *)mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(prefill_grpid, "input"); + memcpy((void *)input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short)); + + layer.layer.inference(prefill_grpid); + + auto &input_decoder_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_decoder_v_cache = layer.layer.get_input(decode_grpid, 
"V_cache"); + + auto &input_prefill_k_cache = layer.layer.get_input(prefill_grpid, "K_cache"); + auto &input_prefill_v_cache = layer.layer.get_input(prefill_grpid, "V_cache"); + + auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out"); + auto &output_v_cache = layer.layer.get_output(prefill_grpid, "V_cache_out"); + + int kv_offset = (p * _attr.prefill_token_num) * _attr.kv_cache_size; + + memcpy((unsigned short *)input_decoder_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_decoder_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(prefill_grpid, "output"); + memcpy(embed_tmp.data(), (void *)output.pVirAddr, embed_tmp.size() * sizeof(unsigned short)); + } + } + + for (size_t i = 0; i < _attr.axmodel_num; i++) { + auto &layer = llama_layers[i]; + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + auto &input_k_cache = layer.layer.get_input(prefill_grpid, "K_cache"); + auto &input_v_cache = layer.layer.get_input(prefill_grpid, "V_cache"); + memcpy((void *)k_caches[i].data(), (void *)input_k_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy((void *)v_caches[i].data(), (void *)input_v_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + return 0; + } + + int GenerateKVCache(std::vector &_token_ids) + 
{ + // clear kv cache + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "V_cache").nSize); + } + + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + mask[_attr.kv_cache_num] = 0; + std::vector embed; + + int next_token = _token_ids[0]; + + t_cqdm cqdm = create_cqdm(_token_ids.size(), 32); + + for (unsigned int indices = 0; indices < _token_ids.size(); indices++) { + // ALOGI("out %d %d", indices, next_token); + embed_selector.getByIndex(next_token, embed); + + for (int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + break; + } + + auto &layer = llama_layers[m]; + + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + auto &input_indices = layer.layer.get_input(decode_grpid, "indices"); + memcpy(input_indices.pVirAddr, &indices, sizeof(indices)); + + auto &input_mask = layer.layer.get_input(decode_grpid, "mask"); + memcpy(input_mask.pVirAddr, mask.data(), mask.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(decode_grpid, "input"); + memcpy(input_input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + + layer.layer.inference(decode_grpid); + + auto &output_k_cache = layer.layer.get_output(decode_grpid, "K_cache_out"); + memcpy(input_k_cache_ptr + indices * _attr.kv_cache_size, output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output_v_cache = layer.layer.get_output(decode_grpid, "V_cache_out"); + 
memcpy(input_v_cache_ptr + indices * _attr.kv_cache_size, output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(decode_grpid, "output"); + memcpy(embed.data(), output.pVirAddr, embed.size() * sizeof(unsigned short)); + } + mask[indices] = 0; + next_token = _token_ids[indices + 1]; + update_cqdm(&cqdm, indices, "token", ""); + // ALOGI(""); + } + return 0; + } + + int GetKVCache(std::vector> &k_caches, + std::vector> &v_caches, int &precompute_len) + { + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + auto &input_mask = llama_layers[0].layer.get_input(decode_grpid, "mask"); + memcpy(mask.data(), (void *)input_mask.pVirAddr, input_mask.nSize); + for (size_t i = 0; i < mask.size(); i++) { + if (mask[i] == bf16.data) { + precompute_len = i + 1; + break; + } + } + ALOGI("precompute_len:%d, remaining:%d", precompute_len, + _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1] - precompute_len); + k_caches.resize(_attr.axmodel_num); + v_caches.resize(_attr.axmodel_num); + for (size_t i = 0; i < _attr.axmodel_num; i++) { + auto &layer = llama_layers[i]; + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + memcpy((void *)k_caches[i].data(), (void *)input_k_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy((void *)v_caches[i].data(), (void *)input_v_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + _attr.prefill_max_token_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + + return 0; + } + + int SetKVCache(std::vector> &k_caches, + std::vector> &v_caches, int precompute_len, int input_num_token) + { + _attr.precompute_len = 
precompute_len; + for (size_t i = 0; i < _attr.prefill_max_kv_cache_num_grp.size(); i++) { + if (_attr.precompute_len + input_num_token <= _attr.prefill_max_kv_cache_num_grp[i]) { + _attr.prefill_grpid = i + 1; + break; + } + } + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_grpid - 1]; + ALOGI("prefill_grpid:%d kv_cache_num:%d precompute_len:%d input_num_token:%d", _attr.prefill_grpid, + kv_cache_num, precompute_len, input_num_token); + + _attr.prefill_max_token_num = + ALIGN_DOWN(_attr.prefill_max_token_num - _attr.precompute_len, _attr.prefill_token_num); + ALOGI("current prefill_max_token_num:%d", _attr.prefill_max_token_num); + + if (precompute_len == 0) { + ALOGI("first run"); + return 0; + } + + if (precompute_len + input_num_token > kv_cache_num) { + ALOGE("precompute_len(%d) + input_num_token(%d) > _attr.prefill_max_kv_cache_num_grp[%d]", precompute_len, + input_num_token, _attr.prefill_grpid - 1); + return -1; + } + + if (input_num_token > _attr.prefill_max_token_num) { + ALOGE("input_num_token(%d) > _attr.prefill_max_token_num(%d)", input_num_token, + _attr.prefill_max_token_num); + return -1; + } + + if (k_caches.size() != v_caches.size()) { + ALOGE("k_caches.size(%d) != v_caches.size(%d)", k_caches.size(), v_caches.size()); + return -1; + } + + if (k_caches.size() != _attr.axmodel_num) { + ALOGE("k_caches.size(%d) != _attr.axmodel_num(%d)", k_caches.size(), _attr.axmodel_num); + return -1; + } + + // clear kv cache + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(_attr.prefill_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(_attr.prefill_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(_attr.prefill_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(_attr.prefill_grpid, "V_cache").nSize); + + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, 
"K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "V_cache").nSize); + } + + // int prefill_grpid = llama_layers[0].layer.get_num_input_groups() - 1; + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + auto &layer = llama_layers[m]; + + auto &k_cache = k_caches[m]; + auto &v_cache = v_caches[m]; + + if (k_cache.size() != _attr.precompute_len * _attr.kv_cache_size) { + ALOGE("k_cache.size(%d) != precompute_len(%d) * _attr.kv_cache_size(%d)", k_cache.size(), + _attr.precompute_len, _attr.kv_cache_size); + return -1; + } + if (v_cache.size() < _attr.precompute_len * _attr.kv_cache_size) { + ALOGE("v_cache.size(%d) < precompute_len(%d) * _attr.kv_cache_size(%d)", v_cache.size(), + _attr.precompute_len, _attr.kv_cache_size); + return -1; + } + + // set kv cache inputs + { + auto &input_k_cache = layer.layer.get_input(_attr.prefill_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(_attr.prefill_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + memcpy(input_k_cache_ptr, k_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy(input_v_cache_ptr, v_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + { + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + memcpy(input_k_cache_ptr, k_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy(input_v_cache_ptr, v_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + } 
+ + return 0; + } + + bool save_kvcache(std::string target_path, std::string system_prompt, int precompute_len, + std::vector> &k_caches, + std::vector> &v_caches) + { + for (size_t i = 0; i < k_caches.size(); i++) { + std::string k_cache_path = target_path + "/k_cache_" + std::to_string(i) + ".bin"; + std::string v_cache_path = target_path + "/v_cache_" + std::to_string(i) + ".bin"; + std::ofstream k_cache_file(k_cache_path); + std::ofstream v_cache_file(v_cache_path); + if (!k_cache_file.is_open() || !v_cache_file.is_open()) { + ALOGE("save kvcache failed"); + return false; + } + k_cache_file.write((char *)k_caches[i].data(), k_caches[i].size() * sizeof(unsigned short)); + v_cache_file.write((char *)v_caches[i].data(), v_caches[i].size() * sizeof(unsigned short)); + k_cache_file.close(); + v_cache_file.close(); + } + nlohmann::json j; + j["system_prompt"] = system_prompt; + j["precompute_len"] = precompute_len; + std::string config_path = target_path + "/config.json"; + std::ofstream config_file(config_path); + config_file << j.dump(); + config_file.close(); + return true; + } + + bool load_kvcache(std::string target_path, int axmodel_num, std::vector> &k_caches, + std::vector> &v_caches, std::string &system_prompt, + int &precompute_len) + { + k_caches.resize(axmodel_num); + v_caches.resize(axmodel_num); + for (size_t i = 0; i < k_caches.size(); i++) { + std::string k_cache_path = target_path + "/k_cache_" + std::to_string(i) + ".bin"; + std::string v_cache_path = target_path + "/v_cache_" + std::to_string(i) + ".bin"; + if (file_exist(k_cache_path) && file_exist(v_cache_path)) { + std::vector k_cache; + std::vector v_cache; + std::ifstream k_cache_file(k_cache_path); + std::ifstream v_cache_file(v_cache_path); + + k_cache_file.seekg(0, std::ios::end); + k_cache.resize(k_cache_file.tellg() / sizeof(unsigned short)); + k_cache_file.seekg(0, std::ios::beg); + + v_cache_file.seekg(0, std::ios::end); + v_cache.resize(v_cache_file.tellg() / sizeof(unsigned short)); + 
v_cache_file.seekg(0, std::ios::beg); + + k_cache_file.read((char *)k_cache.data(), k_cache.size() * sizeof(unsigned short)); + v_cache_file.read((char *)v_cache.data(), v_cache.size() * sizeof(unsigned short)); + + k_cache_file.close(); + v_cache_file.close(); + k_caches[i] = k_cache; + v_caches[i] = v_cache; + } else { + ALOGE("k_cache %s or v_cache %s not exist", k_cache_path.c_str(), v_cache_path.c_str()); + return false; + } + } + + std::string config_path = target_path + "/config.json"; + if (file_exist(config_path)) { + std::ifstream config_file(config_path); + nlohmann::json j; + config_file >> j; + system_prompt = j["system_prompt"].get(); + precompute_len = j["precompute_len"].get(); + config_file.close(); + } else { + ALOGE("config %s not exist", config_path.c_str()); + return false; + } + return true; + } + + int Encode(std::vector &out_embed, std::string prompt, std::string last_reply, + std::vector &tokens_ids, std::vector &tokens_diff) + { + ImageInfo img_info; + img_info.img_prompt = false; + if (!tokenizer->Encode(prompt, last_reply, tokens_ids, tokens_diff, img_info)) { + ALOGE("encode failed"); + return -1; + } + + out_embed.resize(tokens_diff.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < tokens_diff.size(); i++) { + embed_selector.getByIndex(tokens_diff[i], out_embed.data() + i * _attr.tokens_embed_size); + } + + return 0; + } + + std::string Run(std::vector test_embed) + { + b_stop = false; + std::string final_out; + + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + std::vector embed(_attr.tokens_embed_size, 0); + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_grpid - 1]; + + std::vector cached_token; + std::vector token_ids; + + int input_embed_num = test_embed.size() / _attr.tokens_embed_size; + int prefill_split_num = ceil((double)input_embed_num / _attr.prefill_token_num); + ALOGI("input token num : %d, prefill_split_num : %d", input_embed_num, prefill_split_num); + + 
mask[_attr.kv_cache_num] = 0; + for (size_t i = 0; i < _attr.precompute_len + input_embed_num; i++) { + mask[i] = 0; + } + timer t_cost; + timer ttft_timer; + ttft_timer.start(); + + for (size_t p = 0; p < prefill_split_num; p++) { + if (b_stop) { + break; + } + + std::vector mask_tmp; + mask_tmp.resize(1 * _attr.prefill_token_num * (kv_cache_num + _attr.prefill_token_num), bf16.data); + int input_num_token = _attr.prefill_token_num; + if (p == prefill_split_num - 1) { + input_num_token = input_embed_num - p * _attr.prefill_token_num; + } + + ALOGI("input_num_token:%d", input_num_token); + for (size_t i = 0; i < _attr.prefill_token_num; i++) { + if (i < input_num_token) { + int mask_current_start = kv_cache_num; + auto mask_ptr = mask_tmp.data() + i * (kv_cache_num + _attr.prefill_token_num); + + for (int j = 0; j < _attr.precompute_len + p * _attr.prefill_token_num; j++) { + mask_ptr[j] = 0; + } + + for (int j = mask_current_start; j < mask_current_start + i + 1; j++) { + mask_ptr[j] = 0; + } + } + } + + std::vector embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0); + if (p == (prefill_split_num - 1)) { + memcpy( + embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short)); + } else { + memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short)); + } + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + break; + } + + auto &layer = llama_layers[m]; + + // set indices + auto &input_indices = layer.layer.get_input(_attr.prefill_grpid, "indices"); + unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr; + memset(input_indices_ptr, 0, input_indices.nSize); + int idx = 0; + for (unsigned int i = _attr.precompute_len + p * _attr.prefill_token_num; + i < 
_attr.precompute_len + (p + 1) * _attr.prefill_token_num; i++) { + input_indices_ptr[idx] = i; + idx++; + } + + // set mask + auto &input_mask = layer.layer.get_input(_attr.prefill_grpid, "mask"); + memcpy((void *)input_mask.pVirAddr, (void *)mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short)); + + // set input + auto &input_input = layer.layer.get_input(_attr.prefill_grpid, "input"); + memcpy((void *)input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short)); + + layer.layer.inference(_attr.prefill_grpid); + + auto &input_decoder_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_decoder_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + + auto &input_prefill_k_cache = layer.layer.get_input(_attr.prefill_grpid, "K_cache"); + auto &input_prefill_v_cache = layer.layer.get_input(_attr.prefill_grpid, "V_cache"); + + auto &output_k_cache = layer.layer.get_output(_attr.prefill_grpid, "K_cache_out"); + auto &output_v_cache = layer.layer.get_output(_attr.prefill_grpid, "V_cache_out"); + + int kv_offset = (_attr.precompute_len + p * _attr.prefill_token_num) * _attr.kv_cache_size; + + memcpy((unsigned short *)input_decoder_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_decoder_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(_attr.prefill_grpid, "output"); + memcpy(embed_tmp.data(), (void *)output.pVirAddr, embed_tmp.size() * 
sizeof(unsigned short)); + } + if (p == (prefill_split_num - 1)) { + memcpy(embed.data(), + embed_tmp.data() + (input_embed_num - p * _attr.prefill_token_num - 1) * _attr.tokens_embed_size, + _attr.tokens_embed_size * sizeof(unsigned short)); + } + } + + int next_token = -1; + t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32); + + { + // post process + auto &input = llama_post.get_input("input"); + memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + llama_post.inference(); + int max_index; + + auto &output_post = llama_post.get_output("output"); + // AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + // max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); + max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, nullptr); + + next_token = max_index; + + token_ids.push_back(max_index); + cached_token.push_back(max_index); + ALOGI("ttft: %.2f ms", ttft_timer.cost()); + } + t_cost.start(); + + bool b_hit_eos = false; + for (unsigned int indices = _attr.precompute_len + input_embed_num; indices < _attr.max_token_len; indices++) { + if (b_stop) { + break; + } + + // ALOGI("out %d %d", indices, next_token); + embed_selector.getByIndex(next_token, embed); + // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), bfloat16(embed[2]).fp32(), + // bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32()); + + for (int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + break; + } + + auto &layer = llama_layers[m]; + + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + // memcpy(input_k_cache.pVirAddr, k_caches[m].data(), sizeof(unsigned short) * k_caches[m].size()); + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short 
*input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + // memcpy(input_v_cache.pVirAddr, v_caches[m].data(), sizeof(unsigned short) * v_caches[m].size()); + + auto &input_indices = layer.layer.get_input(decode_grpid, "indices"); + memcpy(input_indices.pVirAddr, &indices, sizeof(indices)); + + auto &input_mask = layer.layer.get_input(decode_grpid, "mask"); + memcpy(input_mask.pVirAddr, mask.data(), mask.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(decode_grpid, "input"); + memcpy(input_input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + + layer.layer.inference(decode_grpid); + + auto &output_k_cache = layer.layer.get_output(decode_grpid, "K_cache_out"); + // AX_SYS_MinvalidateCache(output_k_cache.phyAddr, output_k_cache.pVirAddr, output_k_cache.nSize); + memcpy(input_k_cache_ptr + indices * _attr.kv_cache_size, output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output_v_cache = layer.layer.get_output(decode_grpid, "V_cache_out"); + // AX_SYS_MinvalidateCache(output_v_cache.phyAddr, output_v_cache.pVirAddr, output_v_cache.nSize); + memcpy(input_v_cache_ptr + indices * _attr.kv_cache_size, output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(decode_grpid, "output"); + // AX_SYS_MinvalidateCache(output.phyAddr, output.pVirAddr, output.nSize); + memcpy(embed.data(), output.pVirAddr, embed.size() * sizeof(unsigned short)); + + // ALOGI("%f %f %f %f %f", bfloat16(embed[0]).fp32(), bfloat16(embed[1]).fp32(), + // bfloat16(embed[2]).fp32(), bfloat16(embed[3]).fp32(), bfloat16(embed[4]).fp32()); + } + // ALOGI(""); + mask[indices] = 0; + { + // post process + auto &input = llama_post.get_input("input"); + memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + llama_post.inference(); + int max_index; + + auto &output_post = llama_post.get_output("output"); + // 
AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + // max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); + max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, nullptr); + + next_token = max_index; + + if (tokenizer->isEnd(max_index)) { + if (cached_token.size() && _attr.runing_callback) { + float t_cost_ms = t_cost.cost(); + float token_per_sec = token_ids.size() / (t_cost_ms / 1000); + auto tmp_out = tokenizer->Decode(cached_token); + _attr.runing_callback(cached_token.data(), cached_token.size(), tmp_out.c_str(), token_per_sec, + _attr.reserve); + cached_token.clear(); + } + b_hit_eos = true; + break; + } + token_ids.push_back(max_index); + + if (_attr.runing_callback) { + cached_token.push_back(max_index); + if (cached_token.size() >= 3) { + float t_cost_ms = t_cost.cost(); + float token_per_sec = token_ids.size() / (t_cost_ms / 1000); + auto tmp_out = tokenizer->Decode(cached_token); + _attr.runing_callback(cached_token.data(), cached_token.size(), tmp_out.c_str(), token_per_sec, + _attr.reserve); + cached_token.clear(); + } + } + } + + if (_attr.runing_callback == nullptr) update_cqdm(&cqdm, indices, "token", ""); + if (b_hit_eos) { + break; + } + } + printf("\n\n"); + fflush(stdout); + float t_cost_ms = t_cost.cost(); + ALOGN("hit eos,avg %.2f token/s\n", token_ids.size() / (t_cost_ms / 1000)); + + final_out = tokenizer->Decode(token_ids); + + return final_out; + } +}; \ No newline at end of file diff --git a/projects/llm_framework/main_llm/src/runner/LLMPostprocess.hpp b/projects/llm_framework/main_llm/src/runner/LLMPostprocess.hpp index c98205e4..b7d50156 100644 --- a/projects/llm_framework/main_llm/src/runner/LLMPostprocess.hpp +++ b/projects/llm_framework/main_llm/src/runner/LLMPostprocess.hpp @@ -242,10 +242,11 @@ class LLMPostprocess this->temperature = temperature; } - void 
set_repetition_penalty(bool enable, float penalty) + void set_repetition_penalty(bool enable, float penalty, int penalty_window) { enable_repetition_penalty = enable; this->repetition_penalty = penalty; + this->penalty_window = penalty_window; } void set_diversity_penalty(bool enable, const std::vector &common_phrases, float penalty) @@ -295,6 +296,49 @@ class LLMPostprocess return true; } + bool load_config(const nlohmann::json& config) + { + if (config.is_null()) { + ALOGE("config is null or invalid"); + return false; + } + + ALOGI("load config: \n%s\n", config.dump(4).c_str()); + + if (config.contains("enable_temperature")) { + enable_temperature = config["enable_temperature"].get(); + } + if (config.contains("temperature")) { + temperature = config["temperature"].get(); + } + + if (config.contains("enable_repetition_penalty")) { + enable_repetition_penalty = config["enable_repetition_penalty"].get(); + } + if (config.contains("repetition_penalty")) { + repetition_penalty = config["repetition_penalty"].get(); + } + if (config.contains("penalty_window")) { + penalty_window = config["penalty_window"].get(); + } + + if (config.contains("enable_top_p_sampling")) { + enable_top_p_sampling = config["enable_top_p_sampling"].get(); + } + if (config.contains("top_p")) { + top_p = config["top_p"].get(); + } + + if (config.contains("enable_top_k_sampling")) { + enable_top_k_sampling = config["enable_top_k_sampling"].get(); + } + if (config.contains("top_k")) { + top_k = config["top_k"].get(); + } + + return true; + } + int apply(std::vector &logits, const std::vector &history) { if (enable_temperature) diff --git a/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.cpp b/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.cpp index 31aa272d..f676efdd 100644 --- a/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.cpp +++ b/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.cpp @@ -8,6 +8,7 @@ // #include "chatglm.h" #include 
"httplib.h" +#include "http_utils.hpp" #include "json.hpp" #include "sample_log.h" @@ -21,8 +22,7 @@ #include #include -class TokenizerLLaMa : public BaseTokenizer -{ +class TokenizerLLaMa : public BaseTokenizer { protected: sentencepiece::SentencePieceProcessor sp; bool _b_bos, _b_eos; @@ -33,8 +33,7 @@ class TokenizerLLaMa : public BaseTokenizer bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override { auto ret = sp.Load(model_path); - if (!ret.ok()) - { + if (!ret.ok()) { ALOGE("%s", ret.error_message()); return false; } @@ -44,43 +43,37 @@ class TokenizerLLaMa : public BaseTokenizer return ret.ok(); } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override { auto ret = sp.Encode(input, &output); - if (!ret.ok()) - { + if (!ret.ok()) { ALOGE("%s", ret.error_message()); return false; } - if (_b_bos) - { + if (_b_bos) { output.insert(output.begin(), sp.bos_id()); } - if (_b_eos) - { + if (_b_eos) { output.push_back(sp.eos_id()); } return true; } - std::vector Encode(std::string input, bool b_img_prompt = false) override + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); return output; } - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { sentencepiece::SentencePieceText spt; sp.Decode(input, &spt); std::string out = spt.pieces()[0].piece(); - if (*(unsigned short *)out.data() == 38626) - { + if (*(unsigned short *)out.data() == 38626) { return " " + spt.text(); - } - else - { + } else { return spt.text(); } } @@ -98,40 +91,32 @@ class TokenizerLLaMa : public BaseTokenizer { std::ostringstream oss_prompt; int messages_len = messages_.size(); - for(auto &message : messages_) - { - messages_len --; - switch (message.first) - { - case ROLE_USER: - { - 
oss_prompt << "<|user|>\n" << message.second << ""; - } - break; - case ROLE_SYSTEM: - break; - case ROLE_ASSISTANT: - break; - case ROLE_ASSISTANT_HELP: - { - if(messages_len == 0) - { - oss_prompt << "<|assistant|>\n"; - } - } - break; - default: - break; + for (auto &message : messages_) { + messages_len--; + switch (message.first) { + case ROLE_USER: { + oss_prompt << "<|user|>\n" << message.second << ""; + } break; + case ROLE_SYSTEM: + break; + case ROLE_ASSISTANT: + break; + case ROLE_ASSISTANT_HELP: { + if (messages_len == 0) { + oss_prompt << "<|assistant|>\n"; + } + } break; + default: + break; } } return oss_prompt.str(); } }; -class TokenizerMINICPM : public TokenizerLLaMa -{ +class TokenizerMINICPM : public TokenizerLLaMa { public: - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { sentencepiece::SentencePieceText spt; sp.Decode(input, &spt); @@ -141,38 +126,30 @@ class TokenizerMINICPM : public TokenizerLLaMa { std::ostringstream oss_prompt; int messages_len = messages_.size(); - for(auto &message : messages_) - { - messages_len --; - switch (message.first) - { - case ROLE_USER: - { - oss_prompt << "<用户>" << message.second; - } - break; - case ROLE_SYSTEM: - break; - case ROLE_ASSISTANT: - break; - case ROLE_ASSISTANT_HELP: - { - if(messages_len == 0) - { - oss_prompt << ""; - } - } - break; - default: - break; + for (auto &message : messages_) { + messages_len--; + switch (message.first) { + case ROLE_USER: { + oss_prompt << "<用户>" << message.second; + } break; + case ROLE_SYSTEM: + break; + case ROLE_ASSISTANT: + break; + case ROLE_ASSISTANT_HELP: { + if (messages_len == 0) { + oss_prompt << ""; + } + } break; + default: + break; } } return oss_prompt.str(); } }; -class TokenizerPhi3 : public BaseTokenizer -{ +class TokenizerPhi3 : public BaseTokenizer { sentencepiece::SentencePieceProcessor sp; bool _b_bos, _b_eos; @@ -182,8 +159,7 @@ class TokenizerPhi3 : public BaseTokenizer bool 
Init(std::string model_path, bool b_bos = true, bool b_eos = false) override { auto ret = sp.Load(model_path); - if (!ret.ok()) - { + if (!ret.ok()) { ALOGE("%s", ret.error_message()); return false; } @@ -193,46 +169,40 @@ class TokenizerPhi3 : public BaseTokenizer return ret.ok(); } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override { auto ret = sp.Encode(input, &output); - if (!ret.ok()) - { + if (!ret.ok()) { ALOGE("%s", ret.error_message()); return false; } - output.insert(output.begin(), 32010); //"<|user|>" - output.push_back(32007); //"<|end|>" - output.push_back(32001); //"<|assistant|>" - if (_b_bos) - { + output.insert(output.begin(), 32010); //"<|user|>" + output.push_back(32007); //"<|end|>" + output.push_back(32001); //"<|assistant|>" + if (_b_bos) { output.insert(output.begin(), sp.bos_id()); } - if (_b_eos) - { + if (_b_eos) { output.push_back(sp.eos_id()); } return true; } - std::vector Encode(std::string input, bool b_img_prompt = false) override + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); return output; } - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { sentencepiece::SentencePieceText spt; sp.Decode(input, &spt); std::string out = spt.pieces()[0].piece(); - if (*(unsigned short *)out.data() == 38626) - { + if (*(unsigned short *)out.data() == 38626) { return " " + spt.text(); - } - else - { + } else { return spt.text(); } } @@ -251,42 +221,35 @@ class TokenizerPhi3 : public BaseTokenizer { return id == GetEosID() || id > 31999; } + std::string apply_chat_template() override { std::ostringstream oss_prompt; int messages_len = messages_.size(); - for(auto &message : messages_) - { - messages_len --; - switch (message.first) - { - case ROLE_USER: - 
{ - oss_prompt << message.second; - } - break; - case ROLE_SYSTEM: - break; - case ROLE_ASSISTANT: - break; - case ROLE_ASSISTANT_HELP: - { - if(messages_len == 0) - { - oss_prompt << " "; - } - } - break; - default: - break; + for (auto &message : messages_) { + messages_len--; + switch (message.first) { + case ROLE_USER: { + oss_prompt << message.second; + } break; + case ROLE_SYSTEM: + break; + case ROLE_ASSISTANT: + break; + case ROLE_ASSISTANT_HELP: { + if (messages_len == 0) { + oss_prompt << " "; + } + } break; + default: + break; } } return oss_prompt.str(); } }; -class TokenizerQwen : public BaseTokenizer -{ +class TokenizerQwen : public BaseTokenizer { std::shared_ptr sp; bool _b_bos, _b_eos; @@ -295,8 +258,7 @@ class TokenizerQwen : public BaseTokenizer public: bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override { - if (!file_exist(model_path)) - { + if (!file_exist(model_path)) { ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); return false; } @@ -308,14 +270,12 @@ class TokenizerQwen : public BaseTokenizer return true; } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override { - if (_b_bos) - { + if (_b_bos) { // input += "<|im_start|>"; } - if (_b_eos) - { + if (_b_eos) { input += "<|endoftext|>"; } output = sp->encode(input, 1024); @@ -323,14 +283,14 @@ class TokenizerQwen : public BaseTokenizer return true; } - std::vector Encode(std::string input, bool b_img_prompt = false) override + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); return output; } - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { return sp->decode(input); } @@ -348,108 +308,32 @@ class TokenizerQwen : public BaseTokenizer { std::ostringstream 
oss_prompt; int messages_len = messages_.size(); - for(auto &message : messages_) - { - messages_len --; - switch (message.first) - { - case ROLE_USER: - { - oss_prompt << "<|im_start|>user\n" << message.second << "<|im_end|>\n"; - } - break; - case ROLE_SYSTEM: - { - oss_prompt << "<|im_start|>system\n" << message.second << ".<|im_end|>\n"; - } - break; - case ROLE_ASSISTANT: - { - oss_prompt << "<|im_start|>assistant\n" << message.second << ".<|im_end|>\n"; - } - break; - case ROLE_ASSISTANT_HELP: - { - if(messages_len == 0) - { - oss_prompt << "<|im_start|>assistant\n"; - } - } - break; - default: - break; + for (auto &message : messages_) { + messages_len--; + switch (message.first) { + case ROLE_USER: { + oss_prompt << "<|im_start|>user\n" << message.second << "<|im_end|>\n"; + } break; + case ROLE_SYSTEM: { + oss_prompt << "<|im_start|>system\n" << message.second << ".<|im_end|>\n"; + } break; + case ROLE_ASSISTANT: { + oss_prompt << "<|im_start|>assistant\n" << message.second << ".<|im_end|>\n"; + } break; + case ROLE_ASSISTANT_HELP: { + if (messages_len == 0) { + oss_prompt << "<|im_start|>assistant\n"; + } + } break; + default: + break; } } return oss_prompt.str(); } }; -// class TokenizerGLM3 : public BaseTokenizer -// { -// std::shared_ptr sp; -// bool _b_bos, _b_eos; - -// private: -// /* data */ -// public: -// bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override -// { -// if (!file_exist(model_path)) -// { -// ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); -// return false; -// } -// // std::vector sp_model_data; -// // read_file(model_path, sp_model_data); -// // std::string_view serialized_model_proto(sp_model_data.data(), sp_model_data.size()); - -// sp.reset(new chatglm::ChatGLM3Tokenizer(model_path)); - -// this->_b_bos = b_bos; -// this->_b_eos = b_eos; -// return true; -// } - -// bool Encode(std::string input, std::vector &output) override -// { -// if (_b_bos) -// { -// // input += 
"<|im_start|>"; -// } -// if (_b_eos) -// { -// // input += "<|endoftext|>"; -// } -// output = sp->encode(input, 1024); - -// return true; -// } - -// std::vector Encode(std::string input) override -// { -// std::vector output; -// Encode(input, output); -// return output; -// } - -// std::string Decode(const std::vector input) override -// { -// return sp->decode(input); -// } - -// int GetBosID() override -// { -// return sp->sp.bos_id(); -// } - -// int GetEosID() override -// { -// return sp->sp.eos_id(); -// } -// }; - -class Tokenizer_Http : public BaseTokenizer -{ +class Tokenizer_Http : public BaseTokenizer { std::shared_ptr cli; bool _b_bos, _b_eos; @@ -457,14 +341,15 @@ class Tokenizer_Http : public BaseTokenizer int bos_id, eos_id; + std::string uid; + private: /* data */ public: bool Init(std::string model_path = "http://localhost:8080", bool b_bos = true, bool b_eos = false) override { base_url = model_path; - try - { + try { cli = std::make_shared(base_url); cli->set_connection_timeout(1); cli->set_read_timeout(1); @@ -472,30 +357,26 @@ class Tokenizer_Http : public BaseTokenizer { auto ret = cli->Get("/bos_id"); auto rep = ret.value(); - if (rep.status != 200) - { + if (rep.status != 200) { ALOGE("get bos_id failed, status: %d", rep.status); return false; } nlohmann::json j = nlohmann::json::parse(rep.body); - bos_id = j["bos_id"]; + bos_id = j["bos_id"]; } { auto ret = cli->Get("/eos_id"); auto rep = ret.value(); - if (rep.status != 200) - { + if (rep.status != 200) { ALOGE("get eos_id failed, status: %d", rep.status); return false; } nlohmann::json j = nlohmann::json::parse(rep.body); - eos_id = j["eos_id"]; + eos_id = j["eos_id"]; } printf("bos_id: %d, eos_id: %d\n", bos_id, eos_id); - } - catch (const std::exception &e) - { + } catch (const std::exception &e) { std::cerr << e.what() << '\n'; return false; } @@ -505,77 +386,201 @@ class Tokenizer_Http : public BaseTokenizer return true; } - bool Encode(std::string input, std::vector &output, bool 
b_img_prompt = false) override + bool Init(std::string model_path = "http://localhost:8080") override + { + base_url = model_path; + if (!test_connect_http(base_url, 10)) { + ALOGE("connect %s failed", base_url.c_str()); + return false; + } else { + ALOGI("connect %s ok", base_url.c_str()); + } + + cli = std::make_shared(base_url); + cli->set_connection_timeout(10); + cli->set_read_timeout(10); + cli->set_write_timeout(10); + + int try_count = 10; + int count = try_count; + while (count-- > 0) { + try { + auto ret = cli->Get("/get_uid"); + auto rep = ret.value(); + if (rep.status != 200) { + ALOGE("get uid failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + uid = j["uid"]; + ALOGI("uid: %s", uid.c_str()); + break; + } catch (const std::exception &e) { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get uid failed, try again %d/%d", count, try_count); + } + + count = 10; + while (count-- > 0) { + try { + auto ret = cli->Get("/bos_id?uid=" + uid); + auto rep = ret.value(); + if (rep.status != 200) { + ALOGE("get bos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + bos_id = j["bos_id"]; + break; + } catch (const std::exception &e) { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get bos_id failed, try again %d/%d", count, try_count); + } + + count = 10; + while (count-- > 0) { + try { + auto ret = cli->Get("/eos_id?uid=" + uid); + auto rep = ret.value(); + if (rep.status != 200) { + ALOGE("get eos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + eos_id = j["eos_id"]; + break; + } catch (const std::exception &e) { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get eos_id failed, try again %d/%d", count, try_count); + } + + 
printf("bos_id: %d, eos_id: %d\n", bos_id, eos_id); + + return true; + } + + bool Encode(std::string input, std::string last_reply, std::vector &tokens, std::vector &tokens_diff, + ImageInfo img_info) override { nlohmann::json j; + j["uid"] = uid; j["text"] = input; - j["img_prompt"] = b_img_prompt; - auto ret = cli->Post("/encode", j.dump(), "application/json"); - auto rep = ret.value(); - if (rep.status != 200) - { + if (!last_reply.empty() and last_reply != "") { + j["last_reply"] = last_reply; + } + + j["img_prompt"] = img_info.img_prompt; + auto ret = cli->Post("/encode", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) { ALOGE("encode failed, status: %d", rep.status); return false; } nlohmann::json j2; - try - { + try { j2 = nlohmann::json::parse(rep.body); + } catch (const std::exception &e) { + ALOGE("json parse failed: %s", e.what()); + ALOGE("%s", rep.body.c_str()); + return false; + } + + std::vector _token_ids = j2["token_ids"]; + std::vector _tokens_diff = j2["diff"]; + + tokens = _token_ids; + tokens_diff = _tokens_diff; + + return true; + } + + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override + { + nlohmann::json j; + j["text"] = input; + j["img_prompt"] = img_info.img_prompt; + auto ret = cli->Post("/encode", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) { + ALOGE("encode failed, status: %d", rep.status); + return false; } - catch (const std::exception &e) - { + nlohmann::json j2; + try { + j2 = nlohmann::json::parse(rep.body); + } catch (const std::exception &e) { ALOGE("json parse failed: %s", e.what()); ALOGE("%s", rep.body.c_str()); return false; } std::vector out = j2["token_ids"]; - output = out; - // output = sp->encode(input, 1024); - if (_b_bos) - { + output = out; + + if (_b_bos) { output.insert(output.begin(), bos_id); } - if (_b_eos) - { + if (_b_eos) { output.push_back(eos_id); } return true; } - std::vector Encode(std::string input, 
bool b_img_prompt = false) override + bool Reset(std::string system_prompt, std::vector &tokens) override + { + nlohmann::json j; + j["uid"] = uid; + if (!system_prompt.empty() and system_prompt != "") { + j["system_prompt"] = system_prompt; + } + + auto ret = cli->Post("/reset", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) { + ALOGE("reset failed, status: %d", rep.status); + return false; + } + nlohmann::json j_rep = nlohmann::json::parse(rep.body); + std::vector _token_ids = j_rep["token_ids"]; + tokens = _token_ids; + return true; + } + + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); return output; } - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { - int cnt = 2; + int cnt = 2; std::string out_str = ""; - while (cnt--) - { + while (cnt--) { nlohmann::json j; j["token_ids"] = input; - auto ret = cli->Post("/decode", j.dump(), "application/json"); - auto rep = ret.value(); - if (rep.status != 200) - { + j["uid"] = uid; + auto ret = cli->Post("/decode", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) { ALOGE("decode failed, status: %d, try again", rep.status); ALOGE("%s", rep.body.c_str()); usleep(1000 * 1000); continue; } - try - { + try { nlohmann::json j2 = nlohmann::json::parse(rep.body); - out_str = j2["text"]; + out_str = j2["text"]; break; - } - catch (const std::exception &e) - { + } catch (const std::exception &e) { ALOGE("json parse failed: %s, try again", e.what()); ALOGE("%s", rep.body.c_str()); usleep(1000 * 1000); @@ -594,34 +599,25 @@ class Tokenizer_Http : public BaseTokenizer { return eos_id; } + std::string apply_chat_template() override { std::ostringstream oss_prompt; int messages_len = messages_.size(); - for(auto &message : messages_) - { - messages_len --; - switch (message.first) - { - case 
ROLE_USER: - { - oss_prompt << message.second ; - } - break; - case ROLE_SYSTEM: - { - } - break; - case ROLE_ASSISTANT: - { - } - break; - case ROLE_ASSISTANT_HELP: - { - } - break; - default: - break; + for (auto &message : messages_) { + messages_len--; + switch (message.first) { + case ROLE_USER: { + oss_prompt << message.second; + } break; + case ROLE_SYSTEM: { + } break; + case ROLE_ASSISTANT: { + } break; + case ROLE_ASSISTANT_HELP: { + } break; + default: + break; } } return oss_prompt.str(); @@ -801,7 +797,7 @@ class Tokenizer_Auto : public BaseTokenizer { return ret.success; } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override { nlohmann::json rpcobj; rpcobj["method"] = "encode"; @@ -815,14 +811,14 @@ class Tokenizer_Auto : public BaseTokenizer { return ret.success; } - std::vector Encode(std::string input, bool b_img_prompt = false) override + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); return output; } - std::string Decode(const std::vector input) override + std::string Decode(const std::vector &input) override { nlohmann::json rpcobj; rpcobj["method"] = "decode"; @@ -897,22 +893,20 @@ class Tokenizer_Auto : public BaseTokenizer { std::shared_ptr CreateTokenizer(TokenizerType type) { - switch (type) - { - case TKT_LLaMa: - return std::make_shared(); - case TKT_MINICPM: - return std::make_shared(); - case TKT_HTTP: - return std::make_shared(); - case TKT_Qwen: - return std::make_shared(); - case TKT_Phi3: - return std::make_shared(); - case TKT_AUTO: - return std::make_shared(); - default: - return nullptr; + switch (type) { + case TKT_LLaMa: + return std::make_shared(); + case TKT_MINICPM: + return std::make_shared(); + case TKT_HTTP: + return std::make_shared(); + case TKT_Qwen: + return std::make_shared(); + case 
TKT_Phi3: + return std::make_shared(); + case TKT_AUTO: + return std::make_shared(); + default: + return nullptr; } } - diff --git a/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.hpp b/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.hpp index ccb8c4a6..ab545999 100644 --- a/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.hpp +++ b/projects/llm_framework/main_llm/src/runner/Tokenizer/Tokenizer.hpp @@ -5,48 +5,76 @@ #include #include #include -enum TokenizerType -{ - TKT_LLaMa, - TKT_Qwen, - TKT_HTTP, - TKT_Phi3, - TKT_MINICPM, - TKT_AUTO, - TKT_END +enum TokenizerType { TKT_LLaMa, TKT_Qwen, TKT_HTTP, TKT_Phi3, TKT_MINICPM, TKT_AUTO, TKT_END }; + +enum TokenizeRole { + ROLE_USER, // 用户输入 + ROLE_SYSTEM, // 提示词 + ROLE_TOOL, // 工具 + ROLE_IPYTHON, // 工具 + ROLE_ASSISTANT, // 助手回复 + ROLE_ASSISTANT_HELP // 询问句 }; -enum TokenizeRole{ - ROLE_USER,//用户输入 - ROLE_SYSTEM,//提示词 - ROLE_TOOL, //工具 - ROLE_IPYTHON, //工具 - ROLE_ASSISTANT,//助手回复 - ROLE_ASSISTANT_HELP// 询问句 +struct ImageInfo { + int imgsz = 448; + int num_img = 1; + bool img_prompt = false; }; -class BaseTokenizer -{ -public: - std::list> messages_; - void messages_clean() {messages_.clear();}; +class BaseTokenizer { public: - virtual bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) = 0; - virtual bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) = 0; - virtual std::vector Encode(std::string input, bool b_img_prompt = false) = 0; - virtual std::string Decode(const std::vector input) = 0; - virtual int GetBosID() = 0; - virtual int GetEosID() = 0; + std::list> messages_; + + void messages_clean() + { + messages_.clear(); + } + + virtual ~BaseTokenizer() = default; + + virtual bool Init(std::string model_path) + { + return false; + }; + + virtual bool Init(std::string model_path, bool b_bos, bool b_eos) + { + return false; + }; + + virtual bool Reset(std::string system_prompt, std::vector &tokens) + { + return false; + }; + 
+ virtual bool Encode(std::string input, std::string last_reply, std::vector &tokens, + std::vector &tokens_diff, ImageInfo img_info) + { + return false; + }; + + virtual bool Encode(std::string input, std::vector &output, ImageInfo img_info) = 0; + virtual std::vector Encode(std::string input, ImageInfo img_info) = 0; + virtual std::string Decode(const std::vector &input) = 0; + virtual int GetBosID() = 0; + virtual int GetEosID() = 0; + virtual std::string apply_chat_template() = 0; - virtual std::string messages_complete(enum TokenizeRole role, const std::string &content = ""){ + + virtual std::string messages_complete(TokenizeRole role, const std::string &content = "") + { messages_.push_back(std::make_pair(role, content)); - // std::cout << "messages_complete role:" << role << "content:" << content << std::endl; - if(ROLE_ASSISTANT_HELP == role) + if (ROLE_ASSISTANT_HELP == role) return apply_chat_template(); else return ""; } - virtual bool isEnd(int id) { return id == GetEosID(); } + + virtual bool isEnd(int id) + { + return id == GetEosID(); + } }; std::shared_ptr CreateTokenizer(TokenizerType type); \ No newline at end of file diff --git a/projects/llm_framework/main_llm/src/runner/ax_model_runner/ax_model_runner.hpp b/projects/llm_framework/main_llm/src/runner/ax_model_runner/ax_model_runner.hpp index 551f65bd..c9ed4fa3 100644 --- a/projects/llm_framework/main_llm/src/runner/ax_model_runner/ax_model_runner.hpp +++ b/projects/llm_framework/main_llm/src/runner/ax_model_runner/ax_model_runner.hpp @@ -61,6 +61,9 @@ class ax_runner_base int get_num_inputs() { return minput_tensors.size(); }; int get_num_outputs() { return moutput_tensors.size(); }; + int get_num_input_groups() { return mgroup_input_tensors.size(); }; + int get_num_output_groups() { return mgroup_output_tensors.size(); }; + const ax_runner_tensor_t &get_input(int idx) { return minput_tensors[idx]; } const ax_runner_tensor_t *get_inputs_ptr() { return minput_tensors.data(); } const 
ax_runner_tensor_t &get_input(std::string name) diff --git a/projects/llm_framework/main_llm/src/runner/utils/http_utils.hpp b/projects/llm_framework/main_llm/src/runner/utils/http_utils.hpp new file mode 100644 index 00000000..5041b6a1 --- /dev/null +++ b/projects/llm_framework/main_llm/src/runner/utils/http_utils.hpp @@ -0,0 +1,102 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +/** + * @brief Attempts to establish a TCP connection to a specified host and port. + * + * This function creates a socket and tries to connect to a server specified by + * the host and port parameters. It returns true if the connection is successful, + * otherwise it returns false and outputs an error message to standard error. + * + * @param host The IP address of the server to connect to. + * @param port The port number of the server to connect to. + * + * @return true if the connection is successfully established, false otherwise. + */ +static bool test_connect(const std::string &host, int port) +{ + int sock = socket(AF_INET, SOCK_STREAM, 0); + if (sock < 0) + { + // std::cerr << "Socket creation failed\n"; + return false; + } + + sockaddr_in server_addr{}; + server_addr.sin_family = AF_INET; + server_addr.sin_port = htons(port); + + if (inet_pton(AF_INET, host.c_str(), &server_addr.sin_addr) <= 0) + { + // std::cerr << "IP address conversion failed\n"; + close(sock); + return false; + } + + if (connect(sock, (struct sockaddr *)&server_addr, sizeof(server_addr)) < 0) + { + // std::cerr << "Connection failed\n"; + close(sock); + return false; + } + + close(sock); + return true; +} + +/** + * @brief Attempts to establish an HTTP connection to a specified URL with a timeout. + * + * This function parses the provided HTTP URL to extract the host and port information, + * and attempts to establish a TCP connection using the `test_connect` function. + * It retries the connection until the specified timeout is reached. 
+ * + * @param http_url The HTTP URL of the server to connect to. + * @param timeout The maximum number of seconds to keep attempting the connection. + * + * @return true if the connection is successfully established within the timeout period, + * false otherwise. + */ +static bool test_connect_http(const std::string &http_url, int timeout) +{ + size_t pos = http_url.find("://"); + if (pos == std::string::npos) + return false; + std::string host = http_url.substr(pos + 3); + pos = host.find('/'); + if (pos != std::string::npos) + host = host.substr(0, pos); + pos = host.find(':'); + int port = 80; + if (pos != std::string::npos) + { + port = std::stoi(host.substr(pos + 1)); + host = host.substr(0, pos); + } + else + { + return false; + } + + if (host == "localhost") + host = "127.0.0.1"; + int tmp = timeout; + while (timeout--) + { + if (test_connect(host, port)) + return true; + std::this_thread::sleep_for(std::chrono::seconds(1)); + printf("\033[1;30;31m" + "connect failed %s, try again in %2d/%2d \n" + "\033[0m", + http_url.c_str(), timeout, tmp); + } + return false; +} \ No newline at end of file diff --git a/projects/llm_framework/tools/llm_pack.py b/projects/llm_framework/tools/llm_pack.py index 1d2a89f0..1783e536 100755 --- a/projects/llm_framework/tools/llm_pack.py +++ b/projects/llm_framework/tools/llm_pack.py @@ -401,6 +401,7 @@ def create_bin_deb(package_name, version, src_folder, revision = 'm5stack1', dep 'llm-model-qwen2.5-0.5B-prefill-20e':[create_data_deb,'llm-model-qwen2.5-0.5B-prefill-20e', data_version, src_folder, revision], 'llm-model-qwen2.5-0.5B-p256-ax630c':[create_data_deb,'llm-model-qwen2.5-0.5B-p256-ax630c', '0.4', src_folder, revision], 'llm-model-qwen2.5-0.5B-Int4-ax630c':[create_data_deb,'llm-model-qwen2.5-0.5B-Int4-ax630c', '0.4', src_folder, revision], + 'llm-model-qwen2.5-HA-0.5B-ctx-ax630c':[create_data_deb,'llm-model-qwen2.5-HA-0.5B-ctx-ax630c', '0.5', src_folder, revision], 
'llm-model-qwen2.5-1.5B-ax630c':[create_data_deb,'llm-model-qwen2.5-1.5B-ax630c', '0.3', src_folder, revision], 'llm-model-qwen2.5-1.5B-p256-ax630c':[create_data_deb,'llm-model-qwen2.5-1.5B-p256-ax630c', '0.4', src_folder, revision], 'llm-model-qwen2.5-1.5B-Int4-ax630c':[create_data_deb,'llm-model-qwen2.5-1.5B-Int4-ax630c', '0.4', src_folder, revision], From 9167b6e546d12353224e26b03c9db042a1a0d05d Mon Sep 17 00:00:00 2001 From: LittleMouse Date: Tue, 26 Aug 2025 17:20:31 +0800 Subject: [PATCH 30/79] [update] update llm_vlm encoder. update audio cache. --- projects/llm_framework/main_asr/src/main.cpp | 14 +- projects/llm_framework/main_kws/src/main.cpp | 16 +- projects/llm_framework/main_llm/src/main.cpp | 1 + projects/llm_framework/main_vad/src/main.cpp | 13 +- projects/llm_framework/main_vlm/src/main.cpp | 226 +- .../llm_framework/main_vlm/src/runner/LLM.hpp | 1167 ++++++++- .../src/runner/Tokenizer/QwenTokenizer.cpp | 127 - .../src/runner/Tokenizer/QwenTokenizer.hpp | 34 - .../src/runner/Tokenizer/Tokenizer.cpp | 576 ++--- .../src/runner/Tokenizer/Tokenizer.hpp | 21 +- .../main_vlm/src/runner/Tokenizer/base64.h | 54 - .../main_vlm/src/runner/Tokenizer/tiktoken.h | 326 --- .../src/runner/Tokenizer/unordered_dense.h | 2240 ----------------- .../ax_model_runner/ax_model_runner.hpp | 3 + .../main_vlm/src/runner/utils/http_utils.hpp | 102 + .../llm_framework/main_whisper/src/main.cpp | 28 +- 16 files changed, 1725 insertions(+), 3223 deletions(-) delete mode 100644 projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.cpp delete mode 100644 projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.hpp delete mode 100644 projects/llm_framework/main_vlm/src/runner/Tokenizer/base64.h delete mode 100644 projects/llm_framework/main_vlm/src/runner/Tokenizer/tiktoken.h delete mode 100644 projects/llm_framework/main_vlm/src/runner/Tokenizer/unordered_dense.h create mode 100644 projects/llm_framework/main_vlm/src/runner/utils/http_utils.hpp diff --git 
a/projects/llm_framework/main_asr/src/main.cpp b/projects/llm_framework/main_asr/src/main.cpp index 4a5f9703..ae51982b 100644 --- a/projects/llm_framework/main_asr/src/main.cpp +++ b/projects/llm_framework/main_asr/src/main.cpp @@ -58,7 +58,7 @@ class llm_task { std::atomic_bool audio_flage_; std::atomic_bool awake_flage_; int awake_delay_ = 50; - int delay_audio_frame_ = 10; + int delay_audio_frame_ = 11; buffer_t *pcmdata; std::function pause; @@ -187,18 +187,20 @@ class llm_task { count++; return; } - buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); - count = 0; + std::vector floatSamples; { int16_t audio_val; - while (buffer_read_u16(pcmdata, (unsigned short *)&audio_val, 1)) { - float normalizedSample = (float)audio_val / INT16_MAX; + while (buffer_read_i16(pcmdata, &audio_val, 1)) { + float normalizedSample = static_cast(audio_val) / INT16_MAX; floatSamples.push_back(normalizedSample); } } - buffer_position_set(pcmdata, 0); + + buffer_resize(pcmdata, 0); + count = 0; + if (awake_flage_ && recognizer_stream_) { recognizer_stream_.reset(); awake_flage_ = false; diff --git a/projects/llm_framework/main_kws/src/main.cpp b/projects/llm_framework/main_kws/src/main.cpp index 18e4dfd2..8c724599 100644 --- a/projects/llm_framework/main_kws/src/main.cpp +++ b/projects/llm_framework/main_kws/src/main.cpp @@ -59,7 +59,7 @@ class llm_task { bool enwake_audio_; std::atomic_bool audio_flage_; task_callback_t out_callback_; - int delay_audio_frame_ = 10; + int delay_audio_frame_ = 11; buffer_t *pcmdata; std::string wake_wav_file_; @@ -229,22 +229,24 @@ class llm_task { { static int count = 0; if (count < delay_audio_frame_) { - buffer_write_char(pcmdata, raw.c_str(), raw.length()); + buffer_write_char(pcmdata, raw.data(), raw.length()); count++; return; } - buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); - count = 0; + std::vector floatSamples; { int16_t audio_val; - while (buffer_read_u16(pcmdata, 
(unsigned short *)&audio_val, 1)) { - float normalizedSample = (float)audio_val / INT16_MAX; + while (buffer_read_i16(pcmdata, &audio_val, 1)) { + float normalizedSample = static_cast(audio_val) / INT16_MAX; floatSamples.push_back(normalizedSample); } } - buffer_position_set(pcmdata, 0); + + buffer_resize(pcmdata, 0); + count = 0; + spotter_stream_->AcceptWaveform(mode_config_.feat_config.sampling_rate, floatSamples.data(), floatSamples.size()); while (spotter_->IsReady(spotter_stream_.get())) { diff --git a/projects/llm_framework/main_llm/src/main.cpp b/projects/llm_framework/main_llm/src/main.cpp index e198cc9c..dcb64fb8 100644 --- a/projects/llm_framework/main_llm/src/main.cpp +++ b/projects/llm_framework/main_llm/src/main.cpp @@ -130,6 +130,7 @@ class llm_task { std::string base_model = base_model_path_ + model_ + "/"; SLOGI("base_model %s", base_model.c_str()); + CONFIG_AUTO_SET(file_body["mode_param"], system_prompt); CONFIG_AUTO_SET(file_body["mode_param"], tokenizer_type); CONFIG_AUTO_SET(file_body["mode_param"], filename_tokenizer_model); CONFIG_AUTO_SET(file_body["mode_param"], url_tokenizer_model); diff --git a/projects/llm_framework/main_vad/src/main.cpp b/projects/llm_framework/main_vad/src/main.cpp index 20a67eb3..1c49eb33 100644 --- a/projects/llm_framework/main_vad/src/main.cpp +++ b/projects/llm_framework/main_vad/src/main.cpp @@ -60,7 +60,7 @@ class llm_task { std::string superior_id_; task_callback_t out_callback_; int awake_delay_ = 50; - int delay_audio_frame_ = 3; + int delay_audio_frame_ = 4; buffer_t *pcmdata; std::string wake_wav_file_; @@ -158,18 +158,19 @@ class llm_task { count++; return; } - buffer_write_char(pcmdata, raw.data(), raw.length()); buffer_position_set(pcmdata, 0); - count = 0; + std::vector floatSamples; { int16_t audio_val; - while (buffer_read_u16(pcmdata, (unsigned short *)&audio_val, 1)) { - float normalizedSample = (float)audio_val / INT16_MAX; + while (buffer_read_i16(pcmdata, &audio_val, 1)) { + float 
normalizedSample = static_cast(audio_val) / INT16_MAX; floatSamples.push_back(normalizedSample); } } - buffer_position_set(pcmdata, 0); + buffer_resize(pcmdata, 0); + count = 0; + vad_->AcceptWaveform(floatSamples.data(), floatSamples.size()); if (vad_->IsSpeechDetected() && !printed) { diff --git a/projects/llm_framework/main_vlm/src/main.cpp b/projects/llm_framework/main_vlm/src/main.cpp index 6ee6dd67..83a56e51 100644 --- a/projects/llm_framework/main_vlm/src/main.cpp +++ b/projects/llm_framework/main_vlm/src/main.cpp @@ -52,6 +52,7 @@ class llm_task { public: LLMAttrType mode_config_; std::unique_ptr lLaMa_; + std::unique_ptr lLaMa_ctx_; std::string model_; std::string response_format_; std::vector inputs_; @@ -59,6 +60,11 @@ class llm_task { std::vector image_data_; std::vector img_embed; std::string prompt_; + std::string last_reply; + std::vector tokens_ids, tokens_diff; + std::vector> k_caches, v_caches; + int precompute_len = 0; + std::vector _token_ids; task_callback_t out_callback_; bool enoutput_; bool enstream_; @@ -122,8 +128,10 @@ class llm_task { std::string base_model = base_model_path_ + model_ + "/"; SLOGI("base_model %s", base_model.c_str()); + CONFIG_AUTO_SET(file_body["mode_param"], system_prompt); CONFIG_AUTO_SET(file_body["mode_param"], tokenizer_type); CONFIG_AUTO_SET(file_body["mode_param"], filename_tokenizer_model); + CONFIG_AUTO_SET(file_body["mode_param"], url_tokenizer_model); CONFIG_AUTO_SET(file_body["mode_param"], filename_tokens_embed); CONFIG_AUTO_SET(file_body["mode_param"], filename_post_axmodel); CONFIG_AUTO_SET(file_body["mode_param"], filename_vpm_resampler_axmodedl); @@ -139,64 +147,123 @@ class llm_task { CONFIG_AUTO_SET(file_body["mode_param"], b_use_mmap_load_embed); CONFIG_AUTO_SET(file_body["mode_param"], b_dynamic_load_axmodel_layer); CONFIG_AUTO_SET(file_body["mode_param"], max_token_len); + CONFIG_AUTO_SET(file_body["mode_param"], enable_temperature); CONFIG_AUTO_SET(file_body["mode_param"], temperature); + 
CONFIG_AUTO_SET(file_body["mode_param"], enable_top_p_sampling); CONFIG_AUTO_SET(file_body["mode_param"], top_p); + CONFIG_AUTO_SET(file_body["mode_param"], enable_top_k_sampling); + CONFIG_AUTO_SET(file_body["mode_param"], top_k); + CONFIG_AUTO_SET(file_body["mode_param"], enable_repetition_penalty); + CONFIG_AUTO_SET(file_body["mode_param"], repetition_penalty); + CONFIG_AUTO_SET(file_body["mode_param"], penalty_window); CONFIG_AUTO_SET(file_body["mode_param"], vpm_width); CONFIG_AUTO_SET(file_body["mode_param"], vpm_height); + CONFIG_AUTO_SET(file_body["mode_param"], precompute_len); + + { + auto has_http = [](const std::string &s) { return s.find("http") != std::string::npos; }; + + auto find_tokenizer_file = [this]() -> std::string { + const std::string base = "/opt/m5stack/scripts/"; + const std::string a = base + model_ + "_tokenizer.py"; + if (file_exists(a)) return a; + const std::string b = base + "tokenizer_" + model_ + ".py"; + if (file_exists(b)) return b; + SLOGE("%s or %s not found!", a.c_str(), b.c_str()); + return {}; + }; + + auto start_tokenizer_server = [&](const std::string &tokenizer_file) { + if (tokenizer_file.empty()) return; + if (tokenizer_server_flage_.load()) return; - if (mode_config_.filename_tokenizer_model.find("http:") != std::string::npos) { - mode_config_.filename_tokenizer_model = "http://localhost:" + std::to_string(port_); - std::string tokenizer_file; - if (file_exists(std::string("/opt/m5stack/scripts/") + model_ + std::string("_tokenizer.py"))) { - tokenizer_file = std::string("/opt/m5stack/scripts/") + model_ + std::string("_tokenizer.py"); - } else if (file_exists(std::string("/opt/m5stack/scripts/") + std::string("tokenizer_") + model_ + - std::string(".py"))) { - tokenizer_file = - std::string("/opt/m5stack/scripts/") + std::string("tokenizer_") + model_ + std::string(".py"); - } else { - std::string __log = model_ + std::string("_tokenizer.py"); - __log += " or "; - __log += std::string("tokenizer_") + model_ + 
std::string(".py"); - __log += " not found!"; - SLOGE("%s", __log.c_str()); - } - if (!tokenizer_server_flage_.load()) { tokenizer_pid_ = fork(); if (tokenizer_pid_ == 0) { - setenv("PYTHONPATH", "/opt/m5stack/lib/vlm/site-packages", 1); + setenv("PYTHONPATH", "/opt/m5stack/lib/llm/site-packages", 1); + const std::string port_str = std::to_string(port_); + const std::string model_id = base_model + "tokenizer"; + execl("/usr/bin/python3", "python3", tokenizer_file.c_str(), "--host", "localhost", "--port", - std::to_string(port_).c_str(), "--model_id", (base_model + "tokenizer").c_str(), - "--content", ("'" + prompt_ + "'").c_str(), nullptr); + port_str.c_str(), "--model_id", model_id.c_str(), "--content", prompt_.c_str(), + (char *)nullptr); + perror("execl failed"); - exit(1); + _exit(1); } + tokenizer_server_flage_.store(true); SLOGI("port_=%s model_id=%s content=%s", std::to_string(port_).c_str(), - (base_model + "tokenizer").c_str(), ("'" + prompt_ + "'").c_str()); + (base_model + std::string("tokenizer")).c_str(), prompt_.c_str()); + std::this_thread::sleep_for(std::chrono::seconds(15)); + }; + + auto process_field = [&](std::string &field, const char *name_for_log) -> bool { + if (!has_http(field)) return false; + + field = "http://localhost:" + std::to_string(port_); + const std::string tokenizer_file = find_tokenizer_file(); + start_tokenizer_server(tokenizer_file); + SLOGI("%s: %s", name_for_log, field.c_str()); + return true; + }; + + if (!process_field(mode_config_.filename_tokenizer_model, "filename_tokenizer_model") && + !process_field(mode_config_.url_tokenizer_model, "url_tokenizer_model")) { + mode_config_.filename_tokenizer_model = base_model + mode_config_.filename_tokenizer_model; + SLOGE("filename_tokenizer_model: %s", mode_config_.filename_tokenizer_model.c_str()); } - } else { - mode_config_.filename_tokenizer_model = base_model + mode_config_.filename_tokenizer_model; } - SLOGI("filename_tokenizer_model: %s", 
mode_config_.filename_tokenizer_model.c_str()); mode_config_.filename_tokens_embed = base_model + mode_config_.filename_tokens_embed; mode_config_.filename_post_axmodel = base_model + mode_config_.filename_post_axmodel; - mode_config_.filename_vpm_resampler_axmodedl = base_model + mode_config_.filename_vpm_resampler_axmodedl; mode_config_.template_filename_axmodel = base_model + mode_config_.template_filename_axmodel; - + mode_config_.filename_vpm_resampler_axmodedl = base_model + mode_config_.filename_vpm_resampler_axmodedl; mode_config_.runing_callback = [this](int *p_token, int n_token, const char *p_str, float token_per_sec, void *reserve) { if (this->out_callback_) { this->out_callback_(std::string(p_str), false); } }; - lLaMa_ = std::make_unique(); - if (!lLaMa_->Init(mode_config_)) { - lLaMa_->Deinit(); - lLaMa_.reset(); - return -2; + + if (mode_config_.precompute_len > 0) { + lLaMa_ctx_ = std::make_unique(); + if (!lLaMa_ctx_->Init(mode_config_)) { + lLaMa_ctx_->Deinit(); + lLaMa_ctx_.reset(); + return -2; + } + } else { + lLaMa_ = std::make_unique(); + if (!lLaMa_->Init(mode_config_)) { + lLaMa_->Deinit(); + lLaMa_.reset(); + return -2; + } } + if (lLaMa_ctx_) { + lLaMa_ctx_->SetSystemPrompt(mode_config_.system_prompt, _token_ids); + std::string kvcache_path = "/tmp/.vlm/"; + if (!kvcache_path.empty() && kvcache_path != "") { + if (lLaMa_ctx_->load_kvcache(kvcache_path, mode_config_.axmodel_num, k_caches, v_caches, + mode_config_.system_prompt, precompute_len)) { + ALOGI("load kvcache from path: %s success,precompute_len: %d", kvcache_path.c_str(), + precompute_len); + } else { + ALOGW("load kvcache from path: %s failed,generate kvcache", kvcache_path.c_str()); + lLaMa_ctx_->GenerateKVCachePrefill(_token_ids, k_caches, v_caches, precompute_len); + if (!lLaMa_ctx_->save_kvcache(kvcache_path, mode_config_.system_prompt, precompute_len, + k_caches, v_caches)) { + ALOGE("save kvcache failed"); + } + ALOGI("generate kvcache to path: %s", kvcache_path.c_str()); 
+ } + } else { + lLaMa_ctx_->GenerateKVCachePrefill(_token_ids, k_caches, v_caches, precompute_len); + } + ALOGI("precompute_len: %d", precompute_len); + ALOGI("system_prompt: %s", mode_config_.system_prompt.c_str()); + } } catch (...) { SLOGE("config false"); return -3; @@ -208,19 +275,6 @@ class llm_task { { std::ostringstream oss_prompt; switch (mode_config_.tokenizer_type) { - case TKT_LLaMa: - oss_prompt << "<|user|>\n" << input << "<|assistant|>\n"; - break; - case TKT_MINICPM: - oss_prompt << "<用户>" << input << ""; - break; - case TKT_Phi3: - oss_prompt << input << " "; - break; - case TKT_Qwen: - oss_prompt << "<|im_start|>system\n" << prompt_ << ".<|im_end|>"; - oss_prompt << "\n<|im_start|>user\n" << input << "<|im_end|>\n<|im_start|>assistant\n"; - break; case TKT_HTTP: default: oss_prompt << input; @@ -255,32 +309,68 @@ class llm_task { void inference(const std::string &msg) { try { - if (encamera_) { - inference_async_par par; - async_list_.get(); // discard buffered frames - par = async_list_.get(); - if (par.inference_src.empty()) return; - if (par.inference_bgr2rgb) { - cv::Mat rgb; - cv::cvtColor(par.inference_src, rgb, cv::COLOR_BGR2RGB); - par.inference_src = rgb; + if (lLaMa_) { + if (encamera_) { + inference_async_par par; + async_list_.get(); // discard buffered frames + par = async_list_.get(); + if (par.inference_src.empty()) return; + if (par.inference_bgr2rgb) { + cv::Mat rgb; + cv::cvtColor(par.inference_src, rgb, cv::COLOR_BGR2RGB); + par.inference_src = rgb; + } + lLaMa_->Encode(par.inference_src, img_embed); + lLaMa_->Encode(img_embed, prompt_data_, prompt_complete(msg)); + std::string out = lLaMa_->Run(prompt_data_); + if (out_callback_) out_callback_(out, true); + } else if (image_data_.empty()) { + lLaMa_->Encode(prompt_data_, prompt_complete(msg)); + std::string out = lLaMa_->Run(prompt_data_); + if (out_callback_) out_callback_(out, true); + } else { + cv::Mat src = cv::imdecode(image_data_, cv::IMREAD_COLOR); + if (src.empty()) 
return; + image_data_.clear(); + lLaMa_->Encode(src, img_embed); + lLaMa_->Encode(img_embed, prompt_data_, prompt_complete(msg)); + std::string out = lLaMa_->Run(prompt_data_); + if (out_callback_) out_callback_(out, true); + } + } + + if (lLaMa_ctx_) { + if (image_data_.empty()) { + lLaMa_ctx_->Encode(prompt_data_, prompt_complete(prompt_), last_reply, tokens_ids, tokens_diff); + if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size()); + ret != 0) { + ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret); + return; + } + last_reply = lLaMa_ctx_->Run(prompt_data_); + lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len); } - lLaMa_->Encode(par.inference_src, img_embed); - lLaMa_->Encode(img_embed, prompt_data_, prompt_complete(msg)); - std::string out = lLaMa_->Run(prompt_data_); - if (out_callback_) out_callback_(out, true); - } else if (image_data_.empty()) { - lLaMa_->Encode(prompt_data_, prompt_complete(msg)); - std::string out = lLaMa_->Run(prompt_data_); - if (out_callback_) out_callback_(out, true); - } else { cv::Mat src = cv::imdecode(image_data_, cv::IMREAD_COLOR); if (src.empty()) return; image_data_.clear(); - lLaMa_->Encode(src, img_embed); - lLaMa_->Encode(img_embed, prompt_data_, prompt_complete(msg)); - std::string out = lLaMa_->Run(prompt_data_); - if (out_callback_) out_callback_(out, true); + std::vector img_embed; + if (auto ret = lLaMa_ctx_->Encode(src, img_embed); ret != 0) { + ALOGE("lLaMaCtx.Encode failed"); + return; + } + if (auto ret = + lLaMa_ctx_->Encode(img_embed, prompt_data_, prompt_complete(prompt_), tokens_ids, tokens_diff); + ret != 0) { + ALOGE("lLaMaCtx.Encode failed"); + return; + } + if (auto ret = lLaMa_ctx_->SetKVCache(k_caches, v_caches, precompute_len, tokens_diff.size()); + ret != 0) { + ALOGE("SetKVCache failed: %d,the context may be full,input \"reset\" to reset context", ret); + return; + } + last_reply = lLaMa_ctx_->Run(prompt_data_); + 
lLaMa_ctx_->GetKVCache(k_caches, v_caches, precompute_len); } } catch (...) { SLOGW("lLaMa_->Run have error!"); @@ -333,11 +423,9 @@ class llm_task { if (tokenizer_pid_ != -1) { kill(tokenizer_pid_, SIGTERM); waitpid(tokenizer_pid_, nullptr, WNOHANG); - // tokenizer_pid_ = -1; } if (lLaMa_) { lLaMa_->Deinit(); - // lLaMa_.reset(); } } }; diff --git a/projects/llm_framework/main_vlm/src/runner/LLM.hpp b/projects/llm_framework/main_vlm/src/runner/LLM.hpp index 2cbbf388..f7a8b560 100644 --- a/projects/llm_framework/main_vlm/src/runner/LLM.hpp +++ b/projects/llm_framework/main_vlm/src/runner/LLM.hpp @@ -14,39 +14,59 @@ #include "ax_sys_api.h" #include "LLMPostprocess.hpp" +#define ALIGN_DOWN(x, a) ((x) & ~((a) - 1)) + typedef std::function LLMRuningCallback; struct LLMAttrType { + std::string system_prompt; std::string template_filename_axmodel = "tinyllama-int8/tinyllama_l%d.axmodel"; int axmodel_num = 22; - // std::string template_prefill_filename_axmodel = "minicpmv/prefill_axmodel/minicpm_p96_l%d.axmodel"; - // int prefill_axmodel_num = 40; - int prefill_token_num = 96; - - std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel"; - + std::string filename_post_axmodel = "tinyllama-int8/tinyllama_post.axmodel"; + std::string filename_image_encoder_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel"; std::string filename_vpm_encoder_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel"; std::string filename_vpm_resampler_axmodedl = "minicpmv/vpm_resampler_version0_fp16.axmodel"; - int vpm_width = 280; - int vpm_height = 280; - bool b_vpm_two_stage = false; + + int image_encoder_width = 448; + int image_encoder_height = 448; + int vpm_width = 280; + int vpm_height = 280; + bool b_vpm_two_stage = false; + + int prefill_token_num = 96; + int prefill_max_token_num = 512; + std::vector prefill_max_kv_cache_num_grp; + int precompute_len = 0; + int prefill_grpid = -1; TokenizerType tokenizer_type = TKT_LLaMa; std::string filename_tokenizer_model 
= "tokenizer.model"; + std::string url_tokenizer_model; bool b_bos = true, b_eos = false; std::string filename_tokens_embed = "tinyllama.model.embed_tokens.weight.bfloat16.bin"; int tokens_embed_num = 32000; - int img_token_id = 151667; // InternVL2.5 + int img_token_id = 151667; int tokens_embed_size = 2048; - int max_token_len = 127; // auto calc + int max_token_len = 127; + + int kv_cache_num = 1024; + int kv_cache_size = 256; + + bool enable_temperature = false; + float temperature = 0.7f; + + bool enable_top_p_sampling = false; + float top_p = 0.7f; - int kv_cache_num = 1024; // auto calc - int kv_cache_size = 256; // auto calc + bool enable_top_k_sampling = false; + int top_k = 50; + + bool enable_repetition_penalty = false; + float repetition_penalty = 1.2f; + int penalty_window = 50; - float temperature = 0.7f; - float top_p = 0.9f; bool b_use_mmap_load_embed = false; bool b_dynamic_load_axmodel_layer = false; @@ -58,6 +78,11 @@ struct LLMAttrType { // bool b_live_print = true; LLMRuningCallback runing_callback = nullptr; void *reserve = nullptr; + + int IMAGE_CONTEXT_TOKEN = 151667; + int IMAGE_START_TOKEN = 151665; + int IMAGE_ENCODER_INPUT_NCHW = -1; + int IMAGE_ENCODER_OUTPUT_BF16 = -1; }; class LLM { @@ -82,8 +107,6 @@ class LLM { int prefill_grpid = 1; int decode_grpid = 0; - // std::vector> k_caches, v_caches; - bool b_stop = false; LLMPostprocess postprocess; @@ -111,17 +134,6 @@ class LLM { return false; } update_cqdm(&cqdm, 0, "count", "tokenizer init ok"); - // test code - // { - // std::vector output; - // tokenizer.Encode("Today is National", output); - // // print output - // for (size_t i = 0; i < output.size(); i++) - // { - // printf("%d ", output[i]); - // } - // printf("\n"); - // } if (!embed_selector.Init(attr.filename_tokens_embed, attr.tokens_embed_num, attr.tokens_embed_size, attr.b_use_mmap_load_embed)) { @@ -130,20 +142,8 @@ class LLM { return false; } update_cqdm(&cqdm, 1, "count", "embed_selector init ok"); - // test code - // { 
- // std::vector embed = embed_selector.getByIndex(123); - // printf("embed size: %d\n", embed.size()); - // for (int i = 0; i < embed.size(); i++) - // { - // bfloat16 bf16 = bfloat16(embed[i]); - // float val = bf16; - // printf("%d %0.22f\n", embed[i], val); - // } - // } llama_layers.resize(attr.axmodel_num); - // prefill_layers.resize(attr.prefill_axmodel_num); char axmodel_path[1024]; for (int i = 0; i < attr.axmodel_num; i++) { @@ -230,8 +230,6 @@ class LLM { int max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1; _attr.max_token_len = max_token_len > _attr.max_token_len ? _attr.max_token_len : max_token_len; ALOGI("max_token_len : %d", _attr.max_token_len); - // auto &input_k_cache = llama_layers[0].layer.get_input("K_cache"); - // auto &output_k_cache_out = llama_layers[0].layer.get_output("K_cache_out"); _attr.kv_cache_size = llama_layers[0].layer.get_output("K_cache_out").nSize / sizeof(unsigned short); _attr.kv_cache_num = llama_layers[0].layer.get_input("K_cache").nSize / _attr.kv_cache_size / sizeof(unsigned short); @@ -308,14 +306,15 @@ class LLM { out_embed[i] = bfloat16(output_data[i]).data; } - // memcpy(out_embed.data(), vpm_resampler.get_output(0).pVirAddr, vpm_resampler.get_output(0).nSize); ALOGI("image encode time : %f ms, size : %d", t.cost(), out_embed.size()); return 0; } int Encode(std::vector &out_embed, std::string prompt = "What is in the image?") { - std::vector input_ids = tokenizer->Encode(prompt, false); + ImageInfo img_info; + img_info.img_prompt = false; + std::vector input_ids = tokenizer->Encode(prompt, img_info); if (input_ids.size() > _attr.prefill_token_num) { ALOGE("input_ids(%d) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num); return -1; @@ -326,16 +325,17 @@ class LLM { embed_selector.getByIndex(input_ids[i], out_embed.data() + i * _attr.tokens_embed_size); } - // memcpy(out_embed.data() + 5 * _attr.tokens_embed_size, vpm_resampler.get_output(0).pVirAddr, - // 
vpm_resampler.get_output(0).nSize); - return 0; } int Encode(std::vector &img_embed, std::vector &out_embed, std::string prompt = "What is in the image?") { - std::vector input_ids = tokenizer->Encode(prompt, true); + ImageInfo img_info; + img_info.img_prompt = true; + img_info.num_img = 1; + img_info.imgsz = _attr.image_encoder_width; + std::vector input_ids = tokenizer->Encode(prompt, img_info); // constexpr int img_token_id = 49190; // smolvlm // constexpr int img_token_id = 151667; // InternVL2.5 @@ -362,12 +362,6 @@ class LLM { return -1; } - // for (size_t i = 0; i < input_ids.size(); i++) - // { - // printf("%d ", input_ids[i]); - // } - // printf("\n"); - if (input_ids.size() > _attr.prefill_token_num) { ALOGE("input_ids(%d) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num); return -1; @@ -658,3 +652,1068 @@ class LLM { return final_out; } }; + +class LLM_CTX { +private: + std::shared_ptr tokenizer; + LLaMaEmbedSelector embed_selector; + + LLMAttrType _attr; + + struct LLMLayer { + ax_runner_ax650 layer; + std::string filename; + MMap layer_buffer; + std::vector layer_buffer_vec; + }; + + std::vector llama_layers; + ax_runner_ax650 llama_post; + ax_runner_ax650 image_encoder; + + // + int decode_grpid = 0; + + bool b_stop = false; + + LLMPostprocess postprocess; + static int post_process(LLMPostprocess &postprocess, unsigned short *p, int n, std::vector &history, + float *val = 0) + { + std::vector logits(n); + for (int i = 0; i < n; i++) { + unsigned int proc = p[i] << 16; + logits[i] = *reinterpret_cast(&proc); + } + + return postprocess.apply(logits, history); + } + +public: + bool Init(LLMAttrType attr) + { + ALOGI("LLM init start"); + t_cqdm cqdm = create_cqdm(attr.axmodel_num + 3, 32); + this->_attr = attr; + tokenizer = CreateTokenizer(attr.tokenizer_type); + if (!tokenizer->Init(attr.url_tokenizer_model)) { + ALOGE("tokenizer.Init(%s) failed", attr.url_tokenizer_model.c_str()); + return false; + } + std::vector _token_ids; + 
tokenizer->Reset(attr.system_prompt, _token_ids); + update_cqdm(&cqdm, 0, "count", "tokenizer init ok"); + + if (!embed_selector.Init(attr.filename_tokens_embed, attr.tokens_embed_num, attr.tokens_embed_size, + attr.b_use_mmap_load_embed)) { + ALOGE("embed_selector.Init(%s, %d, %d) failed", attr.filename_tokens_embed.c_str(), attr.tokens_embed_num, + attr.tokens_embed_size); + return false; + } + update_cqdm(&cqdm, 1, "count", "embed_selector init ok"); + + llama_layers.resize(attr.axmodel_num); + + char axmodel_path[1024]; + for (int i = 0; i < attr.axmodel_num; i++) { + sprintf(axmodel_path, attr.template_filename_axmodel.c_str(), i); + llama_layers[i].filename = axmodel_path; + + int ret = llama_layers[i].layer.init(llama_layers[i].filename.c_str(), false); + if (ret != 0) { + ALOGE("init axmodel(%s) failed", llama_layers[i].filename.c_str()); + return false; + } + int remain_cmm = get_remaining_cmm_size(); + sprintf(axmodel_path, "init %d axmodel ok,remain_cmm(%d MB)", i, remain_cmm); + update_cqdm(&cqdm, i + 2, "count", axmodel_path); + } + + int ret = llama_post.init(attr.filename_post_axmodel.c_str(), false); + if (ret != 0) { + ALOGE("init post axmodel(%s) failed", attr.filename_post_axmodel.c_str()); + return false; + } + int remain_cmm = get_remaining_cmm_size(); + sprintf(axmodel_path, "init post axmodel ok,remain_cmm(%d MB)", remain_cmm); + update_cqdm(&cqdm, attr.axmodel_num + 2, "count", axmodel_path); + + ret = image_encoder.init(attr.filename_image_encoder_axmodedl.c_str()); + if (ret != 0) { + ALOGE("init vpm axmodel(%s) failed", attr.filename_image_encoder_axmodedl.c_str()); + return false; + } + + _attr.IMAGE_CONTEXT_TOKEN = tokenizer->GetImgContextID(); + _attr.IMAGE_START_TOKEN = tokenizer->GetImgStartID(); + + ALOGI("IMAGE_CONTEXT_TOKEN: %d, IMAGE_START_TOKEN: %d", _attr.IMAGE_CONTEXT_TOKEN, _attr.IMAGE_START_TOKEN); + + _attr.IMAGE_ENCODER_INPUT_NCHW = -1; + for (size_t i = 1; i < image_encoder.get_input(0).vShape.size(); i++) { + if 
(image_encoder.get_input(0).vShape[i] == 3) { + if (i == 1) { + _attr.IMAGE_ENCODER_INPUT_NCHW = 1; + } else if (i == 3) { + _attr.IMAGE_ENCODER_INPUT_NCHW = 0; + } + } + } + if (_attr.IMAGE_ENCODER_INPUT_NCHW == -1) { + ALOGE("image encoder input nchw or nhwc not found"); + return false; + } + + if (_attr.IMAGE_ENCODER_INPUT_NCHW) { + ALOGI("image encoder input nchw@float32"); + _attr.image_encoder_height = image_encoder.get_input(0).vShape[2]; + _attr.image_encoder_width = image_encoder.get_input(0).vShape[3]; + } else { + ALOGI("image encoder input nhwc@uint8"); + _attr.image_encoder_height = image_encoder.get_input(0).vShape[1]; + _attr.image_encoder_width = image_encoder.get_input(0).vShape[2]; + } + + if (_attr.image_encoder_height != _attr.image_encoder_width) { + ALOGE("image encoder height != width"); + return false; + } + int output_elem_size = 1; + for (int i = 0; i < image_encoder.get_output(0).vShape.size(); i++) { + output_elem_size *= image_encoder.get_output(0).vShape[i]; + } + + if (output_elem_size * 2 == image_encoder.get_output(0).nSize) { + _attr.IMAGE_ENCODER_OUTPUT_BF16 = 1; + ALOGI("image encoder output bf16"); + } else if (output_elem_size * 4 == image_encoder.get_output(0).nSize) { + _attr.IMAGE_ENCODER_OUTPUT_BF16 = 0; + ALOGI("image encoder output float32"); + } else { + ALOGE("image encoder output not support"); + return false; + } + + printf("\n"); + + { + ALOGI("image_encoder_height : %d, image_encoder_width: %d", _attr.image_encoder_height, + _attr.image_encoder_width); + _attr.max_token_len = llama_layers[0].layer.get_input("mask").nSize / sizeof(unsigned short) - 1; + ALOGI("max_token_len : %d", _attr.max_token_len); + _attr.kv_cache_size = llama_layers[0].layer.get_output("K_cache_out").nSize / sizeof(unsigned short); + _attr.kv_cache_num = + llama_layers[0].layer.get_input("K_cache").nSize / _attr.kv_cache_size / sizeof(unsigned short); + ALOGI("kv_cache_size : %d, kv_cache_num: %d", _attr.kv_cache_size, _attr.kv_cache_num); + if 
(_attr.max_token_len > _attr.kv_cache_num) { + ALOGE("max_token_len(%d) > kv_cache_num(%d)", _attr.max_token_len, _attr.kv_cache_num); + return false; + } + + _attr.prefill_token_num = llama_layers[0].layer.get_input(1, "indices").vShape[1]; + ALOGI("prefill_token_num : %d", _attr.prefill_token_num); + for (size_t i = 0; i < llama_layers[0].layer.get_num_input_groups() - 1; i++) { + int prefill_max_kv_cache_num = llama_layers[0].layer.get_input(i + 1, "K_cache").vShape[1]; + ALOGI("grp: %d, prefill_max_token_num : %d", i + 1, prefill_max_kv_cache_num); + _attr.prefill_max_kv_cache_num_grp.push_back(prefill_max_kv_cache_num); + } + _attr.prefill_max_token_num = + _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + ALOGI("prefill_max_token_num : %d", _attr.prefill_max_token_num); + } + + if (!postprocess.load_config(attr.post_config_path)) { + ALOGW("load postprocess config(%s) failed", attr.post_config_path.c_str()); + } + + ALOGI("LLM init ok"); + return true; + } + + LLMAttrType *getAttr() + { + return &_attr; + } + + LLMPostprocess *getPostprocess() + { + return &postprocess; + } + + void Deinit() + { + for (int i = 0; i < _attr.axmodel_num; i++) { + llama_layers[i].layer.release(); + } + llama_post.release(); + embed_selector.Deinit(); + } + + void Stop() + { + b_stop = true; + } + + int SetSystemPrompt(std::string system_prompt, std::vector &_token_ids) + { + tokenizer->Reset(system_prompt, _token_ids); + _attr.system_prompt = system_prompt; + _attr.prefill_max_token_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + return 0; + } + + int GenerateKVCachePrefill(std::vector &_token_ids, std::vector> &k_caches, + std::vector> &v_caches, int &precompute_len) + { + bfloat16 bf16 = -65536.f; + int input_embed_num = _token_ids.size(); + precompute_len = _token_ids.size(); + + k_caches.resize(_attr.axmodel_num); + v_caches.resize(_attr.axmodel_num); + int prefill_split_num = 
ceil((double)input_embed_num / _attr.prefill_token_num); + + int prefill_grpid = _attr.prefill_max_kv_cache_num_grp.size(); + + for (size_t i = 0; i < _attr.prefill_max_kv_cache_num_grp.size(); i++) { + if (input_embed_num <= _attr.prefill_max_kv_cache_num_grp[i]) { + prefill_grpid = i + 1; + break; + } + } + ALOGI("input token num : %d, prefill_split_num : %d prefill_grpid : %d", input_embed_num, prefill_split_num, + prefill_grpid); + + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(prefill_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(prefill_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(prefill_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(prefill_grpid, "V_cache").nSize); + } + + if (input_embed_num == 0) { + for (size_t i = 0; i < _attr.axmodel_num; i++) { + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + } + ALOGI("input token num is 0, skip"); + return 0; + } + + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[prefill_grpid - 1]; + + std::vector test_embed; + test_embed.resize(_token_ids.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < _token_ids.size(); i++) { + embed_selector.getByIndex(_token_ids[i], test_embed.data() + i * _attr.tokens_embed_size); + } + + for (size_t p = 0; p < prefill_split_num; p++) { + std::vector mask_tmp; + mask_tmp.resize(1 * _attr.prefill_token_num * (kv_cache_num + _attr.prefill_token_num), bf16.data); + int input_num_token = _attr.prefill_token_num; + if (p == prefill_split_num - 1) { + input_num_token = input_embed_num - p * _attr.prefill_token_num; + } + + ALOGI("input_num_token:%d", input_num_token); + for (size_t i = 0; i < _attr.prefill_token_num; i++) { + if (i < input_num_token) { + int mask_current_start = kv_cache_num; + auto mask_ptr = mask_tmp.data() + i * (kv_cache_num + _attr.prefill_token_num); + + for (int j 
= 0; j < p * _attr.prefill_token_num; j++) { + mask_ptr[j] = 0; + } + + for (int j = mask_current_start; j < mask_current_start + i + 1; j++) { + mask_ptr[j] = 0; + } + } + } + + std::vector embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0); + if (p == (prefill_split_num - 1)) { + memcpy( + embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short)); + } else { + memcpy(embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short)); + } + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + auto &layer = llama_layers[m]; + // set indices + auto &input_indices = layer.layer.get_input(prefill_grpid, "indices"); + unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr; + memset(input_indices_ptr, 0, input_indices.nSize); + int idx = 0; + for (unsigned int i = p * _attr.prefill_token_num; i < (p + 1) * _attr.prefill_token_num; i++) { + input_indices_ptr[idx] = i; + idx++; + } + + // set mask + auto &input_mask = layer.layer.get_input(prefill_grpid, "mask"); + memcpy((void *)input_mask.pVirAddr, (void *)mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(prefill_grpid, "input"); + memcpy((void *)input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short)); + + layer.layer.inference(prefill_grpid); + + auto &input_decoder_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_decoder_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + + auto &input_prefill_k_cache = layer.layer.get_input(prefill_grpid, "K_cache"); + auto &input_prefill_v_cache = layer.layer.get_input(prefill_grpid, "V_cache"); + + auto &output_k_cache = layer.layer.get_output(prefill_grpid, "K_cache_out"); + auto &output_v_cache 
= layer.layer.get_output(prefill_grpid, "V_cache_out"); + + int kv_offset = (p * _attr.prefill_token_num) * _attr.kv_cache_size; + + memcpy((unsigned short *)input_decoder_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_decoder_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.prefill_token_num * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(prefill_grpid, "output"); + memcpy(embed_tmp.data(), (void *)output.pVirAddr, embed_tmp.size() * sizeof(unsigned short)); + } + } + + for (size_t i = 0; i < _attr.axmodel_num; i++) { + auto &layer = llama_layers[i]; + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + auto &input_k_cache = layer.layer.get_input(prefill_grpid, "K_cache"); + auto &input_v_cache = layer.layer.get_input(prefill_grpid, "V_cache"); + memcpy((void *)k_caches[i].data(), (void *)input_k_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy((void *)v_caches[i].data(), (void *)input_v_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + return 0; + } + + int GenerateKVCache(std::vector &_token_ids) + { + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(decode_grpid, 
"V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "V_cache").nSize); + } + + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + mask[_attr.kv_cache_num] = 0; + std::vector embed; + + int next_token = _token_ids[0]; + + t_cqdm cqdm = create_cqdm(_token_ids.size(), 32); + + for (unsigned int indices = 0; indices < _token_ids.size(); indices++) { + embed_selector.getByIndex(next_token, embed); + + for (int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + break; + } + + auto &layer = llama_layers[m]; + + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + auto &input_indices = layer.layer.get_input(decode_grpid, "indices"); + memcpy(input_indices.pVirAddr, &indices, sizeof(indices)); + + auto &input_mask = layer.layer.get_input(decode_grpid, "mask"); + memcpy(input_mask.pVirAddr, mask.data(), mask.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(decode_grpid, "input"); + memcpy(input_input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + + layer.layer.inference(decode_grpid); + + auto &output_k_cache = layer.layer.get_output(decode_grpid, "K_cache_out"); + memcpy(input_k_cache_ptr + indices * _attr.kv_cache_size, output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output_v_cache = layer.layer.get_output(decode_grpid, "V_cache_out"); + memcpy(input_v_cache_ptr + indices * _attr.kv_cache_size, output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(decode_grpid, "output"); + memcpy(embed.data(), output.pVirAddr, embed.size() * sizeof(unsigned short)); + } + mask[indices] = 0; + next_token = _token_ids[indices + 1]; + 
update_cqdm(&cqdm, indices, "token", ""); + // ALOGI(""); + } + return 0; + } + + int GetKVCache(std::vector> &k_caches, + std::vector> &v_caches, int &precompute_len) + { + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + auto &input_mask = llama_layers[0].layer.get_input(decode_grpid, "mask"); + memcpy(mask.data(), (void *)input_mask.pVirAddr, input_mask.nSize); + for (size_t i = 0; i < mask.size(); i++) { + if (mask[i] == bf16.data) { + precompute_len = i + 1; + break; + } + } + ALOGI("precompute_len:%d, remaining:%d", precompute_len, + _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1] - precompute_len); + k_caches.resize(_attr.axmodel_num); + v_caches.resize(_attr.axmodel_num); + for (size_t i = 0; i < _attr.axmodel_num; i++) { + auto &layer = llama_layers[i]; + k_caches[i].resize(precompute_len * _attr.kv_cache_size); + v_caches[i].resize(precompute_len * _attr.kv_cache_size); + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + memcpy((void *)k_caches[i].data(), (void *)input_k_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy((void *)v_caches[i].data(), (void *)input_v_cache.pVirAddr, + precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + _attr.prefill_max_token_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_max_kv_cache_num_grp.size() - 1]; + + return 0; + } + + int SetKVCache(std::vector> &k_caches, + std::vector> &v_caches, int precompute_len, int input_num_token) + { + _attr.precompute_len = precompute_len; + for (size_t i = 0; i < _attr.prefill_max_kv_cache_num_grp.size(); i++) { + if (_attr.precompute_len + input_num_token <= _attr.prefill_max_kv_cache_num_grp[i]) { + _attr.prefill_grpid = i + 1; + break; + } + } + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_grpid - 1]; + ALOGI("prefill_grpid:%d 
kv_cache_num:%d precompute_len:%d input_num_token:%d", _attr.prefill_grpid, + kv_cache_num, precompute_len, input_num_token); + + _attr.prefill_max_token_num = + ALIGN_DOWN(_attr.prefill_max_token_num - _attr.precompute_len, _attr.prefill_token_num); + ALOGI("current prefill_max_token_num:%d", _attr.prefill_max_token_num); + + if (precompute_len == 0) { + ALOGI("first run"); + return 0; + } + + if (precompute_len + input_num_token > kv_cache_num) { + ALOGE("precompute_len(%d) + input_num_token(%d) > _attr.prefill_max_kv_cache_num_grp[%d]", precompute_len, + input_num_token, _attr.prefill_grpid - 1); + return -1; + } + + if (input_num_token > _attr.prefill_max_token_num) { + ALOGE("input_num_token(%d) > _attr.prefill_max_token_num(%d)", input_num_token, + _attr.prefill_max_token_num); + return -1; + } + + if (k_caches.size() != v_caches.size()) { + ALOGE("k_caches.size(%d) != v_caches.size(%d)", k_caches.size(), v_caches.size()); + return -1; + } + + if (k_caches.size() != _attr.axmodel_num) { + ALOGE("k_caches.size(%d) != _attr.axmodel_num(%d)", k_caches.size(), _attr.axmodel_num); + return -1; + } + + // clear kv cache + for (size_t i = 0; i < _attr.axmodel_num; i++) { + memset((void *)llama_layers[i].layer.get_input(_attr.prefill_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(_attr.prefill_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(_attr.prefill_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(_attr.prefill_grpid, "V_cache").nSize); + + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "K_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "K_cache").nSize); + memset((void *)llama_layers[i].layer.get_input(decode_grpid, "V_cache").pVirAddr, 0, + llama_layers[i].layer.get_input(decode_grpid, "V_cache").nSize); + } + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + auto &layer = llama_layers[m]; + + auto &k_cache = k_caches[m]; + auto &v_cache = v_caches[m]; + + 
if (k_cache.size() != _attr.precompute_len * _attr.kv_cache_size) { + ALOGE("k_cache.size(%d) != precompute_len(%d) * _attr.kv_cache_size(%d)", k_cache.size(), + _attr.precompute_len, _attr.kv_cache_size); + return -1; + } + if (v_cache.size() < _attr.precompute_len * _attr.kv_cache_size) { + ALOGE("v_cache.size(%d) < precompute_len(%d) * _attr.kv_cache_size(%d)", v_cache.size(), + _attr.precompute_len, _attr.kv_cache_size); + return -1; + } + + // set kv cache inputs + { + auto &input_k_cache = layer.layer.get_input(_attr.prefill_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(_attr.prefill_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + memcpy(input_k_cache_ptr, k_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy(input_v_cache_ptr, v_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + + { + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + memcpy(input_k_cache_ptr, k_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + memcpy(input_v_cache_ptr, v_cache.data(), + _attr.precompute_len * _attr.kv_cache_size * sizeof(unsigned short)); + } + } + + return 0; + } + + bool save_kvcache(std::string target_path, std::string system_prompt, int precompute_len, + std::vector> &k_caches, + std::vector> &v_caches) + { + for (size_t i = 0; i < k_caches.size(); i++) { + std::string k_cache_path = target_path + "/k_cache_" + std::to_string(i) + ".bin"; + std::string v_cache_path = target_path + "/v_cache_" + std::to_string(i) + ".bin"; + std::ofstream 
k_cache_file(k_cache_path); + std::ofstream v_cache_file(v_cache_path); + if (!k_cache_file.is_open() || !v_cache_file.is_open()) { + ALOGE("save kvcache failed"); + return false; + } + k_cache_file.write((char *)k_caches[i].data(), k_caches[i].size() * sizeof(unsigned short)); + v_cache_file.write((char *)v_caches[i].data(), v_caches[i].size() * sizeof(unsigned short)); + k_cache_file.close(); + v_cache_file.close(); + } + nlohmann::json j; + j["system_prompt"] = system_prompt; + j["precompute_len"] = precompute_len; + std::string config_path = target_path + "/config.json"; + std::ofstream config_file(config_path); + config_file << j.dump(); + config_file.close(); + return true; + } + + bool load_kvcache(std::string target_path, int axmodel_num, std::vector> &k_caches, + std::vector> &v_caches, std::string &system_prompt, + int &precompute_len) + { + k_caches.resize(axmodel_num); + v_caches.resize(axmodel_num); + for (size_t i = 0; i < k_caches.size(); i++) { + std::string k_cache_path = target_path + "/k_cache_" + std::to_string(i) + ".bin"; + std::string v_cache_path = target_path + "/v_cache_" + std::to_string(i) + ".bin"; + if (file_exist(k_cache_path) && file_exist(v_cache_path)) { + std::vector k_cache; + std::vector v_cache; + std::ifstream k_cache_file(k_cache_path); + std::ifstream v_cache_file(v_cache_path); + + k_cache_file.seekg(0, std::ios::end); + k_cache.resize(k_cache_file.tellg() / sizeof(unsigned short)); + k_cache_file.seekg(0, std::ios::beg); + + v_cache_file.seekg(0, std::ios::end); + v_cache.resize(v_cache_file.tellg() / sizeof(unsigned short)); + v_cache_file.seekg(0, std::ios::beg); + + k_cache_file.read((char *)k_cache.data(), k_cache.size() * sizeof(unsigned short)); + v_cache_file.read((char *)v_cache.data(), v_cache.size() * sizeof(unsigned short)); + + k_cache_file.close(); + v_cache_file.close(); + k_caches[i] = k_cache; + v_caches[i] = v_cache; + } else { + ALOGE("k_cache %s or v_cache %s not exist", k_cache_path.c_str(), 
v_cache_path.c_str()); + return false; + } + } + + std::string config_path = target_path + "/config.json"; + if (file_exist(config_path)) { + std::ifstream config_file(config_path); + nlohmann::json j; + config_file >> j; + system_prompt = j["system_prompt"].get(); + precompute_len = j["precompute_len"].get(); + config_file.close(); + } else { + ALOGE("config %s not exist", config_path.c_str()); + return false; + } + return true; + } + + int Encode(cv::Mat src, std::vector &out_embed) + { + timer t; + t.start(); + if (_attr.IMAGE_ENCODER_INPUT_NCHW) { + std::vector mean = {0.485, 0.456, 0.406}; + std::vector scale = {0.229, 0.224, 0.225}; + + cv::Mat dst; + cv::resize(src, dst, cv::Size(_attr.image_encoder_width, _attr.image_encoder_height)); + cv::cvtColor(dst, dst, cv::COLOR_BGR2RGB); + + float *input_data = (float *)image_encoder.get_input(0).pVirAddr; + + unsigned char *img_data = dst.data; + int letterbox_rows = dst.rows; + int letterbox_cols = dst.cols; + + for (int h = 0; h < letterbox_rows; h++) { + for (int w = 0; w < letterbox_cols; w++) { + for (int c = 0; c < 3; c++) { + int in_index = h * letterbox_cols * 3 + w * 3 + c; + int out_index = c * letterbox_rows * letterbox_cols + h * letterbox_cols + w; + input_data[out_index] = (float(img_data[in_index]) / 255.0 - mean[c]) / scale[c]; + } + } + } + image_encoder.inference(); + } else { + cv::Mat dst; + cv::resize(src, dst, cv::Size(_attr.image_encoder_width, _attr.image_encoder_height)); + cv::cvtColor(dst, dst, cv::COLOR_BGR2RGB); + void *data = image_encoder.get_input(0).pVirAddr; + memcpy(data, dst.data, dst.rows * dst.cols * 3); + image_encoder.inference(); + } + + int size = 1; + for (size_t i = 0; i < image_encoder.get_output(0).vShape.size(); i++) { + size *= image_encoder.get_output(0).vShape[i]; + } + + out_embed.resize(size); + + if (_attr.IMAGE_ENCODER_OUTPUT_BF16) + memcpy(out_embed.data(), image_encoder.get_output(0).pVirAddr, image_encoder.get_output(0).nSize); + else { + float *out_data = 
(float *)image_encoder.get_output(0).pVirAddr; + for (size_t i = 0; i < size; i++) { + out_embed[i] = bfloat16(out_data[i]).data; + } + } + + ALOGI("image encode time : %0.2f ms, size : %ld", t.cost(), out_embed.size()); + return 0; + } + + int Encode(std::vector srcs, std::vector> &out_embeds) + { + out_embeds.resize(srcs.size()); + for (size_t i = 0; i < srcs.size(); i++) { + auto ret = Encode(srcs[i], out_embeds[i]); + if (ret != 0) { + ALOGE("Encode image failed"); + return -1; + } + } + + return 0; + } + + int Encode(std::vector &out_embed, std::string prompt = "What is in the image?") + { + ImageInfo img_info; + img_info.img_prompt = false; + std::vector input_ids = tokenizer->Encode(prompt, img_info); + if (input_ids.size() > _attr.prefill_token_num) { + ALOGE("input_ids(%ld) > prefill_token_num(%d)", input_ids.size(), _attr.prefill_token_num); + return -1; + } + out_embed.resize(input_ids.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < input_ids.size(); i++) { + embed_selector.getByIndex(input_ids[i], out_embed.data() + i * _attr.tokens_embed_size); + } + + return 0; + } + + int Encode(std::vector> &imgs_embed, std::vector &out_embed, + std::string prompt, std::vector &tokens_ids, std::vector &tokens_diff) + { + ImageInfo img_info; + img_info.img_prompt = true; + img_info.num_img = imgs_embed.size(); + img_info.imgsz = _attr.image_encoder_width; + std::vector input_ids = tokenizer->Encode_ctx(prompt, img_info, tokens_ids, tokens_diff); + + std::vector img_start_index; + for (size_t i = 0; i < input_ids.size(); i++) { + if (input_ids[i] == _attr.IMAGE_START_TOKEN) { + img_start_index.push_back(i); + } + } + + if (img_start_index.size() != imgs_embed.size()) { + ALOGE("img_start_index.size() != imgs_embed.size(), img_start_index.size() : %ld, imgs_embed.size() : %ld", + img_start_index.size(), imgs_embed.size()); + + printf("input_ids : "); + for (size_t i = 0; i < input_ids.size(); i++) { + printf("%d ", input_ids[i]); + } + printf("\n"); + + 
return -1; + } + + if (input_ids.size() > _attr.prefill_max_token_num) { + ALOGE("input_ids(%ld) > prefill_max_token_num(%d)", input_ids.size(), _attr.prefill_max_token_num); + return -1; + } + out_embed.resize(input_ids.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < input_ids.size(); i++) { + embed_selector.getByIndex(input_ids[i], out_embed.data() + i * _attr.tokens_embed_size); + } + for (size_t i = 0; i < imgs_embed.size(); i++) { + int offset = img_start_index[i] + 1; + auto &img_embed = imgs_embed[i]; + + int img_context_count = 0; + for (size_t j = offset; j < input_ids.size(); j++) { + if (input_ids[j] == _attr.IMAGE_CONTEXT_TOKEN) { + img_context_count++; + } else { + break; + } + } + + if (img_context_count != img_embed.size() / _attr.tokens_embed_size) { + ALOGE("img_context_count(%d) != img_embed.size() / tokens_embed_size(%ld)", img_context_count, + img_embed.size() / _attr.tokens_embed_size); + return -1; + } + + memcpy(out_embed.data() + offset * _attr.tokens_embed_size, img_embed.data(), + img_embed.size() * sizeof(unsigned short)); + ALOGI("idx:%ld offset : %d out_embed.size() : %ld", i, offset, out_embed.size()); + } + + return 0; + } + + int Encode(std::vector &img_embed, std::vector &out_embed, std::string prompt, + std::vector &tokens_ids, std::vector &tokens_diff) + { + std::vector> imgs_embed = {img_embed}; + return Encode(imgs_embed, out_embed, prompt, tokens_ids, tokens_diff); + } + + int Encode(std::vector &out_embed, std::string prompt, std::string last_reply, + std::vector &tokens_ids, std::vector &tokens_diff) + { + ImageInfo img_info; + img_info.img_prompt = false; + if (!tokenizer->Encode(prompt, last_reply, tokens_ids, tokens_diff, img_info)) { + ALOGE("encode failed"); + return -1; + } + + out_embed.resize(tokens_diff.size() * _attr.tokens_embed_size); + + for (size_t i = 0; i < tokens_diff.size(); i++) { + embed_selector.getByIndex(tokens_diff[i], out_embed.data() + i * _attr.tokens_embed_size); + } + + return 0; + } 
+ + std::string Run(std::vector test_embed) + { + b_stop = false; + std::string final_out; + + bfloat16 bf16 = -65536.f; + std::vector mask(_attr.kv_cache_num + 1, bf16.data); + std::vector embed(_attr.tokens_embed_size, 0); + int kv_cache_num = _attr.prefill_max_kv_cache_num_grp[_attr.prefill_grpid - 1]; + + std::vector cached_token; + std::vector token_ids; + + int input_embed_num = test_embed.size() / _attr.tokens_embed_size; + int prefill_split_num = ceil((double)input_embed_num / _attr.prefill_token_num); + ALOGI("input token num : %d, prefill_split_num : %d", input_embed_num, prefill_split_num); + + mask[_attr.kv_cache_num] = 0; + for (size_t i = 0; i < _attr.precompute_len + input_embed_num; i++) { + mask[i] = 0; + } + timer t_cost; + timer ttft_timer; + ttft_timer.start(); + + for (size_t p = 0; p < prefill_split_num; p++) { + if (b_stop) { + break; + } + + std::vector mask_tmp; + mask_tmp.resize(1 * _attr.prefill_token_num * (kv_cache_num + _attr.prefill_token_num), bf16.data); + int input_num_token = _attr.prefill_token_num; + if (p == prefill_split_num - 1) { + input_num_token = input_embed_num - p * _attr.prefill_token_num; + } + + ALOGI("input_num_token:%d", input_num_token); + for (size_t i = 0; i < _attr.prefill_token_num; i++) { + if (i < input_num_token) { + int mask_current_start = kv_cache_num; + auto mask_ptr = mask_tmp.data() + i * (kv_cache_num + _attr.prefill_token_num); + + for (int j = 0; j < _attr.precompute_len + p * _attr.prefill_token_num; j++) { + mask_ptr[j] = 0; + } + + for (int j = mask_current_start; j < mask_current_start + i + 1; j++) { + mask_ptr[j] = 0; + } + } + } + + std::vector embed_tmp(_attr.prefill_token_num * _attr.tokens_embed_size, 0); + if (p == (prefill_split_num - 1)) { + memcpy( + embed_tmp.data(), test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + (input_embed_num - p * _attr.prefill_token_num) * _attr.tokens_embed_size * sizeof(unsigned short)); + } else { + memcpy(embed_tmp.data(), 
test_embed.data() + p * _attr.prefill_token_num * _attr.tokens_embed_size, + _attr.prefill_token_num * _attr.tokens_embed_size * sizeof(unsigned short)); + } + + for (unsigned int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + break; + } + + auto &layer = llama_layers[m]; + + // set indices + auto &input_indices = layer.layer.get_input(_attr.prefill_grpid, "indices"); + unsigned int *input_indices_ptr = (unsigned int *)input_indices.pVirAddr; + memset(input_indices_ptr, 0, input_indices.nSize); + int idx = 0; + for (unsigned int i = _attr.precompute_len + p * _attr.prefill_token_num; + i < _attr.precompute_len + (p + 1) * _attr.prefill_token_num; i++) { + input_indices_ptr[idx] = i; + idx++; + } + + // set mask + auto &input_mask = layer.layer.get_input(_attr.prefill_grpid, "mask"); + memcpy((void *)input_mask.pVirAddr, (void *)mask_tmp.data(), mask_tmp.size() * sizeof(unsigned short)); + + // set input + auto &input_input = layer.layer.get_input(_attr.prefill_grpid, "input"); + memcpy((void *)input_input.pVirAddr, embed_tmp.data(), embed_tmp.size() * sizeof(unsigned short)); + + layer.layer.inference(_attr.prefill_grpid); + + auto &input_decoder_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + auto &input_decoder_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + + auto &input_prefill_k_cache = layer.layer.get_input(_attr.prefill_grpid, "K_cache"); + auto &input_prefill_v_cache = layer.layer.get_input(_attr.prefill_grpid, "V_cache"); + + auto &output_k_cache = layer.layer.get_output(_attr.prefill_grpid, "K_cache_out"); + auto &output_v_cache = layer.layer.get_output(_attr.prefill_grpid, "V_cache_out"); + + int kv_offset = (_attr.precompute_len + p * _attr.prefill_token_num) * _attr.kv_cache_size; + + memcpy((unsigned short *)input_decoder_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_decoder_v_cache.pVirAddr + 
kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_k_cache.pVirAddr + kv_offset, (void *)output_k_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + memcpy((unsigned short *)input_prefill_v_cache.pVirAddr + kv_offset, (void *)output_v_cache.pVirAddr, + sizeof(unsigned short) * input_num_token * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(_attr.prefill_grpid, "output"); + memcpy(embed_tmp.data(), (void *)output.pVirAddr, embed_tmp.size() * sizeof(unsigned short)); + } + if (p == (prefill_split_num - 1)) { + memcpy(embed.data(), + embed_tmp.data() + (input_embed_num - p * _attr.prefill_token_num - 1) * _attr.tokens_embed_size, + _attr.tokens_embed_size * sizeof(unsigned short)); + } + } + + int next_token = -1; + t_cqdm cqdm = create_cqdm(_attr.max_token_len, 32); + + { + // post process + auto &input = llama_post.get_input("input"); + memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + llama_post.inference(); + int max_index; + + auto &output_post = llama_post.get_output("output"); + // AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + // max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); + max_index = post_process(postprocess, post_out, _attr.tokens_embed_num, token_ids, nullptr); + + next_token = max_index; + + token_ids.push_back(max_index); + cached_token.push_back(max_index); + ALOGI("ttft: %.2f ms", ttft_timer.cost()); + } + t_cost.start(); + + bool b_hit_eos = false; + for (unsigned int indices = _attr.precompute_len + input_embed_num; indices < _attr.max_token_len; indices++) { + if (b_stop) { + break; + } + + embed_selector.getByIndex(next_token, embed); + + for (int m = 0; m < _attr.axmodel_num; m++) { + if (b_stop) { + 
break; + } + + auto &layer = llama_layers[m]; + + auto &input_k_cache = layer.layer.get_input(decode_grpid, "K_cache"); + unsigned short *input_k_cache_ptr = (unsigned short *)input_k_cache.pVirAddr; + auto &input_v_cache = layer.layer.get_input(decode_grpid, "V_cache"); + unsigned short *input_v_cache_ptr = (unsigned short *)input_v_cache.pVirAddr; + + auto &input_indices = layer.layer.get_input(decode_grpid, "indices"); + memcpy(input_indices.pVirAddr, &indices, sizeof(indices)); + + auto &input_mask = layer.layer.get_input(decode_grpid, "mask"); + memcpy(input_mask.pVirAddr, mask.data(), mask.size() * sizeof(unsigned short)); + + auto &input_input = layer.layer.get_input(decode_grpid, "input"); + memcpy(input_input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + + layer.layer.inference(decode_grpid); + + auto &output_k_cache = layer.layer.get_output(decode_grpid, "K_cache_out"); + memcpy(input_k_cache_ptr + indices * _attr.kv_cache_size, output_k_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output_v_cache = layer.layer.get_output(decode_grpid, "V_cache_out"); + memcpy(input_v_cache_ptr + indices * _attr.kv_cache_size, output_v_cache.pVirAddr, + sizeof(unsigned short) * _attr.kv_cache_size); + + auto &output = layer.layer.get_output(decode_grpid, "output"); + memcpy(embed.data(), output.pVirAddr, embed.size() * sizeof(unsigned short)); + } + mask[indices] = 0; + { + // post process + auto &input = llama_post.get_input("input"); + memcpy(input.pVirAddr, embed.data(), embed.size() * sizeof(unsigned short)); + llama_post.inference(); + int max_index; + + auto &output_post = llama_post.get_output("output"); + // AX_SYS_MinvalidateCache(output_post.phyAddr, output_post.pVirAddr, output_post.nSize); + unsigned short *post_out = (unsigned short *)output_post.pVirAddr; + float max_val = -MAXFLOAT; + // max_index = FindMax(post_out, _attr.tokens_embed_num, &max_val); + max_index = post_process(postprocess, post_out, 
_attr.tokens_embed_num, token_ids, nullptr); + + next_token = max_index; + + if (tokenizer->isEnd(max_index)) { + if (cached_token.size() && _attr.runing_callback) { + float t_cost_ms = t_cost.cost(); + float token_per_sec = token_ids.size() / (t_cost_ms / 1000); + auto tmp_out = tokenizer->Decode(cached_token); + _attr.runing_callback(cached_token.data(), cached_token.size(), tmp_out.c_str(), token_per_sec, + _attr.reserve); + cached_token.clear(); + } + b_hit_eos = true; + break; + } + token_ids.push_back(max_index); + + if (_attr.runing_callback) { + cached_token.push_back(max_index); + if (cached_token.size() >= 3) { + float t_cost_ms = t_cost.cost(); + float token_per_sec = token_ids.size() / (t_cost_ms / 1000); + auto tmp_out = tokenizer->Decode(cached_token); + _attr.runing_callback(cached_token.data(), cached_token.size(), tmp_out.c_str(), token_per_sec, + _attr.reserve); + cached_token.clear(); + } + } + } + + if (_attr.runing_callback == nullptr) update_cqdm(&cqdm, indices, "token", ""); + if (b_hit_eos) { + break; + } + } + printf("\n\n"); + fflush(stdout); + float t_cost_ms = t_cost.cost(); + ALOGN("hit eos,avg %.2f token/s\n", token_ids.size() / (t_cost_ms / 1000)); + + final_out = tokenizer->Decode(token_ids); + + return final_out; + } +}; \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.cpp b/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.cpp deleted file mode 100644 index 3b723ea1..00000000 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include "QwenTokenizer.hpp" - -#include -#include -#include - -#include "sample_log.h" -#include "base64.h" - -static const std::string PAT_STR = R"((?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?:$|[^\S])|\s+)"; - -static std::pair _parse(const std::string &line) -{ - auto pos = line.find(" "); - if (pos == 
std::string::npos) - { - throw std::runtime_error("invalid encoder line: " + line); - } - - auto token = base64::decode({line.data(), pos}); - int rank = 0; - try - { - rank = std::stoul(line.substr(pos + 1)); - } - catch (const std::exception &) - { - throw std::runtime_error("invalid encoder rank: " + line); - } - - return {std::move(token), rank}; -} - -QwenTokenizer::QwenTokenizer(const std::string &tiktoken_path, const QwenConfig &config) -{ - std::ifstream file(tiktoken_path); - if (!file) - { - throw std::runtime_error("failed to open encoder file: " + tiktoken_path); - } - - ankerl::unordered_dense::map encoder; - std::string line; - while (std::getline(file, line)) - { - auto [token, rank] = _parse(line); - - if (!encoder.emplace(std::move(token), rank).second) - { - throw std::runtime_error("duplicate item: " + line); - } - } - - std::vector special_tokens_s{"<|endoftext|>", "<|im_start|>", "<|im_end|>"}; - char buffer[14]; - for (size_t i = 0; i < 205; i++) - { - snprintf(buffer, 14, "<|extra_%zu|>", i); - special_tokens_s.push_back(buffer); - } - size_t encoder_size = encoder.size(); - ankerl::unordered_dense::map special_tokens; - special_tokens.reserve(special_tokens_s.size()); - for (size_t i = 0; i < special_tokens_s.size(); i++) - { - special_tokens[special_tokens_s[i]] = encoder_size + i; - } - - tokenizer = tiktoken::tiktoken(std::move(encoder), special_tokens, PAT_STR); - eos_token_id = config.eos_token_id; - im_start_id = config.im_start_id; - im_end_id = config.im_end_id; -} - -auto QwenTokenizer::build_prompt(const std::vector &history) const -> std::string -{ - if (history.size() % 2 == 1) - { - ALOGE("invalid history size %d", history.size()); - return ""; - } - - std::ostringstream oss_prompt; - oss_prompt << "<|im_start|>system\nYou are a helpful assistant.<|im_end|>"; - for (size_t i = 0; i < history.size() - 1; i += 2) - { - oss_prompt << "\n<|im_start|>user\n" - << history[i] << "<|im_end|>\n<|im_start|>" << history[i + 1] << 
"<|im_end|>"; - } - oss_prompt << "\n<|im_start|>user\n" - << history.back() << "<|im_end|>\n<|im_start|>assistant\n"; - - return oss_prompt.str(); -} - -auto QwenTokenizer::encode(const std::string &text, int max_length) const -> std::vector -{ - auto ids = tokenizer.encode(text); - if ((int)ids.size() > max_length) - { - ids.erase(ids.begin(), ids.end() - max_length); - } - return ids; -} - -auto QwenTokenizer::decode(const std::vector &ids) const -> std::string -{ - std::vector normal_ids(ids); - normal_ids.erase(std::remove_if(normal_ids.begin(), normal_ids.end(), [this](int id) - { return is_special_id(id); }), - normal_ids.end()); - auto text = tokenizer.decode(normal_ids); - return text; -} - -auto QwenTokenizer::encode_history( - const std::vector &history, int max_length) const -> std::vector -{ - std::string prompt = build_prompt(history); - std::vector input_ids = encode(prompt, max_length); - return input_ids; -} - -auto QwenTokenizer::is_special_id(int id) const -> bool -{ - return id == eos_token_id || id == im_start_id || id == im_end_id; -} \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.hpp b/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.hpp deleted file mode 100644 index a217f9c2..00000000 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/QwenTokenizer.hpp +++ /dev/null @@ -1,34 +0,0 @@ -#pragma once -#include -#include -#include "tiktoken.h" - -struct QwenConfig -{ - // int pad_token_id; - // for tokenizer - int eos_token_id = 151643; - int im_start_id = 151644; - int im_end_id = 151645; -}; - -class QwenTokenizer -{ -public: - QwenTokenizer(const std::string &tiktoken_path, const QwenConfig &config); - - auto encode(const std::string &text, int max_length) const -> std::vector; - - auto decode(const std::vector &ids) const -> std::string; - - auto encode_history(const std::vector &history, int max_length) const -> std::vector; - - auto build_prompt(const 
std::vector &history) const -> std::string; - - auto is_special_id(int id) const -> bool; - - tiktoken::tiktoken tokenizer; - int eos_token_id; - int im_start_id; - int im_end_id; -}; \ No newline at end of file diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.cpp b/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.cpp index 56600968..38974cb6 100644 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.cpp +++ b/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.cpp @@ -1,327 +1,242 @@ #include "Tokenizer.hpp" -#include "sentencepiece_processor.h" -#include "builtin_pb/sentencepiece.pb.h" - -#include "QwenTokenizer.hpp" - -// #include "chatglm.h" - #include "httplib.h" +#include "http_utils.hpp" #include "json.hpp" #include "sample_log.h" #include "string_utility.hpp" #include "memory_utils.hpp" -class TokenizerLLaMa : public BaseTokenizer +class Tokenizer_Http : public BaseTokenizer { -protected: - sentencepiece::SentencePieceProcessor sp; + std::shared_ptr cli; bool _b_bos, _b_eos; + std::string base_url; + + int bos_id, eos_id; + + std::string uid; + int img_start_token, img_context_token; + private: /* data */ public: - bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override + bool Init(std::string model_path) override { - auto ret = sp.Load(model_path); - if (!ret.ok()) - { - ALOGE("%s", ret.error_message()); - return false; - } - - this->_b_bos = b_bos; - this->_b_eos = b_eos; - return ret.ok(); - } - - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override - { - auto ret = sp.Encode(input, &output); - if (!ret.ok()) + base_url = model_path; + if (!test_connect_http(base_url, 10)) { - ALOGE("%s", ret.error_message()); + ALOGE("connect %s failed", base_url.c_str()); return false; } - if (_b_bos) - { - output.insert(output.begin(), sp.bos_id()); - } - if (_b_eos) - { - output.push_back(sp.eos_id()); - } - return true; - } - - 
std::vector Encode(std::string input, bool b_img_prompt = false) override - { - std::vector output; - Encode(input, output, b_img_prompt); - return output; - } - - std::string Decode(const std::vector input) override - { - sentencepiece::SentencePieceText spt; - sp.Decode(input, &spt); - std::string out = spt.pieces()[0].piece(); - if (*(unsigned short *)out.data() == 38626) - { - return " " + spt.text(); - } else { - return spt.text(); + ALOGI("connect %s ok", base_url.c_str()); } - } - - int GetBosID() override - { - return sp.bos_id(); - } - - int GetEosID() override - { - return sp.eos_id(); - } -}; - -class TokenizerMINICPM : public TokenizerLLaMa -{ -public: - std::string Decode(const std::vector input) override - { - sentencepiece::SentencePieceText spt; - sp.Decode(input, &spt); - return spt.text(); - } -}; -class TokenizerPhi3 : public BaseTokenizer -{ - sentencepiece::SentencePieceProcessor sp; - bool _b_bos, _b_eos; + cli = std::make_shared(base_url); + cli->set_connection_timeout(10); + cli->set_read_timeout(10); + cli->set_write_timeout(10); -private: - /* data */ -public: - bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override - { - auto ret = sp.Load(model_path); - if (!ret.ok()) + int try_count = 10; + int count = try_count; + while (count-- > 0) { - ALOGE("%s", ret.error_message()); - return false; + try + { + auto ret = cli->Get("/get_uid"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get uid failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + uid = j["uid"]; + ALOGI("uid: %s", uid.c_str()); + break; + } + catch (const std::exception &e) + { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get uid failed, try again %d/%d", count, try_count); } - this->_b_bos = b_bos; - this->_b_eos = b_eos; - return ret.ok(); - } - - bool Encode(std::string input, std::vector &output, bool b_img_prompt = 
false) override - { - auto ret = sp.Encode(input, &output); - if (!ret.ok()) + count = 10; + while (count-- > 0) { - ALOGE("%s", ret.error_message()); - return false; - } - output.insert(output.begin(), 32010); //"<|user|>" - output.push_back(32007); //"<|end|>" - output.push_back(32001); //"<|assistant|>" - if (_b_bos) - { - output.insert(output.begin(), sp.bos_id()); + try + { + auto ret = cli->Get("/bos_id?uid=" + uid); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get bos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + bos_id = j["bos_id"]; + break; + } + catch (const std::exception &e) + { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get bos_id failed, try again %d/%d", count, try_count); } - if (_b_eos) + + count = 10; + while (count-- > 0) { - output.push_back(sp.eos_id()); + try + { + auto ret = cli->Get("/eos_id?uid=" + uid); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get eos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + eos_id = j["eos_id"]; + break; + } + catch (const std::exception &e) + { + std::cerr << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get eos_id failed, try again %d/%d", count, try_count); } - return true; - } - std::vector Encode(std::string input, bool b_img_prompt = false) override - { - std::vector output; - Encode(input, output, b_img_prompt); - return output; - } - - std::string Decode(const std::vector input) override - { - sentencepiece::SentencePieceText spt; - sp.Decode(input, &spt); - std::string out = spt.pieces()[0].piece(); - if (*(unsigned short *)out.data() == 38626) + count = 10; + while (count-- > 0) { - return " " + spt.text(); + try + { + auto ret = cli->Get("/img_start_token?uid=" + uid); + if (!ret) { + ALOGE("get img_start_token failed, no response"); 
+ continue; + } + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get img_start_token failed, status: %d", rep.status); + continue; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + img_start_token = j["img_start_token"]; + ALOGI("img_start_token: %d", img_start_token); + break; + } + catch (const std::exception &e) + { + std::cerr << "Exception: " << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get img_start_token failed, try again %d/%d", count, try_count); } - else + + count = 10; + while (count-- > 0) { - return spt.text(); + try + { + auto ret = cli->Get("/img_context_token?uid=" + uid); + if (!ret) { + ALOGE("get img_context_token failed, no response"); + continue; + } + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get img_context_token failed, status: %d", rep.status); + continue; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + img_context_token = j["img_context_token"]; + ALOGI("img_context_token: %d", img_context_token); + break; + } + catch (const std::exception &e) + { + std::cerr << "Exception: " << e.what() << '\n'; + } + std::this_thread::sleep_for(std::chrono::seconds(1)); + ALOGE("get img_context_token failed, try again %d/%d", count, try_count); } - } - int GetBosID() override - { - return sp.bos_id(); - } + printf("bos_id: %d, eos_id: %d\n", bos_id, eos_id); + printf("img_start_token: %d, img_context_token: %d\n", img_start_token, img_context_token); - int GetEosID() override - { - return 32007; + return true; } - bool isEnd(int id) override + bool Init(std::string model_path = "http://localhost:8080", bool b_bos = true, bool b_eos = false) override { - return id == GetEosID() || id > 31999; - } -}; - -class TokenizerQwen : public BaseTokenizer -{ - std::shared_ptr sp; - bool _b_bos, _b_eos; + base_url = model_path; + try + { + cli = std::make_shared(base_url); + cli->set_connection_timeout(1); + cli->set_read_timeout(1); + cli->set_write_timeout(1); + { 
+ auto ret = cli->Get("/bos_id"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get bos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + bos_id = j["bos_id"]; + } -private: - /* data */ -public: - bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) override - { - if (!file_exist(model_path)) + { + auto ret = cli->Get("/eos_id"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get eos_id failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + eos_id = j["eos_id"]; + } + printf("bos_id: %d, eos_id: %d\n", bos_id, eos_id); + } + catch (const std::exception &e) { - ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); + std::cerr << e.what() << '\n'; return false; } - sp.reset(new QwenTokenizer(model_path, QwenConfig())); - this->_b_bos = b_bos; this->_b_eos = b_eos; return true; } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Init_new(std::string model_path, bool b_bos, bool b_eos) override { - if (_b_bos) + base_url = model_path; + if (!test_connect_http(base_url, 10)) { - // input += "<|im_start|>"; + ALOGE("connect %s failed", base_url.c_str()); + return false; } - if (_b_eos) + else { - input += "<|endoftext|>"; + ALOGI("connect %s ok", base_url.c_str()); } - output = sp->encode(input, 1024); - - return true; - } - - std::vector Encode(std::string input, bool b_img_prompt = false) override - { - std::vector output; - Encode(input, output, b_img_prompt); - return output; - } - - std::string Decode(const std::vector input) override - { - return sp->decode(input); - } - - int GetBosID() override - { - return -1; - } - - int GetEosID() override - { - return sp->eos_token_id; - } -}; - -// class TokenizerGLM3 : public BaseTokenizer -// { -// std::shared_ptr sp; -// bool _b_bos, _b_eos; - -// private: -// /* data */ -// public: -// bool 
Init(std::string model_path, bool b_bos = true, bool b_eos = false) override -// { -// if (!file_exist(model_path)) -// { -// ALOGE("tokenizer model file(%s) not exist", model_path.c_str()); -// return false; -// } -// // std::vector sp_model_data; -// // read_file(model_path, sp_model_data); -// // std::string_view serialized_model_proto(sp_model_data.data(), sp_model_data.size()); - -// sp.reset(new chatglm::ChatGLM3Tokenizer(model_path)); - -// this->_b_bos = b_bos; -// this->_b_eos = b_eos; -// return true; -// } - -// bool Encode(std::string input, std::vector &output) override -// { -// if (_b_bos) -// { -// // input += "<|im_start|>"; -// } -// if (_b_eos) -// { -// // input += "<|endoftext|>"; -// } -// output = sp->encode(input, 1024); - -// return true; -// } - -// std::vector Encode(std::string input) override -// { -// std::vector output; -// Encode(input, output); -// return output; -// } - -// std::string Decode(const std::vector input) override -// { -// return sp->decode(input); -// } - -// int GetBosID() override -// { -// return sp->sp.bos_id(); -// } - -// int GetEosID() override -// { -// return sp->sp.eos_id(); -// } -// }; - -class Tokenizer_Http : public BaseTokenizer -{ - std::shared_ptr cli; - bool _b_bos, _b_eos; - std::string base_url; - - int bos_id, eos_id; - -private: - /* data */ -public: - bool Init(std::string model_path = "http://localhost:8080", bool b_bos = true, bool b_eos = false) override - { - base_url = model_path; try { cli = std::make_shared(base_url); @@ -351,7 +266,33 @@ class Tokenizer_Http : public BaseTokenizer nlohmann::json j = nlohmann::json::parse(rep.body); eos_id = j["eos_id"]; } - printf("bos_id: %d, eos_id: %d\n", bos_id, eos_id); + ALOGI("bos_id: %d, eos_id: %d", bos_id, eos_id); + + { + auto ret = cli->Get("/img_start_token"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get img_start_token failed, status: %d", rep.status); + return false; + } + nlohmann::json j = 
nlohmann::json::parse(rep.body); + img_start_token = j["img_start_token"]; + } + ALOGI("img_start_token: %d", img_start_token); + + { + auto ret = cli->Get("/img_context_token"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("get img_context_token failed, status: %d", rep.status); + return false; + } + nlohmann::json j = nlohmann::json::parse(rep.body); + img_context_token = j["img_context_token"]; + } + ALOGI("img_context_token: %d", img_context_token); } catch (const std::exception &e) { @@ -364,11 +305,75 @@ class Tokenizer_Http : public BaseTokenizer return true; } - bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) override + bool Reset(std::string system_prompt, std::vector &tokens) override + { + nlohmann::json j; + j["uid"] = uid; + if (!system_prompt.empty() and system_prompt != "") + { + j["system_prompt"] = system_prompt; + } + + auto ret = cli->Post("/reset", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("reset failed, status: %d", rep.status); + return false; + } + nlohmann::json j_rep = nlohmann::json::parse(rep.body); + std::vector _token_ids = j_rep["token_ids"]; + tokens = _token_ids; + return true; + } + + bool Encode(std::string input, std::string last_reply, std::vector &tokens, std::vector &tokens_diff, ImageInfo img_info) override { nlohmann::json j; + j["uid"] = uid; j["text"] = input; - j["img_prompt"] = b_img_prompt; + j["img_prompt"] = img_info.img_prompt; + j["imgsz"] = img_info.imgsz; + j["num_img"] = img_info.num_img; + if (!last_reply.empty() and last_reply != "") + { + j["last_reply"] = last_reply; + } + auto ret = cli->Post("/encode", j.dump(), "application/json"); + auto rep = ret.value(); + if (rep.status != 200) + { + ALOGE("encode failed, status: %d", rep.status); + return false; + } + nlohmann::json j2; + try + { + j2 = nlohmann::json::parse(rep.body); + } + catch (const std::exception &e) + { + ALOGE("json parse failed: %s", e.what()); 
+ ALOGE("%s", rep.body.c_str()); + return false; + } + + std::vector _token_ids = j2["token_ids"]; + std::vector _tokens_diff = j2["diff"]; + + tokens = _token_ids; + tokens_diff = _tokens_diff; + + return true; + } + + bool Encode(std::string input, std::vector &output, ImageInfo img_info) override + { + nlohmann::json j; + j["text"] = input; + j["img_prompt"] = img_info.img_prompt; + j["imgsz"] = img_info.imgsz; + j["num_img"] = img_info.num_img; auto ret = cli->Post("/encode", j.dump(), "application/json"); auto rep = ret.value(); if (rep.status != 200) @@ -403,10 +408,17 @@ class Tokenizer_Http : public BaseTokenizer return true; } - std::vector Encode(std::string input, bool b_img_prompt = false) override + std::vector Encode(std::string input, ImageInfo img_info) override { std::vector output; - Encode(input, output, b_img_prompt); + Encode(input, output, img_info); + return output; + } + + std::vector Encode_ctx(std::string input, ImageInfo img_info, std::vector &tokens_ids, std::vector &tokens_diff) override + { + std::vector output; + Encode(input, "", output, tokens_diff, img_info); return output; } @@ -418,6 +430,7 @@ class Tokenizer_Http : public BaseTokenizer { nlohmann::json j; j["token_ids"] = input; + j["uid"] = uid; auto ret = cli->Post("/decode", j.dump(), "application/json"); auto rep = ret.value(); if (rep.status != 200) @@ -453,23 +466,26 @@ class Tokenizer_Http : public BaseTokenizer { return eos_id; } + + int GetImgStartID() override + { + return img_start_token; + } + + int GetImgContextID() override + { + return img_context_token; + } }; std::shared_ptr CreateTokenizer(TokenizerType type) { switch (type) { - case TKT_LLaMa: - return std::make_shared(); - case TKT_MINICPM: - return std::make_shared(); case TKT_HTTP: return std::make_shared(); - case TKT_Qwen: - return std::make_shared(); - case TKT_Phi3: - return std::make_shared(); default: + ALOGE("unknown tokenizer type: %d", type); return nullptr; } } \ No newline at end of file diff 
--git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.hpp b/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.hpp index c6ad3d8c..e82bdfd8 100644 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.hpp +++ b/projects/llm_framework/main_vlm/src/runner/Tokenizer/Tokenizer.hpp @@ -9,19 +9,32 @@ enum TokenizerType TKT_Qwen, TKT_HTTP, TKT_Phi3, - TKT_MINICPM, TKT_END }; +struct ImageInfo +{ + int imgsz = 448; + int num_img = 1; + bool img_prompt = false; +}; + class BaseTokenizer { public: - virtual bool Init(std::string model_path, bool b_bos = true, bool b_eos = false) = 0; - virtual bool Encode(std::string input, std::vector &output, bool b_img_prompt = false) = 0; - virtual std::vector Encode(std::string input, bool b_img_prompt = false) = 0; + virtual bool Init(std::string model_path) = 0; + virtual bool Init(std::string model_path, bool b_bos, bool b_eos) = 0; + virtual bool Init_new(std::string model_path, bool b_bos, bool b_eos) = 0; + virtual bool Reset(std::string system_prompt, std::vector &tokens) = 0; + virtual bool Encode(std::string input, std::string last_reply, std::vector &tokens, std::vector &tokens_diff, ImageInfo img_info) = 0; + virtual bool Encode(std::string input, std::vector &output, ImageInfo img_info) = 0; + virtual std::vector Encode(std::string input, ImageInfo img_info) = 0; + virtual std::vector Encode_ctx(std::string input, ImageInfo img_info, std::vector &tokens_ids, std::vector &tokens_diff) = 0; virtual std::string Decode(const std::vector input) = 0; virtual int GetBosID() = 0; virtual int GetEosID() = 0; + virtual int GetImgStartID() = 0; + virtual int GetImgContextID() = 0; virtual bool isEnd(int id) { return id == GetEosID(); } }; diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/base64.h b/projects/llm_framework/main_vlm/src/runner/Tokenizer/base64.h deleted file mode 100644 index c42b4828..00000000 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/base64.h 
+++ /dev/null @@ -1,54 +0,0 @@ -#pragma once - -#include -#include -#include - -namespace base64 -{ - - static auto pos_of_char(const unsigned char chr) -> size_t - { - if (chr >= 'A' && chr <= 'Z') - return chr - 'A'; - else if (chr >= 'a' && chr <= 'z') - return chr - 'a' + ('Z' - 'A') + 1; - else if (chr >= '0' && chr <= '9') - return chr - '0' + ('Z' - 'A') + ('z' - 'a') + 2; - else if (chr == '+' || chr == '-') - return 62; - else if (chr == '/' || chr == '_') - return 63; - else - throw std::runtime_error("Input is not valid base64-encoded data."); - } - - inline auto decode(std::string_view s) -> std::string - { - if (s.empty()) - throw std::runtime_error("empty input"); - size_t length = s.length(); - size_t idx = 0; - - std::string out; - out.reserve(length / 4 * 3); - - while (idx < length) - { - size_t pos_of_char_1 = pos_of_char(s.at(idx + 1)); - out.push_back(static_cast(((pos_of_char(s.at(idx + 0))) << 2) + ((pos_of_char_1 & 0x30) >> 4))); - if ((idx + 2 < length) && s.at(idx + 2) != '=' && s.at(idx + 2) != '.') - { - size_t pos_of_char_2 = pos_of_char(s.at(idx + 2)); - out.push_back(static_cast(((pos_of_char_1 & 0x0f) << 4) + ((pos_of_char_2 & 0x3c) >> 2))); - if ((idx + 3 < length) && s.at(idx + 3) != '=' && s.at(idx + 3) != '.') - { - out.push_back(static_cast(((pos_of_char_2 & 0x03) << 6) + pos_of_char(s.at(idx + 3)))); - } - } - idx += 4; - } - return out; - } - -} // namespace base64 diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/tiktoken.h b/projects/llm_framework/main_vlm/src/runner/Tokenizer/tiktoken.h deleted file mode 100644 index 62738502..00000000 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/tiktoken.h +++ /dev/null @@ -1,326 +0,0 @@ -#pragma once - -#include -#include "unordered_dense.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace tiktoken -{ - - static auto _byte_pair_merge( - const std::string &piece, - const ankerl::unordered_dense::map &ranks, - 
std::function func) -> std::vector - { - std::vector> parts; - parts.reserve(piece.size() + 1); - for (auto idx = 0U; idx < piece.size() + 1; ++idx) - { - parts.emplace_back(idx, std::numeric_limits::max()); - } - - auto get_rank = [&piece, &ranks]( - const std::vector> &parts, - int start_idx, - int skip) -> std::optional - { - if (start_idx + skip + 2 < parts.size()) - { - auto s = parts[start_idx].first; - auto e = parts[start_idx + skip + 2].first; - auto key = piece.substr(s, e - s); - auto iter = ranks.find(key); - if (iter != ranks.end()) - { - return iter->second; - } - } - return std::nullopt; - }; - - for (auto i = 0U; i < parts.size() - 2; ++i) - { - auto rank = get_rank(parts, i, 0); - if (rank) - { - assert(*rank != std::numeric_limits::max()); - parts[i].second = *rank; - } - } - - while (true) - { - if (parts.size() == 1) - break; - - auto min_rank = std::make_pair(std::numeric_limits::max(), 0); - for (auto i = 0U; i < parts.size() - 1; ++i) - { - auto rank = parts[i].second; - if (rank < min_rank.first) - { - min_rank = {rank, i}; - } - } - - if (min_rank.first != std::numeric_limits::max()) - { - auto i = min_rank.second; - auto rank = get_rank(parts, i, 1); - if (rank) - { - parts[i].second = *rank; - } - else - { - parts[i].second = std::numeric_limits::max(); - } - if (i > 0) - { - auto rank = get_rank(parts, i - 1, 1); - if (rank) - { - parts[i - 1].second = *rank; - } - else - { - parts[i - 1].second = std::numeric_limits::max(); - } - } - - parts.erase(parts.begin() + (i + 1)); - } - else - { - break; - } - } - std::vector out; - out.reserve(parts.size() - 1); - for (auto i = 0U; i < parts.size() - 1; ++i) - { - out.push_back(func(parts[i].first, parts[i + 1].first)); - } - return out; - } - - static auto byte_pair_encode( - const std::string &piece, - const ankerl::unordered_dense::map &ranks) -> std::vector - { - if (piece.size() == 1) - { - return {ranks.at(piece)}; - } - - auto func = [&piece, &ranks](int start, int stop) -> int - { - 
std::string key = piece.substr(start, stop - start); - return ranks.at(key); - }; - - return _byte_pair_merge(piece, ranks, func); - } - - class tiktoken - { - public: - tiktoken() = default; - tiktoken( - ankerl::unordered_dense::map encoder, - ankerl::unordered_dense::map special_encoder, - const std::string &pattern) - { - regex_ = std::make_unique("(" + pattern + ")"); - - std::string special_pattern; - for (const auto &item : special_encoder) - { - if (!special_pattern.empty()) - { - special_pattern += "|"; - } - special_pattern += re2::RE2::QuoteMeta(item.first); - } - if (special_pattern.empty()) - { - special_regex_ = nullptr; - } - else - { - special_regex_ = std::make_unique("(" + special_pattern + ")"); - } - - encoder_ = std::move(encoder); - special_tokens_encoder = std::move(special_encoder); - - for (const auto &[k, v] : encoder_) - { - decoder_.emplace(v, k); - } - assert(encoder_.size() != decoder_.size() && "Encoder and decoder must be of equal length; maybe you had duplicate token indices in your encoder?"); - - for (const auto &[k, v] : special_tokens_encoder) - { - special_tokens_decoder.emplace(v, k); - } - } - - auto encode_ordinary(const std::string &text) const -> std::vector - { - return _encode_ordinary_native(text); - } - - auto encode(const std::string &text) const -> std::vector - { - return _encode_native(text, special_tokens_encoder).first; - } - - auto encode_single_piece(const std::string &text) const -> std::vector - { - auto iter = encoder_.find(text); - if (iter != encoder_.end()) - { - return {iter->second}; - } - return byte_pair_encode(text, encoder_); - } - - auto decode(const std::vector &tokens) const -> std::string - { - return _decode_native(tokens); - } - - private: - auto split_with_allowed_special_token( - re2::StringPiece &input, - const ankerl::unordered_dense::map &allowed_special) const -> std::pair, re2::StringPiece> - { - if (special_regex_ == nullptr) - return {std::nullopt, input}; - - auto start = 
input.begin(); - std::string special; - while (true) - { - if (!re2::RE2::FindAndConsume(&input, *special_regex_, &special)) - { - break; - } - - if (allowed_special.count(special) == 1) - { - return {std::move(special), re2::StringPiece(start, input.begin() - start - special.size())}; - } - } - - return {std::nullopt, input}; - } - - auto _encode_ordinary_native(const std::string &text) const -> std::vector - { - std::vector ret; - re2::StringPiece input(text); - - std::string piece; - while (re2::RE2::FindAndConsume(&input, *regex_, &piece)) - { - auto iter = encoder_.find(piece); - if (iter != encoder_.end()) - { - ret.push_back(iter->second); - continue; - } - auto tokens = byte_pair_encode(piece, encoder_); - ret.insert(ret.end(), tokens.begin(), tokens.end()); - } - return ret; - } - - auto _encode_native( - const std::string &text, - const ankerl::unordered_dense::map &allowed_special) const -> std::pair, int> - { - std::vector ret; - int last_piece_token_len = 0; - re2::StringPiece input(text); - - while (true) - { - auto [special, sub_input] = split_with_allowed_special_token(input, allowed_special); - std::string piece; - while (re2::RE2::FindAndConsume(&sub_input, *regex_, &piece)) - { - auto iter = encoder_.find(piece); - if (iter != encoder_.end()) - { - last_piece_token_len = 1; - ret.push_back(iter->second); - continue; - } - auto tokens = byte_pair_encode(piece, encoder_); - last_piece_token_len = tokens.size(); - ret.insert(ret.end(), tokens.begin(), tokens.end()); - } - - if (special) - { - int token = special_tokens_encoder.at(*special); - ret.push_back(token); - last_piece_token_len = 0; - } - else - { - break; - } - } - - return {ret, last_piece_token_len}; - } - - auto _decode_native(const std::vector &tokens) const -> std::string - { - std::string ret; - ret.reserve(tokens.size() * 2); - for (auto token : tokens) - { - std::string token_bytes; - auto iter = decoder_.find(token); - if (iter != decoder_.end()) - { - token_bytes = iter->second; 
- } - else - { - iter = special_tokens_decoder.find(token); - if (iter != special_tokens_decoder.end()) - { - token_bytes = iter->second; - } - else - { - throw std::runtime_error("unknown token: " + std::to_string(token)); - } - } - ret += token_bytes; - } - return ret; - } - - ankerl::unordered_dense::map encoder_; - ankerl::unordered_dense::map special_tokens_encoder; - ankerl::unordered_dense::map decoder_; - ankerl::unordered_dense::map special_tokens_decoder; - std::unique_ptr regex_; - std::unique_ptr special_regex_; - }; - -} // namespace tiktoken diff --git a/projects/llm_framework/main_vlm/src/runner/Tokenizer/unordered_dense.h b/projects/llm_framework/main_vlm/src/runner/Tokenizer/unordered_dense.h deleted file mode 100644 index af23690b..00000000 --- a/projects/llm_framework/main_vlm/src/runner/Tokenizer/unordered_dense.h +++ /dev/null @@ -1,2240 +0,0 @@ -///////////////////////// ankerl::unordered_dense::{map, set} ///////////////////////// - -// A fast & densely stored hashmap and hashset based on robin-hood backward shift deletion. -// Version 4.1.2 -// https://github.com/martinus/unordered_dense -// -// Licensed under the MIT License . -// SPDX-License-Identifier: MIT -// Copyright (c) 2022-2023 Martin Leitner-Ankerl -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all -// copies or substantial portions of the Software. 
-// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#ifndef ANKERL_UNORDERED_DENSE_H -#define ANKERL_UNORDERED_DENSE_H - -// see https://semver.org/spec/v2.0.0.html -#define ANKERL_UNORDERED_DENSE_VERSION_MAJOR 4 // NOLINT(cppcoreguidelines-macro-usage) incompatible API changes -#define ANKERL_UNORDERED_DENSE_VERSION_MINOR 1 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible functionality -#define ANKERL_UNORDERED_DENSE_VERSION_PATCH 2 // NOLINT(cppcoreguidelines-macro-usage) backwards compatible bug fixes - -// API versioning with inline namespace, see https://www.foonathan.net/2018/11/inline-namespaces/ - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) v##major##_##minor##_##patch -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_VERSION_CONCAT(major, minor, patch) ANKERL_UNORDERED_DENSE_VERSION_CONCAT1(major, minor, patch) -#define ANKERL_UNORDERED_DENSE_NAMESPACE \ - ANKERL_UNORDERED_DENSE_VERSION_CONCAT( \ - ANKERL_UNORDERED_DENSE_VERSION_MAJOR, ANKERL_UNORDERED_DENSE_VERSION_MINOR, ANKERL_UNORDERED_DENSE_VERSION_PATCH) - -#if defined(_MSVC_LANG) -#define ANKERL_UNORDERED_DENSE_CPP_VERSION _MSVC_LANG -#else -#define ANKERL_UNORDERED_DENSE_CPP_VERSION __cplusplus -#endif - -#if defined(__GNUC__) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_PACK(decl) decl __attribute__((__packed__)) -#elif defined(_MSC_VER) -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define 
ANKERL_UNORDERED_DENSE_PACK(decl) __pragma(pack(push, 1)) decl __pragma(pack(pop)) -#endif - -// exceptions -#if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || defined(_CPPUNWIND) -#define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 1 // NOLINT(cppcoreguidelines-macro-usage) -#else -#define ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() 0 // NOLINT(cppcoreguidelines-macro-usage) -#endif -#ifdef _MSC_VER -#define ANKERL_UNORDERED_DENSE_NOINLINE __declspec(noinline) -#else -#define ANKERL_UNORDERED_DENSE_NOINLINE __attribute__((noinline)) -#endif - -// defined in unordered_dense.cpp -#if !defined(ANKERL_UNORDERED_DENSE_EXPORT) -#define ANKERL_UNORDERED_DENSE_EXPORT -#endif - -#if ANKERL_UNORDERED_DENSE_CPP_VERSION < 201703L -#error ankerl::unordered_dense requires C++17 or higher -#else -#include // for array -#include // for uint64_t, uint32_t, uint8_t, UINT64_C -#include // for size_t, memcpy, memset -#include // for equal_to, hash -#include // for initializer_list -#include // for pair, distance -#include // for numeric_limits -#include // for allocator, allocator_traits, shared_ptr -#include // for out_of_range -#include // for basic_string -#include // for basic_string_view, hash -#include // for forward_as_tuple -#include // for enable_if_t, declval, conditional_t, ena... -#include // for forward, exchange, pair, as_const, piece... 
-#include // for vector -#if ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() == 0 -#include // for abort -#endif - -#if defined(__has_include) -#if __has_include() -#define ANKERL_UNORDERED_DENSE_PMR std::pmr // NOLINT(cppcoreguidelines-macro-usage) -#include // for polymorphic_allocator -#elif __has_include() -#define ANKERL_UNORDERED_DENSE_PMR std::experimental::pmr // NOLINT(cppcoreguidelines-macro-usage) -#include // for polymorphic_allocator -#endif -#endif - -#if defined(_MSC_VER) && defined(_M_X64) -#include -#pragma intrinsic(_umul128) -#endif - -#if defined(__GNUC__) || defined(__INTEL_COMPILER) || defined(__clang__) -#define ANKERL_UNORDERED_DENSE_LIKELY(x) __builtin_expect(x, 1) // NOLINT(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_UNLIKELY(x) __builtin_expect(x, 0) // NOLINT(cppcoreguidelines-macro-usage) -#else -#define ANKERL_UNORDERED_DENSE_LIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_UNLIKELY(x) (x) // NOLINT(cppcoreguidelines-macro-usage) -#endif - -namespace ankerl::unordered_dense -{ - inline namespace ANKERL_UNORDERED_DENSE_NAMESPACE - { - - namespace detail - { - -#if ANKERL_UNORDERED_DENSE_HAS_EXCEPTIONS() - - // make sure this is not inlined as it is slow and dramatically enlarges code, thus making other - // inlinings more difficult. Throws are also generally the slow path. 
- [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_key_not_found() - { - throw std::out_of_range("ankerl::unordered_dense::map::at(): key not found"); - } - [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_bucket_overflow() - { - throw std::overflow_error("ankerl::unordered_dense: reached max bucket size, cannot increase size"); - } - [[noreturn]] inline ANKERL_UNORDERED_DENSE_NOINLINE void on_error_too_many_elements() - { - throw std::out_of_range("ankerl::unordered_dense::map::replace(): too many elements"); - } - -#else - - [[noreturn]] inline void on_error_key_not_found() - { - abort(); - } - [[noreturn]] inline void on_error_bucket_overflow() - { - abort(); - } - [[noreturn]] inline void on_error_too_many_elements() - { - abort(); - } - -#endif - - } // namespace detail - - // hash /////////////////////////////////////////////////////////////////////// - - // This is a stripped-down implementation of wyhash: https://github.com/wangyi-fudan/wyhash - // No big-endian support (because different values on different machines don't matter), - // hardcodes seed and the secret, reformats the code, and clang-tidy fixes. 
- namespace detail::wyhash - { - - inline void mum(uint64_t *a, uint64_t *b) - { -#if defined(__SIZEOF_INT128__) - __uint128_t r = *a; - r *= *b; - *a = static_cast(r); - *b = static_cast(r >> 64U); -#elif defined(_MSC_VER) && defined(_M_X64) - *a = _umul128(*a, *b, b); -#else - uint64_t ha = *a >> 32U; - uint64_t hb = *b >> 32U; - uint64_t la = static_cast(*a); - uint64_t lb = static_cast(*b); - uint64_t hi{}; - uint64_t lo{}; - uint64_t rh = ha * hb; - uint64_t rm0 = ha * lb; - uint64_t rm1 = hb * la; - uint64_t rl = la * lb; - uint64_t t = rl + (rm0 << 32U); - auto c = static_cast(t < rl); - lo = t + (rm1 << 32U); - c += static_cast(lo < t); - hi = rh + (rm0 >> 32U) + (rm1 >> 32U) + c; - *a = lo; - *b = hi; -#endif - } - - // multiply and xor mix function, aka MUM - [[nodiscard]] inline auto mix(uint64_t a, uint64_t b) -> uint64_t - { - mum(&a, &b); - return a ^ b; - } - - // read functions. WARNING: we don't care about endianness, so results are different on big endian! - [[nodiscard]] inline auto r8(const uint8_t *p) -> uint64_t - { - uint64_t v{}; - std::memcpy(&v, p, 8U); - return v; - } - - [[nodiscard]] inline auto r4(const uint8_t *p) -> uint64_t - { - uint32_t v{}; - std::memcpy(&v, p, 4); - return v; - } - - // reads 1, 2, or 3 bytes - [[nodiscard]] inline auto r3(const uint8_t *p, size_t k) -> uint64_t - { - return (static_cast(p[0]) << 16U) | (static_cast(p[k >> 1U]) << 8U) | p[k - 1]; - } - - [[maybe_unused]] [[nodiscard]] inline auto hash(void const *key, size_t len) -> uint64_t - { - static constexpr auto secret = std::array{UINT64_C(0xa0761d6478bd642f), - UINT64_C(0xe7037ed1a0b428db), - UINT64_C(0x8ebc6af09c88c6e3), - UINT64_C(0x589965cc75374cc3)}; - - auto const *p = static_cast(key); - uint64_t seed = secret[0]; - uint64_t a{}; - uint64_t b{}; - if (ANKERL_UNORDERED_DENSE_LIKELY(len <= 16)) - { - if (ANKERL_UNORDERED_DENSE_LIKELY(len >= 4)) - { - a = (r4(p) << 32U) | r4(p + ((len >> 3U) << 2U)); - b = (r4(p + len - 4) << 32U) | r4(p + len - 4 - 
((len >> 3U) << 2U)); - } - else if (ANKERL_UNORDERED_DENSE_LIKELY(len > 0)) - { - a = r3(p, len); - b = 0; - } - else - { - a = 0; - b = 0; - } - } - else - { - size_t i = len; - if (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 48)) - { - uint64_t see1 = seed; - uint64_t see2 = seed; - do - { - seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); - see1 = mix(r8(p + 16) ^ secret[2], r8(p + 24) ^ see1); - see2 = mix(r8(p + 32) ^ secret[3], r8(p + 40) ^ see2); - p += 48; - i -= 48; - } while (ANKERL_UNORDERED_DENSE_LIKELY(i > 48)); - seed ^= see1 ^ see2; - } - while (ANKERL_UNORDERED_DENSE_UNLIKELY(i > 16)) - { - seed = mix(r8(p) ^ secret[1], r8(p + 8) ^ seed); - i -= 16; - p += 16; - } - a = r8(p + i - 16); - b = r8(p + i - 8); - } - - return mix(secret[1] ^ len, mix(a ^ secret[1], b ^ seed)); - } - - [[nodiscard]] inline auto hash(uint64_t x) -> uint64_t - { - return detail::wyhash::mix(x, UINT64_C(0x9E3779B97F4A7C15)); - } - - } // namespace detail::wyhash - - ANKERL_UNORDERED_DENSE_EXPORT template - struct hash - { - auto operator()(T const &obj) const noexcept(noexcept(std::declval>().operator()(std::declval()))) - -> uint64_t - { - return std::hash{}(obj); - } - }; - - template - struct hash> - { - using is_avalanching = void; - auto operator()(std::basic_string const &str) const noexcept -> uint64_t - { - return detail::wyhash::hash(str.data(), sizeof(CharT) * str.size()); - } - }; - - template - struct hash> - { - using is_avalanching = void; - auto operator()(std::basic_string_view const &sv) const noexcept -> uint64_t - { - return detail::wyhash::hash(sv.data(), sizeof(CharT) * sv.size()); - } - }; - - template - struct hash - { - using is_avalanching = void; - auto operator()(T *ptr) const noexcept -> uint64_t - { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - return detail::wyhash::hash(reinterpret_cast(ptr)); - } - }; - - template - struct hash> - { - using is_avalanching = void; - auto operator()(std::unique_ptr const &ptr) const noexcept -> 
uint64_t - { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - return detail::wyhash::hash(reinterpret_cast(ptr.get())); - } - }; - - template - struct hash> - { - using is_avalanching = void; - auto operator()(std::shared_ptr const &ptr) const noexcept -> uint64_t - { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - return detail::wyhash::hash(reinterpret_cast(ptr.get())); - } - }; - - template - struct hash::value>::type> - { - using is_avalanching = void; - auto operator()(Enum e) const noexcept -> uint64_t - { - using underlying = typename std::underlying_type_t; - return detail::wyhash::hash(static_cast(e)); - } - }; - -// NOLINTNEXTLINE(cppcoreguidelines-macro-usage) -#define ANKERL_UNORDERED_DENSE_HASH_STATICCAST(T) \ - template <> \ - struct hash \ - { \ - using is_avalanching = void; \ - auto operator()(T const &obj) const noexcept -> uint64_t \ - { \ - return detail::wyhash::hash(static_cast(obj)); \ - } \ - } - -#if defined(__GNUC__) && !defined(__clang__) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wuseless-cast" -#endif - // see https://en.cppreference.com/w/cpp/utility/hash - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(bool); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(signed char); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned char); -#if ANKERL_UNORDERED_DENSE_CPP_VERSION >= 202002L && defined(__cpp_char8_t) - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char8_t); -#endif - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char16_t); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(char32_t); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(wchar_t); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(short); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned short); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(int); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned int); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(long long); - 
ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long); - ANKERL_UNORDERED_DENSE_HASH_STATICCAST(unsigned long long); - -#if defined(__GNUC__) && !defined(__clang__) -#pragma GCC diagnostic pop -#endif - - // bucket_type ////////////////////////////////////////////////////////// - - namespace bucket_type - { - - struct standard - { - static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint - static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint - - uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash - uint32_t m_value_idx; // index into the m_values vector. - }; - - ANKERL_UNORDERED_DENSE_PACK(struct big { - static constexpr uint32_t dist_inc = 1U << 8U; // skip 1 byte fingerprint - static constexpr uint32_t fingerprint_mask = dist_inc - 1; // mask for 1 byte of fingerprint - - uint32_t m_dist_and_fingerprint; // upper 3 byte: distance to original bucket. lower byte: fingerprint from hash - size_t m_value_idx; // index into the m_values vector. - }); - - } // namespace bucket_type - - namespace detail - { - - struct nonesuch - { - }; - - template class Op, class... Args> - struct detector - { - using value_t = std::false_type; - using type = Default; - }; - - template class Op, class... Args> - struct detector>, Op, Args...> - { - using value_t = std::true_type; - using type = Op; - }; - - template