From 8e7350f9344d065b79727a15cf26c51b93f43a3e Mon Sep 17 00:00:00 2001
From: Your Name
Date: Mon, 23 Mar 2026 13:06:22 -0700
Subject: [PATCH 01/18] Add C++ SDK with CMake build, tests, and sample

---
 sdk/cpp/.clang-format                         |  47 +
 sdk/cpp/CMakeLists.txt                        | 155 ++++
 sdk/cpp/CMakePresets.json                     | 129 +++
 sdk/cpp/include/configuration.h               |  65 ++
 sdk/cpp/include/core_interop_request.h        |  43 +
 sdk/cpp/include/flcore_native.h               |  33 +
 sdk/cpp/include/foundry_local.h               | 413 +++++++++
 sdk/cpp/include/foundry_local_exception.h     |  19 +
 sdk/cpp/include/foundry_local_internal_core.h |  19 +
 sdk/cpp/include/log_level.h                   |  34 +
 sdk/cpp/include/logger.h                      |  16 +
 sdk/cpp/include/parser.h                      | 289 ++++++
 sdk/cpp/sample/main.cpp                       | 398 ++++++++
 sdk/cpp/src/foundry_local.cpp                 | 848 ++++++++++++++++++
 sdk/cpp/test/catalog_test.cpp                 | 372 ++++++++
 sdk/cpp/test/client_test.cpp                  | 541 +++++++++++
 sdk/cpp/test/mock_core.h                      | 148 +++
 sdk/cpp/test/mock_object_factory.h            |  61 ++
 sdk/cpp/test/model_variant_test.cpp           | 251 ++++++
 sdk/cpp/test/parser_and_types_test.cpp        | 592 ++++++++++++
 sdk/cpp/test/testdata/empty_models_list.json  |   1 +
 .../test/testdata/malformed_models_list.json  |   1 +
 .../missing_name_field_models_list.json       |  12 +
 .../test/testdata/mixed_openai_and_local.json |  35 +
 sdk/cpp/test/testdata/real_models_list.json   |  88 ++
 .../test/testdata/single_cached_model.json    |   1 +
 .../testdata/three_variants_one_model.json    |  41 +
 .../test/testdata/valid_cached_models.json    |   1 +
 .../test/testdata/valid_loaded_models.json    |   1 +
 29 files changed, 4654 insertions(+)
 create mode 100644 sdk/cpp/.clang-format
 create mode 100644 sdk/cpp/CMakeLists.txt
 create mode 100644 sdk/cpp/CMakePresets.json
 create mode 100644 sdk/cpp/include/configuration.h
 create mode 100644 sdk/cpp/include/core_interop_request.h
 create mode 100644 sdk/cpp/include/flcore_native.h
 create mode 100644 sdk/cpp/include/foundry_local.h
 create mode 100644 sdk/cpp/include/foundry_local_exception.h
 create mode 100644 sdk/cpp/include/foundry_local_internal_core.h
 create mode 100644 sdk/cpp/include/log_level.h
 create mode 100644 sdk/cpp/include/logger.h
 create mode 100644 sdk/cpp/include/parser.h
 create mode 100644 sdk/cpp/sample/main.cpp
 create mode 100644 sdk/cpp/src/foundry_local.cpp
 create mode 100644 sdk/cpp/test/catalog_test.cpp
 create mode 100644 sdk/cpp/test/client_test.cpp
 create mode 100644 sdk/cpp/test/mock_core.h
 create mode 100644 sdk/cpp/test/mock_object_factory.h
 create mode 100644 sdk/cpp/test/model_variant_test.cpp
 create mode 100644 sdk/cpp/test/parser_and_types_test.cpp
 create mode 100644 sdk/cpp/test/testdata/empty_models_list.json
 create mode 100644 sdk/cpp/test/testdata/malformed_models_list.json
 create mode 100644 sdk/cpp/test/testdata/missing_name_field_models_list.json
 create mode 100644 sdk/cpp/test/testdata/mixed_openai_and_local.json
 create mode 100644 sdk/cpp/test/testdata/real_models_list.json
 create mode 100644 sdk/cpp/test/testdata/single_cached_model.json
 create mode 100644 sdk/cpp/test/testdata/three_variants_one_model.json
 create mode 100644 sdk/cpp/test/testdata/valid_cached_models.json
 create mode 100644 sdk/cpp/test/testdata/valid_loaded_models.json

diff --git a/sdk/cpp/.clang-format b/sdk/cpp/.clang-format
new file mode 100644
index 00000000..751f30aa
--- /dev/null
+++ b/sdk/cpp/.clang-format
@@ -0,0 +1,47 @@
+---
+Language: Cpp
+BasedOnStyle: Microsoft
+
+# Match the existing project style
+Standard: c++17
+ColumnLimit: 120
+
+# Indentation
+IndentWidth: 4
+TabWidth: 4
+UseTab: Never
+AccessModifierOffset: -4
+IndentCaseLabels: false
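+# NamespaceIndentation: All matches the SDK headers, which nest everything
+# inside namespace FoundryLocal.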
+NamespaceIndentation: All
+
+# Braces
+BreakBeforeBraces: Custom
+BraceWrapping:
+  AfterCaseLabel: false
+  AfterClass: false
+  AfterControlStatement: Never
+  AfterEnum: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterStruct: false
+  BeforeCatch: true
+  BeforeElse: true
+  IndentBraces: false
+
+# Alignment
+AlignAfterOpenBracket: Align
+AlignOperands: Align
+AlignTrailingComments: true
+
+# Includes
+SortIncludes: false
+IncludeBlocks: Preserve
+
+# Misc
+AllowShortFunctionsOnASingleLine: Inline
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLoopsOnASingleLine: false
+AllowShortBlocksOnASingleLine: Empty
+PointerAlignment: Left
+SpaceAfterCStyleCast: false
+SpaceBeforeParens: ControlStatements

diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
new file mode 100644
index 00000000..064c46ca
--- /dev/null
+++ b/sdk/cpp/CMakeLists.txt
@@ -0,0 +1,155 @@
+cmake_minimum_required(VERSION 3.20)
+
+# VS hot reload policy (safe-guarded)
+if (POLICY CMP0141)
+    cmake_policy(SET CMP0141 NEW)
+    if (MSVC)
+        set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT
+            "$<$<CONFIG:Debug,RelWithDebInfo>:ProgramDatabase>")
+    endif()
+endif()
+
+project(CppSdk LANGUAGES CXX)
+
+# -----------------------------
+# Windows-only + compiler guard
+# -----------------------------
+if (NOT WIN32)
+    message(FATAL_ERROR "CppSdk is Windows-only for now (uses Win32/WIL headers).")
+endif()
+
+# Accept MSVC OR clang-cl (Clang in MSVC compatibility mode).
+# VS CMake Open-Folder often uses clang-cl by default.
+if (NOT (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")))
+    message(STATUS "CMAKE_CXX_COMPILER_ID = ${CMAKE_CXX_COMPILER_ID}")
+    message(STATUS "CMAKE_CXX_COMPILER = ${CMAKE_CXX_COMPILER}")
+    message(STATUS "CMAKE_CXX_SIMULATE_ID = ${CMAKE_CXX_SIMULATE_ID}")
+    message(FATAL_ERROR "Need MSVC or clang-cl (MSVC-compatible toolchain).")
+endif()
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# Optional: target Windows 10+ APIs (adjust if you need older)
+add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00)
+
+include(FetchContent)
+
+# -----------------------------
+# nlohmann_json (clean CMake target)
+# -----------------------------
+FetchContent_Declare(
+    nlohmann_json
+    GIT_REPOSITORY https://github.com/nlohmann/json.git
+    GIT_TAG v3.12.0
+)
+FetchContent_MakeAvailable(nlohmann_json)
+
+# -----------------------------
+# WIL (download headers only; DO NOT run WIL's CMake)
+# This avoids NuGet/test requirements and missing wil::wil targets.
+# -----------------------------
+FetchContent_Declare(
+    wil_src
+    GIT_REPOSITORY https://github.com/microsoft/wil.git
+    GIT_TAG v1.0.250325.1
+)
+FetchContent_Populate(wil_src)
+
+# -----------------------------
+# Microsoft GSL (Guidelines Support Library)
+# Provides gsl::span for C++17 (std::span is C++20)
+# -----------------------------
+FetchContent_Declare(
+    gsl
+    GIT_REPOSITORY https://github.com/microsoft/GSL.git
+    GIT_TAG v4.0.0
+)
+FetchContent_MakeAvailable(gsl)
+
+# -----------------------------
+# Google Test (for unit tests)
+# -----------------------------
+FetchContent_Declare(
+    googletest
+    GIT_REPOSITORY https://github.com/google/googletest.git
+    GIT_TAG v1.14.0
+)
+# Prevent GoogleTest from overriding our compiler/linker options on Windows
+set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
+FetchContent_MakeAvailable(googletest)
+
+# -----------------------------
+# SDK library (STATIC)
+# List ONLY .cpp files here.
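+# A consuming project could link against the SDK like this (sketch; assumes
+# this directory is pulled in with add_subdirectory(sdk/cpp)):
+#   add_subdirectory(sdk/cpp)
+#   target_link_libraries(MyApp PRIVATE CppSdk)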
+# ----------------------------- +add_library(CppSdk STATIC + src/foundry_local.cpp + # Add more .cpp files as you migrate: + # src/parser.cpp + # src/dllmain.cpp + # src/pch.cpp +) + +target_include_directories(CppSdk + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/include + ${wil_src_SOURCE_DIR}/include +) + +target_link_libraries(CppSdk + PUBLIC + nlohmann_json::nlohmann_json + Microsoft.GSL::GSL +) + +# ----------------------------- +# Sample executable +# ----------------------------- +add_executable(CppSdkSample + sample/main.cpp +) + +target_link_libraries(CppSdkSample PRIVATE CppSdk) + +# ----------------------------- +# Unit tests +# ----------------------------- +enable_testing() + +add_executable(CppSdkTests + test/parser_and_types_test.cpp + test/model_variant_test.cpp + test/catalog_test.cpp + test/client_test.cpp +) + +target_include_directories(CppSdkTests + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/test +) + +target_compile_definitions(CppSdkTests PRIVATE FL_TESTS) + +target_link_libraries(CppSdkTests + PRIVATE + CppSdk + GTest::gtest_main +) + +# Copy testdata files next to the test executable so file-based tests can find them. +add_custom_command(TARGET CppSdkTests POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_directory + ${CMAKE_CURRENT_SOURCE_DIR}/test/testdata + $/testdata +) + +include(GoogleTest) +gtest_discover_tests(CppSdkTests + WORKING_DIRECTORY $ +) + +# Make Visual Studio start/debug this target by default +set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} + PROPERTY VS_STARTUP_PROJECT CppSdkSample) diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json new file mode 100644 index 00000000..aa233618 --- /dev/null +++ b/sdk/cpp/CMakePresets.json @@ -0,0 +1,129 @@ +{ + "version": 6, + "configurePresets": [ + { + "name": "windows-base", + "hidden": true, + "generator": "Ninja", + "binaryDir": "${sourceDir}/out/build/${presetName}", + "installDir": "${sourceDir}/out/install/${presetName}", + "cacheVariables": { + "CMAKE_C_COMPILER": "cl.exe", + "CMAKE_CXX_COMPILER": "cl.exe" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Windows" + } + }, + { + "name": "x64-debug", + "displayName": "MSVC x64 Debug", + "inherits": "windows-base", + "architecture": { + "value": "x64", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + } + }, + { + "name": "x64-release", + "displayName": "MSVC x64 Release", + "inherits": "windows-base", + "architecture": { + "value": "x64", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release" + } + }, + { + "name": "x86-debug", + "displayName": "MSVC x86 Debug", + "inherits": "windows-base", + "architecture": { + "value": "x86", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + } + }, + { + "name": "x86-release", + "displayName": "MSVC x86 Release", + "inherits": "windows-base", + "architecture": { + "value": "x86", + "strategy": "external" + }, + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Release" + } + }, + { + "name": "linux-debug", + "displayName": "Linux Debug", + "generator": "Ninja", + "binaryDir": "${sourceDir}/out/build/${presetName}", + "installDir": "${sourceDir}/out/install/${presetName}", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Linux" + } + }, + { + "name": "macos-debug", + "displayName": "macOS Debug", + "generator": "Ninja", + "binaryDir": "${sourceDir}/out/build/${presetName}", + "installDir": 
"${sourceDir}/out/install/${presetName}", + "cacheVariables": { + "CMAKE_BUILD_TYPE": "Debug" + }, + "condition": { + "type": "equals", + "lhs": "${hostSystemName}", + "rhs": "Darwin" + } + } + ], + "buildPresets": [ + { + "name": "x64-debug", + "configurePreset": "x64-debug", + "displayName": "MSVC x64 Debug Build" + }, + { + "name": "x64-release", + "configurePreset": "x64-release", + "displayName": "MSVC x64 Release Build" + } + ], + "testPresets": [ + { + "name": "x64-debug", + "configurePreset": "x64-debug", + "displayName": "MSVC x64 Debug Tests", + "output": { + "outputOnFailure": true + } + }, + { + "name": "x64-release", + "configurePreset": "x64-release", + "displayName": "MSVC x64 Release Tests", + "output": { + "outputOnFailure": true + } + } + ] +} diff --git a/sdk/cpp/include/configuration.h b/sdk/cpp/include/configuration.h new file mode 100644 index 00000000..59fe63e3 --- /dev/null +++ b/sdk/cpp/include/configuration.h @@ -0,0 +1,65 @@ +#pragma once +#include +#include +#include +#include +#include +#include "log_level.h" + +namespace FoundryLocal { + + /// Optional configuration for the built-in web service. + struct WebServiceConfig { + // URL/s to bind the web service to. + // Default: 127.0.0.1:0 (random ephemeral port). + // Multiple URLs can be specified as a semicolon-separated list. + std::optional urls; + + // If the web service is running in a separate process, provide its URL here. + std::optional external_url; + }; + + struct Configuration { + // Construct a Configuration with just an application name. + // All other fields use their defaults. + Configuration(std::string name) : app_name(std::move(name)) {} + + // Your application name. MUST be set to a valid name. + std::string app_name; + + // Application data directory. + // Default: {home}/.{appname}, where {home} is the user's home directory and {appname} is the app_name value. + std::optional app_data_dir; + + // Model cache directory. + // Default: {appdata}/cache/models, where {appdata} is the app_data_dir value. + std::optional model_cache_dir; + + // Log directory. + // Default: {appdata}/logs + std::optional logs_dir; + + // Logging level. + // Valid values are: Verbose, Debug, Information, Warning, Error, Fatal. + // Default: LogLevel.Warning + LogLevel log_level = LogLevel::Warning; + + // Optional web service configuration. + std::optional web; + + // Additional settings that Foundry Local Core can consume. 
+        std::optional<std::unordered_map<std::string, std::string>> additional_settings;
+
+        void Validate() const {
+            if (app_name.empty()) {
+                throw std::invalid_argument("Configuration app_name must be set to a valid application name.");
+            }
+
+            constexpr std::string_view invalidChars = R"(\/:?\"<>|)";
+            if (app_name.find_first_of(invalidChars) != std::string::npos) {
+                throw std::invalid_argument("Configuration app_name value contains invalid characters.");
+            }
+        }
+    };
+
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/core_interop_request.h b/sdk/cpp/include/core_interop_request.h
new file mode 100644
index 00000000..de03a61e
--- /dev/null
+++ b/sdk/cpp/include/core_interop_request.h
@@ -0,0 +1,43 @@
+#pragma once
+#include <string>
+#include <string_view>
+#include <utility>
+#include <nlohmann/json.hpp>
+
+namespace FoundryLocal {
+
+    class CoreInteropRequest final {
+    public:
+        explicit CoreInteropRequest(std::string command) : command_(std::move(command)) {}
+
+        CoreInteropRequest& AddParam(std::string_view key, std::string_view value) {
+            params_[std::string(key)] = std::string(value);
+            return *this;
+        }
+
+        template <typename T> CoreInteropRequest& AddParam(std::string_view key, const T& value) {
+            params_[std::string(key)] = value;
+            return *this;
+        }
+
+        CoreInteropRequest& AddJsonParam(std::string_view key, const nlohmann::json& jsonValue) {
+            params_[std::string(key)] = jsonValue.dump();
+            return *this;
+        }
+
+        std::string ToJson() const {
+            nlohmann::json wrapper;
+            if (!params_.empty()) {
+                wrapper["Params"] = params_;
+            }
+            return wrapper.dump();
+        }
+
+        const std::string& Command() const noexcept { return command_; }
+
+    private:
+        std::string command_;
+        nlohmann::json params_;
+    };
+
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/flcore_native.h b/sdk/cpp/include/flcore_native.h
new file mode 100644
index 00000000..d67703e0
--- /dev/null
+++ b/sdk/cpp/include/flcore_native.h
@@ -0,0 +1,33 @@
+#pragma once
+#include <cstdint>
+#include <type_traits>
+
+extern "C" {
+    // Layout must match C# structs exactly
+#pragma pack(push, 8)
+    struct RequestBuffer {
+        const void* Command;
+        int32_t CommandLength;
+        const void* Data;
+        int32_t DataLength;
+    };
+
+    struct ResponseBuffer {
+        void* Data;
+        int32_t DataLength;
+        void* Error;
+        int32_t ErrorLength;
+    };
+
+    // Callback signature: void(*)(void* data, int length, void* userData)
+    using UserCallbackFn = void(__cdecl*)(void*, int32_t, void*);
+
+    // Exported function pointer types
+    using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*);
+    using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, void* /*userData*/);
+    using free_response_fn = void(__cdecl*)(ResponseBuffer*);
+
+    static_assert(std::is_standard_layout<RequestBuffer>::value, "RequestBuffer must be standard layout");
+    static_assert(std::is_standard_layout<ResponseBuffer>::value, "ResponseBuffer must be standard layout");
+
+#pragma pack(pop)
+}

diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h
new file mode 100644
index 00000000..9bce8f5f
--- /dev/null
+++ b/sdk/cpp/include/foundry_local.h
@@ -0,0 +1,413 @@
+#pragma once
+#include <chrono>
+#include <cstdint>
+#include <filesystem>
+#include <functional>
+#include <memory>
+#include <optional>
+#include <string>
+#include <string_view>
+#include <unordered_map>
+#include <vector>
+
+#include <gsl/pointers>
+#include <gsl/span>
+
+#include "configuration.h"
+#include "foundry_local_internal_core.h"
+
+#include "logger.h"
+
+namespace FoundryLocal {
+#ifdef FL_TESTS
+    namespace Testing {
+        struct MockObjectFactory;
+    }
+#endif
+
+    enum class DeviceType {
+        Invalid,
+        CPU,
+        GPU,
+        NPU
+    };
+
+    /// Reason the model stopped generating tokens.
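+    /// Parsed from the wire values "stop", "length", "tool_calls" and
+    /// "content_filter" (see parse_finish_reason in parser.h); anything else maps to None.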
+    enum class FinishReason {
+        None,
+        Stop,
+        Length,
+        ToolCalls,
+        ContentFilter
+    };
+
+    struct Runtime {
+        DeviceType device_type = DeviceType::Invalid;
+        std::string execution_provider;
+    };
+
+    struct PromptTemplate {
+        std::string system;
+        std::string user;
+        std::string assistant;
+        std::string prompt;
+    };
+
+    struct AudioCreateTranscriptionResponse {
+        std::string text;
+    };
+
+    /// JSON Schema property definition used to describe tool function parameters.
+    struct PropertyDefinition {
+        std::string type;
+        std::optional<std::string> description;
+        std::optional<std::unordered_map<std::string, PropertyDefinition>> properties;
+        std::optional<std::vector<std::string>> required;
+    };
+
+    /// Describes a function that a model may call.
+    struct FunctionDefinition {
+        std::string name;
+        std::optional<std::string> description;
+        std::optional<PropertyDefinition> parameters;
+    };
+
+    /// A tool definition following the OpenAI tool calling spec.
+    struct ToolDefinition {
+        std::string type = "function";
+        FunctionDefinition function;
+    };
+
+    /// A parsed function call returned by the model.
+    struct FunctionCall {
+        std::string name;
+        std::string arguments; ///< JSON string of the arguments
+    };
+
+    /// A tool call returned by the model in a chat completion response.
+    struct ToolCall {
+        std::string id;
+        std::string type;
+        std::optional<FunctionCall> function_call;
+    };
+
+    /// Controls whether and how the model calls tools.
+    enum class ToolChoiceKind {
+        Auto,
+        None,
+        Required
+    };
+
+    struct ChatMessage {
+        std::string role;
+        std::string content;
+        std::optional<std::string> tool_call_id; ///< For role="tool" responses
+        std::vector<ToolCall> tool_calls;
+    };
+
+    struct ChatChoice {
+        int index = 0;
+        FinishReason finish_reason = FinishReason::None;
+
+        // non-streaming
+        std::optional<ChatMessage> message;
+
+        // streaming
+        std::optional<ChatMessage> delta;
+    };
+
+    struct ChatCompletionCreateResponse {
+        int64_t created = 0;
+        std::string id;
+
+        bool is_delta = false;
+        bool successful = false;
+        int http_status_code = 0;
+
+        std::vector<ChatChoice> choices;
+
+        /// Returns the object type string. Derived from is_delta; no allocation.
+        const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; }
+
+        /// Returns the created timestamp as an ISO 8601 string.
+        /// Computed lazily; only allocates when called.
+        std::string GetCreatedAtIso() const;
+    };
+
+    struct ChatSettings {
+        std::optional<float> frequency_penalty;
+        std::optional<int> max_tokens;
+        std::optional<int> n;
+        std::optional<float> temperature;
+        std::optional<float> presence_penalty;
+        std::optional<int> random_seed;
+        std::optional<int> top_k;
+        std::optional<float> top_p;
+        std::optional<ToolChoiceKind> tool_choice;
+    };
+
+    using DownloadProgressCallback = std::function<void(float)>;
+
+    // Forward declarations
+    class ModelVariant;
+
+    struct Parameter {
+        std::string name;
+        std::optional<std::string> value;
+    };
+
+    struct ModelSettings {
+        std::vector<Parameter> parameters;
+    };
+
+    struct ModelInfo {
+        std::string id;
+        std::string name;
+        uint32_t version = 0;
+        std::string alias;
+        std::optional<std::string> display_name;
+        std::string provider_type;
+        std::string uri;
+        std::string model_type;
+        std::optional<PromptTemplate> prompt_template;
+        std::optional<std::string> publisher;
+        std::optional<ModelSettings> model_settings;
+        std::optional<std::string> license;
+        std::optional<std::string> license_description;
+        bool cached = false;
+        std::optional<std::string> task;
+        std::optional<Runtime> runtime;
+        std::optional<uint32_t> file_size_mb;
+        std::optional<bool> supports_tool_calling;
+        std::optional<int64_t> max_output_tokens;
+        std::optional<std::string> min_fl_version;
+        int64_t created_at_unix = 0;
+    };
+
+    class AudioClient final {
+    public:
+        explicit AudioClient(gsl::not_null<const ModelVariant*> model);
+
+        /// Returns the model ID this client was created for.
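+        /// (Matches ModelInfo::id of the variant the client was constructed from.)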
+        const std::string& GetModelId() const noexcept { return modelId_; }
+
+        AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const;
+
+        using StreamCallback = std::function<void(const AudioCreateTranscriptionResponse&)>;
+        void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const;
+
+    private:
+        AudioClient(gsl::not_null<Internal::IFoundryLocalCore*> core, std::string_view modelId,
+                    gsl::not_null<ILogger*> logger);
+
+        std::string modelId_;
+        gsl::not_null<Internal::IFoundryLocalCore*> core_;
+        gsl::not_null<ILogger*> logger_;
+
+        friend class ModelVariant;
+    };
+
+    class ChatClient final {
+    public:
+        explicit ChatClient(gsl::not_null<const ModelVariant*> model);
+
+        /// Returns the model ID this client was created for.
+        const std::string& GetModelId() const noexcept { return modelId_; }
+
+        ChatCompletionCreateResponse CompleteChat(gsl::span<const ChatMessage> messages,
+                                                  const ChatSettings& settings) const;
+
+        ChatCompletionCreateResponse CompleteChat(gsl::span<const ChatMessage> messages,
+                                                  gsl::span<const ToolDefinition> tools,
+                                                  const ChatSettings& settings) const;
+
+        using StreamCallback = std::function<void(const ChatCompletionCreateResponse&)>;
+        void CompleteChatStreaming(gsl::span<const ChatMessage> messages, const ChatSettings& settings,
+                                   const StreamCallback& onChunk) const;
+
+        void CompleteChatStreaming(gsl::span<const ChatMessage> messages, gsl::span<const ToolDefinition> tools,
+                                   const ChatSettings& settings, const StreamCallback& onChunk) const;
+
+    private:
+        ChatClient(gsl::not_null<Internal::IFoundryLocalCore*> core, std::string_view modelId,
+                   gsl::not_null<ILogger*> logger);
+
+        std::string BuildChatRequestJson(gsl::span<const ChatMessage> messages, gsl::span<const ToolDefinition> tools,
+                                         const ChatSettings& settings, bool stream) const;
+
+        std::string modelId_;
+        gsl::not_null<Internal::IFoundryLocalCore*> core_;
+        gsl::not_null<ILogger*> logger_;
+
+        friend class ModelVariant;
+    };
+
+    class ModelVariant final {
+    public:
+        const ModelInfo& GetInfo() const;
+        const std::filesystem::path& GetPath() const;
+        void Download(DownloadProgressCallback onProgress = nullptr) const;
+        void Load() const;
+
+        bool IsLoaded() const;
+        bool IsCached() const;
+        void Unload() const;
+        void RemoveFromCache();
+
+        [[deprecated("Use AudioClient(model) constructor instead")]]
+        AudioClient GetAudioClient() const;
+
+        [[deprecated("Use ChatClient(model) constructor instead")]]
+        ChatClient GetChatClient() const;
+
+        const std::string& GetId() const noexcept;
+        const std::string& GetAlias() const noexcept;
+        uint32_t GetVersion() const noexcept;
+
+    private:
+        static std::string MakeModelParamRequest(std::string_view modelId);
+        explicit ModelVariant(gsl::not_null<Internal::IFoundryLocalCore*> core, ModelInfo info,
+                              gsl::not_null<ILogger*> logger);
+
+        ModelInfo info_;
+        mutable std::filesystem::path cachedPath_;
+        gsl::not_null<Internal::IFoundryLocalCore*> core_;
+        gsl::not_null<ILogger*> logger_;
+
+        friend class Catalog;
+        friend class AudioClient;
+        friend class ChatClient;
+#ifdef FL_TESTS
+        friend struct Testing::MockObjectFactory;
+#endif
+    };
+
+    class Model final {
+    public:
+        gsl::span<const ModelVariant> GetAllModelVariants() const;
+        const ModelVariant* GetLatestVariant(gsl::not_null<const ModelVariant*> variant) const;
+
+        bool IsLoaded() const { return SelectedVariant().IsLoaded(); }
+        bool IsCached() const { return SelectedVariant().IsCached(); }
+        const std::filesystem::path& GetPath() const { return SelectedVariant().GetPath(); }
+        void Download(DownloadProgressCallback onProgress = nullptr) const {
+            SelectedVariant().Download(std::move(onProgress));
+        }
+        void Load() const { SelectedVariant().Load(); }
+        void Unload() const { SelectedVariant().Unload(); }
+        void RemoveFromCache() { SelectedVariant().RemoveFromCache(); }
+        [[deprecated("Use AudioClient(model) constructor instead")]]
+        AudioClient GetAudioClient() const {
+            return SelectedVariant().GetAudioClient();
+        }
+
+        [[deprecated("Use ChatClient(model) constructor instead")]]
+        ChatClient GetChatClient() const {
+            return SelectedVariant().GetChatClient();
+        }
+
+        const std::string& GetId() const;
+        const std::string& GetAlias() const;
+        void SelectVariant(gsl::not_null<const ModelVariant*> variant) const;
+
+    private:
+        explicit Model(gsl::not_null<Internal::IFoundryLocalCore*> core, gsl::not_null<ILogger*> logger);
+        ModelVariant& SelectedVariant();
+        const ModelVariant& SelectedVariant() const;
+
+        gsl::not_null<Internal::IFoundryLocalCore*> core_;
+
+        std::vector<ModelVariant> variants_;
+        mutable std::optional<size_t> selectedVariantIndex_;
+        gsl::not_null<ILogger*> logger_;
+
+        friend class Catalog;
+#ifdef FL_TESTS
+        friend struct Testing::MockObjectFactory;
+#endif
+    };
+
+    class Catalog final {
+    public:
+        Catalog(const Catalog&) = delete;
+        Catalog& operator=(const Catalog&) = delete;
+        Catalog(Catalog&&) = delete;
+        Catalog& operator=(Catalog&&) = delete;
+
+        static std::unique_ptr<Catalog> Create(gsl::not_null<Internal::IFoundryLocalCore*> core,
+                                               gsl::not_null<ILogger*> logger) {
+            return std::unique_ptr<Catalog>(new Catalog(core, logger));
+        }
+
+        const std::string& GetName() const { return name_; }
+        std::vector<const Model*> ListModels() const;
+        std::vector<const ModelVariant*> GetLoadedModels() const;
+        std::vector<const ModelVariant*> GetCachedModels() const;
+
+        const Model* GetModel(std::string_view modelId) const;
+        const ModelVariant* GetModelVariant(std::string_view modelVariantId) const;
+
+    private:
+        void UpdateModels() const;
+
+        mutable std::chrono::steady_clock::time_point lastFetch_{};
+
+        mutable std::unordered_map<std::string, Model> byAlias_;
+        mutable std::unordered_map<std::string, ModelVariant> modelIdToModelVariant_;
+
+        explicit Catalog(gsl::not_null<Internal::IFoundryLocalCore*> injected,
+                         gsl::not_null<ILogger*> logger);
+
+        gsl::not_null<Internal::IFoundryLocalCore*> core_;
+        std::string name_;
+        gsl::not_null<ILogger*> logger_;
+
+        friend class FoundryLocalManager;
+#ifdef FL_TESTS
+        friend struct Testing::MockObjectFactory;
+#endif
+    };
+
+    class FoundryLocalManager final {
+    public:
+        FoundryLocalManager(const FoundryLocalManager&) = delete;
+        FoundryLocalManager& operator=(const FoundryLocalManager&) = delete;
+        FoundryLocalManager(FoundryLocalManager&& other) noexcept;
+        FoundryLocalManager& operator=(FoundryLocalManager&& other) noexcept;
+
+        explicit FoundryLocalManager(Configuration configuration, ILogger* logger = nullptr);
+        ~FoundryLocalManager();
+
+        const Catalog& GetCatalog() const;
+
+        /// Start the optional built-in web service.
+        /// Provides an OpenAI-compatible REST endpoint.
+        /// After startup, GetUrls() returns the actual bound URL/s.
+        /// Requires Configuration::web to be set.
+        void StartWebService();
+
+        /// Stop the web service if started.
+        void StopWebService();
+
+        /// Returns the bound URL/s after StartWebService(), or empty if not started.
+        gsl::span<const std::string> GetUrls() const noexcept;
+
+        /// Ensure execution providers are downloaded and registered.
+        /// Once downloaded, EPs are not re-downloaded unless a new version is available.
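+        /// Typically called once at application startup, before any model is loaded.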
+        void EnsureEpsDownloaded() const;
+
+    private:
+        bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; }
+
+        Configuration config_;
+
+        void Initialize();
+
+        NullLogger defaultLogger_;
+        std::unique_ptr<Internal::IFoundryLocalCore> core_;
+        std::unique_ptr<Catalog> catalog_;
+        ILogger* logger_;
+        std::vector<std::string> urls_;
+    };
+
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h
new file mode 100644
index 00000000..6ca886a1
--- /dev/null
+++ b/sdk/cpp/include/foundry_local_exception.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <stdexcept>
+#include <string>
+
+#include "logger.h"
+
+namespace FoundryLocal {
+
+    class FoundryLocalException final : public std::runtime_error {
+    public:
+        explicit FoundryLocalException(std::string message) : std::runtime_error(std::move(message)) {}
+
+        FoundryLocalException(std::string message, ILogger& logger) : std::runtime_error(std::move(message)) {
+            logger.Log(LogLevel::Error, what());
+        }
+    };
+
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/foundry_local_internal_core.h b/sdk/cpp/include/foundry_local_internal_core.h
new file mode 100644
index 00000000..eedfa5d4
--- /dev/null
+++ b/sdk/cpp/include/foundry_local_internal_core.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <string>
+#include <string_view>
+#include "logger.h"
+
+namespace FoundryLocal {
+    namespace Internal {
+        struct IFoundryLocalCore {
+            virtual ~IFoundryLocalCore() = default;
+
+            virtual std::string call(std::string_view command, ILogger& logger,
+                                     const std::string* dataArgument = nullptr, void* callback = nullptr,
+                                     void* data = nullptr) const = 0;
+            virtual void unload() = 0;
+        };
+
+    } // namespace Internal
+} // namespace FoundryLocal
\ No newline at end of file

diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h
new file mode 100644
index 00000000..d9b82863
--- /dev/null
+++ b/sdk/cpp/include/log_level.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include <string_view>
+
+namespace FoundryLocal {
+
+    enum class LogLevel {
+        Verbose,
+        Debug,
+        Information,
+        Warning,
+        Error,
+        Fatal
+    };
+
+    inline std::string_view LogLevelToString(LogLevel level) noexcept {
+        switch (level) {
+        case LogLevel::Verbose:
+            return "Verbose";
+        case LogLevel::Debug:
+            return "Debug";
+        case LogLevel::Information:
+            return "Information";
+        case LogLevel::Warning:
+            return "Warning";
+        case LogLevel::Error:
+            return "Error";
+        case LogLevel::Fatal:
+            return "Fatal";
+        }
+        return "Unknown";
+    }
+
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/logger.h b/sdk/cpp/include/logger.h
new file mode 100644
index 00000000..98d10155
--- /dev/null
+++ b/sdk/cpp/include/logger.h
@@ -0,0 +1,16 @@
+#pragma once
+#include <string_view>
+#include "log_level.h"
+
+namespace FoundryLocal {
+    class ILogger {
+    public:
+        virtual ~ILogger() = default;
+        virtual void Log(LogLevel level, std::string_view message) noexcept = 0;
+    };
+
+    class NullLogger final : public ILogger {
+    public:
+        void Log(LogLevel, std::string_view) noexcept override {}
+    };
+} // namespace FoundryLocal

diff --git a/sdk/cpp/include/parser.h b/sdk/cpp/include/parser.h
new file mode 100644
index 00000000..58c31e87
--- /dev/null
+++ b/sdk/cpp/include/parser.h
@@ -0,0 +1,289 @@
+#pragma once
+#include <optional>
+#include <string_view>
+#include "foundry_local.h"
+#include <nlohmann/json.hpp>
+
+namespace FoundryLocal {
+    inline DeviceType parse_device_type(std::string_view v) {
+        if (v == "CPU") {
+            return DeviceType::CPU;
+        }
+        if (v == "NPU") {
+            return DeviceType::NPU;
+        }
+        if (v == "GPU") {
+            return DeviceType::GPU;
+        }
+        return DeviceType::Invalid;
+    }
+
+    inline FinishReason parse_finish_reason(std::string_view v) {
+        if (v == "stop")
+            return FinishReason::Stop;
+        if (v == "length")
+            return FinishReason::Length;
+        if (v == "tool_calls")
+            return FinishReason::ToolCalls;
+        if (v == "content_filter")
+            return FinishReason::ContentFilter;
+        return FinishReason::None;
+    }
+
+    // ---------- Helpers ----------
+    inline std::string get_string_or_empty(const nlohmann::json& j, const char* key) {
+        auto it = j.find(key);
+        std::string out = "";
+        if (it != j.end() && it->is_string()) {
+            out = it->get<std::string>();
+        }
+        return out;
+    }
+
+    inline void from_json(const nlohmann::json& j, Runtime& r) {
+        std::string deviceType;
+        j.at("deviceType").get_to(deviceType);
+        j.at("executionProvider").get_to(r.execution_provider);
+
+        r.device_type = parse_device_type(deviceType);
+    }
+
+    inline void from_json(const nlohmann::json& j, PromptTemplate& p) {
+        p.system = get_string_or_empty(j, "system");
+        p.user = get_string_or_empty(j, "user");
+        p.assistant = get_string_or_empty(j, "assistant");
+        p.prompt = get_string_or_empty(j, "prompt");
+    }
+
+    inline std::optional<std::string> get_opt_string(const nlohmann::json& j, const char* key) {
+        auto it = j.find(key);
+        if (it == j.end() || it->is_null()) {
+            return std::nullopt;
+        }
+        if (it->is_string()) {
+            return it->get<std::string>();
+        }
+        return std::nullopt;
+    }
+
+    inline std::optional<int> get_opt_int(const nlohmann::json& j, const char* key) {
+        auto it = j.find(key);
+        if (it == j.end() || it->is_null()) {
+            return std::nullopt;
+        }
+        if (it->is_number_integer()) {
+            return it->get<int>();
+        }
+        return std::nullopt;
+    }
+
+    inline std::optional<int64_t> get_opt_i64(const nlohmann::json& j, const char* key) {
+        auto it = j.find(key);
+        if (it == j.end() || it->is_null()) {
+            return std::nullopt;
+        }
+        if (it->is_number_integer()) {
+            return it->get<int64_t>();
+        }
+        return std::nullopt;
+    }
+
+    inline std::optional<bool> get_opt_bool(const nlohmann::json& j, const char* key) {
+        auto it = j.find(key);
+        if (it == j.end() || it->is_null()) {
+            return std::nullopt;
+        }
+        if (it->is_boolean()) {
+            return it->get<bool>();
+        }
+        return std::nullopt;
+    }
+
+    inline void from_json(const nlohmann::json& j, Parameter& p) {
+        j.at("name").get_to(p.name);
+        p.value = get_opt_string(j, "value");
+    }
+
+    inline void from_json(const nlohmann::json& j, ModelSettings& ms) {
+        ms.parameters.clear();
+        if (auto it = j.find("parameters"); it != j.end() && it->is_array()) {
+            ms.parameters = it->get<std::vector<Parameter>>();
+        }
+    }
+
+    inline void from_json(const nlohmann::json& j, ModelInfo& m) {
+        j.at("id").get_to(m.id);
+        j.at("name").get_to(m.name);
+        j.at("version").get_to(m.version);
+        j.at("alias").get_to(m.alias);
+        j.at("providerType").get_to(m.provider_type);
+        j.at("uri").get_to(m.uri);
+        j.at("modelType").get_to(m.model_type);
+
+        m.display_name = get_opt_string(j, "displayName");
+        m.publisher = get_opt_string(j, "publisher");
+        m.license = get_opt_string(j, "license");
+        m.license_description = get_opt_string(j, "licenseDescription");
+        m.task = get_opt_string(j, "task");
+        if (auto it = j.find("fileSizeMb"); it != j.end() && !it->is_null() && it->is_number_integer()) {
+            auto v = it->get<int64_t>();
+            m.file_size_mb = (v >= 0) ? static_cast<uint32_t>(v) : 0u;
+        }
+        m.supports_tool_calling = get_opt_bool(j, "supportsToolCalling");
+        m.max_output_tokens = get_opt_i64(j, "maxOutputTokens");
+        m.min_fl_version = get_opt_string(j, "minFLVersion");
+
+        if (auto it = j.find("cached"); it != j.end() && it->is_boolean()) {
+            m.cached = it->get<bool>();
+        }
+        else {
+            m.cached = false;
+        }
+
+        if (auto it = j.find("createdAt"); it != j.end() && it->is_number_integer()) {
+            m.created_at_unix = it->get<int64_t>();
+        }
+        else {
+            m.created_at_unix = 0;
+        }
+
+        // nested optional objects
+        if (auto it = j.find("modelSettings"); it != j.end() && it->is_object()) {
+            m.model_settings = it->get<ModelSettings>();
+        }
+        else {
+            m.model_settings.reset();
+        }
+
+        if (auto it = j.find("promptTemplate"); it != j.end() && it->is_object()) {
+            m.prompt_template = it->get<PromptTemplate>();
+        }
+        else {
+            m.prompt_template.reset();
+        }
+
+        if (auto it = j.find("runtime"); it != j.end() && it->is_object()) {
+            m.runtime = it->get<Runtime>();
+        }
+        else {
+            m.runtime.reset();
+        }
+    }
+
+    // ---------- Tool calling: to_json (serialization for requests) ----------
+
+    inline void to_json(nlohmann::json& j, const PropertyDefinition& pd) {
+        j = nlohmann::json{{"type", pd.type}};
+        if (pd.description)
+            j["description"] = *pd.description;
+        if (pd.properties) {
+            nlohmann::json props = nlohmann::json::object();
+            for (const auto& [key, val] : *pd.properties) {
+                nlohmann::json pj;
+                to_json(pj, val);
+                props[key] = std::move(pj);
+            }
+            j["properties"] = std::move(props);
+        }
+        if (pd.required)
+            j["required"] = *pd.required;
+    }
+
+    inline void to_json(nlohmann::json& j, const FunctionDefinition& fd) {
+        j = nlohmann::json{{"name", fd.name}};
+        if (fd.description)
+            j["description"] = *fd.description;
+        if (fd.parameters) {
+            nlohmann::json pj;
+            to_json(pj, *fd.parameters);
+            j["parameters"] = std::move(pj);
+        }
+    }
+
+    inline void to_json(nlohmann::json& j, const ToolDefinition& td) {
+        j = nlohmann::json{{"type", td.type}};
+        nlohmann::json fj;
+        to_json(fj, td.function);
+        j["function"] = std::move(fj);
+    }
+
+    // ---------- Tool calling: from_json (deserialization from responses) ----------
+
+    inline void from_json(const nlohmann::json& j, FunctionCall& fc) {
+        fc.name = get_string_or_empty(j, "name");
+        if (j.contains("arguments")) {
+            const auto& args = j.at("arguments");
+            if (args.is_string())
+                fc.arguments = args.get<std::string>();
+            else
+                fc.arguments = args.dump();
+        }
+    }
+
+    inline void from_json(const nlohmann::json& j, ToolCall& tc) {
+        tc.id = get_string_or_empty(j, "id");
+        tc.type = get_string_or_empty(j, "type");
+        if (j.contains("function") && j.at("function").is_object())
+            tc.function_call = j.at("function").get<FunctionCall>();
+    }
+
+    inline void from_json(const nlohmann::json& j, ChatMessage& m) {
+        if (j.contains("role"))
+            j.at("role").get_to(m.role);
+        if (j.contains("content") && !j.at("content").is_null())
+            j.at("content").get_to(m.content);
+
+        m.tool_call_id = get_opt_string(j, "tool_call_id");
+
+        m.tool_calls.clear();
+        if (j.contains("tool_calls") && j.at("tool_calls").is_array()) {
+            for (const auto& tc : j.at("tool_calls")) {
+                if (tc.is_object())
+                    m.tool_calls.push_back(tc.get<ToolCall>());
+            }
+        }
+    }
+
+    inline void from_json(const nlohmann::json& j, ChatChoice& c) {
+        if (j.contains("index"))
+            j.at("index").get_to(c.index);
+        if (j.contains("finish_reason") && !j.at("finish_reason").is_null())
+            c.finish_reason = parse_finish_reason(j.at("finish_reason").get<std::string>());
+
+        if (j.contains("message") && !j.at("message").is_null())
+            c.message = j.at("message").get<ChatMessage>();
+
+        if (j.contains("delta") && !j.at("delta").is_null())
+            c.delta = j.at("delta").get<ChatMessage>();
+    }
+
+    inline void from_json(const nlohmann::json& j, ChatCompletionCreateResponse& r) {
+        if (j.contains("created"))
+            j.at("created").get_to(r.created);
+        r.id = get_string_or_empty(j, "id");
+        if (j.contains("IsDelta"))
+            j.at("IsDelta").get_to(r.is_delta);
+        if (j.contains("Successful"))
+            j.at("Successful").get_to(r.successful);
+        if (j.contains("HttpStatusCode"))
+            j.at("HttpStatusCode").get_to(r.http_status_code);
+
+        r.choices.clear();
+        if (j.contains("choices") && j.at("choices").is_array()) {
+            r.choices = j.at("choices").get<std::vector<ChatChoice>>();
+        }
+    }
+
+    // ---------- Tool choice helpers ----------
+
+    inline std::string tool_choice_to_string(ToolChoiceKind kind) {
+        switch (kind) {
+        case ToolChoiceKind::Auto:
+            return "auto";
+        case ToolChoiceKind::None:
+            return "none";
+        case ToolChoiceKind::Required:
+            return "required";
+        }
+        return "auto";
+    }
+
+} // namespace FoundryLocal
\ No newline at end of file

diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp
new file mode 100644
index 00000000..f5166ba9
--- /dev/null
+++ b/sdk/cpp/sample/main.cpp
@@ -0,0 +1,398 @@
+#include <iostream>
+#include "foundry_local.h"
+
+#include <cstdio>
+#include <string>
+#include <vector>
+
+using namespace FoundryLocal;
+
+// ---------------------------------------------------------------------------
+// Logger
+// ---------------------------------------------------------------------------
+class StdLogger final : public ILogger {
+public:
+    void Log(LogLevel level, std::string_view message) noexcept override {
+        const char* tag = "UNK";
+        switch (level) {
+        case LogLevel::Information:
+            tag = "INFO";
+            break;
+        case LogLevel::Warning:
+            tag = "WARN";
+            break;
+        case LogLevel::Error:
+            tag = "ERROR";
+            break;
+        default:
+            tag = "DEBUG";
+            break;
+        }
+        std::fprintf(stderr, "[FoundryLocal][%s] %.*s\n", tag, static_cast<int>(message.size()), message.data());
+    }
+};
+
+// ---------------------------------------------------------------------------
+// Example 1 – Browse the catalog
+// ---------------------------------------------------------------------------
+void BrowseCatalog(FoundryLocalManager& manager) {
+    std::cout << "\n=== Example 1: Browse Catalog ===\n";
+
+    auto& catalog = manager.GetCatalog();
+    std::cout << "Catalog: " << catalog.GetName() << "\n";
+
+    auto models = catalog.ListModels();
+    std::cout << "Models in catalog: " << models.size() << "\n";
+
+    for (const auto* model : models) {
+        std::cout << "  - " << model->GetAlias() << " (" << model->GetId() << ")"
+                  << " cached=" << (model->IsCached() ? "yes" : "no")
+                  << " loaded=" << (model->IsLoaded() ? "yes" : "no") << "\n";
+
+        for (const auto& variant : model->GetAllModelVariants()) {
+            const auto& info = variant.GetInfo();
+            std::cout << "      variant: " << info.name << " v" << info.version;
+            if (info.runtime)
+                std::cout << " device=" << (info.runtime->device_type == DeviceType::GPU ? "GPU" : "CPU");
"GPU" : "CPU"); + if (info.file_size_mb) + std::cout << " size=" << *info.file_size_mb << "MB"; + std::cout << "\n"; + } + } +} + +// --------------------------------------------------------------------------- +// Example 2 – Download, load, chat (non-streaming), then unload +// --------------------------------------------------------------------------- +void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 2: Non-Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + auto models = catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + // Get the selected variant pointer for ChatClient + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + std::vector messages = {{"system", "You are a helpful assistant. Keep answers brief."}, + {"user", "What is the capital of Croatia?"}}; + + ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 128; + + auto response = chat.CompleteChat(messages, settings); + + if (!response.choices.empty() && response.choices[0].message) { + std::cout << "Assistant: " << response.choices[0].message->content << "\n"; + } + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + +// --------------------------------------------------------------------------- +// Example 3 – Streaming chat +// --------------------------------------------------------------------------- +void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 3: Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Load(); + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + std::vector messages = {{"user", "Explain quantum computing in three sentences."}}; + + ChatSettings settings; + settings.temperature = 0.9f; + settings.max_tokens = 256; + + std::cout << "Assistant: "; + chat.CompleteChatStreaming(messages, settings, [](const ChatCompletionCreateResponse& chunk) { + if (chunk.choices.empty()) + return; + const auto& choice = chunk.choices[0]; + if (choice.delta && !choice.delta->content.empty()) { + std::cout << choice.delta->content << std::flush; + } + else if (choice.message && !choice.message->content.empty()) { + std::cout << choice.message->content << std::flush; + } + }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 4 – Audio transcription +// --------------------------------------------------------------------------- +void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, const std::string& audioPath) { + std::cout << "\n=== Example 4: Audio Transcription ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: 
" << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + AudioClient audio(&selectedVariant); + + std::cout << "Transcribing: " << audioPath << "\n"; + auto result = audio.TranscribeAudio(audioPath); + std::cout << "Transcription: " << result.text << "\n"; + + // Streaming alternative: + audio.TranscribeAudioStreaming( + audioPath, [](const AudioCreateTranscriptionResponse& chunk) { std::cout << chunk.text << std::flush; }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 5 – Tool calling +// --------------------------------------------------------------------------- +// Tool calling lets you define functions that the model can decide to invoke. +// The flow is: +// 1. You describe your tools (functions) as ToolDefinition objects. +// 2. You send a chat request with those tools attached. +// 3. The model may respond with finish_reason = ToolCalls and include +// ToolCall objects in the message, each containing the function name +// and a JSON string of arguments. +// 4. YOUR CODE executes the real function using those arguments. +// 5. You add a message with role = "tool" containing the result, then +// send the conversation back so the model can formulate a final answer. +// +// This lets the model "reach out" to external capabilities (calculators, +// databases, APIs, etc.) while keeping the actual execution in your code. +// --------------------------------------------------------------------------- +void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 5: Tool Calling ===\n"; + + auto& catalog = manager.GetCatalog(); + catalog.ListModels(); + + const auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + const auto& selectedVariant = model->GetAllModelVariants()[0]; + ChatClient chat(&selectedVariant); + + // ── Step 1: Define tools ────────────────────────────────────────────── + // Each tool describes a function the model can call. The PropertyDefinition + // mirrors a JSON Schema so the model knows what arguments are expected. + std::vector tools = {{ + "function", + FunctionDefinition{ + "multiply_numbers", // function name + "Multiply two integers and return the result.", // description + PropertyDefinition{ + "object", // top-level schema type + std::nullopt, // no top-level description + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}} + }, + std::vector{"first", "second"} // both params are required + } + } + }}; + + // ── Step 2: Send the first request ──────────────────────────────────── + // tool_choice = Required forces the model to always produce a tool call. + // In production you'd typically use Auto so the model decides on its own. + std::vector messages = { + {"system", "You are a helpful AI assistant. 
+        {"user", "What is 7 multiplied by 6?"}
+    };
+
+    ChatSettings settings;
+    settings.temperature = 0.0f;
+    settings.max_tokens = 500;
+    settings.tool_choice = ToolChoiceKind::Required;
+
+    std::cout << "Sending chat request with tool definitions...\n";
+    auto response = chat.CompleteChat(messages, tools, settings);
+
+    // ── Step 3: Inspect the model's tool call ─────────────────────────────
+    if (response.choices.empty()) {
+        std::cerr << "No choices returned.\n";
+        model->Unload();
+        return;
+    }
+
+    const auto& firstChoice = response.choices[0];
+
+    // The model signals it wants to call a tool via finish_reason == ToolCalls.
+    if (firstChoice.finish_reason == FinishReason::ToolCalls &&
+        firstChoice.message && !firstChoice.message->tool_calls.empty())
+    {
+        const auto& tc = firstChoice.message->tool_calls[0];
+        std::cout << "Model requested tool call:\n"
+                  << "  function : " << (tc.function_call ? tc.function_call->name : "(none)") << "\n"
+                  << "  arguments: " << (tc.function_call ? tc.function_call->arguments : "{}") << "\n";
+
+        // ── Step 4: Execute the tool locally ──────────────────────────────
+        // Parse the arguments JSON and perform the actual computation.
+        // In a real application this could be a web request, DB query, etc.
+        std::string toolResult;
+        if (tc.function_call && tc.function_call->name == "multiply_numbers") {
+            // The arguments string is JSON, e.g. {"first": 7, "second": 6}
+            // For brevity we hard-code the expected result here.
+            toolResult = "7 x 6 = 42.";
+            std::cout << "  result   : " << toolResult << "\n";
+        } else {
+            toolResult = "Unknown tool.";
+        }
+
+        // ── Step 5: Feed the tool result back ─────────────────────────────
+        // Add the assistant's message (including the raw tool_call content)
+        // and then a "tool" message with the result.
+        messages.push_back({"tool", toolResult});
+
+        // Add a follow-up system instruction so the model uses the tool output.
+        messages.push_back({"system", "Respond only with the answer generated by the tool."});
+
+        // Switch to Auto so the model can answer without calling tools again.
+        settings.tool_choice = ToolChoiceKind::Auto;
+
+        std::cout << "\nSending tool result back to model...\n";
+        auto followUp = chat.CompleteChat(messages, tools, settings);
+
+        if (!followUp.choices.empty() && followUp.choices[0].message) {
+            std::cout << "Assistant: " << followUp.choices[0].message->content << "\n";
+        }
+    }
+    else {
+        // The model answered directly without a tool call.
+        if (firstChoice.message)
+            std::cout << "Assistant: " << firstChoice.message->content << "\n";
+    }
+
+    model->Unload();
+    std::cout << "Model unloaded.\n";
+}
+
+// ---------------------------------------------------------------------------
+// Example 6 – Model variant inspection & selection
+// ---------------------------------------------------------------------------
+void InspectVariants(FoundryLocalManager& manager, const std::string& alias) {
+    std::cout << "\n=== Example 6: Variant Inspection ===\n";
+
+    auto& catalog = manager.GetCatalog();
+    catalog.ListModels();
+
+    const auto* model = catalog.GetModel(alias);
+    if (!model) {
+        std::cerr << "Model '" << alias << "' not found in catalog.\n";
+        return;
+    }
+
+    auto variants = model->GetAllModelVariants();
+    std::cout << "Model '" << alias << "' has " << variants.size() << " variant(s):\n";
+
+    for (const auto& v : variants) {
+        const auto& info = v.GetInfo();
+        std::cout << "  " << info.name << " v" << info.version << " cached=" << (v.IsCached() ? "yes" : "no");
"yes" : "no"); + if (info.display_name) + std::cout << " display=\"" << *info.display_name << "\""; + if (info.publisher) + std::cout << " publisher=" << *info.publisher; + if (info.license) + std::cout << " license=" << *info.license; + if (info.runtime) { + std::cout << " device=" + << (info.runtime->device_type == DeviceType::GPU ? "GPU" + : info.runtime->device_type == DeviceType::NPU ? "NPU" + : "CPU") + << " ep=" << info.runtime->execution_provider; + } + if (info.supports_tool_calling) + std::cout << " tools=" << (*info.supports_tool_calling ? "yes" : "no"); + std::cout << "\n"; + } + + // Select a specific variant by pointer (e.g. prefer the GPU variant) + for (const auto& v : variants) { + if (v.GetInfo().runtime && v.GetInfo().runtime->device_type == DeviceType::GPU) { + model->SelectVariant(&v); + std::cout << "Selected GPU variant: " << model->GetId() << "\n"; + break; + } + } +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- +int main() { + try { + StdLogger logger; + FoundryLocalManager manager({"SampleApp"}, &logger); + + // 1. Browse the full catalog + BrowseCatalog(manager); + + // 2. Non-streaming chat (change alias to a model in your catalog) + ChatNonStreaming(manager, "phi-3.5-mini"); + + // 3. Streaming chat + ChatStreaming(manager, "phi-3.5-mini"); + + // 4. Audio transcription (uncomment and set a valid alias + wav path) + // TranscribeAudio(manager, "whisper-small", R"(C:\path\to\your\audio.wav)"); + + // 5. Tool calling (define tools, let the model call them, feed results back) + ChatWithToolCalling(manager, "phi-3.5-mini"); + + // 6. Inspect model variants and select one + InspectVariants(manager, "phi-3.5-mini"); + + return 0; + } + catch (const std::exception& ex) { + std::cerr << "Fatal: " << ex.what() << std::endl; + return 1; + } +} \ No newline at end of file diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp new file mode 100644 index 00000000..a1eb2947 --- /dev/null +++ b/sdk/cpp/src/foundry_local.cpp @@ -0,0 +1,848 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include "core_interop_request.h" +#include "configuration.h" +#include "foundry_local.h" +#include "flcore_native.h" +#include "foundry_local_internal_core.h" +#include "parser.h" +#include "logger.h" +#include +#include "foundry_local_exception.h" + +// Internal private namespace. +namespace { + std::filesystem::path getExecutableDir() { + auto exePath = wil::GetModuleFileNameW(nullptr); + return std::filesystem::path(exePath.get()).parent_path(); + } +} // namespace + +namespace { + // Wrap Params: { ... 
+    inline nlohmann::json MakeParams(nlohmann::json params) {
+        return nlohmann::json{ {"Params", std::move(params)} };
+    }
+
+    // Most common: Params { "Model": <model> }
+    inline nlohmann::json MakeModelParams(std::string_view model) {
+        return MakeParams(nlohmann::json{ {"Model", std::string(model)} });
+    }
+
+    // Serialize + call
+    inline std::string CallWithJson(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command,
+                                    const nlohmann::json& requestJson, FoundryLocal::ILogger& logger) {
+        std::string payload = requestJson.dump();
+        return core->call(command, logger, &payload);
+    }
+
+    // Serialize + call with native callback
+    inline std::string CallWithJsonAndCallback(FoundryLocal::Internal::IFoundryLocalCore* core,
+                                               std::string_view command, const nlohmann::json& requestJson,
+                                               FoundryLocal::ILogger& logger, void* callback, void* userData) {
+        std::string payload = requestJson.dump();
+        return core->call(command, logger, &payload, callback, userData);
+    }
+
+    // Overload: allow Params object directly
+    inline std::string CallWithParams(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command,
+                                      const nlohmann::json& params, FoundryLocal::ILogger& logger) {
+        return CallWithJson(core, command, MakeParams(params), logger);
+    }
+
+    // Overload: no payload
+    inline std::string CallNoArgs(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command,
+                                  FoundryLocal::ILogger& logger) {
+        return core->call(command, logger, nullptr);
+    }
+
+    std::vector<std::string> GetLoadedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core,
+                                                     FoundryLocal::ILogger& logger) {
+        std::string raw = core->call("list_loaded_models", logger);
+        try {
+            auto parsed = nlohmann::json::parse(raw);
+            return parsed.get<std::vector<std::string>>();
+        }
+        catch (const nlohmann::json::exception& e) {
+            throw FoundryLocal::FoundryLocalException(
+                "Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger);
+        }
+    }
+
+    std::vector<std::string> GetCachedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core,
+                                                     FoundryLocal::ILogger& logger) {
+        std::string raw = core->call("get_cached_models", logger);
+
+        try {
+            auto parsed = nlohmann::json::parse(raw);
+            return parsed.get<std::vector<std::string>>();
+        }
+        catch (const nlohmann::json::exception& e) {
+            throw FoundryLocal::FoundryLocalException(
+                "Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger);
+        }
+    }
+
+    inline void StripSuffixAfterColon(std::string& id) {
+        const auto pos = id.find_last_of(':');
+        if (pos != std::string::npos) {
+            id.erase(pos);
+        }
+    }
+
+    std::vector<const FoundryLocal::ModelVariant*>
+    CollectVariantsByIds(const std::unordered_map<std::string, FoundryLocal::ModelVariant>& modelIdToModelVariant,
+                         std::vector<std::string> ids) {
+        std::vector<const FoundryLocal::ModelVariant*> out;
+        out.reserve(ids.size());
+
+        for (auto& id : ids) {
+            StripSuffixAfterColon(id);
+
+            auto it = modelIdToModelVariant.find(id);
+            if (it != modelIdToModelVariant.end()) {
+                out.emplace_back(&it->second);
+            }
+        }
+        return out;
+    }
+
+} // namespace
+
+namespace FoundryLocal {
+    inline static void* RequireProc(HMODULE mod, const char* name) {
+        if (void* p = reinterpret_cast<void*>(::GetProcAddress(mod, name)))
+            return p;
+        throw std::runtime_error(std::string("GetProcAddress failed for ") + name);
+    }
+
+    struct Core : FoundryLocal::Internal::IFoundryLocalCore {
+        using ResponseHandle = std::unique_ptr<ResponseBuffer, free_response_fn>;
+
+        Core() = default;
+        ~Core() = default;
+
+        void loadEmbedded() {
+            loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll");
+        }
+
+        void unload() {
+            module_.reset();
+            execCmd_ = nullptr;
+            execCbCmd_ = nullptr;
+            freeResCmd_ = nullptr;
+        }
+
+        std::string call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr,
+                         void* callback = nullptr, void* data = nullptr) const override {
+            if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) {
+                throw FoundryLocalException(
+                    "Core is not loaded. Cannot call command: " + std::string(command), logger);
+            }
+
+            RequestBuffer request{};
+            request.Command = command.empty() ? nullptr : command.data();
+            request.CommandLength = static_cast<int32_t>(command.size());
+
+            if (dataArgument && !dataArgument->empty()) {
+                request.Data = dataArgument->data();
+                request.DataLength = static_cast<int32_t>(dataArgument->size());
+            }
+
+            ResponseBuffer response{};
+            auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) {
+                if (fn)
+                    fn(buf);
+            };
+            std::unique_ptr<ResponseBuffer, decltype(safeDeleter)> responseGuard(&response, safeDeleter);
+
+            using CallbackFn = void (*)(void*, int32_t, void*);
+
+            if (callback != nullptr) {
+                auto cb = reinterpret_cast<CallbackFn>(callback);
+                execCbCmd_(&request, &response, reinterpret_cast<void*>(cb), data);
+            }
+            else {
+                execCmd_(&request, &response);
+            }
+
+            std::string result;
+            if (response.Error && response.ErrorLength > 0) {
+                std::string err(static_cast<const char*>(response.Error), response.ErrorLength);
+                throw FoundryLocalException(
+                    std::string("Command failed [").append(command).append("]: ").append(err), logger);
+            }
+
+            if (response.Data && response.DataLength > 0) {
+                result.assign(static_cast<const char*>(response.Data), response.DataLength);
+            }
+
+            return result;
+        }
+
+    private:
+        wil::unique_hmodule module_;
+        execute_command_fn execCmd_{};
+        execute_command_with_callback_fn execCbCmd_{};
+        free_response_fn freeResCmd_{};
+
+        void loadFromPath(const std::filesystem::path& path) {
+            wil::unique_hmodule m(::LoadLibraryW(path.c_str()));
+            if (!m)
+                throw std::runtime_error("LoadLibraryW failed");
+
+            execCmd_ = reinterpret_cast<execute_command_fn>(RequireProc(m.get(), "execute_command"));
+            execCbCmd_ = reinterpret_cast<execute_command_with_callback_fn>(
+                RequireProc(m.get(), "execute_command_with_callback"));
+            freeResCmd_ = reinterpret_cast<free_response_fn>(RequireProc(m.get(), "free_response"));
+
+            module_ = std::move(m);
+        }
+    };
+
+    ///
+    /// AudioClient
+    ///
+
+    AudioClient::AudioClient(gsl::not_null<Internal::IFoundryLocalCore*> core, std::string_view modelId,
+                             gsl::not_null<ILogger*> logger)
+        : core_(core), modelId_(modelId), logger_(logger) {
+    }
+
+    AudioCreateTranscriptionResponse AudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const {
+        nlohmann::json openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} };
+        CoreInteropRequest req("audio_transcribe");
+        req.AddParam("OpenAICreateRequest", openAiReq.dump());
+
+        std::string json = req.ToJson();
+
+        AudioCreateTranscriptionResponse response;
+        response.text = core_->call(req.Command(), *logger_, &json);
+
+        return response;
+    }
+
+    void AudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath,
+                                               const StreamCallback& onChunk) const {
+        nlohmann::json openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} };
+        CoreInteropRequest req("audio_transcribe");
+        req.AddParam("OpenAICreateRequest", openAiReq.dump());
+
+        std::string json = req.ToJson();
+
+        struct State {
+            const StreamCallback* cb;
+            std::exception_ptr exception;
+        } state{ &onChunk, nullptr };
+
+        auto streamCallback = [](void* data, int32_t len, void* user) {
+            if (!data || len <= 0)
+                return;
+
+            auto* st = static_cast<State*>(user);
+            if (st->exception)
+                return;
+
+            try {
+                std::string text(static_cast<const char*>(data), static_cast<size_t>(len));
+                AudioCreateTranscriptionResponse chunk;
+                chunk.text = std::move(text);
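+                // Hand the chunk to the user's callback. Exceptions it throws are
+                // captured in State and rethrown after the native call returns,
+                // because unwinding across the C ABI boundary would be undefined.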
(*(st->cb))(chunk); + } + catch (...) { + st->exception = std::current_exception(); + } + }; + + core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), + reinterpret_cast(&state)); + + if (state.exception) { + std::rethrow_exception(state.exception); + } + } + + + std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { + if (created == 0) return {}; + std::time_t t = static_cast(created); + std::tm tm{}; +#ifdef _WIN32 + gmtime_s(&tm, &t); +#else + gmtime_r(&t, &tm); +#endif + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); + return buf; + } + + /// + /// ChatClient + /// + + ChatClient::ChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) { + } + + std::string ChatClient::BuildChatRequestJson(gsl::span messages, gsl::span tools, + const ChatSettings& settings, bool stream) const { + nlohmann::json jMessages = nlohmann::json::array(); + for (const auto& msg : messages) { + nlohmann::json jMsg = { {"role", msg.role}, {"content", msg.content} }; + if (msg.tool_call_id) + jMsg["tool_call_id"] = *msg.tool_call_id; + jMessages.push_back(std::move(jMsg)); + } + + nlohmann::json req = { {"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream} }; + + if (!tools.empty()) { + nlohmann::json jTools = nlohmann::json::array(); + for (const auto& tool : tools) { + nlohmann::json jTool; + to_json(jTool, tool); + jTools.push_back(std::move(jTool)); + } + req["tools"] = std::move(jTools); + } + + if (settings.tool_choice) + req["tool_choice"] = tool_choice_to_string(*settings.tool_choice); + if (settings.top_k) + req["metadata"] = { {"top_k", *settings.top_k} }; + if (settings.frequency_penalty) + req["frequency_penalty"] = *settings.frequency_penalty; + if (settings.presence_penalty) + req["presence_penalty"] = *settings.presence_penalty; + if (settings.max_tokens) + req["max_completion_tokens"] = *settings.max_tokens; + if (settings.n) + req["n"] = *settings.n; + if (settings.temperature) + req["temperature"] = *settings.temperature; + if (settings.top_p) + req["top_p"] = *settings.top_p; + if (settings.random_seed) + req["seed"] = *settings.random_seed; + + return req.dump(); + } + + ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings) const { + return CompleteChat(messages, {}, settings); + } + + ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + gsl::span tools, const ChatSettings& settings) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + std::string rawResult = core_->call(req.Command(), *logger_, &json); + + return nlohmann::json::parse(rawResult).get(); + } + + void ChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, {}, settings, onChunk); + } + + void ChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + std::string json = req.ToJson(); + + struct State { + const StreamCallback* 
cb; + std::exception_ptr exception; + } state{ &onChunk, nullptr }; + + auto streamCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + + auto* st = static_cast(user); + if (st->exception) + return; + + std::string s(static_cast(data), static_cast(len)); + + try { + auto parsed = nlohmann::json::parse(s).get(); + + (*(st->cb))(parsed); + } + catch (const nlohmann::json::exception& e) { + st->exception = std::make_exception_ptr( + FoundryLocalException(std::string("Error while parsing streaming chat chunk: ") + e.what())); + } + catch (...) { + st->exception = std::current_exception(); + } + }; + + core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), + reinterpret_cast(&state)); + + if (state.exception) { + std::rethrow_exception(state.exception); + } + } + + /// + /// ModelVariant + /// + + ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) { + } + + const ModelInfo& ModelVariant::GetInfo() const { + return info_; + } + + void ModelVariant::RemoveFromCache() { + try { + CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); + cachedPath_.clear(); + } + catch (const std::exception& ex) { + throw FoundryLocalException("Error removing model from cache [" + info_.name + "]: " + ex.what(), *logger_); + } + } + + void ModelVariant::Unload() const { + try { + CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); + } + catch (const std::exception& ex) { + throw FoundryLocalException("Error unloading model [" + info_.name + "]: " + ex.what(), *logger_); + } + } + + bool ModelVariant::IsLoaded() const { + std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); + for (auto& id : loadedModelIds) { + auto pos = id.find_last_of(':'); + if (pos != std::string::npos) { + id.erase(pos); + } + + if (id == info_.name) { + return true; + } + } + + return false; + } + + bool ModelVariant::IsCached() const { + auto cachedModels = GetCachedModelsInternal(core_, *logger_); + for (auto& id : cachedModels) { + StripSuffixAfterColon(id); + if (id == info_.name) { + return true; + } + } + return false; + } + + void ModelVariant::Download(DownloadProgressCallback onProgress) const { + if (IsCached()) { + logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); + return; + } + + if (onProgress) { + struct ProgressState { + DownloadProgressCallback* cb; + ILogger* logger; + } state{ &onProgress, logger_ }; + + auto nativeCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + auto* st = static_cast(user); + std::string perc(static_cast(data), static_cast((std::min)(4, static_cast(len)))); + try { + float value = std::stof(perc); + (*(st->cb))(value); + } catch (...) 
{ + st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); + } + }; + + CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, + reinterpret_cast(+nativeCallback), reinterpret_cast(&state)); + } else { + CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); + } + } + + void ModelVariant::Load() const { + CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_); + } + + const std::filesystem::path& ModelVariant::GetPath() const { + if (cachedPath_.empty()) { + cachedPath_ = std::filesystem::path(CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_)); + } + return cachedPath_; + } + + const std::string& ModelVariant::GetId() const noexcept { + return info_.id; + } + + const std::string& ModelVariant::GetAlias() const noexcept { + return info_.alias; + } + + uint32_t ModelVariant::GetVersion() const noexcept { + return info_.version; + } + + AudioClient::AudioClient(gsl::not_null model) + : AudioClient(model->core_, model->info_.name, model->logger_) { + if (!model->IsLoaded()) { + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); + } + } + + AudioClient ModelVariant::GetAudioClient() const { + return AudioClient(this); + } + + ChatClient::ChatClient(gsl::not_null model) + : ChatClient(model->core_, model->info_.name, model->logger_) { + if (!model->IsLoaded()) { + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); + } + } + + ChatClient ModelVariant::GetChatClient() const { + return ChatClient(this); + } + + /// + /// Model + /// + Model::Model(gsl::not_null core, gsl::not_null logger) + : core_(core), logger_(logger) { + } + + ModelVariant& Model::SelectedVariant() { + if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { + throw FoundryLocalException("Model has no selected variant", *logger_); + } + return variants_[*selectedVariantIndex_]; + } + + const ModelVariant& Model::SelectedVariant() const { + if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { + throw FoundryLocalException("Model has no selected variant", *logger_); + } + return variants_[*selectedVariantIndex_]; + } + + gsl::span Model::GetAllModelVariants() const { + return variants_; + } + + const ModelVariant* Model::GetLatestVariant(gsl::not_null variant) const { + const auto& targetName = variant->GetInfo().name; + + for (const auto& v : variants_) { + if (v.GetInfo().name == targetName) { + return &v; + } + } + + throw FoundryLocalException( + "Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", *logger_); + } + + const std::string& Model::GetId() const { + return SelectedVariant().GetId(); + } + + const std::string& Model::GetAlias() const { + return SelectedVariant().GetAlias(); + } + + void Model::SelectVariant(gsl::not_null variant) const { + auto it = std::find_if(variants_.begin(), variants_.end(), + [&](const ModelVariant& v) { return &v == variant.get(); }); + + if (it == variants_.end()) { + throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", + *logger_); + } + + selectedVariantIndex_ = static_cast(std::distance(variants_.begin(), it)); + } + + /// + /// Catalog + /// + + Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + : core_(injected), logger_(logger) { + try { + name_ = core_->call("get_catalog_name", *logger_, 
/*dataArgument*/ nullptr);
+        }
+        catch (const std::exception& ex) {
+            throw FoundryLocalException(std::string("Error getting catalog name: ") + ex.what(), *logger_);
+        }
+    }
+
+    std::vector<const ModelVariant*> Catalog::GetLoadedModels() const {
+        return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_));
+    }
+
+    std::vector<const ModelVariant*> Catalog::GetCachedModels() const {
+        return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_));
+    }
+
+    const Model* Catalog::GetModel(std::string_view alias) const {
+        auto it = byAlias_.find(std::string(alias));
+        if (it != byAlias_.end()) {
+            return &it->second;
+        }
+        return nullptr;
+    }
+
+    std::vector<const Model*> Catalog::ListModels() const {
+        UpdateModels();
+
+        std::vector<const Model*> out;
+        out.reserve(byAlias_.size());
+        for (auto& kv : byAlias_)
+            out.emplace_back(&kv.second);
+
+        return out;
+    }
+
+    void Catalog::UpdateModels() const {
+        using clock = std::chrono::steady_clock;
+
+        // TODO: make this configurable
+        constexpr auto kRefreshInterval = std::chrono::hours(6);
+
+        const auto now = clock::now();
+        if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) {
+            return;
+        }
+
+        const std::string raw = core_->call("get_model_list", *logger_);
+        const auto arr = nlohmann::json::parse(raw);
+
+        byAlias_.clear();
+        modelIdToModelVariant_.clear();
+
+        for (const auto& j : arr) {
+            const std::string alias = j.at("alias").get<std::string>();
+            if (alias.rfind("openai-", 0) == 0)
+                continue;
+
+            auto it = byAlias_.find(alias);
+            if (it == byAlias_.end()) {
+                Model m(core_, logger_);
+                it = byAlias_.emplace(alias, std::move(m)).first;
+            }
+
+            ModelInfo modelVariantInfo;
+            from_json(j, modelVariantInfo);
+            std::string variantId = modelVariantInfo.name;
+            ModelVariant modelVariant(core_, modelVariantInfo, logger_);
+            modelIdToModelVariant_.emplace(variantId, modelVariant);
+
+            it->second.variants_.emplace_back(std::move(modelVariant));
+        }
+
+        // Auto-select the first variant for each model.
+        for (auto& [alias, model] : byAlias_) {
+            if (!model.variants_.empty()) {
+                model.selectedVariantIndex_ = 0;
+            }
+        }
+
+        lastFetch_ = now;
+    }
+
+    const ModelVariant* Catalog::GetModelVariant(std::string_view id) const {
+        auto it = modelIdToModelVariant_.find(std::string(id));
+        if (it != modelIdToModelVariant_.end()) {
+            return &it->second;
+        }
+        return nullptr;
+    }
+
+    ///
+    /// FoundryLocalManager
+    ///
+
+    FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger)
+        : config_(std::move(configuration)), core_(std::make_unique()), logger_(logger ? logger : &defaultLogger_) {
+        static_cast(core_.get())->loadEmbedded();
+        Initialize();
+        catalog_ = Catalog::Create(core_.get(), logger_);
+    }
+
+    FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept
+        : config_(std::move(other.config_)),
+          core_(std::move(other.core_)),
+          catalog_(std::move(other.catalog_)),
+          logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_),
+          urls_(std::move(other.urls_)) {
+        other.logger_ = &other.defaultLogger_;
+    }
+
+    FoundryLocalManager& FoundryLocalManager::operator=(FoundryLocalManager&& other) noexcept {
+        if (this != &other) {
+            config_ = std::move(other.config_);
+            core_ = std::move(other.core_);
+            catalog_ = std::move(other.catalog_);
+            logger_ = other.OwnsLogger() ?
&defaultLogger_ : other.logger_; + urls_ = std::move(other.urls_); + other.logger_ = &other.defaultLogger_; + } + return *this; + } + + FoundryLocalManager::~FoundryLocalManager() { + // Unload all loaded models before tearing down. + if (catalog_) { + try { + auto loadedModels = catalog_->GetLoadedModels(); + for (const auto* variant : loadedModels) { + try { + variant->Unload(); + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error unloading model during destruction: ") + ex.what()); + } + } + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error retrieving loaded models during destruction: ") + ex.what()); + } + } + + if (!urls_.empty()) { + try { + StopWebService(); + } catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, std::string("Error stopping web service during destruction: ") + ex.what()); + } + } + } + + const Catalog& FoundryLocalManager::GetCatalog() const { + return *catalog_; + } + + void FoundryLocalManager::StartWebService() { + if (!config_.web) { + throw FoundryLocalException("Web service configuration was not provided.", *logger_); + } + + try { + std::string raw = core_->call("start_service", *logger_); + auto arr = nlohmann::json::parse(raw); + urls_ = arr.get>(); + } catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error starting web service: ") + ex.what(), *logger_); + } + } + + void FoundryLocalManager::StopWebService() { + if (!config_.web) { + throw FoundryLocalException("Web service configuration was not provided.", *logger_); + } + + try { + core_->call("stop_service", *logger_); + urls_.clear(); + } catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error stopping web service: ") + ex.what(), *logger_); + } + } + + gsl::span FoundryLocalManager::GetUrls() const noexcept { + return urls_; + } + + void FoundryLocalManager::EnsureEpsDownloaded() const { + try { + core_->call("ensure_eps_downloaded", *logger_); + } catch (const std::exception& ex) { + throw FoundryLocalException( + std::string("Error ensuring execution providers downloaded: ") + ex.what(), *logger_); + } + } + + void FoundryLocalManager::Initialize() { + config_.Validate(); + + try { + CoreInteropRequest initReq("initialize"); + initReq.AddParam("AppName", config_.app_name); + initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); + + if (config_.app_data_dir) { + initReq.AddParam("AppDataDir", config_.app_data_dir->string()); + } + if (config_.logs_dir) { + initReq.AddParam("LogsDir", config_.logs_dir->string()); + } + if (config_.web && config_.web->urls) { + initReq.AddParam("WebServiceUrls", *config_.web->urls); + } + if (config_.additional_settings) { + for (const auto& [key, value] : *config_.additional_settings) { + if (!key.empty()) { + initReq.AddParam(key, value); + } + } + } + + std::string initJson = initReq.ToJson(); + core_->call(initReq.Command(), *logger_, &initJson); + + if (config_.model_cache_dir) { + std::string current = core_->call("get_cache_directory", *logger_); + + if (current != config_.model_cache_dir->string()) { + CoreInteropRequest setReq("set_cache_directory"); + setReq.AddParam("Directory", config_.model_cache_dir->string()); + std::string setJson = setReq.ToJson(); + core_->call(setReq.Command(), *logger_, &setJson); + + logger_->Log(LogLevel::Information, + std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); + } + else { + 
logger_->Log(LogLevel::Information, std::string("Model cache directory already set to: ") + current);
+                }
+            }
+        }
+        catch (const std::exception& ex) {
+            throw FoundryLocalException(std::string("FoundryLocalManager::Initialize failed: ") + ex.what(), *logger_);
+        }
+    }
+
+} // namespace FoundryLocal
diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp
new file mode 100644
index 00000000..e40d7c11
--- /dev/null
+++ b/sdk/cpp/test/catalog_test.cpp
@@ -0,0 +1,372 @@
+#include <gtest/gtest.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "mock_core.h"
+#include "mock_object_factory.h"
+#include "parser.h"
+#include "foundry_local_exception.h"
+
+#include <nlohmann/json.hpp>
+
+using namespace FoundryLocal;
+using namespace FoundryLocal::Testing;
+
+using Factory = MockObjectFactory;
+
+class CatalogTest : public ::testing::Test {
+protected:
+    MockCore core_;
+    NullLogger logger_;
+
+    std::string MakeModelListJson(const std::vector<std::pair<std::string, std::string>>& models) {
+        nlohmann::json arr = nlohmann::json::array();
+        for (const auto& [name, alias] : models) {
+            arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson(name, alias)));
+        }
+        return arr.dump();
+    }
+
+    std::unique_ptr<Catalog> MakeCatalog() {
+        core_.OnCall("get_catalog_name", "test-catalog");
+        return Factory::CreateCatalog(&core_, &logger_);
+    }
+};
+
+TEST_F(CatalogTest, GetName) {
+    auto catalog = MakeCatalog();
+    EXPECT_EQ("test-catalog", catalog->GetName());
+}
+
+TEST_F(CatalogTest, Create_ThrowsOnCoreError) {
+    core_.OnCallThrow("get_catalog_name", "catalog error");
+    EXPECT_THROW(MockObjectFactory::CreateCatalog(&core_, &logger_), FoundryLocalException);
+}
+
+TEST_F(CatalogTest, ListModels_Empty) {
+    core_.OnCall("get_model_list", "[]");
+    auto catalog = MakeCatalog();
+    auto models = catalog->ListModels();
+    EXPECT_TRUE(models.empty());
+}
+
+TEST_F(CatalogTest, ListModels_SingleModel) {
+    core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
+    auto catalog = MakeCatalog();
+    auto models = catalog->ListModels();
+    ASSERT_EQ(1u, models.size());
+    EXPECT_EQ("my-model", models[0]->GetAlias());
+}
+
+TEST_F(CatalogTest, ListModels_MultipleVariantsSameAlias) {
+    // Two variants of the same model (same alias, different names)
+    nlohmann::json arr = nlohmann::json::array();
+    arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v1", "my-model", 1)));
+    arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v2", "my-model", 2)));
+    core_.OnCall("get_model_list", arr.dump());
+
+    auto catalog = MakeCatalog();
+    auto models = catalog->ListModels();
+
+    // Should be grouped into one Model
+    ASSERT_EQ(1u, models.size());
+    EXPECT_EQ(2u, models[0]->GetAllModelVariants().size());
+}
+
+TEST_F(CatalogTest, ListModels_DifferentAliases) {
+    core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "alias-a"}, {"model-b", "alias-b"}}));
+    auto catalog = MakeCatalog();
+    auto models = catalog->ListModels();
+    EXPECT_EQ(2u, models.size());
+}
+
+TEST_F(CatalogTest, ListModels_FiltersOpenAIPrefix) {
+    core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "my-model"}, {"openai-model", "openai-stuff"}}));
+    auto catalog = MakeCatalog();
+    auto models = catalog->ListModels();
+    ASSERT_EQ(1u, models.size());
+    EXPECT_EQ("my-model", models[0]->GetAlias());
+}
+
+TEST_F(CatalogTest, GetModel_Found) {
+    core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
+    auto catalog = MakeCatalog();
+    catalog->ListModels(); // populate
+
+    auto* model = catalog->GetModel("my-model");
+    ASSERT_NE(nullptr, model);
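+    // GetModel() resolves by *alias* ("my-model"), not by variant id ("model-1");
+    // variant ids are looked up with GetModelVariant() instead.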
+ EXPECT_EQ("my-model", model->GetAlias()); +} + +TEST_F(CatalogTest, GetModel_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + EXPECT_EQ(nullptr, catalog->GetModel("nonexistent")); +} + +TEST_F(CatalogTest, GetModelVariant_Found) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto* variant = catalog->GetModelVariant("model-1"); + ASSERT_NE(nullptr, variant); + EXPECT_EQ("model-1", variant->GetId()); +} + +TEST_F(CatalogTest, GetModelVariant_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + catalog->ListModels(); + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent")); +} + +TEST_F(CatalogTest, GetLoadedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("list_loaded_models", R"(["model-1:v1"])"); + + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto loaded = catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("model-1", loaded[0]->GetId()); +} + +TEST_F(CatalogTest, GetCachedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("get_cached_models", R"(["model-1:1", "model-2:1"])"); + + auto catalog = MakeCatalog(); + catalog->ListModels(); // populate + + auto cached = catalog->GetCachedModels(); + EXPECT_EQ(2u, cached.size()); +} + +TEST_F(CatalogTest, ListModels_CachesResults) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + catalog->ListModels(); + catalog->ListModels(); + + // Should only call get_model_list once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_list")); +} + +class FileBasedCatalogTest : public ::testing::Test { +protected: + NullLogger logger_; + + static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; } +}; + +TEST_F(FileBasedCatalogTest, RealModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(2u, models.size()); + + int phi_models = 0, mistral_models = 0; + size_t phi_variants = 0, mistral_variants = 0; + + for (const auto* model : models) { + if (model->GetAlias() == "phi-4") { + phi_models++; + phi_variants = model->GetAllModelVariants().size(); + } + else if (model->GetAlias() == "mistral-7b-v0.2") { + mistral_models++; + mistral_variants = model->GetAllModelVariants().size(); + } + } + + EXPECT_EQ(1, phi_models); + EXPECT_EQ(1, mistral_models); + EXPECT_EQ(2u, phi_variants); + EXPECT_EQ(2u, mistral_variants); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_VariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* gpuVariant = catalog->GetModelVariant("Phi-4-generic-gpu"); + ASSERT_NE(nullptr, gpuVariant); + + const auto& info = gpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-gpu", info.id); + EXPECT_EQ("Phi-4-generic-gpu", info.name); + EXPECT_EQ("phi-4", info.alias); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("Phi-4 
(GPU)", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("Microsoft", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::GPU, info.runtime->device_type); + EXPECT_EQ("DML", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(8192u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.prompt_template.has_value()); + EXPECT_EQ("<|system|>", info.prompt_template->system); + EXPECT_EQ("<|user|>", info.prompt_template->user); + EXPECT_EQ("<|assistant|>", info.prompt_template->assistant); + EXPECT_EQ("<|prompt|>", info.prompt_template->prompt); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_CpuVariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu"); + ASSERT_NE(nullptr, cpuVariant); + + const auto& info = cpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-cpu", info.name); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(4096u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_FALSE(*info.supports_tool_calling); + EXPECT_FALSE(info.prompt_template.has_value()); +} + +TEST_F(FileBasedCatalogTest, EmptyModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("empty_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + EXPECT_TRUE(models.empty()); +} + +TEST_F(FileBasedCatalogTest, MalformedJson) { + auto core = FileBackedCore::FromModelList(TestDataPath("malformed_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MissingNameField) { + auto core = FileBackedCore::FromModelList(TestDataPath("missing_name_field_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + try { + catalog->ListModels(); + FAIL() << "Expected exception for missing 'name' field"; + } + catch (const std::exception& e) { + std::string msg = e.what(); + EXPECT_NE(std::string::npos, msg.find("name")) << "Actual: " << msg; + } +} + +TEST_F(FileBasedCatalogTest, CachedModels) { + auto core = + FileBackedCore::FromBoth(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate internal maps + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(2u, cached.size()); + + std::vector names; + names.reserve(cached.size()); + for (const auto* mv : cached) + names.push_back(mv->GetInfo().name); + + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-gpu"), names.end()); + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-cpu"), names.end()); +} + +TEST_F(FileBasedCatalogTest, CoreErrorOnModelList) { + auto core = FileBackedCore::FromModelList("testdata/nonexistent_file.json"); + auto catalog = Factory::CreateCatalog(&core, 
&logger_); + + EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MixedOpenAIAndLocal_FiltersOpenAIPrefix) { + auto core = FileBackedCore::FromModelList(TestDataPath("mixed_openai_and_local.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ("phi-4", models[0]->GetAlias()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel) { + auto core = FileBackedCore::FromModelList(TestDataPath("three_variants_one_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ(3u, models[0]->GetAllModelVariants().size()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel_CachedSubset) { + auto core = FileBackedCore::FromBoth(TestDataPath("three_variants_one_model.json"), + TestDataPath("single_cached_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(1u, cached.size()); + EXPECT_EQ("multi-v1-cpu", cached[0]->GetInfo().name); +} + +TEST_F(FileBasedCatalogTest, GetModelByAlias) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + const auto* model = catalog->GetModel("phi-4"); + ASSERT_NE(nullptr, model); + EXPECT_EQ("phi-4", model->GetAlias()); + EXPECT_EQ(2u, model->GetAllModelVariants().size()); + + const auto* missing = catalog->GetModel("nonexistent-alias"); + EXPECT_EQ(nullptr, missing); +} + +TEST_F(FileBasedCatalogTest, GetModelVariant_NotInCatalog) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent-variant-id")); +} + +TEST_F(FileBasedCatalogTest, LoadedModels) { + auto core = FileBackedCore::FromAll(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json"), + TestDataPath("valid_loaded_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + catalog->ListModels(); // populate + + auto loaded = catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("Phi-4-generic-gpu", loaded[0]->GetInfo().name); +} diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp new file mode 100644 index 00000000..0857bc92 --- /dev/null +++ b/sdk/cpp/test/client_test.cpp @@ -0,0 +1,541 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +using Factory = MockObjectFactory; + +class ChatClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + std::string MakeChatResponseJson(const std::string& content = "Hello!") { + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-test"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", content}}}}}}}; + return resp.dump(); + } + + ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]"); + return 
Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(ChatClientTest, CompleteChat_BasicResponse) { + core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!")); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "Say hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + EXPECT_TRUE(response.successful); + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ("Hello world!", response.choices[0].message->content); +} + +TEST_F(ChatClientTest, CompleteChat_WithSettings) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 100; + settings.top_p = 0.9f; + settings.frequency_penalty = 0.5f; + settings.presence_penalty = 0.3f; + settings.n = 2; + settings.random_seed = 42; + settings.top_k = 10; + + auto response = client.CompleteChat(messages, settings); + + // Verify the request JSON contains the settings + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_NEAR(0.7f, openAiReq["temperature"].get(), 0.001f); + EXPECT_EQ(100, openAiReq["max_completion_tokens"].get()); + EXPECT_NEAR(0.9f, openAiReq["top_p"].get(), 0.001f); + EXPECT_NEAR(0.5f, openAiReq["frequency_penalty"].get(), 0.001f); + EXPECT_NEAR(0.3f, openAiReq["presence_penalty"].get(), 0.001f); + EXPECT_EQ(2, openAiReq["n"].get()); + EXPECT_EQ(42, openAiReq["seed"].get()); + EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_RequestFormat) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_EQ("chat-model", openAiReq["model"].get()); + EXPECT_FALSE(openAiReq["stream"].get()); + ASSERT_EQ(2u, openAiReq["messages"].size()); + EXPECT_EQ("system", openAiReq["messages"][0]["role"].get()); + EXPECT_EQ("user", openAiReq["messages"][1]["role"].get()); +} + +TEST_F(ChatClientTest, CompleteChatStreaming) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hello"}}}}}}}; + nlohmann::json chunk2 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto 
cb = reinterpret_cast(callback); + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + cb(s1.data(), static_cast(s1.size()), userData); + cb(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + std::vector chunks; + client.CompleteChatStreaming(messages, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_TRUE(chunks[0].is_delta); + ASSERT_TRUE(chunks[0].choices[0].delta.has_value()); + EXPECT_EQ("Hello", chunks[0].choices[0].delta->content); + EXPECT_EQ(" world", chunks[1].choices[0].delta->content); +} + +TEST_F(ChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { + nlohmann::json chunk = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string s = chunk.dump(); + cb(s.data(), static_cast(s.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + EXPECT_THROW(client.CompleteChatStreaming( + messages, settings, + [](const ChatCompletionCreateResponse&) { throw std::runtime_error("callback error"); }), + std::runtime_error); +} + +TEST_F(ChatClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(ChatClient client(&variant), FoundryLocalException); +} + +TEST_F(ChatClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + EXPECT_EQ("chat-model", client.GetModelId()); +} + +// ---------- Tool calling tests ---------- + +TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + + std::vector tools = {{ + "function", + FunctionDefinition{ + "multiply_numbers", + "A tool for multiplying two numbers.", + PropertyDefinition{ + "object", + std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}} + }, + std::vector{"first", "second"} + } + } + }}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + + // Verify the request JSON contains tools and tool_choice + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + 
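+    // The interop envelope nests the OpenAI-style payload as a JSON *string* under
+    // Params.OpenAICreateRequest, hence the double parse above. Sketch (abridged to
+    // the fields asserted below):
+    //   {"Params":{"OpenAICreateRequest":"{\"model\":\"chat-model\",\"tools\":[...],\"tool_choice\":\"required\"}"}}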
ASSERT_TRUE(openAiReq.contains("tools")); + ASSERT_TRUE(openAiReq["tools"].is_array()); + EXPECT_EQ(1u, openAiReq["tools"].size()); + EXPECT_EQ("function", openAiReq["tools"][0]["type"].get()); + EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get()); + EXPECT_EQ("A tool for multiplying two numbers.", openAiReq["tools"][0]["function"]["description"].get()); + EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get()); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("second")); + + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_FALSE(openAiReq.contains("tools")); + EXPECT_FALSE(openAiReq.contains("tool_choice")); +} + +TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { + // Simulate a response with tool calls from the model + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-tool"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"message", + {{"role", "assistant"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", resp.dump()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason); + ASSERT_TRUE(response.choices[0].message.has_value()); + + const auto& msg = *response.choices[0].message; + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_1", msg.tool_calls[0].id); + EXPECT_EQ("function", msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("multiply_numbers", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Auto; + + client.CompleteChat(messages, settings); + + auto requestJson = 
nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("auto", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::None; + + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("none", openAiReq["tool_choice"].get()); +} + +TEST_F(ChatClientTest, CompleteChat_ToolMessageWithToolCallId) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + ChatMessage toolMsg; + toolMsg.role = "tool"; + toolMsg.content = "42"; + toolMsg.tool_call_id = "call_1"; + + std::vector messages = { + {"user", "What is 7 * 6?", {}}, + std::move(toolMsg) + }; + ChatSettings settings; + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_EQ(2u, openAiReq["messages"].size()); + EXPECT_FALSE(openAiReq["messages"][0].contains("tool_call_id")); + EXPECT_EQ("call_1", openAiReq["messages"][1]["tool_call_id"].get()); + EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get()); +} + +TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", nullptr}, + {"delta", {{"role", "assistant"}, {"content", ""}}}}}}}; + nlohmann::json chunk2 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"delta", + {{"content", ""}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + cb(s1.data(), static_cast(s1.size()), userData); + cb(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + + auto variant = MakeLoadedVariant(); + ChatClient client(&variant); + + std::vector messages = {{"user", "test", {}}}; + + std::vector tools = {{ + "function", + FunctionDefinition{"multiply", "Multiply numbers."} + }}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + std::vector chunks; + client.CompleteChatStreaming(messages, tools, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ(FinishReason::ToolCalls, 
chunks[1].choices[0].finish_reason); + ASSERT_TRUE(chunks[1].choices[0].delta.has_value()); + ASSERT_EQ(1u, chunks[1].choices[0].delta->tool_calls.size()); + EXPECT_EQ("multiply", chunks[1].choices[0].delta->tool_calls[0].function_call->name); + + // Verify tools were included in the request + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + ASSERT_TRUE(openAiReq.contains("tools")); + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +class AudioClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(AudioClientTest, TranscribeAudio) { + core_.OnCall("audio_transcribe", "Hello world transcribed text"); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + auto response = client.TranscribeAudio("test.wav"); + + EXPECT_EQ("Hello world transcribed text", response.text); +} + +TEST_F(AudioClientTest, TranscribeAudio_RequestFormat) { + core_.OnCall("audio_transcribe", "text"); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + client.TranscribeAudio("audio.wav"); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("audio-model", openAiReq["Model"].get()); + EXPECT_EQ("audio.wav", openAiReq["FileName"].get()); +} + +TEST_F(AudioClientTest, TranscribeAudioStreaming) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string text1 = "Hello "; + std::string text2 = "world!"; + cb(text1.data(), static_cast(text1.size()), userData); + cb(text2.data(), static_cast(text2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + + std::vector chunks; + client.TranscribeAudioStreaming( + "test.wav", [&](const AudioCreateTranscriptionResponse& chunk) { chunks.push_back(chunk.text); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ("Hello ", chunks[0]); + EXPECT_EQ("world!", chunks[1]); +} + +TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string text = "test"; + cb(text.data(), static_cast(text.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + + auto variant = MakeLoadedVariant(); + AudioClient client(&variant); + + EXPECT_THROW( + client.TranscribeAudioStreaming( + "test.wav", [](const AudioCreateTranscriptionResponse&) { throw std::runtime_error("streaming error"); }), + std::runtime_error); +} + +TEST_F(AudioClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = 
Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_);
+    EXPECT_THROW(AudioClient client(&variant), FoundryLocalException);
+}
+
+TEST_F(AudioClientTest, GetModelId) {
+    core_.OnCall("list_loaded_models", R"(["audio-model:v1"])");
+    auto variant = MakeLoadedVariant();
+    AudioClient client(&variant);
+    EXPECT_EQ("audio-model", client.GetModelId());
+}
diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h
new file mode 100644
index 00000000..b7aa349d
--- /dev/null
+++ b/sdk/cpp/test/mock_core.h
@@ -0,0 +1,148 @@
+#pragma once
+
+#include <cstdint>
+#include <fstream>
+#include <functional>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <unordered_map>
+
+#include "foundry_local_internal_core.h"
+#include "logger.h"
+
+namespace FoundryLocal::Testing {
+
+    /// A mock implementation of IFoundryLocalCore for unit testing.
+    /// Register expected command -> response mappings before use.
+    class MockCore final : public Internal::IFoundryLocalCore {
+    public:
+        using CallbackFn = void (*)(void*, int32_t, void*);
+
+        /// Handler signature: (command, dataArgument, callback, userData) -> response string.
+        using Handler = std::function<std::string(std::string_view, const std::string*, void*, void*)>;
+
+        /// Register a fixed response for a command.
+        void OnCall(std::string command, std::string response) {
+            handlers_[std::move(command)] = [r = std::move(response)](std::string_view, const std::string*, void*,
+                                                                      void*) { return r; };
+        }
+
+        /// Register a custom handler for a command.
+        void OnCall(std::string command, Handler handler) { handlers_[std::move(command)] = std::move(handler); }
+
+        /// Register a handler that throws for a command.
+        void OnCallThrow(std::string command, std::string errorMessage) {
+            handlers_[std::move(command)] = [msg = std::move(errorMessage)](std::string_view, const std::string*, void*,
+                                                                            void*) -> std::string {
+                throw std::runtime_error(msg);
+            };
+        }
+
+        /// Returns the number of times a command was called.
+        int GetCallCount(const std::string& command) const {
+            auto it = callCounts_.find(command);
+            return it != callCounts_.end() ? it->second : 0;
+        }
+
+        /// Returns the last data argument passed for a command.
+        const std::string& GetLastDataArg(const std::string& command) const {
+            auto it = lastDataArgs_.find(command);
+            if (it == lastDataArgs_.end()) {
+                static const std::string empty;
+                return empty;
+            }
+            return it->second;
+        }
+
+        // IFoundryLocalCore implementation
+        std::string call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr,
+                         void* callback = nullptr, void* data = nullptr) const override {
+
+            std::string cmd(command);
+            const_cast<MockCore*>(this)->callCounts_[cmd]++;
+            if (dataArgument) {
+                const_cast<MockCore*>(this)->lastDataArgs_[cmd] = *dataArgument;
+            }
+
+            auto it = handlers_.find(cmd);
+            if (it == handlers_.end()) {
+                throw std::runtime_error("MockCore: no handler registered for command '" + cmd + "'");
+            }
+
+            return it->second(command, dataArgument, callback, data);
+        }
+
+        void unload() override {}
+
+    private:
+        std::unordered_map<std::string, Handler> handlers_;
+        std::unordered_map<std::string, int> callCounts_;
+        std::unordered_map<std::string, std::string> lastDataArgs_;
+    };
+
+    /// Read a file into a string. Throws on failure.
+    inline std::string ReadFile(const std::string& path) {
+        std::ifstream in(path, std::ios::in | std::ios::binary);
+        if (!in)
+            throw std::runtime_error("Failed to open test data file: " + path);
+        std::ostringstream contents;
+        contents << in.rdbuf();
+        return contents.str();
+    }
+
+    /// A mock core that reads model list, cached models and loaded models from JSON files on disk.
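+    /// Usage sketch (fixture paths are illustrative):
+    ///   auto core = FileBackedCore::FromBoth("testdata/models.json", "testdata/cached.json");
+    ///   auto catalog = MockObjectFactory::CreateCatalog(&core, &logger);
+    ///   catalog->ListModels(); // served from models.json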
+    class FileBackedCore final : public Internal::IFoundryLocalCore {
+    public:
+        FileBackedCore(std::string modelListPath, std::string cachedModelsPath, std::string loadedModelsPath = "")
+            : modelListPath_(std::move(modelListPath)), cachedModelsPath_(std::move(cachedModelsPath)),
+              loadedModelsPath_(std::move(loadedModelsPath)) {}
+
+        static FileBackedCore FromModelList(const std::string& path) { return FileBackedCore(path, ""); }
+
+        static FileBackedCore FromBoth(const std::string& modelListPath, const std::string& cachedModelsPath) {
+            return FileBackedCore(modelListPath, cachedModelsPath);
+        }
+
+        static FileBackedCore FromAll(const std::string& modelListPath, const std::string& cachedModelsPath,
+                                      const std::string& loadedModelsPath) {
+            return FileBackedCore(modelListPath, cachedModelsPath, loadedModelsPath);
+        }
+
+        std::string call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr,
+                         void* /*callback*/ = nullptr, void* /*data*/ = nullptr) const override {
+
+            if (command == "get_catalog_name")
+                return "TestCatalog";
+
+            if (command == "get_model_list") {
+                if (modelListPath_.empty())
+                    return "[]";
+                return ReadFile(modelListPath_);
+            }
+
+            if (command == "get_cached_models") {
+                if (cachedModelsPath_.empty())
+                    return "[]";
+                return ReadFile(cachedModelsPath_);
+            }
+
+            if (command == "list_loaded_models") {
+                if (loadedModelsPath_.empty())
+                    return "[]";
+                return ReadFile(loadedModelsPath_);
+            }
+
+            return "{}";
+        }
+
+        void unload() override {}
+
+    private:
+        std::string modelListPath_;
+        std::string cachedModelsPath_;
+        std::string loadedModelsPath_;
+    };
+
+} // namespace FoundryLocal::Testing
diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h
new file mode 100644
index 00000000..9d029aec
--- /dev/null
+++ b/sdk/cpp/test/mock_object_factory.h
@@ -0,0 +1,61 @@
+#pragma once
+
+#ifndef FL_TESTS
+#define FL_TESTS
+#endif
+
+#include "foundry_local.h"
+#include "foundry_local_internal_core.h"
+#include "logger.h"
+
+namespace FoundryLocal::Testing {
+
+    /// Factory to construct private-constructor types for testing.
+    /// Declared as a friend (Testing::MockObjectFactory) in ModelVariant, Model, and Catalog when FL_TESTS is defined.
+    struct MockObjectFactory {
+        static ModelVariant CreateModelVariant(gsl::not_null<Internal::IFoundryLocalCore*> core, ModelInfo info,
+                                               gsl::not_null<ILogger*> logger) {
+            return ModelVariant(core, std::move(info), logger);
+        }
+
+        static std::unique_ptr<Catalog> CreateCatalog(gsl::not_null<Internal::IFoundryLocalCore*> core,
+                                                      gsl::not_null<ILogger*> logger) {
+            return std::unique_ptr<Catalog>(new Catalog(core, logger));
+        }
+
+        static Model CreateModel(gsl::not_null<Internal::IFoundryLocalCore*> core, gsl::not_null<ILogger*> logger) {
+            return Model(core, logger);
+        }
+
+        /// Push a variant into a Model's internal variant list.
+        static void AddVariantToModel(Model& model, ModelVariant variant) {
+            model.variants_.push_back(std::move(variant));
+        }
+
+        /// Set the selected variant index on a Model.
+        static void SetSelectedVariantIndex(Model& model, size_t index) { model.selectedVariantIndex_ = index; }
+
+        /// Helper to build a minimal ModelInfo with defaults.
+        static ModelInfo MakeModelInfo(std::string name, std::string alias = "", uint32_t version = 1) {
+            ModelInfo info;
+            info.id = name;
+            info.name = std::move(name);
+            info.alias = alias.empty() ? info.name : std::move(alias);
+            info.version = version;
+            info.provider_type = "test";
+            info.uri = "test://uri";
+            info.model_type = "text";
+            return info;
+        }
+
+        /// Helper to build a JSON string representing a model list entry.
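+        /// For name "m" and alias "a" the emitted entry mirrors MakeModelInfo's defaults:
+        ///   {"id":"m","name":"m","version":1,"alias":"a","providerType":"test",
+        ///    "uri":"test://uri","modelType":"text","cached":false,"createdAt":0}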
+ static std::string MakeModelInfoJson(const std::string& name, const std::string& alias = "", + uint32_t version = 1, bool cached = false) { + std::string a = alias.empty() ? name : alias; + return R"({"id":")" + name + R"(","name":")" + name + R"(","version":)" + std::to_string(version) + + R"(,"alias":")" + a + R"(","providerType":"test","uri":"test://uri","modelType":"text","cached":)" + + (cached ? "true" : "false") + R"(,"createdAt":0})"; + } + }; + +} // namespace FoundryLocal::Testing diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp new file mode 100644 index 00000000..5ecbb696 --- /dev/null +++ b/sdk/cpp/test/model_variant_test.cpp @@ -0,0 +1,251 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +using Factory = MockObjectFactory; + +class ModelVariantTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } +}; + +TEST_F(ModelVariantTest, GetInfo) { + auto variant = MakeVariant("my-model", "my-alias", 3); + const auto& info = variant.GetInfo(); + EXPECT_EQ("my-model", info.name); + EXPECT_EQ("my-alias", info.alias); + EXPECT_EQ(3u, info.version); +} + +TEST_F(ModelVariantTest, GetId) { + auto variant = MakeVariant("my-model"); + EXPECT_EQ("my-model", variant.GetId()); +} + +TEST_F(ModelVariantTest, GetAlias) { + auto variant = MakeVariant("name", "alias"); + EXPECT_EQ("alias", variant.GetAlias()); +} + +TEST_F(ModelVariantTest, GetVersion) { + auto variant = MakeVariant("name", "alias", 5); + EXPECT_EQ(5u, variant.GetVersion()); +} + +TEST_F(ModelVariantTest, IsLoaded_True) { + core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_False) { + core_.OnCall("list_loaded_models", R"(["other-model:v1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_EmptyList) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsCached_True) { + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, IsCached_False) { + core_.OnCall("get_cached_models", R"(["other-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, Load_CallsCore) { + core_.OnCall("load_model", ""); + auto variant = MakeVariant("test-model"); + variant.Load(); + EXPECT_EQ(1, core_.GetCallCount("load_model")); + + // Verify the data argument contains the model name + auto parsed = nlohmann::json::parse(core_.GetLastDataArg("load_model")); + EXPECT_EQ("test-model", parsed["Params"]["Model"].get()); +} + +TEST_F(ModelVariantTest, Unload_CallsCore) { + core_.OnCall("unload_model", ""); + auto variant = MakeVariant("test-model"); + variant.Unload(); + EXPECT_EQ(1, core_.GetCallCount("unload_model")); +} + +TEST_F(ModelVariantTest, Unload_ThrowsOnError) { + core_.OnCallThrow("unload_model", "unload 
failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Unload(), FoundryLocalException); +} + +TEST_F(ModelVariantTest, Download_NoCallback) { +core_.OnCall("get_cached_models", R"([])"); +core_.OnCall("download_model", ""); +auto variant = MakeVariant("test-model"); +variant.Download(); + EXPECT_EQ(1, core_.GetCallCount("download_model")); +} + +TEST_F(ModelVariantTest, Download_WithCallback) { +core_.OnCall("get_cached_models", R"([])"); +core_.OnCall("download_model", + [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + // Simulate calling the progress callback + if (callback && userData) { + auto cb = reinterpret_cast(callback); + std::string progress = "50"; + cb(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + float lastProgress = -1.0f; + variant.Download([&](float pct) { lastProgress = pct; }); + EXPECT_NEAR(50.0f, lastProgress, 0.01f); +} + +TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { + core_.OnCall("remove_cached_model", ""); + auto variant = MakeVariant("test-model"); + variant.RemoveFromCache(); + EXPECT_EQ(1, core_.GetCallCount("remove_cached_model")); +} + +TEST_F(ModelVariantTest, RemoveFromCache_ThrowsOnError) { + core_.OnCallThrow("remove_cached_model", "remove failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.RemoveFromCache(), FoundryLocalException); +} + +TEST_F(ModelVariantTest, GetPath_CallsCore) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + const auto& path = variant.GetPath(); + EXPECT_EQ(std::filesystem::path(R"(C:\models\test)"), path); +} + +TEST_F(ModelVariantTest, GetPath_CachesResult) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + variant.GetPath(); + variant.GetPath(); + // Should only call once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_path")); +} + +class ModelTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + Model MakeModel() { return Factory::CreateModel(&core_, &logger_); } + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } + + /// Helper: create a Model with one variant and selectedVariantIndex_=0. 
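+    /// Shorthand for MakeModel() followed by Factory::AddVariantToModel(...) and
+    /// Factory::SetSelectedVariantIndex(model, 0), as the body below shows.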
+ Model MakeModelWithVariant(const std::string& name = "test-model", const std::string& alias = "test-alias") { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant(name, alias, 1)); + Factory::SetSelectedVariantIndex(model, 0); + return model; + } +}; + +TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) { + auto model = MakeModel(); + EXPECT_THROW(model.GetId(), FoundryLocalException); +} + +TEST_F(ModelTest, AddVariant_AndSelect) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SetSelectedVariantIndex(model, 0); + + EXPECT_EQ("v1", model.GetId()); + EXPECT_EQ("alias", model.GetAlias()); +} + +TEST_F(ModelTest, GetAllModelVariants) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + auto variants = model.GetAllModelVariants(); + EXPECT_EQ(2u, variants.size()); +} + +TEST_F(ModelTest, SelectVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + const auto* v2 = &model.GetAllModelVariants()[1]; + model.SelectVariant(v2); + EXPECT_EQ("v2", model.GetId()); +} + +TEST_F(ModelTest, SelectVariant_NotFound_Throws) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SetSelectedVariantIndex(model, 0); + + auto external = MakeVariant("external", "alias", 1); + EXPECT_THROW(model.SelectVariant(&external), FoundryLocalException); +} + +TEST_F(ModelTest, GetLatestVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); + Factory::SetSelectedVariantIndex(model, 0); + + const auto* first = &model.GetAllModelVariants()[0]; + const auto* latest = model.GetLatestVariant(first); + // Should return the first one with matching name (which is variants_[0]) + EXPECT_EQ(first, latest); +} + +TEST_F(ModelTest, DelegationMethods) { + // Test that Model delegates to SelectedVariant + core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + core_.OnCall("load_model", ""); + core_.OnCall("unload_model", ""); + core_.OnCall("download_model", ""); + core_.OnCall("get_model_path", R"(C:\test)"); + + auto model = MakeModelWithVariant("test-model", "alias"); + + EXPECT_TRUE(model.IsLoaded()); + EXPECT_TRUE(model.IsCached()); + model.Load(); + model.Unload(); + model.Download(); + EXPECT_EQ(std::filesystem::path(R"(C:\test)"), model.GetPath()); +} diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp new file mode 100644 index 00000000..4515c3a0 --- /dev/null +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -0,0 +1,592 @@ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" + +#include + +using namespace FoundryLocal; +using namespace FoundryLocal::Testing; + +class ParserTest : public ::testing::Test { +protected: + static nlohmann::json MinimalModelJson() { + return nlohmann::json{{"id", "model-1"}, {"name", "model-1"}, {"version", 1}, + {"alias", "my-model"}, {"providerType", "onnx"}, {"uri", 
"https://example.com/model"}, + {"modelType", "text"}, {"cached", false}, {"createdAt", 1700000000}}; + } +}; + +TEST_F(ParserTest, ParseDeviceType_CPU) { + EXPECT_EQ(DeviceType::CPU, parse_device_type("CPU")); +} + +TEST_F(ParserTest, ParseDeviceType_GPU) { + EXPECT_EQ(DeviceType::GPU, parse_device_type("GPU")); +} + +TEST_F(ParserTest, ParseDeviceType_NPU) { + EXPECT_EQ(DeviceType::NPU, parse_device_type("NPU")); +} + +TEST_F(ParserTest, ParseDeviceType_Unknown) { + EXPECT_EQ(DeviceType::Invalid, parse_device_type("FPGA")); +} + +TEST_F(ParserTest, ParseFinishReason_Stop) { + EXPECT_EQ(FinishReason::Stop, parse_finish_reason("stop")); +} + +TEST_F(ParserTest, ParseFinishReason_Length) { + EXPECT_EQ(FinishReason::Length, parse_finish_reason("length")); +} + +TEST_F(ParserTest, ParseFinishReason_ToolCalls) { + EXPECT_EQ(FinishReason::ToolCalls, parse_finish_reason("tool_calls")); +} + +TEST_F(ParserTest, ParseFinishReason_ContentFilter) { + EXPECT_EQ(FinishReason::ContentFilter, parse_finish_reason("content_filter")); +} + +TEST_F(ParserTest, ParseFinishReason_None) { + EXPECT_EQ(FinishReason::None, parse_finish_reason("unknown_value")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Present) { + nlohmann::json j = {{"key", "value"}}; + EXPECT_EQ("value", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Missing) { + nlohmann::json j = {{"other", "value"}}; + EXPECT_EQ("", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_NonString) { + nlohmann::json j = {{"key", 42}}; + EXPECT_EQ("", get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetOptString_Present) { + nlohmann::json j = {{"key", "hello"}}; + auto result = get_opt_string(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ("hello", *result); +} + +TEST_F(ParserTest, GetOptString_Null) { + nlohmann::json j = {{"key", nullptr}}; + EXPECT_FALSE(get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptString_Missing) { + nlohmann::json j = {{"other", "v"}}; + EXPECT_FALSE(get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptInt_Present) { + nlohmann::json j = {{"key", 42}}; + auto result = get_opt_int(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(42, *result); +} + +TEST_F(ParserTest, GetOptInt_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(get_opt_int(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptBool_Present) { + nlohmann::json j = {{"key", true}}; + auto result = get_opt_bool(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(*result); +} + +TEST_F(ParserTest, GetOptBool_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(get_opt_bool(j, "key").has_value()); +} + +TEST_F(ParserTest, ParseRuntime) { + nlohmann::json j = {{"deviceType", "GPU"}, {"executionProvider", "DML"}}; + Runtime r = j.get(); + EXPECT_EQ(DeviceType::GPU, r.device_type); + EXPECT_EQ("DML", r.execution_provider); +} + +TEST_F(ParserTest, ParsePromptTemplate) { + nlohmann::json j = {{"system", "sys"}, {"user", "usr"}, {"assistant", "asst"}, {"prompt", "p"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("usr", pt.user); + EXPECT_EQ("asst", pt.assistant); + EXPECT_EQ("p", pt.prompt); +} + +TEST_F(ParserTest, ParsePromptTemplate_MissingFields) { + nlohmann::json j = {{"system", "sys"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("", pt.user); + EXPECT_EQ("", pt.assistant); + EXPECT_EQ("", pt.prompt); +} + +TEST_F(ParserTest, ParseModelInfo_Minimal) { + auto j = MinimalModelJson(); + ModelInfo 
info = j.get(); + EXPECT_EQ("model-1", info.id); + EXPECT_EQ("model-1", info.name); + EXPECT_EQ(1u, info.version); + EXPECT_EQ("my-model", info.alias); + EXPECT_EQ("onnx", info.provider_type); + EXPECT_EQ("https://example.com/model", info.uri); + EXPECT_EQ("text", info.model_type); + EXPECT_FALSE(info.cached); + EXPECT_EQ(1700000000, info.created_at_unix); + EXPECT_FALSE(info.display_name.has_value()); + EXPECT_FALSE(info.publisher.has_value()); + EXPECT_FALSE(info.runtime.has_value()); + EXPECT_FALSE(info.prompt_template.has_value()); + EXPECT_FALSE(info.model_settings.has_value()); +} + +TEST_F(ParserTest, ParseModelInfo_WithOptionals) { + auto j = MinimalModelJson(); + j["displayName"] = "My Model"; + j["publisher"] = "TestPublisher"; + j["license"] = "MIT"; + j["fileSizeMb"] = 512; + j["supportsToolCalling"] = true; + j["maxOutputTokens"] = 4096; + j["runtime"] = {{"deviceType", "CPU"}, {"executionProvider", "ORT"}}; + + ModelInfo info = j.get(); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("My Model", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("TestPublisher", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(512u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); +} + +TEST_F(ParserTest, ParseModelSettings) { + nlohmann::json j = {{"parameters", {{{"name", "p1"}, {"value", "v1"}}, {{"name", "p2"}}}}}; + ModelSettings ms = j.get(); + ASSERT_EQ(2u, ms.parameters.size()); + EXPECT_EQ("p1", ms.parameters[0].name); + ASSERT_TRUE(ms.parameters[0].value.has_value()); + EXPECT_EQ("v1", *ms.parameters[0].value); + EXPECT_EQ("p2", ms.parameters[1].name); + EXPECT_FALSE(ms.parameters[1].value.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage) { + nlohmann::json j = {{"role", "user"}, {"content", "hello"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("user", msg.role); + EXPECT_EQ("hello", msg.content); + EXPECT_TRUE(msg.tool_calls.empty()); + EXPECT_FALSE(msg.tool_call_id.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCalls) { + nlohmann::json j = { + {"role", "assistant"}, + {"content", "I'll call a tool."}, + {"tool_calls", + {{{"id", "call_abc123"}, + {"type", "function"}, + {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Seattle\"}"}}}}}}}; + ChatMessage msg = j.get(); + EXPECT_EQ("assistant", msg.role); + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_abc123", msg.tool_calls[0].id); + EXPECT_EQ("function", msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("get_weather", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"city\": \"Seattle\"}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCallId) { + nlohmann::json j = { + {"role", "tool"}, + {"content", "72 degrees and sunny"}, + {"tool_call_id", "call_abc123"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("tool", msg.role); + EXPECT_EQ("72 degrees and sunny", msg.content); + ASSERT_TRUE(msg.tool_call_id.has_value()); + EXPECT_EQ("call_abc123", *msg.tool_call_id); +} + +TEST_F(ParserTest, ParseFunctionCall) { + nlohmann::json j = {{"name", 
"multiply"}, {"arguments", "{\"a\": 1, \"b\": 2}"}}; + FunctionCall fc = j.get(); + EXPECT_EQ("multiply", fc.name); + EXPECT_EQ("{\"a\": 1, \"b\": 2}", fc.arguments); +} + +TEST_F(ParserTest, ParseFunctionCall_ObjectArguments) { + nlohmann::json j = {{"name", "add"}, {"arguments", {{"x", 10}}}}; + FunctionCall fc = j.get(); + EXPECT_EQ("add", fc.name); + EXPECT_EQ("{\"x\":10}", fc.arguments); +} + +TEST_F(ParserTest, ParseToolCall) { + nlohmann::json j = { + {"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "search"}, {"arguments", "{\"query\": \"test\"}"}}}}; + ToolCall tc = j.get(); + EXPECT_EQ("call_1", tc.id); + EXPECT_EQ("function", tc.type); + ASSERT_TRUE(tc.function_call.has_value()); + EXPECT_EQ("search", tc.function_call->name); +} + +TEST_F(ParserTest, SerializeToolDefinition) { + ToolDefinition tool; + tool.type = "function"; + tool.function.name = "get_weather"; + tool.function.description = "Get the current weather"; + tool.function.parameters = PropertyDefinition{ + "object", + std::nullopt, + std::unordered_map{ + {"location", PropertyDefinition{"string", "The city name"}} + }, + std::vector{"location"} + }; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("get_weather", j["function"]["name"].get()); + EXPECT_EQ("Get the current weather", j["function"]["description"].get()); + EXPECT_EQ("object", j["function"]["parameters"]["type"].get()); + ASSERT_TRUE(j["function"]["parameters"]["properties"].contains("location")); + EXPECT_EQ("string", j["function"]["parameters"]["properties"]["location"]["type"].get()); + ASSERT_EQ(1u, j["function"]["parameters"]["required"].size()); + EXPECT_EQ("location", j["function"]["parameters"]["required"][0].get()); +} + +TEST_F(ParserTest, SerializeToolDefinition_MinimalFunction) { + ToolDefinition tool; + tool.function.name = "noop"; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("noop", j["function"]["name"].get()); + EXPECT_FALSE(j["function"].contains("description")); + EXPECT_FALSE(j["function"].contains("parameters")); +} + +TEST_F(ParserTest, ToolChoiceToString) { + EXPECT_EQ("auto", tool_choice_to_string(ToolChoiceKind::Auto)); + EXPECT_EQ("none", tool_choice_to_string(ToolChoiceKind::None)); + EXPECT_EQ("required", tool_choice_to_string(ToolChoiceKind::Required)); +} + +TEST_F(ParserTest, ParseChatChoice_NonStreaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hi there!"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(0, c.index); + EXPECT_EQ(FinishReason::Stop, c.finish_reason); + ASSERT_TRUE(c.message.has_value()); + EXPECT_EQ("assistant", c.message->role); + EXPECT_EQ("Hi there!", c.message->content); + EXPECT_FALSE(c.delta.has_value()); +} + +TEST_F(ParserTest, ParseChatChoice_Streaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(FinishReason::None, c.finish_reason); + EXPECT_FALSE(c.message.has_value()); + ASSERT_TRUE(c.delta.has_value()); + EXPECT_EQ("Hi", c.delta->content); +} + +TEST_F(ParserTest, ParseChatCompletionCreateResponse) { + nlohmann::json j = { + {"created", 1700000000}, + {"id", "chatcmpl-123"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hello!"}}}}}}}; + 
ChatCompletionCreateResponse r = j.get(); + EXPECT_EQ(1700000000, r.created); + EXPECT_EQ("chatcmpl-123", r.id); + EXPECT_FALSE(r.is_delta); + EXPECT_TRUE(r.successful); + EXPECT_EQ(200, r.http_status_code); + ASSERT_EQ(1u, r.choices.size()); + EXPECT_EQ("Hello!", r.choices[0].message->content); +} + +TEST(ChatCompletionCreateResponseTest, GetObject_NonDelta) { + ChatCompletionCreateResponse r; + r.is_delta = false; + EXPECT_STREQ("chat.completion", r.GetObject()); +} + +TEST(ChatCompletionCreateResponseTest, GetObject_Delta) { + ChatCompletionCreateResponse r; + r.is_delta = true; + EXPECT_STREQ("chat.completion.chunk", r.GetObject()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_Zero) { + ChatCompletionCreateResponse r; + r.created = 0; + EXPECT_EQ("", r.GetCreatedAtIso()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_ValidTimestamp) { + ChatCompletionCreateResponse r; + r.created = 1700000000; // 2023-11-14T22:13:20Z + std::string iso = r.GetCreatedAtIso(); + EXPECT_FALSE(iso.empty()); + EXPECT_EQ('Z', iso.back()); + EXPECT_NE(std::string::npos, iso.find("2023")); +} + +// ============================================================================= +// CoreInteropRequest tests +// ============================================================================= + +TEST(CoreInteropRequestTest, Command) { + CoreInteropRequest req("test_command"); + EXPECT_EQ("test_command", req.Command()); +} + +TEST(CoreInteropRequestTest, ToJson_NoParams) { + CoreInteropRequest req("cmd"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + EXPECT_FALSE(parsed.contains("Params")); +} + +TEST(CoreInteropRequestTest, ToJson_WithParams) { + CoreInteropRequest req("cmd"); + req.AddParam("key1", "value1"); + req.AddParam("key2", "value2"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + ASSERT_TRUE(parsed.contains("Params")); + EXPECT_EQ("value1", parsed["Params"]["key1"].get()); + EXPECT_EQ("value2", parsed["Params"]["key2"].get()); +} + +TEST(CoreInteropRequestTest, AddParam_Chaining) { + CoreInteropRequest req("cmd"); + auto& ref = req.AddParam("a", "1").AddParam("b", "2"); + EXPECT_EQ(&req, &ref); +} + +// ============================================================================= +// FoundryLocalException tests +// ============================================================================= + +TEST(FoundryLocalExceptionTest, MessageOnly) { + FoundryLocalException ex("test error"); + EXPECT_STREQ("test error", ex.what()); +} + +TEST(FoundryLocalExceptionTest, MessageAndLogger) { + NullLogger logger; + FoundryLocalException ex("logged error", logger); + EXPECT_STREQ("logged error", ex.what()); +} + +// ============================================================================= +// File-based parser tests (read JSON from testdata/) +// ============================================================================= + +class FileBasedParserTest : public ::testing::Test { +protected: + static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; } + + static nlohmann::json LoadJsonArray(const std::string& filename) { + std::string raw = Testing::ReadFile(TestDataPath(filename)); + return nlohmann::json::parse(raw); + } +}; + +TEST_F(FileBasedParserTest, AllFields_RequiredFields) { + auto arr = LoadJsonArray("model_all_fields.json"); + ModelInfo info = arr.at(0).get(); + EXPECT_EQ("model-all-fields", info.id); + EXPECT_EQ("model-all-fields", info.name); + EXPECT_EQ(3u, info.version); 
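+    // Remaining required fields, as defined in testdata/model_all_fields.json: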
+    EXPECT_EQ("full-model", info.alias);
+    EXPECT_EQ("onnx", info.provider_type);
+    EXPECT_EQ("https://example.com/full-model", info.uri);
+    EXPECT_EQ("text", info.model_type);
+    EXPECT_TRUE(info.cached);
+    EXPECT_EQ(1710000000, info.created_at_unix);
+}
+
+TEST_F(FileBasedParserTest, AllFields_OptionalStrings) {
+    auto arr = LoadJsonArray("model_all_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    ASSERT_TRUE(info.display_name.has_value());
+    EXPECT_EQ("Full Model Display Name", *info.display_name);
+    ASSERT_TRUE(info.publisher.has_value());
+    EXPECT_EQ("TestPublisher", *info.publisher);
+    ASSERT_TRUE(info.license.has_value());
+    EXPECT_EQ("Apache-2.0", *info.license);
+    ASSERT_TRUE(info.license_description.has_value());
+    EXPECT_EQ("Permissive open source license", *info.license_description);
+    ASSERT_TRUE(info.task.has_value());
+    EXPECT_EQ("text-generation", *info.task);
+    ASSERT_TRUE(info.min_fl_version.has_value());
+    EXPECT_EQ("1.0.0", *info.min_fl_version);
+}
+
+TEST_F(FileBasedParserTest, AllFields_NumericOptionals) {
+    auto arr = LoadJsonArray("model_all_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    ASSERT_TRUE(info.file_size_mb.has_value());
+    EXPECT_EQ(16384u, *info.file_size_mb);
+    ASSERT_TRUE(info.supports_tool_calling.has_value());
+    EXPECT_TRUE(*info.supports_tool_calling);
+    ASSERT_TRUE(info.max_output_tokens.has_value());
+    EXPECT_EQ(8192, *info.max_output_tokens);
+}
+
+TEST_F(FileBasedParserTest, AllFields_Runtime) {
+    auto arr = LoadJsonArray("model_all_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    ASSERT_TRUE(info.runtime.has_value());
+    EXPECT_EQ(DeviceType::NPU, info.runtime->device_type);
+    EXPECT_EQ("QNN", info.runtime->execution_provider);
+}
+
+TEST_F(FileBasedParserTest, AllFields_PromptTemplate) {
+    auto arr = LoadJsonArray("model_all_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    ASSERT_TRUE(info.prompt_template.has_value());
+    EXPECT_EQ("<|system|>\n", info.prompt_template->system);
+    EXPECT_EQ("<|user|>\n", info.prompt_template->user);
+    EXPECT_EQ("<|assistant|>\n", info.prompt_template->assistant);
+    EXPECT_EQ("<|endoftext|>", info.prompt_template->prompt);
+}
+
+TEST_F(FileBasedParserTest, AllFields_ModelSettings) {
+    auto arr = LoadJsonArray("model_all_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    ASSERT_TRUE(info.model_settings.has_value());
+    ASSERT_EQ(3u, info.model_settings->parameters.size());
+    EXPECT_EQ("temperature", info.model_settings->parameters[0].name);
+    ASSERT_TRUE(info.model_settings->parameters[0].value.has_value());
+    EXPECT_EQ("0.7", *info.model_settings->parameters[0].value);
+    EXPECT_EQ("top_p", info.model_settings->parameters[1].name);
+    ASSERT_TRUE(info.model_settings->parameters[1].value.has_value());
+    EXPECT_EQ("0.9", *info.model_settings->parameters[1].value);
+    EXPECT_EQ("max_tokens", info.model_settings->parameters[2].name);
+    EXPECT_FALSE(info.model_settings->parameters[2].value.has_value());
+}
+
+TEST_F(FileBasedParserTest, MinimalFields_RequiredOnly) {
+    auto arr = LoadJsonArray("model_minimal_fields.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    EXPECT_EQ("minimal-model", info.id);
+    EXPECT_EQ("minimal-model", info.name);
+    EXPECT_EQ(1u, info.version);
+    EXPECT_EQ("minimal", info.alias);
+    EXPECT_EQ("onnx", info.provider_type);
+    EXPECT_EQ("text", info.model_type);
+    EXPECT_FALSE(info.cached);
+    EXPECT_EQ(0, info.created_at_unix);
+}
+
+TEST_F(FileBasedParserTest, MinimalFields_AllOptionalsAbsent) {
+    auto arr = LoadJsonArray("model_minimal_fields.json");
+    ModelInfo info = 
arr.at(0).get<ModelInfo>();
+
+    EXPECT_FALSE(info.display_name.has_value());
+    EXPECT_FALSE(info.publisher.has_value());
+    EXPECT_FALSE(info.license.has_value());
+    EXPECT_FALSE(info.license_description.has_value());
+    EXPECT_FALSE(info.task.has_value());
+    EXPECT_FALSE(info.file_size_mb.has_value());
+    EXPECT_FALSE(info.supports_tool_calling.has_value());
+    EXPECT_FALSE(info.max_output_tokens.has_value());
+    EXPECT_FALSE(info.min_fl_version.has_value());
+    EXPECT_FALSE(info.runtime.has_value());
+    EXPECT_FALSE(info.prompt_template.has_value());
+    EXPECT_FALSE(info.model_settings.has_value());
+}
+
+TEST_F(FileBasedParserTest, NullOptionals_AllOptionalsAbsent) {
+    auto arr = LoadJsonArray("model_null_optionals.json");
+    ModelInfo info = arr.at(0).get<ModelInfo>();
+
+    EXPECT_EQ("model-null-optionals", info.id);
+    EXPECT_EQ("null-opts", info.alias);
+
+    // All explicitly-null fields should parse as absent
+    EXPECT_FALSE(info.display_name.has_value());
+    EXPECT_FALSE(info.publisher.has_value());
+    EXPECT_FALSE(info.license.has_value());
+    EXPECT_FALSE(info.license_description.has_value());
+    EXPECT_FALSE(info.task.has_value());
+    EXPECT_FALSE(info.file_size_mb.has_value());
+    EXPECT_FALSE(info.supports_tool_calling.has_value());
+    EXPECT_FALSE(info.max_output_tokens.has_value());
+    EXPECT_FALSE(info.min_fl_version.has_value());
+    EXPECT_FALSE(info.runtime.has_value());
+    EXPECT_FALSE(info.prompt_template.has_value());
+    EXPECT_FALSE(info.model_settings.has_value());
+}
+
+TEST_F(FileBasedParserTest, RealModelsList_ParseAllEntries) {
+    auto arr = LoadJsonArray("real_models_list.json");
+    ASSERT_EQ(4u, arr.size());
+
+    for (const auto& j : arr) {
+        EXPECT_NO_THROW({
+            auto info = j.get<ModelInfo>();
+            EXPECT_FALSE(info.id.empty());
+            EXPECT_FALSE(info.name.empty());
+            EXPECT_FALSE(info.alias.empty());
+        });
+    }
+}
+
+TEST_F(FileBasedParserTest, MalformedJson_Throws) {
+    EXPECT_ANY_THROW({
+        std::string raw = Testing::ReadFile(TestDataPath("malformed_models_list.json"));
+        nlohmann::json::parse(raw);
+    });
+}
diff --git a/sdk/cpp/test/testdata/empty_models_list.json b/sdk/cpp/test/testdata/empty_models_list.json
new file mode 100644
index 00000000..fe51488c
--- /dev/null
+++ b/sdk/cpp/test/testdata/empty_models_list.json
@@ -0,0 +1 @@
+[]
diff --git a/sdk/cpp/test/testdata/malformed_models_list.json b/sdk/cpp/test/testdata/malformed_models_list.json
new file mode 100644
index 00000000..a04360f5
--- /dev/null
+++ b/sdk/cpp/test/testdata/malformed_models_list.json
@@ -0,0 +1 @@
+{this is not valid json[}
diff --git a/sdk/cpp/test/testdata/missing_name_field_models_list.json b/sdk/cpp/test/testdata/missing_name_field_models_list.json
new file mode 100644
index 00000000..ff4742f3
--- /dev/null
+++ b/sdk/cpp/test/testdata/missing_name_field_models_list.json
@@ -0,0 +1,12 @@
+[
+  {
+    "id": "model-missing-name",
+    "version": 1,
+    "alias": "test",
+    "providerType": "onnx",
+    "uri": "https://example.com/model",
+    "modelType": "text",
+    "cached": false,
+    "createdAt": 0
+  }
+]
diff --git a/sdk/cpp/test/testdata/mixed_openai_and_local.json b/sdk/cpp/test/testdata/mixed_openai_and_local.json
new file mode 100644
index 00000000..091e473f
--- /dev/null
+++ b/sdk/cpp/test/testdata/mixed_openai_and_local.json
@@ -0,0 +1,35 @@
+[
+  {
+    "id": "openai-gpt4",
+    "name": "openai-gpt4",
+    "version": 1,
+    "alias": "openai-gpt4",
+    "providerType": "openai",
+    "uri": "https://example.com/openai-gpt4",
+    "modelType": "text",
+    "cached": false,
+    "createdAt": 0
+  },
+  {
+    "id": "openai-whisper",
+    "name": "openai-whisper",
+    "version": 1,
+    "alias": 
"openai-whisper", + "providerType": "openai", + "uri": "https://example.com/openai-whisper", + "modelType": "audio", + "cached": false, + "createdAt": 0 + }, + { + "id": "local-phi-4", + "name": "local-phi-4", + "version": 1, + "alias": "phi-4", + "providerType": "onnx", + "uri": "https://example.com/phi-4", + "modelType": "text", + "cached": false, + "createdAt": 1700000000 + } +] diff --git a/sdk/cpp/test/testdata/real_models_list.json b/sdk/cpp/test/testdata/real_models_list.json new file mode 100644 index 00000000..45f456af --- /dev/null +++ b/sdk/cpp/test/testdata/real_models_list.json @@ -0,0 +1,88 @@ +[ + { + "id": "Phi-4-generic-gpu", + "name": "Phi-4-generic-gpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-gpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 8192, + "supportsToolCalling": true, + "maxOutputTokens": 4096, + "cached": false, + "createdAt": 1700000000, + "runtime": { + "deviceType": "GPU", + "executionProvider": "DML" + }, + "promptTemplate": { + "system": "<|system|>", + "user": "<|user|>", + "assistant": "<|assistant|>", + "prompt": "<|prompt|>" + } + }, + { + "id": "Phi-4-generic-cpu", + "name": "Phi-4-generic-cpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-cpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 4096, + "supportsToolCalling": false, + "maxOutputTokens": 2048, + "cached": false, + "createdAt": 1700000000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + }, + { + "id": "Mistral-7b-v0.2-generic-gpu", + "name": "Mistral-7b-v0.2-generic-gpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-gpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 14000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "GPU", + "executionProvider": "DML" + } + }, + { + "id": "Mistral-7b-v0.2-generic-cpu", + "name": "Mistral-7b-v0.2-generic-cpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-cpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 7000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + } +] diff --git a/sdk/cpp/test/testdata/single_cached_model.json b/sdk/cpp/test/testdata/single_cached_model.json new file mode 100644 index 00000000..76efa8e7 --- /dev/null +++ b/sdk/cpp/test/testdata/single_cached_model.json @@ -0,0 +1 @@ +["multi-v1-cpu:1"] diff --git a/sdk/cpp/test/testdata/three_variants_one_model.json b/sdk/cpp/test/testdata/three_variants_one_model.json new file mode 100644 index 00000000..e60581ee --- /dev/null +++ b/sdk/cpp/test/testdata/three_variants_one_model.json @@ -0,0 +1,41 @@ +[ + { + "id": "multi-v1-gpu", + "name": "multi-v1-gpu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 GPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-gpu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "GPU", "executionProvider": "DML" } + }, + { + "id": "multi-v1-cpu", + "name": "multi-v1-cpu", + "version": 
1, + "alias": "multi-model", + "displayName": "Multi Model v1 CPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-cpu", + "modelType": "text", + "cached": true, + "createdAt": 1700000000, + "runtime": { "deviceType": "CPU", "executionProvider": "ORT" } + }, + { + "id": "multi-v1-npu", + "name": "multi-v1-npu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 NPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-npu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "NPU", "executionProvider": "QNN" } + } +] diff --git a/sdk/cpp/test/testdata/valid_cached_models.json b/sdk/cpp/test/testdata/valid_cached_models.json new file mode 100644 index 00000000..2b144174 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_cached_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1", "Phi-4-generic-cpu:1"] diff --git a/sdk/cpp/test/testdata/valid_loaded_models.json b/sdk/cpp/test/testdata/valid_loaded_models.json new file mode 100644 index 00000000..4d2ef328 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_loaded_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1"] From 35476b3a0ed4c56e289e197bc7d264c3cd79321f Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 23 Mar 2026 14:33:13 -0700 Subject: [PATCH 02/18] Copilot comments fix 1 --- sdk/cpp/CMakePresets.json | 30 ------------------------------ sdk/cpp/include/foundry_local.h | 2 +- sdk/cpp/include/parser.h | 1 - sdk/cpp/src/foundry_local.cpp | 17 ++++++----------- 4 files changed, 7 insertions(+), 43 deletions(-) diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json index aa233618..3defcc5c 100644 --- a/sdk/cpp/CMakePresets.json +++ b/sdk/cpp/CMakePresets.json @@ -64,36 +64,6 @@ "cacheVariables": { "CMAKE_BUILD_TYPE": "Release" } - }, - { - "name": "linux-debug", - "displayName": "Linux Debug", - "generator": "Ninja", - "binaryDir": "${sourceDir}/out/build/${presetName}", - "installDir": "${sourceDir}/out/install/${presetName}", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug" - }, - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Linux" - } - }, - { - "name": "macos-debug", - "displayName": "macOS Debug", - "generator": "Ninja", - "binaryDir": "${sourceDir}/out/build/${presetName}", - "installDir": "${sourceDir}/out/install/${presetName}", - "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug" - }, - "condition": { - "type": "equals", - "lhs": "${hostSystemName}", - "rhs": "Darwin" - } } ], "buildPresets": [ diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index 9bce8f5f..fa061ba4 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -353,7 +353,7 @@ namespace FoundryLocal { mutable std::chrono::steady_clock::time_point lastFetch_{}; mutable std::unordered_map byAlias_; - mutable std::unordered_map modelIdToModelVariant_; + mutable std::unordered_map modelIdToModelVariant_; explicit Catalog(gsl::not_null injected, gsl::not_null logger); diff --git a/sdk/cpp/include/parser.h b/sdk/cpp/include/parser.h index 58c31e87..5396596d 100644 --- a/sdk/cpp/include/parser.h +++ b/sdk/cpp/include/parser.h @@ -42,7 +42,6 @@ namespace FoundryLocal { inline void from_json(const nlohmann::json& j, Runtime& r) { std::string deviceType; - std::string executionProvider; j.at("deviceType").get_to(deviceType); j.at("executionProvider").get_to(r.execution_provider); diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index a1eb2947..37da932e 100644 
--- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -97,7 +97,7 @@ namespace { } std::vector - CollectVariantsByIds(const std::unordered_map& modelIdToModelVariant, + CollectVariantsByIds(const std::unordered_map& modelIdToModelVariant, std::vector ids) { std::vector out; out.reserve(ids.size()); @@ -107,7 +107,7 @@ namespace { auto it = modelIdToModelVariant.find(id); if (it != modelIdToModelVariant.end()) { - out.emplace_back(&it->second); + out.emplace_back(it->second); } } return out; @@ -666,18 +666,13 @@ namespace FoundryLocal { ModelInfo modelVariantInfo; from_json(j, modelVariantInfo); - std::string variantId = modelVariantInfo.name; ModelVariant modelVariant(core_, modelVariantInfo, logger_); - modelIdToModelVariant_.emplace(variantId, modelVariant); - it->second.variants_.emplace_back(std::move(modelVariant)); - } - // Auto-select the first variant for each model. - for (auto& [alias, model] : byAlias_) { - if (!model.variants_.empty()) { - model.selectedVariantIndex_ = 0; + for (const auto& v : it->second.variants_) { + modelIdToModelVariant_[v.GetInfo().name] = &v; } + it->second.selectedVariantIndex_ = 0; } lastFetch_ = now; @@ -686,7 +681,7 @@ namespace FoundryLocal { const ModelVariant* Catalog::GetModelVariant(std::string_view id) const { auto it = modelIdToModelVariant_.find(std::string(id)); if (it != modelIdToModelVariant_.end()) { - return &it->second; + return it->second; } return nullptr; } From 44f86d1c1274a91acfdc944962edde293cb85030 Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 23 Mar 2026 14:50:56 -0700 Subject: [PATCH 03/18] Copilot comments fix 2 --- sdk/cpp/include/flcore_native.h | 1 + sdk/cpp/include/foundry_local.h | 2 +- sdk/cpp/sample/main.cpp | 6 ++++-- sdk/cpp/test/model_variant_test.cpp | 12 ++++++------ 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/sdk/cpp/include/flcore_native.h b/sdk/cpp/include/flcore_native.h index d67703e0..8cde9ec8 100644 --- a/sdk/cpp/include/flcore_native.h +++ b/sdk/cpp/include/flcore_native.h @@ -1,5 +1,6 @@ #pragma once #include +#include extern "C" { // Layout must match C# structs exactly diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index fa061ba4..7db8987b 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -130,7 +130,7 @@ namespace FoundryLocal { const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; } /// Returns the created timestamp as an ISO 8601 string. - /// Computed lazilym only allocates when called. + /// Computed lazily; only allocates when called. 
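+    /// Example (matching the unit tests): created == 1700000000 yields "2023-11-14T22:13:20Z".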
std::string GetCreatedAtIso() const; }; diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index f5166ba9..7e152d4b 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -1,10 +1,12 @@ #include -#include "foundry_local.h" -#include +#include #include #include +#include "foundry_local.h" + + using namespace FoundryLocal; // --------------------------------------------------------------------------- diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index 5ecbb696..7660207c 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -100,16 +100,16 @@ TEST_F(ModelVariantTest, Unload_ThrowsOnError) { } TEST_F(ModelVariantTest, Download_NoCallback) { -core_.OnCall("get_cached_models", R"([])"); -core_.OnCall("download_model", ""); -auto variant = MakeVariant("test-model"); -variant.Download(); + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", ""); + auto variant = MakeVariant("test-model"); + variant.Download(); EXPECT_EQ(1, core_.GetCallCount("download_model")); } TEST_F(ModelVariantTest, Download_WithCallback) { -core_.OnCall("get_cached_models", R"([])"); -core_.OnCall("download_model", + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { // Simulate calling the progress callback if (callback && userData) { From fb5e329b0f59fadd2a048d65ba413c9086cce1db Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 10:06:17 -0700 Subject: [PATCH 04/18] remove stripping version --- sdk/cpp/sample/main.cpp | 14 ++-- sdk/cpp/src/foundry_local.cpp | 42 +++++------- sdk/cpp/test/catalog_test.cpp | 16 ++--- sdk/cpp/test/client_test.cpp | 66 +++++++++---------- sdk/cpp/test/mock_object_factory.h | 5 +- sdk/cpp/test/model_variant_test.cpp | 24 +++---- sdk/cpp/test/parser_and_types_test.cpp | 4 +- .../missing_name_field_models_list.json | 2 +- .../test/testdata/mixed_openai_and_local.json | 6 +- sdk/cpp/test/testdata/real_models_list.json | 8 +-- .../testdata/three_variants_one_model.json | 6 +- 11 files changed, 94 insertions(+), 99 deletions(-) diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 7e152d4b..cddede9c 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -1,12 +1,10 @@ #include +#include "foundry_local.h" -#include +#include #include #include -#include "foundry_local.h" - - using namespace FoundryLocal; // --------------------------------------------------------------------------- @@ -82,7 +80,13 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { std::cout << "\n"; model->Load(); - std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + if (model->IsLoaded()) { + std::cout << "Model is loaded and ready for inference.\n"; + } else { + std::cerr << "Failed to load model.\n"; + return; + } // Get the selected variant pointer for ChatClient const auto& selectedVariant = model->GetAllModelVariants()[0]; diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index 37da932e..28df4d2c 100644 --- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -89,25 +89,16 @@ namespace { } } - inline void StripSuffixAfterColon(std::string& id) { - const auto pos = id.find_last_of(':'); - if (pos != std::string::npos) { - id.erase(pos); - } - } - std::vector - CollectVariantsByIds(const std::unordered_map& modelIdToModelVariant, + CollectVariantsByIds(const 
std::unordered_map& modelIdToModelVariant, std::vector ids) { std::vector out; out.reserve(ids.size()); - for (auto& id : ids) { - StripSuffixAfterColon(id); - + for (const auto& id : ids) { auto it = modelIdToModelVariant.find(id); if (it != modelIdToModelVariant.end()) { - out.emplace_back(it->second); + out.emplace_back(&it->second); } } return out; @@ -435,13 +426,8 @@ namespace FoundryLocal { bool ModelVariant::IsLoaded() const { std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); - for (auto& id : loadedModelIds) { - auto pos = id.find_last_of(':'); - if (pos != std::string::npos) { - id.erase(pos); - } - - if (id == info_.name) { + for (const auto& id : loadedModelIds) { + if (id == info_.id) { return true; } } @@ -451,9 +437,8 @@ namespace FoundryLocal { bool ModelVariant::IsCached() const { auto cachedModels = GetCachedModelsInternal(core_, *logger_); - for (auto& id : cachedModels) { - StripSuffixAfterColon(id); - if (id == info_.name) { + for (const auto& id : cachedModels) { + if (id == info_.id) { return true; } } @@ -666,13 +651,18 @@ namespace FoundryLocal { ModelInfo modelVariantInfo; from_json(j, modelVariantInfo); + std::string variantId = modelVariantInfo.id; ModelVariant modelVariant(core_, modelVariantInfo, logger_); + modelIdToModelVariant_.emplace(variantId, modelVariant); + it->second.variants_.emplace_back(std::move(modelVariant)); + } - for (const auto& v : it->second.variants_) { - modelIdToModelVariant_[v.GetInfo().name] = &v; + // Auto-select the first variant for each model. + for (auto& [alias, model] : byAlias_) { + if (!model.variants_.empty()) { + model.selectedVariantIndex_ = 0; } - it->second.selectedVariantIndex_ = 0; } lastFetch_ = now; @@ -681,7 +671,7 @@ namespace FoundryLocal { const ModelVariant* Catalog::GetModelVariant(std::string_view id) const { auto it = modelIdToModelVariant_.find(std::string(id)); if (it != modelIdToModelVariant_.end()) { - return it->second; + return &it->second; } return nullptr; } diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp index e40d7c11..78824574 100644 --- a/sdk/cpp/test/catalog_test.cpp +++ b/sdk/cpp/test/catalog_test.cpp @@ -113,9 +113,9 @@ TEST_F(CatalogTest, GetModelVariant_Found) { auto catalog = MakeCatalog(); catalog->ListModels(); // populate - auto* variant = catalog->GetModelVariant("model-1"); + auto* variant = catalog->GetModelVariant("model-1:1"); ASSERT_NE(nullptr, variant); - EXPECT_EQ("model-1", variant->GetId()); + EXPECT_EQ("model-1:1", variant->GetId()); } TEST_F(CatalogTest, GetModelVariant_NotFound) { @@ -123,19 +123,19 @@ TEST_F(CatalogTest, GetModelVariant_NotFound) { auto catalog = MakeCatalog(); catalog->ListModels(); - EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent")); + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent:1")); } TEST_F(CatalogTest, GetLoadedModels) { core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); - core_.OnCall("list_loaded_models", R"(["model-1:v1"])"); + core_.OnCall("list_loaded_models", R"(["model-1:1"])"); auto catalog = MakeCatalog(); catalog->ListModels(); // populate auto loaded = catalog->GetLoadedModels(); ASSERT_EQ(1u, loaded.size()); - EXPECT_EQ("model-1", loaded[0]->GetId()); + EXPECT_EQ("model-1:1", loaded[0]->GetId()); } TEST_F(CatalogTest, GetCachedModels) { @@ -200,11 +200,11 @@ TEST_F(FileBasedCatalogTest, RealModelsList_VariantDetails) { catalog->ListModels(); // populate - const auto* gpuVariant = 
catalog->GetModelVariant("Phi-4-generic-gpu");
+    const auto* gpuVariant = catalog->GetModelVariant("Phi-4-generic-gpu:1");
     ASSERT_NE(nullptr, gpuVariant);
 
     const auto& info = gpuVariant->GetInfo();
-    EXPECT_EQ("Phi-4-generic-gpu", info.id);
+    EXPECT_EQ("Phi-4-generic-gpu:1", info.id);
     EXPECT_EQ("Phi-4-generic-gpu", info.name);
     EXPECT_EQ("phi-4", info.alias);
     ASSERT_TRUE(info.display_name.has_value());
@@ -235,7 +235,7 @@ TEST_F(FileBasedCatalogTest, RealModelsList_CpuVariantDetails) {
 
     catalog->ListModels(); // populate
 
-    const auto* cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu");
+    const auto* cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu:1");
     ASSERT_NE(nullptr, cpuVariant);
 
     const auto& info = cpuVariant->GetInfo();
diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp
index 0857bc92..ceae995c 100644
--- a/sdk/cpp/test/client_test.cpp
+++ b/sdk/cpp/test/client_test.cpp
@@ -30,14 +30,14 @@ class ChatClientTest : public ::testing::Test {
     }
 
     ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") {
-        core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]");
+        core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]");
         return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_);
     }
 };
 
 TEST_F(ChatClientTest, CompleteChat_BasicResponse) {
     core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!"));
-    core_.OnCall("list_loaded_models", R"(["chat-model:v1"])");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
 
     auto variant = MakeLoadedVariant();
     ChatClient client(&variant);
@@ -52,16 +52,16 @@ TEST_F(ChatClientTest, CompleteChat_BasicResponse) {
 }
 
 TEST_F(ChatClientTest, CompleteChat_WithSettings) {
     core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:v1"])");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
 
     auto variant = MakeLoadedVariant();
     ChatClient client(&variant);
 
     std::vector<ChatMessage> messages = {{"user", "test", {}}};
     ChatSettings settings;
     settings.temperature = 0.7f;
     settings.max_tokens = 100;
     settings.top_p = 0.9f;
     settings.frequency_penalty = 0.5f;
     settings.presence_penalty = 0.3f;
@@ -86,15 +86,15 @@ TEST_F(ChatClientTest, CompleteChat_WithSettings) {
 }
 
 TEST_F(ChatClientTest, CompleteChat_RequestFormat) {
     core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:v1"])");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
 
     auto variant = MakeLoadedVariant();
     ChatClient client(&variant);
 
     std::vector<ChatMessage> messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}};
     ChatSettings settings;
     auto response = client.CompleteChat(messages, settings);
 
     auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
     auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
@@ -134,7 +134,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming) {
         }
         return 
""; }); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -172,7 +172,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { } return ""; }); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -193,7 +193,7 @@ TEST_F(ChatClientTest, Constructor_ThrowsIfNotLoaded) { } TEST_F(ChatClientTest, GetModelId) { - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); EXPECT_EQ("chat-model", client.GetModelId()); @@ -203,7 +203,7 @@ TEST_F(ChatClientTest, GetModelId) { TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -252,7 +252,7 @@ TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { TEST_F(ChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -288,7 +288,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}}; core_.OnCall("chat_completions", resp.dump()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -312,7 +312,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -330,7 +330,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -348,7 +348,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { TEST_F(ChatClientTest, CompleteChat_ToolMessageWithToolCallId) { core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -412,7 +412,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { } return ""; }); - core_.OnCall("list_loaded_models", R"(["chat-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); ChatClient client(&variant); @@ -450,14 +450,14 @@ class AudioClientTest : public ::testing::Test { NullLogger 
logger_; ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") { - core_.OnCall("list_loaded_models", "[\"" + name + ":v1\"]"); + core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]"); return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); } }; TEST_F(AudioClientTest, TranscribeAudio) { core_.OnCall("audio_transcribe", "Hello world transcribed text"); - core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); AudioClient client(&variant); @@ -468,7 +468,7 @@ TEST_F(AudioClientTest, TranscribeAudio) { TEST_F(AudioClientTest, TranscribeAudio_RequestFormat) { core_.OnCall("audio_transcribe", "text"); - core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); AudioClient client(&variant); @@ -492,7 +492,7 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming) { } return ""; }); - core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); AudioClient client(&variant); @@ -516,7 +516,7 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { } return ""; }); - core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); AudioClient client(&variant); @@ -534,7 +534,7 @@ TEST_F(AudioClientTest, Constructor_ThrowsIfNotLoaded) { } TEST_F(AudioClientTest, GetModelId) { - core_.OnCall("list_loaded_models", R"(["audio-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); AudioClient client(&variant); EXPECT_EQ("audio-model", client.GetModelId()); diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h index 9d029aec..ac055da8 100644 --- a/sdk/cpp/test/mock_object_factory.h +++ b/sdk/cpp/test/mock_object_factory.h @@ -38,7 +38,7 @@ namespace FoundryLocal::Testing { /// Helper to build a minimal ModelInfo with defaults. static ModelInfo MakeModelInfo(std::string name, std::string alias = "", uint32_t version = 1) { ModelInfo info; - info.id = name; + info.id = name + ":" + std::to_string(version); info.name = std::move(name); info.alias = alias.empty() ? info.name : std::move(alias); info.version = version; @@ -52,7 +52,8 @@ namespace FoundryLocal::Testing { static std::string MakeModelInfoJson(const std::string& name, const std::string& alias = "", uint32_t version = 1, bool cached = false) { std::string a = alias.empty() ? name : alias; - return R"({"id":")" + name + R"(","name":")" + name + R"(","version":)" + std::to_string(version) + + std::string id = name + ":" + std::to_string(version); + return R"({"id":")" + id + R"(","name":")" + name + R"(","version":)" + std::to_string(version) + R"(,"alias":")" + a + R"(","providerType":"test","uri":"test://uri","modelType":"text","cached":)" + (cached ? 
"true" : "false") + R"(,"createdAt":0})"; } diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index 7660207c..f791dc4f 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -32,7 +32,7 @@ TEST_F(ModelVariantTest, GetInfo) { TEST_F(ModelVariantTest, GetId) { auto variant = MakeVariant("my-model"); - EXPECT_EQ("my-model", variant.GetId()); + EXPECT_EQ("my-model:1", variant.GetId()); } TEST_F(ModelVariantTest, GetAlias) { @@ -46,13 +46,13 @@ TEST_F(ModelVariantTest, GetVersion) { } TEST_F(ModelVariantTest, IsLoaded_True) { - core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["test-model:1"])"); auto variant = MakeVariant("test-model"); EXPECT_TRUE(variant.IsLoaded()); } TEST_F(ModelVariantTest, IsLoaded_False) { - core_.OnCall("list_loaded_models", R"(["other-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["other-model:1"])"); auto variant = MakeVariant("test-model"); EXPECT_FALSE(variant.IsLoaded()); } @@ -100,16 +100,16 @@ TEST_F(ModelVariantTest, Unload_ThrowsOnError) { } TEST_F(ModelVariantTest, Download_NoCallback) { - core_.OnCall("get_cached_models", R"([])"); - core_.OnCall("download_model", ""); - auto variant = MakeVariant("test-model"); - variant.Download(); +core_.OnCall("get_cached_models", R"([])"); +core_.OnCall("download_model", ""); +auto variant = MakeVariant("test-model"); +variant.Download(); EXPECT_EQ(1, core_.GetCallCount("download_model")); } TEST_F(ModelVariantTest, Download_WithCallback) { - core_.OnCall("get_cached_models", R"([])"); - core_.OnCall("download_model", +core_.OnCall("get_cached_models", R"([])"); +core_.OnCall("download_model", [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { // Simulate calling the progress callback if (callback && userData) { @@ -185,7 +185,7 @@ TEST_F(ModelTest, AddVariant_AndSelect) { Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); Factory::SetSelectedVariantIndex(model, 0); - EXPECT_EQ("v1", model.GetId()); + EXPECT_EQ("v1:1", model.GetId()); EXPECT_EQ("alias", model.GetAlias()); } @@ -207,7 +207,7 @@ TEST_F(ModelTest, SelectVariant) { const auto* v2 = &model.GetAllModelVariants()[1]; model.SelectVariant(v2); - EXPECT_EQ("v2", model.GetId()); + EXPECT_EQ("v2:2", model.GetId()); } TEST_F(ModelTest, SelectVariant_NotFound_Throws) { @@ -233,7 +233,7 @@ TEST_F(ModelTest, GetLatestVariant) { TEST_F(ModelTest, DelegationMethods) { // Test that Model delegates to SelectedVariant - core_.OnCall("list_loaded_models", R"(["test-model:v1"])"); + core_.OnCall("list_loaded_models", R"(["test-model:1"])"); core_.OnCall("get_cached_models", R"(["test-model:1"])"); core_.OnCall("load_model", ""); core_.OnCall("unload_model", ""); diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index 4515c3a0..c83acbc9 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -14,7 +14,7 @@ using namespace FoundryLocal::Testing; class ParserTest : public ::testing::Test { protected: static nlohmann::json MinimalModelJson() { - return nlohmann::json{{"id", "model-1"}, {"name", "model-1"}, {"version", 1}, + return nlohmann::json{{"id", "model-1:1"}, {"name", "model-1"}, {"version", 1}, {"alias", "my-model"}, {"providerType", "onnx"}, {"uri", "https://example.com/model"}, {"modelType", "text"}, {"cached", false}, {"createdAt", 1700000000}}; } @@ -140,7 +140,7 @@ TEST_F(ParserTest, 
ParsePromptTemplate_MissingFields) { TEST_F(ParserTest, ParseModelInfo_Minimal) { auto j = MinimalModelJson(); ModelInfo info = j.get(); - EXPECT_EQ("model-1", info.id); + EXPECT_EQ("model-1:1", info.id); EXPECT_EQ("model-1", info.name); EXPECT_EQ(1u, info.version); EXPECT_EQ("my-model", info.alias); diff --git a/sdk/cpp/test/testdata/missing_name_field_models_list.json b/sdk/cpp/test/testdata/missing_name_field_models_list.json index ff4742f3..da1e9465 100644 --- a/sdk/cpp/test/testdata/missing_name_field_models_list.json +++ b/sdk/cpp/test/testdata/missing_name_field_models_list.json @@ -1,6 +1,6 @@ [ { - "id": "model-missing-name", + "id": "model-missing-name:1", "version": 1, "alias": "test", "providerType": "onnx", diff --git a/sdk/cpp/test/testdata/mixed_openai_and_local.json b/sdk/cpp/test/testdata/mixed_openai_and_local.json index 091e473f..9d8de80b 100644 --- a/sdk/cpp/test/testdata/mixed_openai_and_local.json +++ b/sdk/cpp/test/testdata/mixed_openai_and_local.json @@ -1,6 +1,6 @@ [ { - "id": "openai-gpt4", + "id": "openai-gpt4:1", "name": "openai-gpt4", "version": 1, "alias": "openai-gpt4", @@ -11,7 +11,7 @@ "createdAt": 0 }, { - "id": "openai-whisper", + "id": "openai-whisper:1", "name": "openai-whisper", "version": 1, "alias": "openai-whisper", @@ -22,7 +22,7 @@ "createdAt": 0 }, { - "id": "local-phi-4", + "id": "local-phi-4:1", "name": "local-phi-4", "version": 1, "alias": "phi-4", diff --git a/sdk/cpp/test/testdata/real_models_list.json b/sdk/cpp/test/testdata/real_models_list.json index 45f456af..284d3a1a 100644 --- a/sdk/cpp/test/testdata/real_models_list.json +++ b/sdk/cpp/test/testdata/real_models_list.json @@ -1,6 +1,6 @@ [ { - "id": "Phi-4-generic-gpu", + "id": "Phi-4-generic-gpu:1", "name": "Phi-4-generic-gpu", "version": 1, "alias": "phi-4", @@ -27,7 +27,7 @@ } }, { - "id": "Phi-4-generic-cpu", + "id": "Phi-4-generic-cpu:1", "name": "Phi-4-generic-cpu", "version": 1, "alias": "phi-4", @@ -48,7 +48,7 @@ } }, { - "id": "Mistral-7b-v0.2-generic-gpu", + "id": "Mistral-7b-v0.2-generic-gpu:1", "name": "Mistral-7b-v0.2-generic-gpu", "version": 1, "alias": "mistral-7b-v0.2", @@ -67,7 +67,7 @@ } }, { - "id": "Mistral-7b-v0.2-generic-cpu", + "id": "Mistral-7b-v0.2-generic-cpu:1", "name": "Mistral-7b-v0.2-generic-cpu", "version": 1, "alias": "mistral-7b-v0.2", diff --git a/sdk/cpp/test/testdata/three_variants_one_model.json b/sdk/cpp/test/testdata/three_variants_one_model.json index e60581ee..fad0555d 100644 --- a/sdk/cpp/test/testdata/three_variants_one_model.json +++ b/sdk/cpp/test/testdata/three_variants_one_model.json @@ -1,6 +1,6 @@ [ { - "id": "multi-v1-gpu", + "id": "multi-v1-gpu:1", "name": "multi-v1-gpu", "version": 1, "alias": "multi-model", @@ -13,7 +13,7 @@ "runtime": { "deviceType": "GPU", "executionProvider": "DML" } }, { - "id": "multi-v1-cpu", + "id": "multi-v1-cpu:1", "name": "multi-v1-cpu", "version": 1, "alias": "multi-model", @@ -26,7 +26,7 @@ "runtime": { "deviceType": "CPU", "executionProvider": "ORT" } }, { - "id": "multi-v1-npu", + "id": "multi-v1-npu:1", "name": "multi-v1-npu", "version": 1, "alias": "multi-model", From 24eb623c86041030745ca3f0ccd171cdb04998ed Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 10:33:20 -0700 Subject: [PATCH 05/18] more changes --- sdk/cpp/CMakeLists.txt | 3 + sdk/cpp/CMakePresets.json | 36 ++-- sdk/cpp/include/foundry_local.h | 9 +- .../{include => src}/core_interop_request.h | 0 sdk/cpp/{include => src}/flcore_native.h | 1 - sdk/cpp/src/foundry_local.cpp | 170 +++++++++--------- 
 .../foundry_local_internal_core.h              |   0
 sdk/cpp/{include => src}/parser.h              |   1 +
 8 files changed, 120 insertions(+), 100 deletions(-)
 rename sdk/cpp/{include => src}/core_interop_request.h (100%)
 rename sdk/cpp/{include => src}/flcore_native.h (97%)
 rename sdk/cpp/{include => src}/foundry_local_internal_core.h (100%)
 rename sdk/cpp/{include => src}/parser.h (99%)

diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
index 064c46ca..1ef1194a 100644
--- a/sdk/cpp/CMakeLists.txt
+++ b/sdk/cpp/CMakeLists.txt
@@ -96,6 +96,8 @@ target_include_directories(CppSdk
     PUBLIC
         ${CMAKE_CURRENT_SOURCE_DIR}/include
         ${wil_src_SOURCE_DIR}/include
+    PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/src
 )

 target_link_libraries(CppSdk
@@ -128,6 +130,7 @@ add_executable(CppSdkTests
 target_include_directories(CppSdkTests
     PRIVATE
         ${CMAKE_CURRENT_SOURCE_DIR}/test
+        ${CMAKE_CURRENT_SOURCE_DIR}/src
 )

 target_compile_definitions(CppSdkTests PRIVATE FL_TESTS)
diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json
index 3defcc5c..f9ab249d 100644
--- a/sdk/cpp/CMakePresets.json
+++ b/sdk/cpp/CMakePresets.json
@@ -42,27 +42,33 @@
       }
     },
     {
-      "name": "x86-debug",
-      "displayName": "MSVC x86 Debug",
-      "inherits": "windows-base",
-      "architecture": {
-        "value": "x86",
-        "strategy": "external"
-      },
+      "name": "linux-debug",
+      "displayName": "Linux Debug",
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
       "cacheVariables": {
         "CMAKE_BUILD_TYPE": "Debug"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Linux"
       }
     },
     {
-      "name": "x86-release",
-      "displayName": "MSVC x86 Release",
-      "inherits": "windows-base",
-      "architecture": {
-        "value": "x86",
-        "strategy": "external"
-      },
+      "name": "macos-debug",
+      "displayName": "macOS Debug",
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
       "cacheVariables": {
-        "CMAKE_BUILD_TYPE": "Release"
+        "CMAKE_BUILD_TYPE": "Debug"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Darwin"
       }
     }
   ],
diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h
index 7db8987b..ba5aee98 100644
--- a/sdk/cpp/include/foundry_local.h
+++ b/sdk/cpp/include/foundry_local.h
@@ -14,10 +14,13 @@
 #include

 #include "configuration.h"
-#include "foundry_local_internal_core.h"
 #include "logger.h"

+namespace FoundryLocal::Internal {
+    struct IFoundryLocalCore;
+}
+
 namespace FoundryLocal {
#ifdef FL_TESTS
     namespace Testing {
@@ -130,7 +133,7 @@ namespace FoundryLocal {
         const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; }

         /// Returns the created timestamp as an ISO 8601 string.
-        /// Computed lazilym only allocates when called.
+        /// Computed lazily; only allocates when called.
std::string GetCreatedAtIso() const; }; @@ -353,7 +356,7 @@ namespace FoundryLocal { mutable std::chrono::steady_clock::time_point lastFetch_{}; mutable std::unordered_map byAlias_; - mutable std::unordered_map modelIdToModelVariant_; + mutable std::unordered_map modelIdToModelVariant_; explicit Catalog(gsl::not_null injected, gsl::not_null logger); diff --git a/sdk/cpp/include/core_interop_request.h b/sdk/cpp/src/core_interop_request.h similarity index 100% rename from sdk/cpp/include/core_interop_request.h rename to sdk/cpp/src/core_interop_request.h diff --git a/sdk/cpp/include/flcore_native.h b/sdk/cpp/src/flcore_native.h similarity index 97% rename from sdk/cpp/include/flcore_native.h rename to sdk/cpp/src/flcore_native.h index 8cde9ec8..d67703e0 100644 --- a/sdk/cpp/include/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -1,6 +1,5 @@ #pragma once #include -#include extern "C" { // Layout must match C# structs exactly diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index 28df4d2c..ae199555 100644 --- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -27,43 +27,43 @@ namespace { namespace { // Wrap Params: { ... } into a request object inline nlohmann::json MakeParams(nlohmann::json params) { - return nlohmann::json{ {"Params", std::move(params)} }; + return nlohmann::json{{"Params", std::move(params)}}; } // Most common: Params { "Model": } inline nlohmann::json MakeModelParams(std::string_view model) { - return MakeParams(nlohmann::json{ {"Model", std::string(model)} }); + return MakeParams(nlohmann::json{{"Model", std::string(model)}}); } // Serialize + call inline std::string CallWithJson(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& requestJson, FoundryLocal::ILogger& logger) { + const nlohmann::json& requestJson, FoundryLocal::ILogger& logger) { std::string payload = requestJson.dump(); return core->call(command, logger, &payload); } // Serialize + call with native callback inline std::string CallWithJsonAndCallback(FoundryLocal::Internal::IFoundryLocalCore* core, - std::string_view command, const nlohmann::json& requestJson, FoundryLocal::ILogger& logger, - void* callback, void* userData) { + std::string_view command, const nlohmann::json& requestJson, + FoundryLocal::ILogger& logger, void* callback, void* userData) { std::string payload = requestJson.dump(); return core->call(command, logger, &payload, callback, userData); } // Overload: allow Params object directly inline std::string CallWithParams(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& params, FoundryLocal::ILogger& logger) { + const nlohmann::json& params, FoundryLocal::ILogger& logger) { return CallWithJson(core, command, MakeParams(params), logger); } // Overload: no payload inline std::string CallNoArgs(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, - FoundryLocal::ILogger& logger) { + FoundryLocal::ILogger& logger) { return core->call(command, logger, nullptr); } std::vector GetLoadedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, - FoundryLocal::ILogger& logger) { + FoundryLocal::ILogger& logger) { std::string raw = core->call("list_loaded_models", logger); try { auto parsed = nlohmann::json::parse(raw); @@ -76,7 +76,7 @@ namespace { } std::vector GetCachedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, - FoundryLocal::ILogger& logger) { + FoundryLocal::ILogger& logger) { std::string raw = 
core->call("get_cached_models", logger); try { @@ -89,9 +89,9 @@ namespace { } } - std::vector - CollectVariantsByIds(const std::unordered_map& modelIdToModelVariant, - std::vector ids) { + std::vector CollectVariantsByIds( + const std::unordered_map& modelIdToModelVariant, + std::vector ids) { std::vector out; out.reserve(ids.size()); @@ -119,9 +119,7 @@ namespace FoundryLocal { Core() = default; ~Core() = default; - void loadEmbedded() { - loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); - } + void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); } void unload() { module_.reset(); @@ -130,10 +128,9 @@ namespace FoundryLocal { freeResCmd_ = nullptr; } std::string call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, - void* callback = nullptr, void* data = nullptr) const override { + void* callback = nullptr, void* data = nullptr) const override { if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { - throw FoundryLocalException( - "Core is not loaded. Cannot call command: " + std::string(command), logger); + throw FoundryLocalException("Core is not loaded. Cannot call command: " + std::string(command), logger); } RequestBuffer request{}; @@ -147,7 +144,8 @@ namespace FoundryLocal { ResponseBuffer response{}; auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { - if (fn) fn(buf); + if (fn) + fn(buf); }; std::unique_ptr responseGuard(&response, safeDeleter); @@ -164,8 +162,8 @@ namespace FoundryLocal { std::string result; if (response.Error && response.ErrorLength > 0) { std::string err(static_cast(response.Error), response.ErrorLength); - throw FoundryLocalException( - std::string("Command failed [").append(command).append("]: ").append(err), logger); + throw FoundryLocalException(std::string("Command failed [").append(command).append("]: ").append(err), + logger); } if (response.Data && response.DataLength > 0) { @@ -200,12 +198,11 @@ namespace FoundryLocal { /// AudioClient::AudioClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) { - } + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} AudioCreateTranscriptionResponse AudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { - nlohmann::json openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} }; + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); req.AddParam("OpenAICreateRequest", openAiReq.dump()); @@ -217,8 +214,9 @@ namespace FoundryLocal { return response; } - void AudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const { - nlohmann::json openAiReq = { {"Model", modelId_}, {"FileName", audioFilePath.string()} }; + void AudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + const StreamCallback& onChunk) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); req.AddParam("OpenAICreateRequest", openAiReq.dump()); @@ -227,7 +225,7 @@ namespace FoundryLocal { struct State { const StreamCallback* cb; std::exception_ptr exception; - } state{ &onChunk, nullptr }; + } state{&onChunk, nullptr}; auto streamCallback = [](void* data, int32_t len, void* user) { if (!data || len <= 0) @@ -249,16 +247,16 @@ namespace 
FoundryLocal { }; core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), - reinterpret_cast(&state)); + reinterpret_cast(&state)); if (state.exception) { std::rethrow_exception(state.exception); } } - std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { - if (created == 0) return {}; + if (created == 0) + return {}; std::time_t t = static_cast(created); std::tm tm{}; #ifdef _WIN32 @@ -276,21 +274,21 @@ namespace FoundryLocal { /// ChatClient::ChatClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) { - } + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} - std::string ChatClient::BuildChatRequestJson(gsl::span messages, gsl::span tools, - const ChatSettings& settings, bool stream) const { + std::string ChatClient::BuildChatRequestJson(gsl::span messages, + gsl::span tools, const ChatSettings& settings, + bool stream) const { nlohmann::json jMessages = nlohmann::json::array(); for (const auto& msg : messages) { - nlohmann::json jMsg = { {"role", msg.role}, {"content", msg.content} }; + nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; if (msg.tool_call_id) jMsg["tool_call_id"] = *msg.tool_call_id; jMessages.push_back(std::move(jMsg)); } - nlohmann::json req = { {"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream} }; + nlohmann::json req = {{"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream}}; if (!tools.empty()) { nlohmann::json jTools = nlohmann::json::array(); @@ -305,7 +303,7 @@ namespace FoundryLocal { if (settings.tool_choice) req["tool_choice"] = tool_choice_to_string(*settings.tool_choice); if (settings.top_k) - req["metadata"] = { {"top_k", *settings.top_k} }; + req["metadata"] = {{"top_k", *settings.top_k}}; if (settings.frequency_penalty) req["frequency_penalty"] = *settings.frequency_penalty; if (settings.presence_penalty) @@ -325,12 +323,13 @@ namespace FoundryLocal { } ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, - const ChatSettings& settings) const { + const ChatSettings& settings) const { return CompleteChat(messages, {}, settings); } ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, - gsl::span tools, const ChatSettings& settings) const { + gsl::span tools, + const ChatSettings& settings) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); CoreInteropRequest req("chat_completions"); @@ -343,12 +342,12 @@ namespace FoundryLocal { } void ChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, - const StreamCallback& onChunk) const { + const StreamCallback& onChunk) const { CompleteChatStreaming(messages, {}, settings, onChunk); } void ChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, - const ChatSettings& settings, const StreamCallback& onChunk) const { + const ChatSettings& settings, const StreamCallback& onChunk) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); CoreInteropRequest req("chat_completions"); @@ -358,7 +357,7 @@ namespace FoundryLocal { struct State { const StreamCallback* cb; std::exception_ptr exception; - } state{ &onChunk, nullptr }; + } state{&onChunk, nullptr}; auto streamCallback = [](void* data, int32_t len, void* user) { if (!data || len <= 0) @@ -382,10 +381,10 @@ namespace FoundryLocal { catch (...) 
{ st->exception = std::current_exception(); } - }; + }; core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), - reinterpret_cast(&state)); + reinterpret_cast(&state)); if (state.exception) { std::rethrow_exception(state.exception); @@ -397,9 +396,8 @@ namespace FoundryLocal { /// ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, - gsl::not_null logger) - : core_(core), info_(std::move(info)), logger_(logger) { - } + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) {} const ModelInfo& ModelVariant::GetInfo() const { return info_; @@ -455,7 +453,7 @@ namespace FoundryLocal { struct ProgressState { DownloadProgressCallback* cb; ILogger* logger; - } state{ &onProgress, logger_ }; + } state{&onProgress, logger_}; auto nativeCallback = [](void* data, int32_t len, void* user) { if (!data || len <= 0) @@ -465,14 +463,16 @@ namespace FoundryLocal { try { float value = std::stof(perc); (*(st->cb))(value); - } catch (...) { + } + catch (...) { st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); } }; CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, - reinterpret_cast(+nativeCallback), reinterpret_cast(&state)); - } else { + reinterpret_cast(+nativeCallback), reinterpret_cast(&state)); + } + else { CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); } } @@ -483,7 +483,8 @@ namespace FoundryLocal { const std::filesystem::path& ModelVariant::GetPath() const { if (cachedPath_.empty()) { - cachedPath_ = std::filesystem::path(CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_)); + cachedPath_ = + std::filesystem::path(CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_)); } return cachedPath_; } @@ -503,7 +504,8 @@ namespace FoundryLocal { AudioClient::AudioClient(gsl::not_null model) : AudioClient(model->core_, model->info_.name, model->logger_) { if (!model->IsLoaded()) { - throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", + *model->logger_); } } @@ -514,7 +516,8 @@ namespace FoundryLocal { ChatClient::ChatClient(gsl::not_null model) : ChatClient(model->core_, model->info_.name, model->logger_) { if (!model->IsLoaded()) { - throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); + throw FoundryLocalException("Model " + model->info_.name + " is not loaded. 
Call Load() first.", + *model->logger_); } } @@ -526,8 +529,7 @@ namespace FoundryLocal { /// Model /// Model::Model(gsl::not_null core, gsl::not_null logger) - : core_(core), logger_(logger) { - } + : core_(core), logger_(logger) {} ModelVariant& Model::SelectedVariant() { if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { @@ -556,8 +558,8 @@ namespace FoundryLocal { } } - throw FoundryLocalException( - "Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", *logger_); + throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", + *logger_); } const std::string& Model::GetId() const { @@ -570,11 +572,11 @@ namespace FoundryLocal { void Model::SelectVariant(gsl::not_null variant) const { auto it = std::find_if(variants_.begin(), variants_.end(), - [&](const ModelVariant& v) { return &v == variant.get(); }); + [&](const ModelVariant& v) { return &v == variant.get(); }); if (it == variants_.end()) { throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", - *logger_); + *logger_); } selectedVariantIndex_ = static_cast(std::distance(variants_.begin(), it)); @@ -681,18 +683,16 @@ namespace FoundryLocal { /// FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger) - : config_(std::move(configuration)), core_(std::make_unique()), logger_(logger ? logger : &defaultLogger_) { + : config_(std::move(configuration)), core_(std::make_unique()), + logger_(logger ? logger : &defaultLogger_) { static_cast(core_.get())->loadEmbedded(); Initialize(); catalog_ = Catalog::Create(core_.get(), logger_); } FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept - : config_(std::move(other.config_)), - core_(std::move(other.core_)), - catalog_(std::move(other.catalog_)), - logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_), - urls_(std::move(other.urls_)) { + : config_(std::move(other.config_)), core_(std::move(other.core_)), catalog_(std::move(other.catalog_)), + logger_(other.OwnsLogger() ? 
&defaultLogger_ : other.logger_), urls_(std::move(other.urls_)) { other.logger_ = &other.defaultLogger_; } @@ -716,22 +716,26 @@ namespace FoundryLocal { for (const auto* variant : loadedModels) { try { variant->Unload(); - } catch (const std::exception& ex) { + } + catch (const std::exception& ex) { logger_->Log(LogLevel::Warning, - std::string("Error unloading model during destruction: ") + ex.what()); + std::string("Error unloading model during destruction: ") + ex.what()); } } - } catch (const std::exception& ex) { + } + catch (const std::exception& ex) { logger_->Log(LogLevel::Warning, - std::string("Error retrieving loaded models during destruction: ") + ex.what()); + std::string("Error retrieving loaded models during destruction: ") + ex.what()); } } if (!urls_.empty()) { try { StopWebService(); - } catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, std::string("Error stopping web service during destruction: ") + ex.what()); + } + catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error stopping web service during destruction: ") + ex.what()); } } } @@ -749,7 +753,8 @@ namespace FoundryLocal { std::string raw = core_->call("start_service", *logger_); auto arr = nlohmann::json::parse(raw); urls_ = arr.get>(); - } catch (const std::exception& ex) { + } + catch (const std::exception& ex) { throw FoundryLocalException(std::string("Error starting web service: ") + ex.what(), *logger_); } } @@ -762,7 +767,8 @@ namespace FoundryLocal { try { core_->call("stop_service", *logger_); urls_.clear(); - } catch (const std::exception& ex) { + } + catch (const std::exception& ex) { throw FoundryLocalException(std::string("Error stopping web service: ") + ex.what(), *logger_); } } @@ -774,9 +780,10 @@ namespace FoundryLocal { void FoundryLocalManager::EnsureEpsDownloaded() const { try { core_->call("ensure_eps_downloaded", *logger_); - } catch (const std::exception& ex) { - throw FoundryLocalException( - std::string("Error ensuring execution providers downloaded: ") + ex.what(), *logger_); + } + catch (const std::exception& ex) { + throw FoundryLocalException(std::string("Error ensuring execution providers downloaded: ") + ex.what(), + *logger_); } } @@ -818,10 +825,11 @@ namespace FoundryLocal { core_->call(setReq.Command(), *logger_, &setJson); logger_->Log(LogLevel::Information, - std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); + std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); } else { - logger_->Log(LogLevel::Information, std::string("Model cache directory already set to: ") + current); + logger_->Log(LogLevel::Information, + std::string("Model cache directory already set to: ") + current); } } } diff --git a/sdk/cpp/include/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h similarity index 100% rename from sdk/cpp/include/foundry_local_internal_core.h rename to sdk/cpp/src/foundry_local_internal_core.h diff --git a/sdk/cpp/include/parser.h b/sdk/cpp/src/parser.h similarity index 99% rename from sdk/cpp/include/parser.h rename to sdk/cpp/src/parser.h index 5396596d..58c31e87 100644 --- a/sdk/cpp/include/parser.h +++ b/sdk/cpp/src/parser.h @@ -42,6 +42,7 @@ namespace FoundryLocal { inline void from_json(const nlohmann::json& j, Runtime& r) { std::string deviceType; + std::string executionProvider; j.at("deviceType").get_to(deviceType); j.at("executionProvider").get_to(r.execution_provider); From 0abca3c6ea59b0d5f37c4873845d361bcfcb9558 Mon Sep 17 
00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 11:01:09 -0700 Subject: [PATCH 06/18] openai chat/audio api --- sdk/cpp/include/configuration.h | 3 + sdk/cpp/include/foundry_local.h | 191 ++----------------- sdk/cpp/include/foundry_local_exception.h | 3 + sdk/cpp/include/log_level.h | 3 + sdk/cpp/include/logger.h | 3 + sdk/cpp/include/openai/openai_audio_client.h | 51 +++++ sdk/cpp/include/openai/openai_chat_client.h | 161 ++++++++++++++++ sdk/cpp/sample/main.cpp | 12 +- sdk/cpp/src/core_interop_request.h | 3 + sdk/cpp/src/flcore_native.h | 3 + sdk/cpp/src/foundry_local.cpp | 41 ++-- sdk/cpp/src/foundry_local_internal_core.h | 3 + sdk/cpp/src/parser.h | 3 + sdk/cpp/test/catalog_test.cpp | 3 + sdk/cpp/test/client_test.cpp | 113 +++++------ sdk/cpp/test/mock_core.h | 3 + sdk/cpp/test/mock_object_factory.h | 3 + sdk/cpp/test/model_variant_test.cpp | 3 + sdk/cpp/test/parser_and_types_test.cpp | 3 + 19 files changed, 357 insertions(+), 251 deletions(-) create mode 100644 sdk/cpp/include/openai/openai_audio_client.h create mode 100644 sdk/cpp/include/openai/openai_chat_client.h diff --git a/sdk/cpp/include/configuration.h b/sdk/cpp/include/configuration.h index 59fe63e3..99c6a52c 100644 --- a/sdk/cpp/include/configuration.h +++ b/sdk/cpp/include/configuration.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include #include diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index ba5aee98..07970c69 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -1,4 +1,7 @@ -#pragma once +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once #include #include #include @@ -14,9 +17,13 @@ #include #include "configuration.h" - #include "logger.h" +// OpenAI-based API types and clients are in a separate directory to keep the +// OpenAI surface well-separated from the core SDK (mirrors the C# layout). +#include "openai/openai_chat_client.h" +#include "openai/openai_audio_client.h" + namespace FoundryLocal::Internal { struct IFoundryLocalCore; } @@ -35,15 +42,6 @@ namespace FoundryLocal { NPU }; - /// Reason the model stopped generating tokens. - enum class FinishReason { - None, - Stop, - Length, - ToolCalls, - ContentFilter - }; - struct Runtime { DeviceType device_type = DeviceType::Invalid; std::string execution_provider; @@ -56,99 +54,6 @@ namespace FoundryLocal { std::string prompt; }; - struct AudioCreateTranscriptionResponse { - std::string text; - }; - - /// JSON Schema property definition used to describe tool function parameters. - struct PropertyDefinition { - std::string type; - std::optional description; - std::optional> properties; - std::optional> required; - }; - - /// Describes a function that a model may call. - struct FunctionDefinition { - std::string name; - std::optional description; - std::optional parameters; - }; - - /// A tool definition following the OpenAI tool calling spec. - struct ToolDefinition { - std::string type = "function"; - FunctionDefinition function; - }; - - /// A parsed function call returned by the model. - struct FunctionCall { - std::string name; - std::string arguments; ///< JSON string of the arguments - }; - - /// A tool call returned by the model in a chat completion response. - struct ToolCall { - std::string id; - std::string type; - std::optional function_call; - }; - - /// Controls whether and how the model calls tools. 
-    enum class ToolChoiceKind {
-        Auto,
-        None,
-        Required
-    };
-
-    struct ChatMessage {
-        std::string role;
-        std::string content;
-        std::optional tool_call_id; ///< For role="tool" responses
-        std::vector tool_calls;
-    };
-
-    struct ChatChoice {
-        int index = 0;
-        FinishReason finish_reason = FinishReason::None;
-
-        // non-streaming
-        std::optional message;
-
-        // streaming
-        std::optional delta;
-    };
-
-    struct ChatCompletionCreateResponse {
-        int64_t created = 0;
-        std::string id;
-
-        bool is_delta = false;
-        bool successful = false;
-        int http_status_code = 0;
-
-        std::vector choices;
-
-        /// Returns the object type string. Derived from is_delta — no allocation.
-        const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; }
-
-        /// Returns the created timestamp as an ISO 8601 string.
-        /// Computed lazily; only allocates when called.
-        std::string GetCreatedAtIso() const;
-    };
-
-    struct ChatSettings {
-        std::optional frequency_penalty;
-        std::optional max_tokens;
-        std::optional n;
-        std::optional temperature;
-        std::optional presence_penalty;
-        std::optional random_seed;
-        std::optional top_k;
-        std::optional top_p;
-        std::optional tool_choice;
-    };
-
     using DownloadProgressCallback = std::function;

     // Forward declarations
@@ -187,64 +92,6 @@
         int64_t created_at_unix = 0;
     };

-    class AudioClient final {
-    public:
-        explicit AudioClient(gsl::not_null model);
-
-        /// Returns the model ID this client was created for.
-        const std::string& GetModelId() const noexcept { return modelId_; }
-
-        AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const;
-
-        using StreamCallback = std::function;
-        void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const;
-
-    private:
-        AudioClient(gsl::not_null core, std::string_view modelId,
-                    gsl::not_null logger);
-
-        std::string modelId_;
-        gsl::not_null core_;
-        gsl::not_null logger_;
-
-        friend class ModelVariant;
-    };
-
-    class ChatClient final {
-    public:
-        explicit ChatClient(gsl::not_null model);
-
-        /// Returns the model ID this client was created for.
- const std::string& GetModelId() const noexcept { return modelId_; } - - ChatCompletionCreateResponse CompleteChat(gsl::span messages, - const ChatSettings& settings) const; - - ChatCompletionCreateResponse CompleteChat(gsl::span messages, - gsl::span tools, - const ChatSettings& settings) const; - - using StreamCallback = std::function; - void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, - const StreamCallback& onChunk) const; - - void CompleteChatStreaming(gsl::span messages, gsl::span tools, - const ChatSettings& settings, const StreamCallback& onChunk) const; - - private: - ChatClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger); - - std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, - const ChatSettings& settings, bool stream) const; - - std::string modelId_; - gsl::not_null core_; - gsl::not_null logger_; - - friend class ModelVariant; - }; - class ModelVariant final { public: const ModelInfo& GetInfo() const; @@ -257,11 +104,11 @@ namespace FoundryLocal { void Unload() const; void RemoveFromCache(); - [[deprecated("Use AudioClient(model) constructor instead")]] - AudioClient GetAudioClient() const; + [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] + OpenAIAudioClient GetAudioClient() const; - [[deprecated("Use ChatClient(model) constructor instead")]] - ChatClient GetChatClient() const; + [[deprecated("Use OpenAIChatClient(model) constructor instead")]] + OpenAIChatClient GetChatClient() const; const std::string& GetId() const noexcept; const std::string& GetAlias() const noexcept; @@ -278,8 +125,8 @@ namespace FoundryLocal { gsl::not_null logger_; friend class Catalog; - friend class AudioClient; - friend class ChatClient; + friend class OpenAIAudioClient; + friend class OpenAIChatClient; #ifdef FL_TESTS friend struct Testing::MockObjectFactory; #endif @@ -299,13 +146,13 @@ namespace FoundryLocal { void Load() const { SelectedVariant().Load(); } void Unload() const { SelectedVariant().Unload(); } void RemoveFromCache() { SelectedVariant().RemoveFromCache(); } - [[deprecated("Use AudioClient(model) constructor instead")]] - AudioClient GetAudioClient() const { + [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] + OpenAIAudioClient GetAudioClient() const { return SelectedVariant().GetAudioClient(); } - [[deprecated("Use ChatClient(model) constructor instead")]] - ChatClient GetChatClient() const { + [[deprecated("Use OpenAIChatClient(model) constructor instead")]] + OpenAIChatClient GetChatClient() const { return SelectedVariant().GetChatClient(); } diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h index 6ca886a1..79648238 100644 --- a/sdk/cpp/include/foundry_local_exception.h +++ b/sdk/cpp/include/foundry_local_exception.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h index d9b82863..304abca1 100644 --- a/sdk/cpp/include/log_level.h +++ b/sdk/cpp/include/log_level.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include diff --git a/sdk/cpp/include/logger.h b/sdk/cpp/include/logger.h index 98d10155..53922a91 100644 --- a/sdk/cpp/include/logger.h +++ b/sdk/cpp/include/logger.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
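logger.h stays a thin seam: anything that can sink a (LogLevel, message) pair can observe the SDK. For illustration, a caller-supplied sink might look like the sketch below. It only assumes the Log(LogLevel, const std::string&) shape implied by the logger_->Log(...) call sites in foundry_local.cpp; take the exact virtual signature from logger.h before copying it.

    #include <iostream>
    #include <mutex>
    #include <string>

    #include "logger.h"

    class StderrLogger final : public FoundryLocal::ILogger {
    public:
        void Log(FoundryLocal::LogLevel level, const std::string& message) override {
            const char* tag = (level == FoundryLocal::LogLevel::Warning) ? "warn" : "info";
            std::lock_guard<std::mutex> lock(mutex_);  // the SDK may log from teardown paths
            std::cerr << "[foundry-local][" << tag << "] " << message << "\n";
        }

    private:
        std::mutex mutex_;
    };

FoundryLocalManager treats a null ILogger* as a request for its built-in default logger, so supplying a sink like this is optional.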
+ #pragma once #include #include "log_level.h" diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h new file mode 100644 index 00000000..da024fc0 --- /dev/null +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -0,0 +1,51 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +#include + +namespace FoundryLocal::Internal { + struct IFoundryLocalCore; +} + +namespace FoundryLocal { + class ILogger; + class ModelVariant; + + struct AudioCreateTranscriptionResponse { + std::string text; + }; + + class OpenAIAudioClient final { + public: + explicit OpenAIAudioClient(gsl::not_null model); + + /// Returns the model ID this client was created for. + const std::string& GetModelId() const noexcept { return modelId_; } + + AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const; + + using StreamCallback = std::function; + void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; + + private: + OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class ModelVariant; + }; + + /// Backward-compatible alias. + using AudioClient = OpenAIAudioClient; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h new file mode 100644 index 00000000..9acc66cb --- /dev/null +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace FoundryLocal::Internal { + struct IFoundryLocalCore; +} + +namespace FoundryLocal { + class ILogger; + class ModelVariant; + + /// Reason the model stopped generating tokens. + enum class FinishReason { + None, + Stop, + Length, + ToolCalls, + ContentFilter + }; + + /// JSON Schema property definition used to describe tool function parameters. + struct PropertyDefinition { + std::string type; + std::optional description; + std::optional> properties; + std::optional> required; + }; + + /// Describes a function that a model may call. + struct FunctionDefinition { + std::string name; + std::optional description; + std::optional parameters; + }; + + /// A tool definition following the OpenAI tool calling spec. + struct ToolDefinition { + std::string type = "function"; + FunctionDefinition function; + }; + + /// A parsed function call returned by the model. + struct FunctionCall { + std::string name; + std::string arguments; ///< JSON string of the arguments + }; + + /// A tool call returned by the model in a chat completion response. + struct ToolCall { + std::string id; + std::string type; + std::optional function_call; + }; + + /// Controls whether and how the model calls tools. 
+ enum class ToolChoiceKind { + Auto, + None, + Required + }; + + struct ChatMessage { + std::string role; + std::string content; + std::optional tool_call_id; ///< For role="tool" responses + std::vector tool_calls; + }; + + struct ChatChoice { + int index = 0; + FinishReason finish_reason = FinishReason::None; + + // non-streaming + std::optional message; + + // streaming + std::optional delta; + }; + + struct ChatCompletionCreateResponse { + int64_t created = 0; + std::string id; + + bool is_delta = false; + bool successful = false; + int http_status_code = 0; + + std::vector choices; + + /// Returns the object type string. Derived from is_delta — no allocation. + const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; } + + /// Returns the created timestamp as an ISO 8601 string. + /// Computed lazily, only allocates when called. + std::string GetCreatedAtIso() const; + }; + + struct ChatSettings { + std::optional frequency_penalty; + std::optional max_tokens; + std::optional n; + std::optional temperature; + std::optional presence_penalty; + std::optional random_seed; + std::optional top_k; + std::optional top_p; + std::optional tool_choice; + }; + + class OpenAIChatClient final { + public: + explicit OpenAIChatClient(gsl::not_null model); + + /// Returns the model ID this client was created for. + const std::string& GetModelId() const noexcept { return modelId_; } + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + const ChatSettings& settings) const; + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const; + + using StreamCallback = std::function; + void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const; + + void CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk) const; + + private: + OpenAIChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, + const ChatSettings& settings, bool stream) const; + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class ModelVariant; + }; + + /// Backward-compatible alias. + using ChatClient = OpenAIChatClient; + +} // namespace FoundryLocal diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index cddede9c..9385b9ec 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include "foundry_local.h" @@ -88,9 +91,8 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { return; } - // Get the selected variant pointer for ChatClient const auto& selectedVariant = model->GetAllModelVariants()[0]; - ChatClient chat(&selectedVariant); + OpenAIChatClient chat(&selectedVariant); std::vector messages = {{"system", "You are a helpful assistant. 
Keep answers brief."}, {"user", "What is the capital of Croatia?"}}; @@ -127,7 +129,7 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { model->Load(); const auto& selectedVariant = model->GetAllModelVariants()[0]; - ChatClient chat(&selectedVariant); + OpenAIChatClient chat(&selectedVariant); std::vector messages = {{"user", "Explain quantum computing in three sentences."}}; @@ -173,7 +175,7 @@ void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, con model->Load(); const auto& selectedVariant = model->GetAllModelVariants()[0]; - AudioClient audio(&selectedVariant); + OpenAIAudioClient audio(&selectedVariant); std::cout << "Transcribing: " << audioPath << "\n"; auto result = audio.TranscribeAudio(audioPath); @@ -223,7 +225,7 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) std::cout << "Model loaded: " << model->GetAlias() << "\n"; const auto& selectedVariant = model->GetAllModelVariants()[0]; - ChatClient chat(&selectedVariant); + OpenAIChatClient chat(&selectedVariant); // ── Step 1: Define tools ────────────────────────────────────────────── // Each tool describes a function the model can call. The PropertyDefinition diff --git a/sdk/cpp/src/core_interop_request.h b/sdk/cpp/src/core_interop_request.h index de03a61e..bb35d324 100644 --- a/sdk/cpp/src/core_interop_request.h +++ b/sdk/cpp/src/core_interop_request.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include #include diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index d67703e0..dffcdad0 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index ae199555..31085e0e 100644 --- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
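Every streaming path in this file funnels through the same shape: a small State struct carries the user callback across the C boundary, the native trampoline catches anything the callback throws, and the exception is rethrown only after the core call returns. Stripped of SDK specifics (the names below are illustrative, not the SDK's), the pattern is:

    #include <cstddef>
    #include <cstdint>
    #include <exception>
    #include <functional>
    #include <string_view>

    using ChunkHandler = std::function<void(std::string_view)>;

    struct CallbackState {
        const ChunkHandler* cb;
        std::exception_ptr exception;  // set the first time *cb throws
    };

    // Trampoline with a plain-C shape, suitable for handing to the native core.
    inline void OnNativeChunk(void* data, int32_t len, void* user) {
        auto* st = static_cast<CallbackState*>(user);
        if (st->exception || !data || len <= 0)
            return;  // drop further chunks once the callback has failed
        try {
            (*st->cb)(std::string_view(static_cast<const char*>(data), static_cast<std::size_t>(len)));
        } catch (...) {
            st->exception = std::current_exception();  // never unwind across the C ABI
        }
    }

    // After the core call returns:
    //     if (state.exception) std::rethrow_exception(state.exception);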
+ #include #include #include @@ -194,14 +197,14 @@ namespace FoundryLocal { }; /// - /// AudioClient + /// OpenAIAudioClient /// - AudioClient::AudioClient(gsl::not_null core, std::string_view modelId, + OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger) : core_(core), modelId_(modelId), logger_(logger) {} - AudioCreateTranscriptionResponse AudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { + AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); req.AddParam("OpenAICreateRequest", openAiReq.dump()); @@ -214,7 +217,7 @@ namespace FoundryLocal { return response; } - void AudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const { nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; CoreInteropRequest req("audio_transcribe"); @@ -270,14 +273,14 @@ namespace FoundryLocal { } /// - /// ChatClient + /// OpenAIChatClient /// - ChatClient::ChatClient(gsl::not_null core, std::string_view modelId, + OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger) : core_(core), modelId_(modelId), logger_(logger) {} - std::string ChatClient::BuildChatRequestJson(gsl::span messages, + std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, gsl::span tools, const ChatSettings& settings, bool stream) const { nlohmann::json jMessages = nlohmann::json::array(); @@ -322,12 +325,12 @@ namespace FoundryLocal { return req.dump(); } - ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, const ChatSettings& settings) const { return CompleteChat(messages, {}, settings); } - ChatCompletionCreateResponse ChatClient::CompleteChat(gsl::span messages, + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, gsl::span tools, const ChatSettings& settings) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); @@ -341,12 +344,12 @@ namespace FoundryLocal { return nlohmann::json::parse(rawResult).get(); } - void ChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, const StreamCallback& onChunk) const { CompleteChatStreaming(messages, {}, settings, onChunk); } - void ChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, const ChatSettings& settings, const StreamCallback& onChunk) const { std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); @@ -501,28 +504,28 @@ namespace FoundryLocal { return info_.version; } - AudioClient::AudioClient(gsl::not_null model) - : AudioClient(model->core_, model->info_.name, model->logger_) { + OpenAIAudioClient::OpenAIAudioClient(gsl::not_null model) + : OpenAIAudioClient(model->core_, model->info_.name, model->logger_) { if (!model->IsLoaded()) { throw FoundryLocalException("Model " + model->info_.name + " is not loaded. 
Call Load() first.", *model->logger_); } } - AudioClient ModelVariant::GetAudioClient() const { - return AudioClient(this); + OpenAIAudioClient ModelVariant::GetAudioClient() const { + return OpenAIAudioClient(this); } - ChatClient::ChatClient(gsl::not_null model) - : ChatClient(model->core_, model->info_.name, model->logger_) { + OpenAIChatClient::OpenAIChatClient(gsl::not_null model) + : OpenAIChatClient(model->core_, model->info_.name, model->logger_) { if (!model->IsLoaded()) { throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", *model->logger_); } } - ChatClient ModelVariant::GetChatClient() const { - return ChatClient(this); + OpenAIChatClient ModelVariant::GetChatClient() const { + return OpenAIChatClient(this); } /// diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index eedfa5d4..aa702e3c 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 58c31e87..555d5078 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include #include diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp index 78824574..b25f9457 100644 --- a/sdk/cpp/test/catalog_test.cpp +++ b/sdk/cpp/test/catalog_test.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index ceae995c..602df3e6 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
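The tool-calling tests below pin down the wire format; in application code the same round trip reads roughly as follows. This is a sketch under assumptions: variant is a loaded ModelVariant, the tool result is hard-coded instead of computed, and the types come from openai/openai_chat_client.h.

    #include <utility>
    #include <vector>

    #include "foundry_local.h"

    using namespace FoundryLocal;

    void MultiplyViaTool(const ModelVariant& variant) {
        ToolDefinition multiply;                       // type defaults to "function"
        multiply.function.name = "multiply";
        multiply.function.description = "Multiply two integers";

        ChatSettings settings;
        settings.tool_choice = ToolChoiceKind::Auto;

        std::vector<ChatMessage> messages = {{"user", "What is 7 * 6?", {}}};

        OpenAIChatClient chat(&variant);
        auto first = chat.CompleteChat(messages, gsl::span<const ToolDefinition>(&multiply, 1), settings);

        // If the model asked for the tool, run it and answer with role="tool".
        if (!first.choices.empty() && first.choices[0].message &&
            !first.choices[0].message->tool_calls.empty()) {
            const ToolCall& call = first.choices[0].message->tool_calls[0];
            ChatMessage reply;
            reply.role = "tool";
            reply.tool_call_id = call.id;  // ties this result to the model's request
            reply.content = "42";          // stand-in for actually running the tool
            messages.push_back(std::move(reply));
            auto second = chat.CompleteChat(messages, gsl::span<const ToolDefinition>(&multiply, 1), settings);
        }
    }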
+ #include #include "mock_core.h" @@ -12,7 +15,7 @@ using namespace FoundryLocal::Testing; using Factory = MockObjectFactory; -class ChatClientTest : public ::testing::Test { +class OpenAIChatClientTest : public ::testing::Test { protected: MockCore core_; NullLogger logger_; @@ -35,12 +38,12 @@ class ChatClientTest : public ::testing::Test { } }; -TEST_F(ChatClientTest, CompleteChat_BasicResponse) { +TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) { core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!")); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "Say hello", {}}}; ChatSettings settings; @@ -51,12 +54,12 @@ TEST_F(ChatClientTest, CompleteChat_BasicResponse) { EXPECT_EQ("Hello world!", response.choices[0].message->content); } -TEST_F(ChatClientTest, CompleteChat_WithSettings) { +TEST_F(OpenAIChatClientTest, CompleteChat_WithSettings) { core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); -ChatClient client(&variant); +OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -85,12 +88,12 @@ settings.max_tokens = 100; EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get()); } -TEST_F(ChatClientTest, CompleteChat_RequestFormat) { +TEST_F(OpenAIChatClientTest, CompleteChat_RequestFormat) { core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); -ChatClient client(&variant); +OpenAIChatClient client(&variant); std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; ChatSettings settings; @@ -106,8 +109,8 @@ auto response = client.CompleteChat(messages, settings); EXPECT_EQ("user", openAiReq["messages"][1]["role"].get()); } -TEST_F(ChatClientTest, CompleteChatStreaming) { - nlohmann::json chunk1 = { +TEST_F(OpenAIChatClientTest, CompleteChatStreaming) { +nlohmann::json chunk1 = { {"created", 1700000000}, {"id", "chatcmpl-1"}, {"IsDelta", true}, @@ -137,7 +140,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -153,7 +156,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming) { EXPECT_EQ(" world", chunks[1].choices[0].delta->content); } -TEST_F(ChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { +TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { nlohmann::json chunk = { {"created", 1700000000}, {"id", "chatcmpl-1"}, @@ -175,7 +178,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -186,27 +189,27 @@ TEST_F(ChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { std::runtime_error); } -TEST_F(ChatClientTest, Constructor_ThrowsIfNotLoaded) { +TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) { core_.OnCall("list_loaded_models", R"([])"); auto variant = Factory::CreateModelVariant(&core_, 
Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(ChatClient client(&variant), FoundryLocalException); + EXPECT_THROW(OpenAIChatClient client(&variant), FoundryLocalException); } -TEST_F(ChatClientTest, GetModelId) { +TEST_F(OpenAIChatClientTest, GetModelId) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); EXPECT_EQ("chat-model", client.GetModelId()); } // ---------- Tool calling tests ---------- -TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { - core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); +TEST_F(OpenAIChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { +core_.OnCall("chat_completions", MakeChatResponseJson()); +core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - auto variant = MakeLoadedVariant(); - ChatClient client(&variant); +auto variant = MakeLoadedVariant(); +OpenAIChatClient client(&variant); std::vector messages = {{"user", "What is 7 * 6?", {}}}; @@ -250,12 +253,12 @@ TEST_F(ChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { EXPECT_EQ("required", openAiReq["tool_choice"].get()); } -TEST_F(ChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { - core_.OnCall("chat_completions", MakeChatResponseJson()); - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); +TEST_F(OpenAIChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { +core_.OnCall("chat_completions", MakeChatResponseJson()); +core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - auto variant = MakeLoadedVariant(); - ChatClient client(&variant); +auto variant = MakeLoadedVariant(); +OpenAIChatClient client(&variant); std::vector messages = {{"user", "Hello", {}}}; ChatSettings settings; @@ -268,7 +271,7 @@ TEST_F(ChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { EXPECT_FALSE(openAiReq.contains("tool_choice")); } -TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { +TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallResponse_Parsed) { // Simulate a response with tool calls from the model nlohmann::json resp = { {"created", 1700000000}, @@ -291,7 +294,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "What is 7 * 6?", {}}}; ChatSettings settings; @@ -310,12 +313,12 @@ TEST_F(ChatClientTest, CompleteChat_ToolCallResponse_Parsed) { EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments); } -TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { +TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceAuto) { core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -328,12 +331,12 @@ TEST_F(ChatClientTest, CompleteChat_ToolChoiceAuto) { EXPECT_EQ("auto", openAiReq["tool_choice"].get()); } -TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { +TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceNone) { core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = 
MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -346,12 +349,12 @@ TEST_F(ChatClientTest, CompleteChat_ToolChoiceNone) { EXPECT_EQ("none", openAiReq["tool_choice"].get()); } -TEST_F(ChatClientTest, CompleteChat_ToolMessageWithToolCallId) { +TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) { core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); ChatMessage toolMsg; toolMsg.role = "tool"; @@ -374,7 +377,7 @@ TEST_F(ChatClientTest, CompleteChat_ToolMessageWithToolCallId) { EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get()); } -TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { +TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { nlohmann::json chunk1 = { {"created", 1700000000}, {"id", "chatcmpl-1"}, @@ -415,7 +418,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - ChatClient client(&variant); + OpenAIChatClient client(&variant); std::vector messages = {{"user", "test", {}}}; @@ -444,7 +447,7 @@ TEST_F(ChatClientTest, CompleteChatStreaming_WithTools) { EXPECT_EQ("required", openAiReq["tool_choice"].get()); } -class AudioClientTest : public ::testing::Test { +class OpenAIAudioClientTest : public ::testing::Test { protected: MockCore core_; NullLogger logger_; @@ -455,23 +458,23 @@ class AudioClientTest : public ::testing::Test { } }; -TEST_F(AudioClientTest, TranscribeAudio) { - core_.OnCall("audio_transcribe", "Hello world transcribed text"); - core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); +TEST_F(OpenAIAudioClientTest, TranscribeAudio) { +core_.OnCall("audio_transcribe", "Hello world transcribed text"); +core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); - auto variant = MakeLoadedVariant(); - AudioClient client(&variant); +auto variant = MakeLoadedVariant(); +OpenAIAudioClient client(&variant); auto response = client.TranscribeAudio("test.wav"); EXPECT_EQ("Hello world transcribed text", response.text); } -TEST_F(AudioClientTest, TranscribeAudio_RequestFormat) { - core_.OnCall("audio_transcribe", "text"); - core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); +TEST_F(OpenAIAudioClientTest, TranscribeAudio_RequestFormat) { +core_.OnCall("audio_transcribe", "text"); +core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); - auto variant = MakeLoadedVariant(); - AudioClient client(&variant); +auto variant = MakeLoadedVariant(); +OpenAIAudioClient client(&variant); client.TranscribeAudio("audio.wav"); auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); @@ -480,7 +483,7 @@ TEST_F(AudioClientTest, TranscribeAudio_RequestFormat) { EXPECT_EQ("audio.wav", openAiReq["FileName"].get()); } -TEST_F(AudioClientTest, TranscribeAudioStreaming) { +TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) { core_.OnCall("audio_transcribe", [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { if (callback && userData) { @@ -495,7 +498,7 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming) { core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - AudioClient client(&variant); + OpenAIAudioClient client(&variant); std::vector chunks; 
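    // What this exercises: the mock core drives the native callback twice
    // ("Hello ", "world!"), TranscribeAudioStreaming forwards each buffer to the
    // lambda below, and the vector ends up holding both fragments in arrival order.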
client.TranscribeAudioStreaming( @@ -506,7 +509,7 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming) { EXPECT_EQ("world!", chunks[1]); } -TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { +TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { core_.OnCall("audio_transcribe", [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { if (callback && userData) { @@ -519,7 +522,7 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - AudioClient client(&variant); + OpenAIAudioClient client(&variant); EXPECT_THROW( client.TranscribeAudioStreaming( @@ -527,15 +530,15 @@ TEST_F(AudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { std::runtime_error); } -TEST_F(AudioClientTest, Constructor_ThrowsIfNotLoaded) { +TEST_F(OpenAIAudioClientTest, Constructor_ThrowsIfNotLoaded) { core_.OnCall("list_loaded_models", R"([])"); auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(AudioClient client(&variant), FoundryLocalException); + EXPECT_THROW(OpenAIAudioClient client(&variant), FoundryLocalException); } -TEST_F(AudioClientTest, GetModelId) { +TEST_F(OpenAIAudioClientTest, GetModelId) { core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - AudioClient client(&variant); + OpenAIAudioClient client(&variant); EXPECT_EQ("audio-model", client.GetModelId()); } diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index b7aa349d..16dc5991 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h index ac055da8..2addbd67 100644 --- a/sdk/cpp/test/mock_object_factory.h +++ b/sdk/cpp/test/mock_object_factory.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #ifndef FL_TESTS diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index f791dc4f..6c1835a6 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include "mock_core.h" diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index c83acbc9..9703c8b6 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ #include #include "mock_core.h" From c63a14da2bd52550b73309dfc88fe5c7a6397296 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 11:49:23 -0700 Subject: [PATCH 07/18] solving nit to split files --- sdk/cpp/include/catalog.h | 69 +++++ sdk/cpp/include/foundry_local.h | 270 ++------------------ sdk/cpp/include/foundry_local_manager.h | 64 +++++ sdk/cpp/include/model.h | 174 +++++++++++++ sdk/cpp/include/openai/openai_chat_client.h | 45 +--- sdk/cpp/include/openai/openai_tool_types.h | 54 ++++ 6 files changed, 380 insertions(+), 296 deletions(-) create mode 100644 sdk/cpp/include/catalog.h create mode 100644 sdk/cpp/include/foundry_local_manager.h create mode 100644 sdk/cpp/include/model.h create mode 100644 sdk/cpp/include/openai/openai_tool_types.h diff --git a/sdk/cpp/include/catalog.h b/sdk/cpp/include/catalog.h new file mode 100644 index 00000000..a9cd9185 --- /dev/null +++ b/sdk/cpp/include/catalog.h @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "model.h" + +namespace FoundryLocal::Internal { + struct IFoundryLocalCore; +} + +namespace FoundryLocal { +#ifdef FL_TESTS + namespace Testing { + struct MockObjectFactory; + } +#endif + + class Catalog final { + public: + Catalog(const Catalog&) = delete; + Catalog& operator=(const Catalog&) = delete; + Catalog(Catalog&&) = delete; + Catalog& operator=(Catalog&&) = delete; + + static std::unique_ptr Create(gsl::not_null core, + gsl::not_null logger) { + return std::unique_ptr(new Catalog(core, logger)); + } + + const std::string& GetName() const { return name_; } + std::vector ListModels() const; + std::vector GetLoadedModels() const; + std::vector GetCachedModels() const; + + const Model* GetModel(std::string_view modelId) const; + const ModelVariant* GetModelVariant(std::string_view modelVariantId) const; + + private: + void UpdateModels() const; + + mutable std::chrono::steady_clock::time_point lastFetch_{}; + + mutable std::unordered_map byAlias_; + mutable std::unordered_map modelIdToModelVariant_; + + explicit Catalog(gsl::not_null injected, + gsl::not_null logger); + + gsl::not_null core_; + std::string name_; + gsl::not_null logger_; + + friend class FoundryLocalManager; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index 07970c69..f3c22526 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -1,263 +1,25 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// +// Umbrella header – includes every public header for convenience. 
+// Consumers may also include individual headers directly: +// #include "model.h" +// #include "catalog.h" +// #include "foundry_local_manager.h" +// #include "openai/openai_tool_types.h" +// #include "openai/openai_chat_client.h" +// #include "openai/openai_audio_client.h" #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include #include "configuration.h" +#include "foundry_local_exception.h" +#include "log_level.h" #include "logger.h" - -// OpenAI-based API types and clients are in a separate directory to keep the -// OpenAI surface well-separated from the core SDK (mirrors the C# layout). +#include "model.h" +#include "catalog.h" +#include "foundry_local_manager.h" +#include "openai/openai_tool_types.h" #include "openai/openai_chat_client.h" #include "openai/openai_audio_client.h" - -namespace FoundryLocal::Internal { - struct IFoundryLocalCore; -} - -namespace FoundryLocal { -#ifdef FL_TESTS - namespace Testing { - struct MockObjectFactory; - } -#endif - - enum class DeviceType { - Invalid, - CPU, - GPU, - NPU - }; - - struct Runtime { - DeviceType device_type = DeviceType::Invalid; - std::string execution_provider; - }; - - struct PromptTemplate { - std::string system; - std::string user; - std::string assistant; - std::string prompt; - }; - - using DownloadProgressCallback = std::function; - - // Forward declarations - class ModelVariant; - - struct Parameter { - std::string name; - std::optional value; - }; - - struct ModelSettings { - std::vector parameters; - }; - - struct ModelInfo { - std::string id; - std::string name; - uint32_t version = 0; - std::string alias; - std::optional display_name; - std::string provider_type; - std::string uri; - std::string model_type; - std::optional prompt_template; - std::optional publisher; - std::optional model_settings; - std::optional license; - std::optional license_description; - bool cached = false; - std::optional task; - std::optional runtime; - std::optional file_size_mb; - std::optional supports_tool_calling; - std::optional max_output_tokens; - std::optional min_fl_version; - int64_t created_at_unix = 0; - }; - - class ModelVariant final { - public: - const ModelInfo& GetInfo() const; - const std::filesystem::path& GetPath() const; - void Download(DownloadProgressCallback onProgress = nullptr) const; - void Load() const; - - bool IsLoaded() const; - bool IsCached() const; - void Unload() const; - void RemoveFromCache(); - - [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] - OpenAIAudioClient GetAudioClient() const; - - [[deprecated("Use OpenAIChatClient(model) constructor instead")]] - OpenAIChatClient GetChatClient() const; - - const std::string& GetId() const noexcept; - const std::string& GetAlias() const noexcept; - uint32_t GetVersion() const noexcept; - - private: - static std::string MakeModelParamRequest(std::string_view modelId); - explicit ModelVariant(gsl::not_null core, ModelInfo info, - gsl::not_null logger); - - ModelInfo info_; - mutable std::filesystem::path cachedPath_; - gsl::not_null core_; - gsl::not_null logger_; - - friend class Catalog; - friend class OpenAIAudioClient; - friend class OpenAIChatClient; -#ifdef FL_TESTS - friend struct Testing::MockObjectFactory; -#endif - }; - - class Model final { - public: - gsl::span GetAllModelVariants() const; - const ModelVariant* GetLatestVariant(gsl::not_null variant) const; - - bool IsLoaded() const { return SelectedVariant().IsLoaded(); } - bool IsCached() const { return 
SelectedVariant().IsCached(); } - const std::filesystem::path& GetPath() const { return SelectedVariant().GetPath(); } - void Download(DownloadProgressCallback onProgress = nullptr) const { - SelectedVariant().Download(std::move(onProgress)); - } - void Load() const { SelectedVariant().Load(); } - void Unload() const { SelectedVariant().Unload(); } - void RemoveFromCache() { SelectedVariant().RemoveFromCache(); } - [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] - OpenAIAudioClient GetAudioClient() const { - return SelectedVariant().GetAudioClient(); - } - - [[deprecated("Use OpenAIChatClient(model) constructor instead")]] - OpenAIChatClient GetChatClient() const { - return SelectedVariant().GetChatClient(); - } - - const std::string& GetId() const; - const std::string& GetAlias() const; - void SelectVariant(gsl::not_null variant) const; - - private: - explicit Model(gsl::not_null core, gsl::not_null logger); - ModelVariant& SelectedVariant(); - const ModelVariant& SelectedVariant() const; - - gsl::not_null core_; - - std::vector variants_; - mutable std::optional selectedVariantIndex_; - gsl::not_null logger_; - - friend class Catalog; -#ifdef FL_TESTS - friend struct Testing::MockObjectFactory; -#endif - }; - - class Catalog final { - public: - Catalog(const Catalog&) = delete; - Catalog& operator=(const Catalog&) = delete; - Catalog(Catalog&&) = delete; - Catalog& operator=(Catalog&&) = delete; - - static std::unique_ptr Create(gsl::not_null core, - gsl::not_null logger) { - return std::unique_ptr(new Catalog(core, logger)); - } - - const std::string& GetName() const { return name_; } - std::vector ListModels() const; - std::vector GetLoadedModels() const; - std::vector GetCachedModels() const; - - const Model* GetModel(std::string_view modelId) const; - const ModelVariant* GetModelVariant(std::string_view modelVariantId) const; - - private: - void UpdateModels() const; - - mutable std::chrono::steady_clock::time_point lastFetch_{}; - - mutable std::unordered_map byAlias_; - mutable std::unordered_map modelIdToModelVariant_; - - explicit Catalog(gsl::not_null injected, - gsl::not_null logger); - - gsl::not_null core_; - std::string name_; - gsl::not_null logger_; - - friend class FoundryLocalManager; -#ifdef FL_TESTS - friend struct Testing::MockObjectFactory; -#endif - }; - - class FoundryLocalManager final { - public: - FoundryLocalManager(const FoundryLocalManager&) = delete; - FoundryLocalManager& operator=(const FoundryLocalManager&) = delete; - FoundryLocalManager(FoundryLocalManager&& other) noexcept; - FoundryLocalManager& operator=(FoundryLocalManager&& other) noexcept; - - explicit FoundryLocalManager(Configuration configuration, ILogger* logger = nullptr); - ~FoundryLocalManager(); - - const Catalog& GetCatalog() const; - - /// Start the optional built-in web service. - /// Provides an OpenAI-compatible REST endpoint. - /// After startup, GetUrls() returns the actual bound URL/s. - /// Requires Configuration::Web to be set. - void StartWebService(); - - /// Stop the web service if started. - void StopWebService(); - - /// Returns the bound URL/s after StartWebService(), or empty if not started. - gsl::span GetUrls() const noexcept; - - /// Ensure execution providers are downloaded and registered. - /// Once downloaded, EPs are not re-downloaded unless a new version is available. 
- void EnsureEpsDownloaded() const; - - private: - bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; } - - Configuration config_; - - void Initialize(); - - NullLogger defaultLogger_; - std::unique_ptr core_; - std::unique_ptr catalog_; - ILogger* logger_; - std::vector urls_; - }; - -} // namespace FoundryLocal diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h new file mode 100644 index 00000000..447894eb --- /dev/null +++ b/sdk/cpp/include/foundry_local_manager.h @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include + +#include +#include + +#include "configuration.h" +#include "logger.h" +#include "catalog.h" + +namespace FoundryLocal::Internal { + struct IFoundryLocalCore; +} + +namespace FoundryLocal { + + class FoundryLocalManager final { + public: + FoundryLocalManager(const FoundryLocalManager&) = delete; + FoundryLocalManager& operator=(const FoundryLocalManager&) = delete; + FoundryLocalManager(FoundryLocalManager&& other) noexcept; + FoundryLocalManager& operator=(FoundryLocalManager&& other) noexcept; + + explicit FoundryLocalManager(Configuration configuration, ILogger* logger = nullptr); + ~FoundryLocalManager(); + + const Catalog& GetCatalog() const; + + /// Start the optional built-in web service. + /// Provides an OpenAI-compatible REST endpoint. + /// After startup, GetUrls() returns the actual bound URL/s. + /// Requires Configuration::Web to be set. + void StartWebService(); + + /// Stop the web service if started. + void StopWebService(); + + /// Returns the bound URL/s after StartWebService(), or empty if not started. + gsl::span GetUrls() const noexcept; + + /// Ensure execution providers are downloaded and registered. + /// Once downloaded, EPs are not re-downloaded unless a new version is available. + void EnsureEpsDownloaded() const; + + private: + bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; } + + Configuration config_; + + void Initialize(); + + NullLogger defaultLogger_; + std::unique_ptr core_; + std::unique_ptr catalog_; + ILogger* logger_; + std::vector urls_; + }; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h new file mode 100644 index 00000000..55f00028 --- /dev/null +++ b/sdk/cpp/include/model.h @@ -0,0 +1,174 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "logger.h" +#include "openai/openai_chat_client.h" +#include "openai/openai_audio_client.h" + +namespace FoundryLocal::Internal { + struct IFoundryLocalCore; +} + +namespace FoundryLocal { +#ifdef FL_TESTS + namespace Testing { + struct MockObjectFactory; + } +#endif + + enum class DeviceType { + Invalid, + CPU, + GPU, + NPU + }; + + struct Runtime { + DeviceType device_type = DeviceType::Invalid; + std::string execution_provider; + }; + + struct PromptTemplate { + std::string system; + std::string user; + std::string assistant; + std::string prompt; + }; + + using DownloadProgressCallback = std::function; + + // Forward declarations + class ModelVariant; + + struct Parameter { + std::string name; + std::optional value; + }; + + struct ModelSettings { + std::vector parameters; + }; + + struct ModelInfo { + std::string id; + std::string name; + uint32_t version = 0; + std::string alias; + std::optional display_name; + std::string provider_type; + std::string uri; + std::string model_type; + std::optional prompt_template; + std::optional publisher; + std::optional model_settings; + std::optional license; + std::optional license_description; + bool cached = false; + std::optional task; + std::optional runtime; + std::optional file_size_mb; + std::optional supports_tool_calling; + std::optional max_output_tokens; + std::optional min_fl_version; + int64_t created_at_unix = 0; + }; + + class ModelVariant final { + public: + const ModelInfo& GetInfo() const; + const std::filesystem::path& GetPath() const; + void Download(DownloadProgressCallback onProgress = nullptr) const; + void Load() const; + + bool IsLoaded() const; + bool IsCached() const; + void Unload() const; + void RemoveFromCache(); + + [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] + OpenAIAudioClient GetAudioClient() const; + + [[deprecated("Use OpenAIChatClient(model) constructor instead")]] + OpenAIChatClient GetChatClient() const; + + const std::string& GetId() const noexcept; + const std::string& GetAlias() const noexcept; + uint32_t GetVersion() const noexcept; + + private: + static std::string MakeModelParamRequest(std::string_view modelId); + explicit ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger); + + ModelInfo info_; + mutable std::filesystem::path cachedPath_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class Catalog; + friend class OpenAIAudioClient; + friend class OpenAIChatClient; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + + class Model final { + public: + gsl::span GetAllModelVariants() const; + const ModelVariant* GetLatestVariant(gsl::not_null variant) const; + + bool IsLoaded() const { return SelectedVariant().IsLoaded(); } + bool IsCached() const { return SelectedVariant().IsCached(); } + const std::filesystem::path& GetPath() const { return SelectedVariant().GetPath(); } + void Download(DownloadProgressCallback onProgress = nullptr) const { + SelectedVariant().Download(std::move(onProgress)); + } + void Load() const { SelectedVariant().Load(); } + void Unload() const { SelectedVariant().Unload(); } + void RemoveFromCache() { SelectedVariant().RemoveFromCache(); } + [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] + OpenAIAudioClient GetAudioClient() const { + return SelectedVariant().GetAudioClient(); + } + + [[deprecated("Use OpenAIChatClient(model) constructor instead")]] + 
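// Migration off the deprecated accessors is one line per call site
// (a sketch based on the [[deprecated]] messages above; the constructor's
// exact parameter form is refined in a later patch in this series):
//
//   auto chatOld = model.GetChatClient();   // old, emits a deprecation warning
//   OpenAIChatClient chatNew(model);        // new, constructor-based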
OpenAIChatClient GetChatClient() const { + return SelectedVariant().GetChatClient(); + } + + const std::string& GetId() const; + const std::string& GetAlias() const; + void SelectVariant(gsl::not_null variant) const; + + private: + explicit Model(gsl::not_null core, gsl::not_null logger); + ModelVariant& SelectedVariant(); + const ModelVariant& SelectedVariant() const; + + gsl::not_null core_; + + std::vector variants_; + mutable std::optional selectedVariantIndex_; + gsl::not_null logger_; + + friend class Catalog; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + +} // namespace FoundryLocal diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h index 9acc66cb..8d24242d 100644 --- a/sdk/cpp/include/openai/openai_chat_client.h +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #pragma once @@ -8,12 +9,13 @@ #include #include #include -#include #include #include #include +#include "openai_tool_types.h" + namespace FoundryLocal::Internal { struct IFoundryLocalCore; } @@ -31,47 +33,6 @@ namespace FoundryLocal { ContentFilter }; - /// JSON Schema property definition used to describe tool function parameters. - struct PropertyDefinition { - std::string type; - std::optional description; - std::optional> properties; - std::optional> required; - }; - - /// Describes a function that a model may call. - struct FunctionDefinition { - std::string name; - std::optional description; - std::optional parameters; - }; - - /// A tool definition following the OpenAI tool calling spec. - struct ToolDefinition { - std::string type = "function"; - FunctionDefinition function; - }; - - /// A parsed function call returned by the model. - struct FunctionCall { - std::string name; - std::string arguments; ///< JSON string of the arguments - }; - - /// A tool call returned by the model in a chat completion response. - struct ToolCall { - std::string id; - std::string type; - std::optional function_call; - }; - - /// Controls whether and how the model calls tools. - enum class ToolChoiceKind { - Auto, - None, - Required - }; - struct ChatMessage { std::string role; std::string content; diff --git a/sdk/cpp/include/openai/openai_tool_types.h b/sdk/cpp/include/openai/openai_tool_types.h new file mode 100644 index 00000000..8ae4a068 --- /dev/null +++ b/sdk/cpp/include/openai/openai_tool_types.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +namespace FoundryLocal { + + /// JSON Schema property definition used to describe tool function parameters. + struct PropertyDefinition { + std::string type; + std::optional description; + std::optional> properties; + std::optional> required; + }; + + /// Describes a function that a model may call. + struct FunctionDefinition { + std::string name; + std::optional description; + std::optional parameters; + }; + + /// A tool definition following the OpenAI tool calling spec. + struct ToolDefinition { + std::string type = "function"; + FunctionDefinition function; + }; + + /// A parsed function call returned by the model. + struct FunctionCall { + std::string name; + std::string arguments; ///< JSON string of the arguments + }; + + /// A tool call returned by the model in a chat completion response. 
+ struct ToolCall { + std::string id; + std::string type; + std::optional function_call; + }; + + /// Controls whether and how the model calls tools. + enum class ToolChoiceKind { + Auto, + None, + Required + }; + +} // namespace FoundryLocal From 7887a0455d4e5df584d7c578abf15b55e32f9b45 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 12:25:18 -0700 Subject: [PATCH 08/18] resolving nits --- sdk/cpp/include/model.h | 84 ++++++++++++++------ sdk/cpp/include/openai/openai_audio_client.h | 8 +- sdk/cpp/include/openai/openai_chat_client.h | 12 ++- sdk/cpp/sample/main.cpp | 15 ++-- sdk/cpp/src/foundry_local.cpp | 44 +++++----- sdk/cpp/test/client_test.cpp | 40 +++++----- sdk/cpp/test/model_variant_test.cpp | 8 +- 7 files changed, 122 insertions(+), 89 deletions(-) diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index 55f00028..a591ba20 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -29,6 +29,35 @@ namespace FoundryLocal { } #endif + using DownloadProgressCallback = std::function; + + class IModel { + public: + virtual ~IModel() = default; + + virtual const std::string& GetId() const = 0; + virtual const std::string& GetAlias() const = 0; + virtual bool IsLoaded() const = 0; + virtual bool IsCached() const = 0; + virtual const std::filesystem::path& GetPath() const = 0; + virtual void Download(DownloadProgressCallback onProgress = nullptr) const = 0; + virtual void Load() const = 0; + virtual void Unload() const = 0; + virtual void RemoveFromCache() = 0; + + protected: + struct CoreAccess { + gsl::not_null core; + std::string modelName; + gsl::not_null logger; + }; + + virtual CoreAccess GetCoreAccess() const = 0; + + friend class OpenAIChatClient; + friend class OpenAIAudioClient; + }; + enum class DeviceType { Invalid, CPU, @@ -48,8 +77,6 @@ namespace FoundryLocal { std::string prompt; }; - using DownloadProgressCallback = std::function; - // Forward declarations class ModelVariant; @@ -86,17 +113,17 @@ namespace FoundryLocal { int64_t created_at_unix = 0; }; - class ModelVariant final { + class ModelVariant final : public IModel { public: const ModelInfo& GetInfo() const; - const std::filesystem::path& GetPath() const; - void Download(DownloadProgressCallback onProgress = nullptr) const; - void Load() const; + const std::filesystem::path& GetPath() const override; + void Download(DownloadProgressCallback onProgress = nullptr) const override; + void Load() const override; - bool IsLoaded() const; - bool IsCached() const; - void Unload() const; - void RemoveFromCache(); + bool IsLoaded() const override; + bool IsCached() const override; + void Unload() const override; + void RemoveFromCache() override; [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] OpenAIAudioClient GetAudioClient() const; @@ -104,10 +131,13 @@ namespace FoundryLocal { [[deprecated("Use OpenAIChatClient(model) constructor instead")]] OpenAIChatClient GetChatClient() const; - const std::string& GetId() const noexcept; - const std::string& GetAlias() const noexcept; + const std::string& GetId() const noexcept override; + const std::string& GetAlias() const noexcept override; uint32_t GetVersion() const noexcept; + protected: + CoreAccess GetCoreAccess() const override; + private: static std::string MakeModelParamRequest(std::string_view modelId); explicit ModelVariant(gsl::not_null core, ModelInfo info, @@ -119,27 +149,26 @@ namespace FoundryLocal { gsl::not_null logger_; friend class Catalog; - friend class OpenAIAudioClient; - friend class 
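// The new IModel interface lets both Model and ModelVariant back an OpenAI
// client, so callers no longer need to dig out a variant first:
//
//   OpenAIChatClient chat(*catalog.GetModel("alias"));  // whole model
//   OpenAIChatClient chat2(someVariant);                // specific variant
//
// (a sketch; "alias" is illustrative, and GetModel can return nullptr,
// which the sample guards against before constructing a client)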
OpenAIChatClient; + friend class Model; #ifdef FL_TESTS friend struct Testing::MockObjectFactory; #endif }; - class Model final { + class Model final : public IModel { public: gsl::span GetAllModelVariants() const; - const ModelVariant* GetLatestVariant(gsl::not_null variant) const; + const ModelVariant* GetLatestVariant(const ModelVariant& variant) const; - bool IsLoaded() const { return SelectedVariant().IsLoaded(); } - bool IsCached() const { return SelectedVariant().IsCached(); } - const std::filesystem::path& GetPath() const { return SelectedVariant().GetPath(); } - void Download(DownloadProgressCallback onProgress = nullptr) const { + bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } + bool IsCached() const override { return SelectedVariant().IsCached(); } + const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } + void Download(DownloadProgressCallback onProgress = nullptr) const override { SelectedVariant().Download(std::move(onProgress)); } - void Load() const { SelectedVariant().Load(); } - void Unload() const { SelectedVariant().Unload(); } - void RemoveFromCache() { SelectedVariant().RemoveFromCache(); } + void Load() const override { SelectedVariant().Load(); } + void Unload() const override { SelectedVariant().Unload(); } + void RemoveFromCache() override { SelectedVariant().RemoveFromCache(); } [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] OpenAIAudioClient GetAudioClient() const { return SelectedVariant().GetAudioClient(); @@ -150,9 +179,12 @@ namespace FoundryLocal { return SelectedVariant().GetChatClient(); } - const std::string& GetId() const; - const std::string& GetAlias() const; - void SelectVariant(gsl::not_null variant) const; + const std::string& GetId() const override; + const std::string& GetAlias() const override; + void SelectVariant(const ModelVariant& variant) const; + + protected: + CoreAccess GetCoreAccess() const override; private: explicit Model(gsl::not_null core, gsl::not_null logger); diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h index da024fc0..cd5f52e4 100644 --- a/sdk/cpp/include/openai/openai_audio_client.h +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -15,8 +15,8 @@ namespace FoundryLocal::Internal { } namespace FoundryLocal { - class ILogger; - class ModelVariant; +class ILogger; +class IModel; struct AudioCreateTranscriptionResponse { std::string text; @@ -24,7 +24,7 @@ namespace FoundryLocal { class OpenAIAudioClient final { public: - explicit OpenAIAudioClient(gsl::not_null model); + explicit OpenAIAudioClient(const IModel& model); /// Returns the model ID this client was created for. const std::string& GetModelId() const noexcept { return modelId_; } @@ -41,8 +41,6 @@ namespace FoundryLocal { std::string modelId_; gsl::not_null core_; gsl::not_null logger_; - - friend class ModelVariant; }; /// Backward-compatible alias. diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h index 8d24242d..20d97fae 100644 --- a/sdk/cpp/include/openai/openai_chat_client.h +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -21,11 +21,11 @@ namespace FoundryLocal::Internal { } namespace FoundryLocal { - class ILogger; - class ModelVariant; +class ILogger; +class IModel; - /// Reason the model stopped generating tokens. - enum class FinishReason { +/// Reason the model stopped generating tokens. 
+enum class FinishReason { None, Stop, Length, @@ -83,7 +83,7 @@ namespace FoundryLocal { class OpenAIChatClient final { public: - explicit OpenAIChatClient(gsl::not_null model); + explicit OpenAIChatClient(const IModel& model); /// Returns the model ID this client was created for. const std::string& GetModelId() const noexcept { return modelId_; } @@ -112,8 +112,6 @@ namespace FoundryLocal { std::string modelId_; gsl::not_null core_; gsl::not_null logger_; - - friend class ModelVariant; }; /// Backward-compatible alias. diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 9385b9ec..dc1327d9 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -31,7 +31,7 @@ class StdLogger final : public ILogger { tag = "DEBUG"; break; } - std::fprintf(stderr, "[FoundryLocal][%s] %.*s\n", tag, static_cast(message.size()), message.data()); + std::cout << "[FoundryLocal][" << tag << "] " << message << "\n"; } }; @@ -91,8 +91,7 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { return; } - const auto& selectedVariant = model->GetAllModelVariants()[0]; - OpenAIChatClient chat(&selectedVariant); + OpenAIChatClient chat(*model); std::vector messages = {{"system", "You are a helpful assistant. Keep answers brief."}, {"user", "What is the capital of Croatia?"}}; @@ -128,8 +127,7 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { model->Load(); - const auto& selectedVariant = model->GetAllModelVariants()[0]; - OpenAIChatClient chat(&selectedVariant); + OpenAIChatClient chat(*model); std::vector messages = {{"user", "Explain quantum computing in three sentences."}}; @@ -175,7 +173,7 @@ void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, con model->Load(); const auto& selectedVariant = model->GetAllModelVariants()[0]; - OpenAIAudioClient audio(&selectedVariant); + OpenAIAudioClient audio(*model); std::cout << "Transcribing: " << audioPath << "\n"; auto result = audio.TranscribeAudio(audioPath); @@ -224,8 +222,7 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) model->Load(); std::cout << "Model loaded: " << model->GetAlias() << "\n"; - const auto& selectedVariant = model->GetAllModelVariants()[0]; - OpenAIChatClient chat(&selectedVariant); + OpenAIChatClient chat(*model); // ── Step 1: Define tools ────────────────────────────────────────────── // Each tool describes a function the model can call. The PropertyDefinition @@ -364,7 +361,7 @@ void InspectVariants(FoundryLocalManager& manager, const std::string& alias) { // Select a specific variant by pointer (e.g. prefer the GPU variant) for (const auto& v : variants) { if (v.GetInfo().runtime && v.GetInfo().runtime->device_type == DeviceType::GPU) { - model->SelectVariant(&v); + model->SelectVariant(v); std::cout << "Selected GPU variant: " << model->GetId() << "\n"; break; } diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index 31085e0e..b81d70a2 100644 --- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -504,28 +504,32 @@ namespace FoundryLocal { return info_.version; } - OpenAIAudioClient::OpenAIAudioClient(gsl::not_null model) - : OpenAIAudioClient(model->core_, model->info_.name, model->logger_) { - if (!model->IsLoaded()) { - throw FoundryLocalException("Model " + model->info_.name + " is not loaded. 
Call Load() first.", - *model->logger_); + IModel::CoreAccess ModelVariant::GetCoreAccess() const { + return {core_, info_.name, logger_}; + } + + OpenAIAudioClient::OpenAIAudioClient(const IModel& model) + : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw FoundryLocalException("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); } } OpenAIAudioClient ModelVariant::GetAudioClient() const { - return OpenAIAudioClient(this); + return OpenAIAudioClient(*this); } - OpenAIChatClient::OpenAIChatClient(gsl::not_null model) - : OpenAIChatClient(model->core_, model->info_.name, model->logger_) { - if (!model->IsLoaded()) { - throw FoundryLocalException("Model " + model->info_.name + " is not loaded. Call Load() first.", - *model->logger_); + OpenAIChatClient::OpenAIChatClient(const IModel& model) + : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw FoundryLocalException("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); } } OpenAIChatClient ModelVariant::GetChatClient() const { - return OpenAIChatClient(this); + return OpenAIChatClient(*this); } /// @@ -552,8 +556,8 @@ namespace FoundryLocal { return variants_; } - const ModelVariant* Model::GetLatestVariant(gsl::not_null variant) const { - const auto& targetName = variant->GetInfo().name; + const ModelVariant* Model::GetLatestVariant(const ModelVariant& variant) const { + const auto& targetName = variant.GetInfo().name; for (const auto& v : variants_) { if (v.GetInfo().name == targetName) { @@ -561,7 +565,7 @@ namespace FoundryLocal { } } - throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", + throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); } @@ -573,18 +577,22 @@ namespace FoundryLocal { return SelectedVariant().GetAlias(); } - void Model::SelectVariant(gsl::not_null variant) const { + void Model::SelectVariant(const ModelVariant& variant) const { auto it = std::find_if(variants_.begin(), variants_.end(), - [&](const ModelVariant& v) { return &v == variant.get(); }); + [&](const ModelVariant& v) { return &v == &variant; }); if (it == variants_.end()) { - throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant->GetId() + " variant.", + throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); } selectedVariantIndex_ = static_cast(std::distance(variants_.begin(), it)); } + IModel::CoreAccess Model::GetCoreAccess() const { + return SelectedVariant().GetCoreAccess(); + } + /// /// Catalog /// diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index 602df3e6..71445b1b 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -43,7 +43,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "Say hello", {}}}; ChatSettings settings; @@ -59,7 +59,7 @@ core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = 
MakeLoadedVariant(); -OpenAIChatClient client(&variant); +OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -93,7 +93,7 @@ core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); -OpenAIChatClient client(&variant); +OpenAIChatClient client(variant); std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; ChatSettings settings; @@ -140,7 +140,7 @@ nlohmann::json chunk1 = { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -178,7 +178,7 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -192,13 +192,13 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) { core_.OnCall("list_loaded_models", R"([])"); auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(OpenAIChatClient client(&variant), FoundryLocalException); + EXPECT_THROW(OpenAIChatClient client(variant), FoundryLocalException); } TEST_F(OpenAIChatClientTest, GetModelId) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); EXPECT_EQ("chat-model", client.GetModelId()); } @@ -209,7 +209,7 @@ core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); -OpenAIChatClient client(&variant); +OpenAIChatClient client(variant); std::vector messages = {{"user", "What is 7 * 6?", {}}}; @@ -258,7 +258,7 @@ core_.OnCall("chat_completions", MakeChatResponseJson()); core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); -OpenAIChatClient client(&variant); +OpenAIChatClient client(variant); std::vector messages = {{"user", "Hello", {}}}; ChatSettings settings; @@ -294,7 +294,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallResponse_Parsed) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "What is 7 * 6?", {}}}; ChatSettings settings; @@ -318,7 +318,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceAuto) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -336,7 +336,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceNone) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; ChatSettings settings; @@ -354,7 +354,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) { 
core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); ChatMessage toolMsg; toolMsg.role = "tool"; @@ -418,7 +418,7 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIChatClient client(&variant); + OpenAIChatClient client(variant); std::vector messages = {{"user", "test", {}}}; @@ -463,7 +463,7 @@ core_.OnCall("audio_transcribe", "Hello world transcribed text"); core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); -OpenAIAudioClient client(&variant); +OpenAIAudioClient client(variant); auto response = client.TranscribeAudio("test.wav"); EXPECT_EQ("Hello world transcribed text", response.text); @@ -474,7 +474,7 @@ core_.OnCall("audio_transcribe", "text"); core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); -OpenAIAudioClient client(&variant); +OpenAIAudioClient client(variant); client.TranscribeAudio("audio.wav"); auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); @@ -498,7 +498,7 @@ TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) { core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIAudioClient client(&variant); + OpenAIAudioClient client(variant); std::vector chunks; client.TranscribeAudioStreaming( @@ -522,7 +522,7 @@ TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackExcepti core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIAudioClient client(&variant); + OpenAIAudioClient client(variant); EXPECT_THROW( client.TranscribeAudioStreaming( @@ -533,12 +533,12 @@ TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackExcepti TEST_F(OpenAIAudioClientTest, Constructor_ThrowsIfNotLoaded) { core_.OnCall("list_loaded_models", R"([])"); auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(OpenAIAudioClient client(&variant), FoundryLocalException); + EXPECT_THROW(OpenAIAudioClient client(variant), FoundryLocalException); } TEST_F(OpenAIAudioClientTest, GetModelId) { core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); auto variant = MakeLoadedVariant(); - OpenAIAudioClient client(&variant); + OpenAIAudioClient client(variant); EXPECT_EQ("audio-model", client.GetModelId()); } diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index 6c1835a6..929e7c87 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -208,7 +208,7 @@ TEST_F(ModelTest, SelectVariant) { Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); Factory::SetSelectedVariantIndex(model, 0); - const auto* v2 = &model.GetAllModelVariants()[1]; + const auto& v2 = model.GetAllModelVariants()[1]; model.SelectVariant(v2); EXPECT_EQ("v2:2", model.GetId()); } @@ -219,7 +219,7 @@ TEST_F(ModelTest, SelectVariant_NotFound_Throws) { Factory::SetSelectedVariantIndex(model, 0); auto external = MakeVariant("external", "alias", 1); - EXPECT_THROW(model.SelectVariant(&external), FoundryLocalException); + EXPECT_THROW(model.SelectVariant(external), FoundryLocalException); } TEST_F(ModelTest, GetLatestVariant) { @@ -228,10 +228,10 @@ TEST_F(ModelTest, GetLatestVariant) { 
Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); Factory::SetSelectedVariantIndex(model, 0); - const auto* first = &model.GetAllModelVariants()[0]; + const auto& first = model.GetAllModelVariants()[0]; const auto* latest = model.GetLatestVariant(first); // Should return the first one with matching name (which is variants_[0]) - EXPECT_EQ(first, latest); + EXPECT_EQ(&first, latest); } TEST_F(ModelTest, DelegationMethods) { From fd4c472068e8201eb639492f413f418c4fd281ff Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 13:48:26 -0700 Subject: [PATCH 09/18] rename namespace --- sdk/cpp/include/catalog.h | 22 ++--- sdk/cpp/include/configuration.h | 4 +- sdk/cpp/include/foundry_local.h | 1 - sdk/cpp/include/foundry_local_exception.h | 4 +- sdk/cpp/include/foundry_local_manager.h | 7 +- sdk/cpp/include/log_level.h | 6 +- sdk/cpp/include/logger.h | 4 +- sdk/cpp/include/model.h | 34 ++++---- sdk/cpp/include/openai/openai_audio_client.h | 10 +-- sdk/cpp/include/openai/openai_chat_client.h | 10 +-- sdk/cpp/include/openai/openai_tool_types.h | 4 +- sdk/cpp/sample/main.cpp | 14 ++-- sdk/cpp/src/core_interop_request.h | 4 +- sdk/cpp/src/foundry_local.cpp | 86 ++++++++++---------- sdk/cpp/src/foundry_local_internal_core.h | 6 +- sdk/cpp/src/parser.h | 4 +- sdk/cpp/test/catalog_test.cpp | 4 +- sdk/cpp/test/client_test.cpp | 4 +- sdk/cpp/test/mock_core.h | 4 +- sdk/cpp/test/mock_object_factory.h | 4 +- sdk/cpp/test/model_variant_test.cpp | 8 +- sdk/cpp/test/parser_and_types_test.cpp | 4 +- 22 files changed, 126 insertions(+), 122 deletions(-) diff --git a/sdk/cpp/include/catalog.h b/sdk/cpp/include/catalog.h index a9cd9185..2cb05a1c 100644 --- a/sdk/cpp/include/catalog.h +++ b/sdk/cpp/include/catalog.h @@ -14,11 +14,11 @@ #include "model.h" -namespace FoundryLocal::Internal { +namespace foundry_local::Internal { struct IFoundryLocalCore; } -namespace FoundryLocal { +namespace foundry_local { #ifdef FL_TESTS namespace Testing { struct MockObjectFactory; @@ -32,18 +32,18 @@ namespace FoundryLocal { Catalog(Catalog&&) = delete; Catalog& operator=(Catalog&&) = delete; - static std::unique_ptr Create(gsl::not_null core, + static std::unique_ptr Create(gsl::not_null core, gsl::not_null logger) { return std::unique_ptr(new Catalog(core, logger)); } const std::string& GetName() const { return name_; } - std::vector ListModels() const; - std::vector GetLoadedModels() const; - std::vector GetCachedModels() const; + std::vector ListModels() const; + std::vector GetLoadedModels() const; + std::vector GetCachedModels() const; - const Model* GetModel(std::string_view modelId) const; - const ModelVariant* GetModelVariant(std::string_view modelVariantId) const; + Model* GetModel(std::string_view modelId) const; + ModelVariant* GetModelVariant(std::string_view modelVariantId) const; private: void UpdateModels() const; @@ -53,10 +53,10 @@ namespace FoundryLocal { mutable std::unordered_map byAlias_; mutable std::unordered_map modelIdToModelVariant_; - explicit Catalog(gsl::not_null injected, + explicit Catalog(gsl::not_null injected, gsl::not_null logger); - gsl::not_null core_; + gsl::not_null core_; std::string name_; gsl::not_null logger_; @@ -66,4 +66,4 @@ namespace FoundryLocal { #endif }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/configuration.h b/sdk/cpp/include/configuration.h index 99c6a52c..21c40473 100644 --- a/sdk/cpp/include/configuration.h +++ b/sdk/cpp/include/configuration.h @@ -9,7 +9,7 @@ #include #include "log_level.h" 
-namespace FoundryLocal { +namespace foundry_local { /// Optional configuration for the built-in web service. struct WebServiceConfig { @@ -65,4 +65,4 @@ namespace FoundryLocal { } }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h index f3c22526..5ee0dd6f 100644 --- a/sdk/cpp/include/foundry_local.h +++ b/sdk/cpp/include/foundry_local.h @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. // // Umbrella header – includes every public header for convenience. diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h index 79648238..a352a3fb 100644 --- a/sdk/cpp/include/foundry_local_exception.h +++ b/sdk/cpp/include/foundry_local_exception.h @@ -8,7 +8,7 @@ #include "logger.h" -namespace FoundryLocal { +namespace foundry_local { class FoundryLocalException final : public std::runtime_error { public: @@ -19,4 +19,4 @@ namespace FoundryLocal { } }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h index 447894eb..00621b24 100644 --- a/sdk/cpp/include/foundry_local_manager.h +++ b/sdk/cpp/include/foundry_local_manager.h @@ -13,11 +13,11 @@ #include "logger.h" #include "catalog.h" -namespace FoundryLocal::Internal { +namespace foundry_local::Internal { struct IFoundryLocalCore; } -namespace FoundryLocal { +namespace foundry_local { class FoundryLocalManager final { public: @@ -30,6 +30,7 @@ namespace FoundryLocal { ~FoundryLocalManager(); const Catalog& GetCatalog() const; + Catalog& GetCatalog(); /// Start the optional built-in web service. /// Provides an OpenAI-compatible REST endpoint. 
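// A minimal web-service flow, assuming Configuration::Web is populated
// (names follow the declarations in this header; "config" is an
// illustrative Configuration instance):
//
//   foundry_local::FoundryLocalManager manager(config);
//   manager.StartWebService();
//   for (const auto& url : manager.GetUrls())
//       std::cout << "OpenAI-compatible endpoint at " << url << "\n";
//   manager.StopWebService();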
@@ -61,4 +62,4 @@ namespace FoundryLocal { std::vector urls_; }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h index 304abca1..887189ec 100644 --- a/sdk/cpp/include/log_level.h +++ b/sdk/cpp/include/log_level.h @@ -5,9 +5,9 @@ #include -namespace FoundryLocal { +namespace foundry_local { - enum class LogLevel { +enum class LogLevel { Verbose, Debug, Information, @@ -34,4 +34,4 @@ namespace FoundryLocal { return "Unknown"; } -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/logger.h b/sdk/cpp/include/logger.h index 53922a91..d0b05b4e 100644 --- a/sdk/cpp/include/logger.h +++ b/sdk/cpp/include/logger.h @@ -5,7 +5,7 @@ #include #include "log_level.h" -namespace FoundryLocal { +namespace foundry_local { class ILogger { public: virtual ~ILogger() = default; @@ -16,4 +16,4 @@ namespace FoundryLocal { public: void Log(LogLevel, std::string_view) noexcept override {} }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index a591ba20..a5008ff8 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -18,11 +18,11 @@ #include "openai/openai_chat_client.h" #include "openai/openai_audio_client.h" -namespace FoundryLocal::Internal { +namespace foundry_local::Internal { struct IFoundryLocalCore; } -namespace FoundryLocal { +namespace foundry_local { #ifdef FL_TESTS namespace Testing { struct MockObjectFactory; @@ -40,9 +40,9 @@ namespace FoundryLocal { virtual bool IsLoaded() const = 0; virtual bool IsCached() const = 0; virtual const std::filesystem::path& GetPath() const = 0; - virtual void Download(DownloadProgressCallback onProgress = nullptr) const = 0; - virtual void Load() const = 0; - virtual void Unload() const = 0; + virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0; + virtual void Load() = 0; + virtual void Unload() = 0; virtual void RemoveFromCache() = 0; protected: @@ -117,12 +117,12 @@ namespace FoundryLocal { public: const ModelInfo& GetInfo() const; const std::filesystem::path& GetPath() const override; - void Download(DownloadProgressCallback onProgress = nullptr) const override; - void Load() const override; + void Download(DownloadProgressCallback onProgress = nullptr) override; + void Load() override; bool IsLoaded() const override; bool IsCached() const override; - void Unload() const override; + void Unload() override; void RemoveFromCache() override; [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] @@ -140,12 +140,12 @@ namespace FoundryLocal { private: static std::string MakeModelParamRequest(std::string_view modelId); - explicit ModelVariant(gsl::not_null core, ModelInfo info, + explicit ModelVariant(gsl::not_null core, ModelInfo info, gsl::not_null logger); ModelInfo info_; mutable std::filesystem::path cachedPath_; - gsl::not_null core_; + gsl::not_null core_; gsl::not_null logger_; friend class Catalog; @@ -158,16 +158,16 @@ namespace FoundryLocal { class Model final : public IModel { public: gsl::span GetAllModelVariants() const; - const ModelVariant* GetLatestVariant(const ModelVariant& variant) const; + const ModelVariant& GetLatestVersion(const ModelVariant& variant) const; bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } bool IsCached() const override { return SelectedVariant().IsCached(); } const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } - void 
Download(DownloadProgressCallback onProgress = nullptr) const override { + void Download(DownloadProgressCallback onProgress = nullptr) override { SelectedVariant().Download(std::move(onProgress)); } - void Load() const override { SelectedVariant().Load(); } - void Unload() const override { SelectedVariant().Unload(); } + void Load() override { SelectedVariant().Load(); } + void Unload() override { SelectedVariant().Unload(); } void RemoveFromCache() override { SelectedVariant().RemoveFromCache(); } [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] OpenAIAudioClient GetAudioClient() const { @@ -187,11 +187,11 @@ namespace FoundryLocal { CoreAccess GetCoreAccess() const override; private: - explicit Model(gsl::not_null core, gsl::not_null logger); + explicit Model(gsl::not_null core, gsl::not_null logger); ModelVariant& SelectedVariant(); const ModelVariant& SelectedVariant() const; - gsl::not_null core_; + gsl::not_null core_; std::vector variants_; mutable std::optional selectedVariantIndex_; @@ -203,4 +203,4 @@ namespace FoundryLocal { #endif }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h index cd5f52e4..876f93bd 100644 --- a/sdk/cpp/include/openai/openai_audio_client.h +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -10,11 +10,11 @@ #include -namespace FoundryLocal::Internal { +namespace foundry_local::Internal { struct IFoundryLocalCore; } -namespace FoundryLocal { +namespace foundry_local { class ILogger; class IModel; @@ -35,15 +35,15 @@ class IModel; void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; private: - OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + OpenAIAudioClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger); std::string modelId_; - gsl::not_null core_; + gsl::not_null core_; gsl::not_null logger_; }; /// Backward-compatible alias. using AudioClient = OpenAIAudioClient; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h index 20d97fae..49ce0fe8 100644 --- a/sdk/cpp/include/openai/openai_chat_client.h +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -16,11 +16,11 @@ #include "openai_tool_types.h" -namespace FoundryLocal::Internal { +namespace foundry_local::Internal { struct IFoundryLocalCore; } -namespace FoundryLocal { +namespace foundry_local { class ILogger; class IModel; @@ -103,18 +103,18 @@ enum class FinishReason { const ChatSettings& settings, const StreamCallback& onChunk) const; private: - OpenAIChatClient(gsl::not_null core, std::string_view modelId, + OpenAIChatClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger); std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, const ChatSettings& settings, bool stream) const; std::string modelId_; - gsl::not_null core_; + gsl::not_null core_; gsl::not_null logger_; }; /// Backward-compatible alias. 
using ChatClient = OpenAIChatClient; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_tool_types.h b/sdk/cpp/include/openai/openai_tool_types.h index 8ae4a068..4130d2b7 100644 --- a/sdk/cpp/include/openai/openai_tool_types.h +++ b/sdk/cpp/include/openai/openai_tool_types.h @@ -8,7 +8,7 @@ #include #include -namespace FoundryLocal { +namespace foundry_local { /// JSON Schema property definition used to describe tool function parameters. struct PropertyDefinition { @@ -51,4 +51,4 @@ namespace FoundryLocal { Required }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index dc1327d9..8efe78fc 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -8,7 +8,7 @@ #include #include -using namespace FoundryLocal; +using namespace foundry_local; // --------------------------------------------------------------------------- // Logger @@ -73,7 +73,7 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { auto& catalog = manager.GetCatalog(); auto models = catalog.ListModels(); - const auto* model = catalog.GetModel(alias); + auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; return; @@ -119,7 +119,7 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { auto& catalog = manager.GetCatalog(); catalog.ListModels(); - const auto* model = catalog.GetModel(alias); + auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; return; @@ -161,7 +161,7 @@ void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, con auto& catalog = manager.GetCatalog(); catalog.ListModels(); - const auto* model = catalog.GetModel(alias); + auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; return; @@ -210,7 +210,7 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) auto& catalog = manager.GetCatalog(); catalog.ListModels(); - const auto* model = catalog.GetModel(alias); + auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; return; @@ -328,7 +328,7 @@ void InspectVariants(FoundryLocalManager& manager, const std::string& alias) { auto& catalog = manager.GetCatalog(); catalog.ListModels(); - const auto* model = catalog.GetModel(alias); + auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; return; @@ -400,4 +400,4 @@ int main() { std::cerr << "Fatal: " << ex.what() << std::endl; return 1; } -} \ No newline at end of file +} diff --git a/sdk/cpp/src/core_interop_request.h b/sdk/cpp/src/core_interop_request.h index bb35d324..67ef1590 100644 --- a/sdk/cpp/src/core_interop_request.h +++ b/sdk/cpp/src/core_interop_request.h @@ -7,7 +7,7 @@ #include #include -namespace FoundryLocal { +namespace foundry_local { class CoreInteropRequest final { public: @@ -43,4 +43,4 @@ namespace FoundryLocal { nlohmann::json params_; }; -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp index b81d70a2..c776d640 100644 --- a/sdk/cpp/src/foundry_local.cpp +++ b/sdk/cpp/src/foundry_local.cpp @@ -39,47 +39,47 @@ namespace { } // Serialize + call - inline std::string CallWithJson(FoundryLocal::Internal::IFoundryLocalCore* 
core, std::string_view command, - const nlohmann::json& requestJson, FoundryLocal::ILogger& logger) { + inline std::string CallWithJson(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, foundry_local::ILogger& logger) { std::string payload = requestJson.dump(); return core->call(command, logger, &payload); } // Serialize + call with native callback - inline std::string CallWithJsonAndCallback(FoundryLocal::Internal::IFoundryLocalCore* core, - std::string_view command, const nlohmann::json& requestJson, - FoundryLocal::ILogger& logger, void* callback, void* userData) { + inline std::string CallWithJsonAndCallback(foundry_local::Internal::IFoundryLocalCore* core, + std::string_view command, const nlohmann::json& requestJson, + foundry_local::ILogger& logger, void* callback, void* userData) { std::string payload = requestJson.dump(); return core->call(command, logger, &payload, callback, userData); } // Overload: allow Params object directly - inline std::string CallWithParams(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& params, FoundryLocal::ILogger& logger) { + inline std::string CallWithParams(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& params, foundry_local::ILogger& logger) { return CallWithJson(core, command, MakeParams(params), logger); } // Overload: no payload - inline std::string CallNoArgs(FoundryLocal::Internal::IFoundryLocalCore* core, std::string_view command, - FoundryLocal::ILogger& logger) { + inline std::string CallNoArgs(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, + foundry_local::ILogger& logger) { return core->call(command, logger, nullptr); } - std::vector GetLoadedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, - FoundryLocal::ILogger& logger) { + std::vector GetLoadedModelsInternal(foundry_local::Internal::IFoundryLocalCore* core, + foundry_local::ILogger& logger) { std::string raw = core->call("list_loaded_models", logger); try { auto parsed = nlohmann::json::parse(raw); return parsed.get>(); } catch (const nlohmann::json::exception& e) { - throw FoundryLocal::FoundryLocalException( + throw foundry_local::FoundryLocalException( "Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); } } - std::vector GetCachedModelsInternal(FoundryLocal::Internal::IFoundryLocalCore* core, - FoundryLocal::ILogger& logger) { + std::vector GetCachedModelsInternal(foundry_local::Internal::IFoundryLocalCore* core, + foundry_local::ILogger& logger) { std::string raw = core->call("get_cached_models", logger); try { @@ -87,15 +87,15 @@ namespace { return parsed.get>(); } catch (const nlohmann::json::exception& e) { - throw FoundryLocal::FoundryLocalException( + throw foundry_local::FoundryLocalException( "Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); } } - std::vector CollectVariantsByIds( - const std::unordered_map& modelIdToModelVariant, + std::vector CollectVariantsByIds( + std::unordered_map& modelIdToModelVariant, std::vector ids) { - std::vector out; + std::vector out; out.reserve(ids.size()); for (const auto& id : ids) { @@ -109,14 +109,14 @@ namespace { } // namespace -namespace FoundryLocal { - inline static void* RequireProc(HMODULE mod, const char* name) { +namespace foundry_local { +inline static void* RequireProc(HMODULE mod, const char* name) { if (void* p = ::GetProcAddress(mod, name)) return p; 
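// Note on the wire format used throughout this file: the MakeParams /
// MakeModelParams helpers above wrap arguments in a "Params" envelope before
// core->call() serializes them. The shape below is read straight off the
// helper code; the schema itself is owned by the core DLL:
//
//   MakeModelParams("phi-3.5-mini")  ==>  {"Params": {"Model": "phi-3.5-mini"}}
//   CallWithParams(core, "load_model", {{"Model", "phi-3.5-mini"}}, logger)
//       sends the same envelope with command "load_model".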
throw std::runtime_error(std::string("GetProcAddress failed for ") + name); } - struct Core : FoundryLocal::Internal::IFoundryLocalCore { + struct Core : foundry_local::Internal::IFoundryLocalCore { using ResponseHandle = std::unique_ptr; Core() = default; @@ -200,7 +200,7 @@ namespace FoundryLocal { /// OpenAIAudioClient /// - OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger) : core_(core), modelId_(modelId), logger_(logger) {} @@ -276,7 +276,7 @@ namespace FoundryLocal { /// OpenAIChatClient /// - OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, + OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, gsl::not_null logger) : core_(core), modelId_(modelId), logger_(logger) {} @@ -341,7 +341,7 @@ namespace FoundryLocal { std::string json = req.ToJson(); std::string rawResult = core_->call(req.Command(), *logger_, &json); - return nlohmann::json::parse(rawResult).get(); + return nlohmann::json::parse(rawResult).get(); } void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, @@ -373,7 +373,7 @@ namespace FoundryLocal { std::string s(static_cast(data), static_cast(len)); try { - auto parsed = nlohmann::json::parse(s).get(); + auto parsed = nlohmann::json::parse(s).get(); (*(st->cb))(parsed); } @@ -398,7 +398,7 @@ namespace FoundryLocal { /// ModelVariant /// - ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, gsl::not_null logger) : core_(core), info_(std::move(info)), logger_(logger) {} @@ -416,7 +416,7 @@ namespace FoundryLocal { } } - void ModelVariant::Unload() const { + void ModelVariant::Unload() { try { CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); } @@ -446,7 +446,7 @@ namespace FoundryLocal { return false; } - void ModelVariant::Download(DownloadProgressCallback onProgress) const { + void ModelVariant::Download(DownloadProgressCallback onProgress) { if (IsCached()) { logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); return; @@ -480,7 +480,7 @@ namespace FoundryLocal { } } - void ModelVariant::Load() const { + void ModelVariant::Load() { CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_); } @@ -535,7 +535,7 @@ namespace FoundryLocal { /// /// Model /// - Model::Model(gsl::not_null core, gsl::not_null logger) + Model::Model(gsl::not_null core, gsl::not_null logger) : core_(core), logger_(logger) {} ModelVariant& Model::SelectedVariant() { @@ -556,12 +556,12 @@ namespace FoundryLocal { return variants_; } - const ModelVariant* Model::GetLatestVariant(const ModelVariant& variant) const { + const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const { const auto& targetName = variant.GetInfo().name; for (const auto& v : variants_) { if (v.GetInfo().name == targetName) { - return &v; + return v; } } @@ -597,7 +597,7 @@ namespace FoundryLocal { /// Catalog /// - Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) : core_(injected), logger_(logger) { try { name_ = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); @@ -607,15 +607,15 @@ namespace FoundryLocal { } } - std::vector Catalog::GetLoadedModels() const { + std::vector Catalog::GetLoadedModels() const { 
return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); } - std::vector Catalog::GetCachedModels() const { + std::vector Catalog::GetCachedModels() const { return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); } - const Model* Catalog::GetModel(std::string_view modelId) const { + Model* Catalog::GetModel(std::string_view modelId) const { auto it = byAlias_.find(std::string(modelId)); if (it != byAlias_.end()) { return &it->second; @@ -623,10 +623,10 @@ namespace FoundryLocal { return nullptr; } - std::vector Catalog::ListModels() const { + std::vector Catalog::ListModels() const { UpdateModels(); - std::vector out; + std::vector out; out.reserve(byAlias_.size()); for (auto& kv : byAlias_) out.emplace_back(&kv.second); @@ -681,7 +681,7 @@ namespace FoundryLocal { lastFetch_ = now; } - const ModelVariant* Catalog::GetModelVariant(std::string_view id) const { + ModelVariant* Catalog::GetModelVariant(std::string_view id) const { auto it = modelIdToModelVariant_.find(std::string(id)); if (it != modelIdToModelVariant_.end()) { return &it->second; @@ -724,7 +724,7 @@ namespace FoundryLocal { if (catalog_) { try { auto loadedModels = catalog_->GetLoadedModels(); - for (const auto* variant : loadedModels) { + for (auto* variant : loadedModels) { try { variant->Unload(); } @@ -755,6 +755,10 @@ namespace FoundryLocal { return *catalog_; } + Catalog& FoundryLocalManager::GetCatalog() { + return *catalog_; + } + void FoundryLocalManager::StartWebService() { if (!config_.web) { throw FoundryLocalException("Web service configuration was not provided.", *logger_); @@ -849,4 +853,4 @@ namespace FoundryLocal { } } -} // namespace FoundryLocal +} // namespace foundry_local diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index aa702e3c..c5ee81f2 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -7,8 +7,8 @@ #include #include "logger.h" -namespace FoundryLocal { - namespace Internal { +namespace foundry_local { +namespace Internal { struct IFoundryLocalCore { virtual ~IFoundryLocalCore() = default; @@ -19,4 +19,4 @@ namespace FoundryLocal { }; } // namespace Internal -} // namespace FoundryLocal \ No newline at end of file +} // namespace foundry_local \ No newline at end of file diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 555d5078..7930b7cb 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -7,7 +7,7 @@ #include "foundry_local.h" #include -namespace FoundryLocal { +namespace foundry_local { inline DeviceType parse_device_type(std::string_view v) { if (v == "CPU") { return DeviceType::CPU; @@ -289,4 +289,4 @@ namespace FoundryLocal { return "auto"; } -} // namespace FoundryLocal \ No newline at end of file +} // namespace foundry_local \ No newline at end of file diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp index b25f9457..5ed41e5d 100644 --- a/sdk/cpp/test/catalog_test.cpp +++ b/sdk/cpp/test/catalog_test.cpp @@ -14,8 +14,8 @@ #include -using namespace FoundryLocal; -using namespace FoundryLocal::Testing; +using namespace foundry_local; +using namespace foundry_local::Testing; using Factory = MockObjectFactory; diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index 71445b1b..bfe6bffb 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -10,8 +10,8 @@ #include -using namespace FoundryLocal; -using namespace 
FoundryLocal::Testing; +using namespace foundry_local; +using namespace foundry_local::Testing; using Factory = MockObjectFactory; diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index 16dc5991..c2b4f4e6 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -14,7 +14,7 @@ #include "foundry_local_internal_core.h" #include "logger.h" -namespace FoundryLocal::Testing { +namespace foundry_local::Testing { /// A mock implementation of IFoundryLocalCore for unit testing. /// Register expected command -> response mappings before use. @@ -148,4 +148,4 @@ namespace FoundryLocal::Testing { std::string loadedModelsPath_; }; -} // namespace FoundryLocal::Testing +} // namespace foundry_local::Testing diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h index 2addbd67..6a3d3d14 100644 --- a/sdk/cpp/test/mock_object_factory.h +++ b/sdk/cpp/test/mock_object_factory.h @@ -11,7 +11,7 @@ #include "foundry_local_internal_core.h" #include "logger.h" -namespace FoundryLocal::Testing { +namespace foundry_local::Testing { /// Factory to construct private-constructor types for testing. /// Declared as a friend (Testing::MockObjectFactory) in ModelVariant, Model, and Catalog when FL_TESTS is defined. @@ -62,4 +62,4 @@ namespace FoundryLocal::Testing { } }; -} // namespace FoundryLocal::Testing +} // namespace foundry_local::Testing diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index 929e7c87..1da0d2a1 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -10,8 +10,8 @@ #include -using namespace FoundryLocal; -using namespace FoundryLocal::Testing; +using namespace foundry_local; +using namespace foundry_local::Testing; using Factory = MockObjectFactory; @@ -229,9 +229,9 @@ TEST_F(ModelTest, GetLatestVariant) { Factory::SetSelectedVariantIndex(model, 0); const auto& first = model.GetAllModelVariants()[0]; - const auto* latest = model.GetLatestVariant(first); + const auto& latest = model.GetLatestVersion(first); // Should return the first one with matching name (which is variants_[0]) - EXPECT_EQ(&first, latest); + EXPECT_EQ(&first, &latest); } TEST_F(ModelTest, DelegationMethods) { diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index 9703c8b6..38d5c992 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -11,8 +11,8 @@ #include -using namespace FoundryLocal; -using namespace FoundryLocal::Testing; +using namespace foundry_local; +using namespace foundry_local::Testing; class ParserTest : public ::testing::Test { protected: From 70fa791fdf5d2ba537b0ed0f549f573f1a9c2453 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 17:11:42 -0700 Subject: [PATCH 10/18] Resolve more comments --- sdk/cpp/CMakeLists.txt | 14 +- sdk/cpp/include/foundry_local.h | 8 +- sdk/cpp/include/foundry_local_exception.h | 9 +- sdk/cpp/sample/main.cpp | 82 +-- sdk/cpp/src/catalog.cpp | 115 +++ sdk/cpp/src/core.h | 114 +++ sdk/cpp/src/core_helpers.h | 146 ++++ sdk/cpp/src/foundry_local.cpp | 856 ---------------------- sdk/cpp/src/foundry_local_internal_core.h | 21 +- sdk/cpp/src/foundry_local_manager.cpp | 178 +++++ sdk/cpp/src/model.cpp | 212 ++++++ sdk/cpp/src/openai_audio_client.cpp | 69 ++ sdk/cpp/src/openai_chat_client.cpp | 139 ++++ sdk/cpp/src/parser.h | 207 +++--- sdk/cpp/test/catalog_test.cpp | 12 +- sdk/cpp/test/client_test.cpp | 33 +- sdk/cpp/test/mock_core.h | 61 +- 
 sdk/cpp/test/model_variant_test.cpp       | 25 +-
 sdk/cpp/test/parser_and_types_test.cpp    | 54 +-
 19 files changed, 1220 insertions(+), 1131 deletions(-)
 create mode 100644 sdk/cpp/src/catalog.cpp
 create mode 100644 sdk/cpp/src/core.h
 create mode 100644 sdk/cpp/src/core_helpers.h
 delete mode 100644 sdk/cpp/src/foundry_local.cpp
 create mode 100644 sdk/cpp/src/foundry_local_manager.cpp
 create mode 100644 sdk/cpp/src/model.cpp
 create mode 100644 sdk/cpp/src/openai_audio_client.cpp
 create mode 100644 sdk/cpp/src/openai_chat_client.cpp

diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
index 1ef1194a..f35f2ee9 100644
--- a/sdk/cpp/CMakeLists.txt
+++ b/sdk/cpp/CMakeLists.txt
@@ -85,11 +85,11 @@ FetchContent_MakeAvailable(googletest)
 # List ONLY .cpp files here.
 # -----------------------------
 add_library(CppSdk STATIC
-    src/foundry_local.cpp
-    # Add more .cpp files as you migrate:
-    #   src/parser.cpp
-    #   src/dllmain.cpp
-    #   src/pch.cpp
+    src/model.cpp
+    src/catalog.cpp
+    src/openai_chat_client.cpp
+    src/openai_audio_client.cpp
+    src/foundry_local_manager.cpp
 )
 
 target_include_directories(CppSdk
diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h
index 5ee0dd6f..c16337e1 100644
--- a/sdk/cpp/include/foundry_local.h
+++ b/sdk/cpp/include/foundry_local.h
@@ -2,13 +2,7 @@
 // Licensed under the MIT License.
 //
 // Umbrella header – includes every public header for convenience.
-// Consumers may also include individual headers directly:
-//   #include "model.h"
-//   #include "catalog.h"
-//   #include "foundry_local_manager.h"
-//   #include "openai/openai_tool_types.h"
-//   #include "openai/openai_chat_client.h"
-//   #include "openai/openai_audio_client.h"
+// Consumers may also include individual headers directly.
 
 #pragma once
 
diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h
index a352a3fb..1dba9119 100644
--- a/sdk/cpp/include/foundry_local_exception.h
+++ b/sdk/cpp/include/foundry_local_exception.h
@@ -10,13 +10,16 @@ namespace foundry_local {
 
-    class FoundryLocalException final : public std::runtime_error {
+    class Exception final : public std::runtime_error {
     public:
-        explicit FoundryLocalException(std::string message) : std::runtime_error(std::move(message)) {}
+        explicit Exception(std::string message) : std::runtime_error(std::move(message)) {}
 
-        FoundryLocalException(std::string message, ILogger& logger) : std::runtime_error(std::move(message)) {
+        Exception(std::string message, ILogger& logger) : std::runtime_error(std::move(message)) {
             logger.Log(LogLevel::Error, what());
         }
     };
 
+    // Backward compatibility alias.
+    using FoundryLocalException = Exception;
+
 } // namespace foundry_local
diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp
index 8efe78fc..e1ed84be 100644
--- a/sdk/cpp/sample/main.cpp
+++ b/sdk/cpp/sample/main.cpp
@@ -54,11 +54,25 @@ void BrowseCatalog(FoundryLocalManager& manager) {
 
         for (const auto& variant : model->GetAllModelVariants()) {
             const auto& info = variant.GetInfo();
-            std::cout << "  variant: " << info.name << " v" << info.version;
-            if (info.runtime)
-                std::cout << " device=" << (info.runtime->device_type == DeviceType::GPU ? "GPU" : "CPU");
+            std::cout << "  variant: " << info.name << " v" << info.version
+                      << " cached=" << (variant.IsCached() ?
"yes" : "no"); + if (info.display_name) + std::cout << " display=\"" << *info.display_name << "\""; + if (info.publisher) + std::cout << " publisher=" << *info.publisher; + if (info.license) + std::cout << " license=" << *info.license; + if (info.runtime) { + std::cout << " device=" + << (info.runtime->device_type == DeviceType::GPU ? "GPU" + : info.runtime->device_type == DeviceType::NPU ? "NPU" + : "CPU") + << " ep=" << info.runtime->execution_provider; + } if (info.file_size_mb) std::cout << " size=" << *info.file_size_mb << "MB"; + if (info.supports_tool_calling) + std::cout << " tools=" << (*info.supports_tool_calling ? "yes" : "no"); std::cout << "\n"; } } @@ -93,8 +107,7 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { OpenAIChatClient chat(*model); - std::vector messages = {{"system", "You are a helpful assistant. Keep answers brief."}, - {"user", "What is the capital of Croatia?"}}; + std::vector messages = {{"user", "What is the capital of Croatia?"}}; ChatSettings settings; settings.temperature = 0.7f; @@ -117,7 +130,7 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { std::cout << "\n=== Example 3: Streaming Chat ===\n"; auto& catalog = manager.GetCatalog(); - catalog.ListModels(); + auto* model = catalog.GetModel(alias); if (!model) { @@ -143,9 +156,6 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { if (choice.delta && !choice.delta->content.empty()) { std::cout << choice.delta->content << std::flush; } - else if (choice.message && !choice.message->content.empty()) { - std::cout << choice.message->content << std::flush; - } }); std::cout << "\n"; @@ -159,7 +169,6 @@ void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, con std::cout << "\n=== Example 4: Audio Transcription ===\n"; auto& catalog = manager.GetCatalog(); - catalog.ListModels(); auto* model = catalog.GetModel(alias); if (!model) { @@ -208,7 +217,6 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) std::cout << "\n=== Example 5: Tool Calling ===\n"; auto& catalog = manager.GetCatalog(); - catalog.ListModels(); auto* model = catalog.GetModel(alias); if (!model) { @@ -319,55 +327,6 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) std::cout << "Model unloaded.\n"; } -// --------------------------------------------------------------------------- -// Example 6 – Model variant inspection & selection -// --------------------------------------------------------------------------- -void InspectVariants(FoundryLocalManager& manager, const std::string& alias) { - std::cout << "\n=== Example 6: Variant Inspection ===\n"; - - auto& catalog = manager.GetCatalog(); - catalog.ListModels(); - - auto* model = catalog.GetModel(alias); - if (!model) { - std::cerr << "Model '" << alias << "' not found in catalog.\n"; - return; - } - - auto variants = model->GetAllModelVariants(); - std::cout << "Model '" << alias << "' has " << variants.size() << " variant(s):\n"; - - for (const auto& v : variants) { - const auto& info = v.GetInfo(); - std::cout << " " << info.name << " v" << info.version << " cached=" << (v.IsCached() ? "yes" : "no"); - if (info.display_name) - std::cout << " display=\"" << *info.display_name << "\""; - if (info.publisher) - std::cout << " publisher=" << *info.publisher; - if (info.license) - std::cout << " license=" << *info.license; - if (info.runtime) { - std::cout << " device=" - << (info.runtime->device_type == DeviceType::GPU ? 
"GPU" - : info.runtime->device_type == DeviceType::NPU ? "NPU" - : "CPU") - << " ep=" << info.runtime->execution_provider; - } - if (info.supports_tool_calling) - std::cout << " tools=" << (*info.supports_tool_calling ? "yes" : "no"); - std::cout << "\n"; - } - - // Select a specific variant by pointer (e.g. prefer the GPU variant) - for (const auto& v : variants) { - if (v.GetInfo().runtime && v.GetInfo().runtime->device_type == DeviceType::GPU) { - model->SelectVariant(v); - std::cout << "Selected GPU variant: " << model->GetId() << "\n"; - break; - } - } -} - // --------------------------------------------------------------------------- // main // --------------------------------------------------------------------------- @@ -391,9 +350,6 @@ int main() { // 5. Tool calling (define tools, let the model call them, feed results back) ChatWithToolCalling(manager, "phi-3.5-mini"); - // 6. Inspect model variants and select one - InspectVariants(manager, "phi-3.5-mini"); - return 0; } catch (const std::exception& ex) { diff --git a/sdk/cpp/src/catalog.cpp b/sdk/cpp/src/catalog.cpp new file mode 100644 index 00000000..0de167f9 --- /dev/null +++ b/sdk/cpp/src/catalog.cpp @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_helpers.h" +#include "parser.h" +#include "logger.h" + +namespace foundry_local { + +using namespace detail; + +Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + : core_(injected), logger_(logger) { + auto response = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); + if (response.HasError()) { + throw Exception(std::string("Error getting catalog name: ") + response.error, *logger_); + } + name_ = std::move(response.data); +} + +std::vector Catalog::GetLoadedModels() const { + return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); +} + +std::vector Catalog::GetCachedModels() const { + return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); +} + +Model* Catalog::GetModel(std::string_view modelId) const { + auto it = byAlias_.find(std::string(modelId)); + if (it != byAlias_.end()) { + return &it->second; + } + return nullptr; +} + +std::vector Catalog::ListModels() const { + UpdateModels(); + + std::vector out; + out.reserve(byAlias_.size()); + for (auto& kv : byAlias_) + out.emplace_back(&kv.second); + + return out; +} + +void Catalog::UpdateModels() const { + using clock = std::chrono::steady_clock; + + // TODO: make this configurable + constexpr auto kRefreshInterval = std::chrono::hours(6); + + const auto now = clock::now(); + if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) { + return; + } + + const auto response = core_->call("get_model_list", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error getting model list: ") + response.error, *logger_); + } + const auto arr = nlohmann::json::parse(response.data); + + byAlias_.clear(); + modelIdToModelVariant_.clear(); + + for (const auto& j : arr) { + const std::string alias = j.at("alias").get(); + + auto it = byAlias_.find(alias); + if (it == byAlias_.end()) { + Model m(core_, logger_); + it = byAlias_.emplace(alias, std::move(m)).first; + } + + ModelInfo modelVariantInfo; + from_json(j, 
modelVariantInfo); + std::string variantId = modelVariantInfo.id; + ModelVariant modelVariant(core_, modelVariantInfo, logger_); + modelIdToModelVariant_.emplace(variantId, modelVariant); + + it->second.variants_.emplace_back(std::move(modelVariant)); + } + + // Auto-select the first variant for each model. + for (auto& [alias, model] : byAlias_) { + if (!model.variants_.empty()) { + model.selectedVariantIndex_ = 0; + } + } + + lastFetch_ = now; +} + +ModelVariant* Catalog::GetModelVariant(std::string_view id) const { + auto it = modelIdToModelVariant_.find(std::string(id)); + if (it != modelIdToModelVariant_.end()) { + return &it->second; + } + return nullptr; +} + +} // namespace foundry_local diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h new file mode 100644 index 00000000..d0f3b682 --- /dev/null +++ b/sdk/cpp/src/core.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Core DLL interop – loads Microsoft.AI.Foundry.Local.Core.dll at runtime. +// Internal header, not part of the public API. + +#pragma once + +#include +#include +#include +#include +#include + +#include + +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "flcore_native.h" +#include "logger.h" + +namespace foundry_local { + +namespace { +inline std::filesystem::path getExecutableDir() { + auto exePath = wil::GetModuleFileNameW(nullptr); + return std::filesystem::path(exePath.get()).parent_path(); +} + +inline void* RequireProc(HMODULE mod, const char* name) { + if (void* p = ::GetProcAddress(mod, name)) + return p; + throw std::runtime_error(std::string("GetProcAddress failed for ") + name); +} +} // namespace + +struct Core : Internal::IFoundryLocalCore { + using ResponseHandle = std::unique_ptr; + + Core() = default; + ~Core() = default; + + void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); } + + void unload() override { + module_.reset(); + execCmd_ = nullptr; + execCbCmd_ = nullptr; + freeResCmd_ = nullptr; + } + + CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, + NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { + throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger); + } + + RequestBuffer request{}; + request.Command = command.empty() ? 
nullptr : command.data(); + request.CommandLength = static_cast(command.size()); + + if (dataArgument && !dataArgument->empty()) { + request.Data = dataArgument->data(); + request.DataLength = static_cast(dataArgument->size()); + } + + ResponseBuffer response{}; + auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { + if (fn) + fn(buf); + }; + std::unique_ptr responseGuard(&response, safeDeleter); + + if (callback != nullptr) { + execCbCmd_(&request, &response, reinterpret_cast(callback), data); + } + else { + execCmd_(&request, &response); + } + + CoreResponse result; + if (response.Error && response.ErrorLength > 0) { + result.error.assign(static_cast(response.Error), response.ErrorLength); + return result; + } + + if (response.Data && response.DataLength > 0) { + result.data.assign(static_cast(response.Data), response.DataLength); + } + + return result; + } + +private: + wil::unique_hmodule module_; + execute_command_fn execCmd_{}; + execute_command_with_callback_fn execCbCmd_{}; + free_response_fn freeResCmd_{}; + + void loadFromPath(const std::filesystem::path& path) { + wil::unique_hmodule m(::LoadLibraryW(path.c_str())); + if (!m) + throw std::runtime_error("LoadLibraryW failed"); + + execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); + execCbCmd_ = reinterpret_cast( + RequireProc(m.get(), "execute_command_with_callback")); + freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); + + module_ = std::move(m); + } +}; + +} // namespace foundry_local diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h new file mode 100644 index 00000000..d35d87b4 --- /dev/null +++ b/sdk/cpp/src/core_helpers.h @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Internal helpers shared across implementation files. +// Not part of the public API. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "logger.h" +#include "model.h" + +namespace foundry_local::detail { + +// Wrap Params: { ... } into a request object +inline nlohmann::json MakeParams(nlohmann::json params) { + return nlohmann::json{{"Params", std::move(params)}}; +} + +// Most common: Params { "Model": } +inline nlohmann::json MakeModelParams(std::string_view model) { + return MakeParams(nlohmann::json{{"Model", std::string(model)}}); +} + +// Serialize + call +inline CoreResponse CallWithJson(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload); +} + +// Serialize + call with native callback +inline CoreResponse CallWithJsonAndCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger, + NativeCallbackFn callback, void* userData) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload, callback, userData); +} + +// Serialize + call with a streaming chunk handler. +// Wraps the caller-supplied onChunk with the native callback boilerplate +// (null/length checks, exception capture, rethrow after the call). +// The errorContext string is used to prefix any core-layer error message. 
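// Illustrative call shape (the std::function parameter type is elided in this
// extract; it is assumed here to take the raw chunk string, matching the
// State::cb usage in the body; actual call sites live in the new client .cpp
// files):
//
//   detail::CallWithStreamingCallback(
//       core_, "chat_completions", payloadJson, *logger_,
//       [&](const std::string& chunk) { /* parse chunk JSON, forward to caller */ },
//       "Error during streaming chat completion: ");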
+inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext) { + struct State { + const std::function* cb; + std::exception_ptr exception; + } state{&onChunk, nullptr}; + + auto nativeCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + + auto* st = static_cast(user); + if (st->exception) + return; + + try { + std::string chunk(static_cast(data), static_cast(len)); + (*(st->cb))(chunk); + } + catch (...) { + st->exception = std::current_exception(); + } + }; + + auto response = core->call(command, logger, &payload, +nativeCallback, &state); + if (response.HasError()) { + throw Exception(std::string(errorContext) + response.error, logger); + } + + if (state.exception) { + std::rethrow_exception(state.exception); + } + + return response; +} + +// Overload: allow Params object directly +inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& params, ILogger& logger) { + return CallWithJson(core, command, MakeParams(params), logger); +} + +// Overload: no payload +inline CoreResponse CallNoArgs(Internal::IFoundryLocalCore* core, std::string_view command, ILogger& logger) { + return core->call(command, logger, nullptr); +} + +inline std::vector GetLoadedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { + auto response = core->call("list_loaded_models", logger); + if (response.HasError()) { + throw Exception("Failed to get loaded models: " + response.error, logger); + } + try { + auto parsed = nlohmann::json::parse(response.data); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw Exception("Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); + } +} + +inline std::vector GetCachedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { + auto response = core->call("get_cached_models", logger); + if (response.HasError()) { + throw Exception("Failed to get cached models: " + response.error, logger); + } + + try { + auto parsed = nlohmann::json::parse(response.data); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw Exception("Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); + } +} + +inline std::vector CollectVariantsByIds( + std::unordered_map& modelIdToModelVariant, std::vector ids) { + std::vector out; + out.reserve(ids.size()); + + for (const auto& id : ids) { + auto it = modelIdToModelVariant.find(id); + if (it != modelIdToModelVariant.end()) { + out.emplace_back(&it->second); + } + } + return out; +} + +} // namespace foundry_local::detail diff --git a/sdk/cpp/src/foundry_local.cpp b/sdk/cpp/src/foundry_local.cpp deleted file mode 100644 index c776d640..00000000 --- a/sdk/cpp/src/foundry_local.cpp +++ /dev/null @@ -1,856 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include "core_interop_request.h" -#include "configuration.h" -#include "foundry_local.h" -#include "flcore_native.h" -#include "foundry_local_internal_core.h" -#include "parser.h" -#include "logger.h" -#include -#include "foundry_local_exception.h" - -// Internal private namespace. 
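The file deleted below is the monolithic implementation this patch splits into
catalog.cpp, core.h, core_helpers.h, foundry_local_manager.cpp, model.cpp, and
the two OpenAI client files. One recurring pattern worth noting while reading
the removed code: every streaming method open-coded the same callback plumbing,
which core_helpers.h now centralizes. Roughly, with casts elided and onRawChunk
as a stand-in name for the wrapped handler:

    // before: repeated per streaming method
    struct State { const StreamCallback* cb; std::exception_ptr exception; } state{&onChunk, nullptr};
    core_->call(req.Command(), *logger_, &json, nativeCallback, &state);
    if (state.exception) std::rethrow_exception(state.exception);

    // after: one shared helper (string-level callback; each client parses its own chunk type)
    detail::CallWithStreamingCallback(core_, req.Command(), json, *logger_, onRawChunk, "Error ...: ");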
-namespace { - std::filesystem::path getExecutableDir() { - auto exePath = wil::GetModuleFileNameW(nullptr); - return std::filesystem::path(exePath.get()).parent_path(); - } -} // namespace - -namespace { - // Wrap Params: { ... } into a request object - inline nlohmann::json MakeParams(nlohmann::json params) { - return nlohmann::json{{"Params", std::move(params)}}; - } - - // Most common: Params { "Model": } - inline nlohmann::json MakeModelParams(std::string_view model) { - return MakeParams(nlohmann::json{{"Model", std::string(model)}}); - } - - // Serialize + call - inline std::string CallWithJson(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& requestJson, foundry_local::ILogger& logger) { - std::string payload = requestJson.dump(); - return core->call(command, logger, &payload); - } - - // Serialize + call with native callback - inline std::string CallWithJsonAndCallback(foundry_local::Internal::IFoundryLocalCore* core, - std::string_view command, const nlohmann::json& requestJson, - foundry_local::ILogger& logger, void* callback, void* userData) { - std::string payload = requestJson.dump(); - return core->call(command, logger, &payload, callback, userData); - } - - // Overload: allow Params object directly - inline std::string CallWithParams(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& params, foundry_local::ILogger& logger) { - return CallWithJson(core, command, MakeParams(params), logger); - } - - // Overload: no payload - inline std::string CallNoArgs(foundry_local::Internal::IFoundryLocalCore* core, std::string_view command, - foundry_local::ILogger& logger) { - return core->call(command, logger, nullptr); - } - - std::vector GetLoadedModelsInternal(foundry_local::Internal::IFoundryLocalCore* core, - foundry_local::ILogger& logger) { - std::string raw = core->call("list_loaded_models", logger); - try { - auto parsed = nlohmann::json::parse(raw); - return parsed.get>(); - } - catch (const nlohmann::json::exception& e) { - throw foundry_local::FoundryLocalException( - "Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); - } - } - - std::vector GetCachedModelsInternal(foundry_local::Internal::IFoundryLocalCore* core, - foundry_local::ILogger& logger) { - std::string raw = core->call("get_cached_models", logger); - - try { - auto parsed = nlohmann::json::parse(raw); - return parsed.get>(); - } - catch (const nlohmann::json::exception& e) { - throw foundry_local::FoundryLocalException( - "Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); - } - } - - std::vector CollectVariantsByIds( - std::unordered_map& modelIdToModelVariant, - std::vector ids) { - std::vector out; - out.reserve(ids.size()); - - for (const auto& id : ids) { - auto it = modelIdToModelVariant.find(id); - if (it != modelIdToModelVariant.end()) { - out.emplace_back(&it->second); - } - } - return out; - } - -} // namespace - -namespace foundry_local { -inline static void* RequireProc(HMODULE mod, const char* name) { - if (void* p = ::GetProcAddress(mod, name)) - return p; - throw std::runtime_error(std::string("GetProcAddress failed for ") + name); - } - - struct Core : foundry_local::Internal::IFoundryLocalCore { - using ResponseHandle = std::unique_ptr; - - Core() = default; - ~Core() = default; - - void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); } - - void unload() { - module_.reset(); - execCmd_ = 
nullptr; - execCbCmd_ = nullptr; - freeResCmd_ = nullptr; - } - std::string call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, - void* callback = nullptr, void* data = nullptr) const override { - if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { - throw FoundryLocalException("Core is not loaded. Cannot call command: " + std::string(command), logger); - } - - RequestBuffer request{}; - request.Command = command.empty() ? nullptr : command.data(); - request.CommandLength = static_cast(command.size()); - - if (dataArgument && !dataArgument->empty()) { - request.Data = dataArgument->data(); - request.DataLength = static_cast(dataArgument->size()); - } - - ResponseBuffer response{}; - auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { - if (fn) - fn(buf); - }; - std::unique_ptr responseGuard(&response, safeDeleter); - - using CallbackFn = void (*)(void*, int32_t, void*); - - if (callback != nullptr) { - auto cb = reinterpret_cast(callback); - execCbCmd_(&request, &response, reinterpret_cast(cb), data); - } - else { - execCmd_(&request, &response); - } - - std::string result; - if (response.Error && response.ErrorLength > 0) { - std::string err(static_cast(response.Error), response.ErrorLength); - throw FoundryLocalException(std::string("Command failed [").append(command).append("]: ").append(err), - logger); - } - - if (response.Data && response.DataLength > 0) { - result.assign(static_cast(response.Data), response.DataLength); - } - - return result; - } - - private: - wil::unique_hmodule module_; - execute_command_fn execCmd_{}; - execute_command_with_callback_fn execCbCmd_{}; - free_response_fn freeResCmd_{}; - - void loadFromPath(const std::filesystem::path& path) { - wil::unique_hmodule m(::LoadLibraryW(path.c_str())); - if (!m) - throw std::runtime_error("LoadLibraryW failed"); - - execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); - execCbCmd_ = reinterpret_cast( - RequireProc(m.get(), "execute_command_with_callback")); - freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); - - module_ = std::move(m); - } - }; - - /// - /// OpenAIAudioClient - /// - - OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) {} - - AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { - nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; - CoreInteropRequest req("audio_transcribe"); - req.AddParam("OpenAICreateRequest", openAiReq.dump()); - - std::string json = req.ToJson(); - - AudioCreateTranscriptionResponse response; - response.text = core_->call(req.Command(), *logger_, &json); - - return response; - } - - void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, - const StreamCallback& onChunk) const { - nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; - CoreInteropRequest req("audio_transcribe"); - req.AddParam("OpenAICreateRequest", openAiReq.dump()); - - std::string json = req.ToJson(); - - struct State { - const StreamCallback* cb; - std::exception_ptr exception; - } state{&onChunk, nullptr}; - - auto streamCallback = [](void* data, int32_t len, void* user) { - if (!data || len <= 0) - return; - - auto* st = static_cast(user); - if (st->exception) - return; - - try { - std::string text(static_cast(data), static_cast(len)); - 
AudioCreateTranscriptionResponse chunk; - chunk.text = std::move(text); - (*(st->cb))(chunk); - } - catch (...) { - st->exception = std::current_exception(); - } - }; - - core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), - reinterpret_cast(&state)); - - if (state.exception) { - std::rethrow_exception(state.exception); - } - } - - std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { - if (created == 0) - return {}; - std::time_t t = static_cast(created); - std::tm tm{}; -#ifdef _WIN32 - gmtime_s(&tm, &t); -#else - gmtime_r(&t, &tm); -#endif - char buf[32]; - std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); - return buf; - } - - /// - /// OpenAIChatClient - /// - - OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) {} - - std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, - gsl::span tools, const ChatSettings& settings, - bool stream) const { - nlohmann::json jMessages = nlohmann::json::array(); - for (const auto& msg : messages) { - nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; - if (msg.tool_call_id) - jMsg["tool_call_id"] = *msg.tool_call_id; - jMessages.push_back(std::move(jMsg)); - } - - nlohmann::json req = {{"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream}}; - - if (!tools.empty()) { - nlohmann::json jTools = nlohmann::json::array(); - for (const auto& tool : tools) { - nlohmann::json jTool; - to_json(jTool, tool); - jTools.push_back(std::move(jTool)); - } - req["tools"] = std::move(jTools); - } - - if (settings.tool_choice) - req["tool_choice"] = tool_choice_to_string(*settings.tool_choice); - if (settings.top_k) - req["metadata"] = {{"top_k", *settings.top_k}}; - if (settings.frequency_penalty) - req["frequency_penalty"] = *settings.frequency_penalty; - if (settings.presence_penalty) - req["presence_penalty"] = *settings.presence_penalty; - if (settings.max_tokens) - req["max_completion_tokens"] = *settings.max_tokens; - if (settings.n) - req["n"] = *settings.n; - if (settings.temperature) - req["temperature"] = *settings.temperature; - if (settings.top_p) - req["top_p"] = *settings.top_p; - if (settings.random_seed) - req["seed"] = *settings.random_seed; - - return req.dump(); - } - - ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, - const ChatSettings& settings) const { - return CompleteChat(messages, {}, settings); - } - - ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, - gsl::span tools, - const ChatSettings& settings) const { - std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); - - CoreInteropRequest req("chat_completions"); - req.AddParam("OpenAICreateRequest", openAiReqJson); - - std::string json = req.ToJson(); - std::string rawResult = core_->call(req.Command(), *logger_, &json); - - return nlohmann::json::parse(rawResult).get(); - } - - void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, - const StreamCallback& onChunk) const { - CompleteChatStreaming(messages, {}, settings, onChunk); - } - - void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, gsl::span tools, - const ChatSettings& settings, const StreamCallback& onChunk) const { - std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); - - CoreInteropRequest req("chat_completions"); - 
req.AddParam("OpenAICreateRequest", openAiReqJson); - std::string json = req.ToJson(); - - struct State { - const StreamCallback* cb; - std::exception_ptr exception; - } state{&onChunk, nullptr}; - - auto streamCallback = [](void* data, int32_t len, void* user) { - if (!data || len <= 0) - return; - - auto* st = static_cast(user); - if (st->exception) - return; - - std::string s(static_cast(data), static_cast(len)); - - try { - auto parsed = nlohmann::json::parse(s).get(); - - (*(st->cb))(parsed); - } - catch (const nlohmann::json::exception& e) { - st->exception = std::make_exception_ptr( - FoundryLocalException(std::string("Error while parsing streaming chat chunk: ") + e.what())); - } - catch (...) { - st->exception = std::current_exception(); - } - }; - - core_->call(req.Command(), *logger_, &json, reinterpret_cast(+streamCallback), - reinterpret_cast(&state)); - - if (state.exception) { - std::rethrow_exception(state.exception); - } - } - - /// - /// ModelVariant - /// - - ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, - gsl::not_null logger) - : core_(core), info_(std::move(info)), logger_(logger) {} - - const ModelInfo& ModelVariant::GetInfo() const { - return info_; - } - - void ModelVariant::RemoveFromCache() { - try { - CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); - cachedPath_.clear(); - } - catch (const std::exception& ex) { - throw FoundryLocalException("Error removing model from cache [" + info_.name + "]: " + ex.what(), *logger_); - } - } - - void ModelVariant::Unload() { - try { - CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); - } - catch (const std::exception& ex) { - throw FoundryLocalException("Error unloading model [" + info_.name + "]: " + ex.what(), *logger_); - } - } - - bool ModelVariant::IsLoaded() const { - std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); - for (const auto& id : loadedModelIds) { - if (id == info_.id) { - return true; - } - } - - return false; - } - - bool ModelVariant::IsCached() const { - auto cachedModels = GetCachedModelsInternal(core_, *logger_); - for (const auto& id : cachedModels) { - if (id == info_.id) { - return true; - } - } - return false; - } - - void ModelVariant::Download(DownloadProgressCallback onProgress) { - if (IsCached()) { - logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); - return; - } - - if (onProgress) { - struct ProgressState { - DownloadProgressCallback* cb; - ILogger* logger; - } state{&onProgress, logger_}; - - auto nativeCallback = [](void* data, int32_t len, void* user) { - if (!data || len <= 0) - return; - auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast((std::min)(4, static_cast(len)))); - try { - float value = std::stof(perc); - (*(st->cb))(value); - } - catch (...) 
{ - st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc); - } - }; - - CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_, - reinterpret_cast(+nativeCallback), reinterpret_cast(&state)); - } - else { - CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_); - } - } - - void ModelVariant::Load() { - CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_); - } - - const std::filesystem::path& ModelVariant::GetPath() const { - if (cachedPath_.empty()) { - cachedPath_ = - std::filesystem::path(CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_)); - } - return cachedPath_; - } - - const std::string& ModelVariant::GetId() const noexcept { - return info_.id; - } - - const std::string& ModelVariant::GetAlias() const noexcept { - return info_.alias; - } - - uint32_t ModelVariant::GetVersion() const noexcept { - return info_.version; - } - - IModel::CoreAccess ModelVariant::GetCoreAccess() const { - return {core_, info_.name, logger_}; - } - - OpenAIAudioClient::OpenAIAudioClient(const IModel& model) - : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { - if (!model.IsLoaded()) { - throw FoundryLocalException("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", - *model.GetCoreAccess().logger); - } - } - - OpenAIAudioClient ModelVariant::GetAudioClient() const { - return OpenAIAudioClient(*this); - } - - OpenAIChatClient::OpenAIChatClient(const IModel& model) - : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { - if (!model.IsLoaded()) { - throw FoundryLocalException("Model " + model.GetCoreAccess().modelName + " is not loaded. 
Call Load() first.", - *model.GetCoreAccess().logger); - } - } - - OpenAIChatClient ModelVariant::GetChatClient() const { - return OpenAIChatClient(*this); - } - - /// - /// Model - /// - Model::Model(gsl::not_null core, gsl::not_null logger) - : core_(core), logger_(logger) {} - - ModelVariant& Model::SelectedVariant() { - if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { - throw FoundryLocalException("Model has no selected variant", *logger_); - } - return variants_[*selectedVariantIndex_]; - } - - const ModelVariant& Model::SelectedVariant() const { - if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { - throw FoundryLocalException("Model has no selected variant", *logger_); - } - return variants_[*selectedVariantIndex_]; - } - - gsl::span Model::GetAllModelVariants() const { - return variants_; - } - - const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const { - const auto& targetName = variant.GetInfo().name; - - for (const auto& v : variants_) { - if (v.GetInfo().name == targetName) { - return v; - } - } - - throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", - *logger_); - } - - const std::string& Model::GetId() const { - return SelectedVariant().GetId(); - } - - const std::string& Model::GetAlias() const { - return SelectedVariant().GetAlias(); - } - - void Model::SelectVariant(const ModelVariant& variant) const { - auto it = std::find_if(variants_.begin(), variants_.end(), - [&](const ModelVariant& v) { return &v == &variant; }); - - if (it == variants_.end()) { - throw FoundryLocalException("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", - *logger_); - } - - selectedVariantIndex_ = static_cast(std::distance(variants_.begin(), it)); - } - - IModel::CoreAccess Model::GetCoreAccess() const { - return SelectedVariant().GetCoreAccess(); - } - - /// - /// Catalog - /// - - Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) - : core_(injected), logger_(logger) { - try { - name_ = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); - } - catch (const std::exception& ex) { - throw FoundryLocalException(std::string("Error getting catalog name: ") + ex.what(), *logger_); - } - } - - std::vector Catalog::GetLoadedModels() const { - return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); - } - - std::vector Catalog::GetCachedModels() const { - return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); - } - - Model* Catalog::GetModel(std::string_view modelId) const { - auto it = byAlias_.find(std::string(modelId)); - if (it != byAlias_.end()) { - return &it->second; - } - return nullptr; - } - - std::vector Catalog::ListModels() const { - UpdateModels(); - - std::vector out; - out.reserve(byAlias_.size()); - for (auto& kv : byAlias_) - out.emplace_back(&kv.second); - - return out; - } - - void Catalog::UpdateModels() const { - using clock = std::chrono::steady_clock; - - // TODO: make this configurable - constexpr auto kRefreshInterval = std::chrono::hours(6); - - const auto now = clock::now(); - if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) { - return; - } - - const std::string raw = core_->call("get_model_list", *logger_); - const auto arr = nlohmann::json::parse(raw); - - byAlias_.clear(); - modelIdToModelVariant_.clear(); - - for (const auto& j : arr) { - const std::string alias 
= j.at("alias").get(); - if (alias.rfind("openai-", 0) == 0) - continue; - - auto it = byAlias_.find(alias); - if (it == byAlias_.end()) { - Model m(core_, logger_); - it = byAlias_.emplace(alias, std::move(m)).first; - } - - ModelInfo modelVariantInfo; - from_json(j, modelVariantInfo); - std::string variantId = modelVariantInfo.id; - ModelVariant modelVariant(core_, modelVariantInfo, logger_); - modelIdToModelVariant_.emplace(variantId, modelVariant); - - it->second.variants_.emplace_back(std::move(modelVariant)); - } - - // Auto-select the first variant for each model. - for (auto& [alias, model] : byAlias_) { - if (!model.variants_.empty()) { - model.selectedVariantIndex_ = 0; - } - } - - lastFetch_ = now; - } - - ModelVariant* Catalog::GetModelVariant(std::string_view id) const { - auto it = modelIdToModelVariant_.find(std::string(id)); - if (it != modelIdToModelVariant_.end()) { - return &it->second; - } - return nullptr; - } - - /// - /// FoundryLocalManager - /// - - FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger) - : config_(std::move(configuration)), core_(std::make_unique()), - logger_(logger ? logger : &defaultLogger_) { - static_cast(core_.get())->loadEmbedded(); - Initialize(); - catalog_ = Catalog::Create(core_.get(), logger_); - } - - FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept - : config_(std::move(other.config_)), core_(std::move(other.core_)), catalog_(std::move(other.catalog_)), - logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_), urls_(std::move(other.urls_)) { - other.logger_ = &other.defaultLogger_; - } - - FoundryLocalManager& FoundryLocalManager::operator=(FoundryLocalManager&& other) noexcept { - if (this != &other) { - config_ = std::move(other.config_); - core_ = std::move(other.core_); - catalog_ = std::move(other.catalog_); - logger_ = other.OwnsLogger() ? &defaultLogger_ : other.logger_; - urls_ = std::move(other.urls_); - other.logger_ = &other.defaultLogger_; - } - return *this; - } - - FoundryLocalManager::~FoundryLocalManager() { - // Unload all loaded models before tearing down. 
- if (catalog_) { - try { - auto loadedModels = catalog_->GetLoadedModels(); - for (auto* variant : loadedModels) { - try { - variant->Unload(); - } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error unloading model during destruction: ") + ex.what()); - } - } - } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error retrieving loaded models during destruction: ") + ex.what()); - } - } - - if (!urls_.empty()) { - try { - StopWebService(); - } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error stopping web service during destruction: ") + ex.what()); - } - } - } - - const Catalog& FoundryLocalManager::GetCatalog() const { - return *catalog_; - } - - Catalog& FoundryLocalManager::GetCatalog() { - return *catalog_; - } - - void FoundryLocalManager::StartWebService() { - if (!config_.web) { - throw FoundryLocalException("Web service configuration was not provided.", *logger_); - } - - try { - std::string raw = core_->call("start_service", *logger_); - auto arr = nlohmann::json::parse(raw); - urls_ = arr.get>(); - } - catch (const std::exception& ex) { - throw FoundryLocalException(std::string("Error starting web service: ") + ex.what(), *logger_); - } - } - - void FoundryLocalManager::StopWebService() { - if (!config_.web) { - throw FoundryLocalException("Web service configuration was not provided.", *logger_); - } - - try { - core_->call("stop_service", *logger_); - urls_.clear(); - } - catch (const std::exception& ex) { - throw FoundryLocalException(std::string("Error stopping web service: ") + ex.what(), *logger_); - } - } - - gsl::span FoundryLocalManager::GetUrls() const noexcept { - return urls_; - } - - void FoundryLocalManager::EnsureEpsDownloaded() const { - try { - core_->call("ensure_eps_downloaded", *logger_); - } - catch (const std::exception& ex) { - throw FoundryLocalException(std::string("Error ensuring execution providers downloaded: ") + ex.what(), - *logger_); - } - } - - void FoundryLocalManager::Initialize() { - config_.Validate(); - - try { - CoreInteropRequest initReq("initialize"); - initReq.AddParam("AppName", config_.app_name); - initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); - - if (config_.app_data_dir) { - initReq.AddParam("AppDataDir", config_.app_data_dir->string()); - } - if (config_.logs_dir) { - initReq.AddParam("LogsDir", config_.logs_dir->string()); - } - if (config_.web && config_.web->urls) { - initReq.AddParam("WebServiceUrls", *config_.web->urls); - } - if (config_.additional_settings) { - for (const auto& [key, value] : *config_.additional_settings) { - if (!key.empty()) { - initReq.AddParam(key, value); - } - } - } - - std::string initJson = initReq.ToJson(); - core_->call(initReq.Command(), *logger_, &initJson); - - if (config_.model_cache_dir) { - std::string current = core_->call("get_cache_directory", *logger_); - - if (current != config_.model_cache_dir->string()) { - CoreInteropRequest setReq("set_cache_directory"); - setReq.AddParam("Directory", config_.model_cache_dir->string()); - std::string setJson = setReq.ToJson(); - core_->call(setReq.Command(), *logger_, &setJson); - - logger_->Log(LogLevel::Information, - std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); - } - else { - logger_->Log(LogLevel::Information, - std::string("Model cache directory already set to: ") + current); - } - } - } - catch (const std::exception& ex) { - throw 
FoundryLocalException(std::string("FoundryLocalManager::Initialize failed: ") + ex.what(), *logger_);
-        }
-    }
-
-} // namespace foundry_local
diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h
index c5ee81f2..0cbc3d68 100644
--- a/sdk/cpp/src/foundry_local_internal_core.h
+++ b/sdk/cpp/src/foundry_local_internal_core.h
@@ -3,17 +3,34 @@
 
 #pragma once
 
+#include <cstdint>
 #include <string>
 
 #include "logger.h"
 
 namespace foundry_local {
+
+    /// Native callback signature used by the core DLL interop.
+    /// Parameters: (data, dataLength, userData).
+    using NativeCallbackFn = void (*)(void*, int32_t, void*);
+
+    /// Value returned by IFoundryLocalCore::call().
+    /// On success, `data` contains the response payload and `error` is empty.
+    /// On failure, `error` contains the error message from the core layer.
+    struct CoreResponse {
+        std::string data;
+        std::string error;
+
+        bool HasError() const noexcept { return !error.empty(); }
+    };
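+    /// Illustrative sketch only (added as documentation; `core`, `logger`, and
+    /// `consume` below are placeholders, not symbols defined in this header):
+    /// call sites are expected to branch on HasError() instead of catching
+    /// exceptions from call():
+    ///
+    ///     auto response = core->call("get_catalog_name", logger);
+    ///     if (response.HasError()) {
+    ///         throw Exception("catalog lookup failed: " + response.error, logger);
+    ///     }
+    ///     consume(response.data);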
+
 namespace Internal {
     struct IFoundryLocalCore {
         virtual ~IFoundryLocalCore() = default;
-        virtual std::string call(std::string_view command, ILogger& logger,
-                                 const std::string* dataArgument = nullptr, void* callback = nullptr,
-                                 void* data = nullptr) const = 0;
+        virtual CoreResponse call(std::string_view command, ILogger& logger,
+                                  const std::string* dataArgument = nullptr,
+                                  NativeCallbackFn callback = nullptr,
+                                  void* data = nullptr) const = 0;
         virtual void unload() = 0;
     };
diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp
new file mode 100644
index 00000000..f383abf6
--- /dev/null
+++ b/sdk/cpp/src/foundry_local_manager.cpp
@@ -0,0 +1,178 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include <nlohmann/json.hpp>
+
+#include "foundry_local.h"
+#include "foundry_local_internal_core.h"
+#include "foundry_local_exception.h"
+#include "core_interop_request.h"
+#include "core.h"
+#include "logger.h"
+
+namespace foundry_local {
+
+FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger)
+    : config_(std::move(configuration)), core_(std::make_unique<Internal::FoundryLocalCore>()),
+      logger_(logger ? logger : &defaultLogger_) {
+    static_cast<Internal::FoundryLocalCore*>(core_.get())->loadEmbedded();
+    Initialize();
+    catalog_ = Catalog::Create(core_.get(), logger_);
+}
+
+FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept
+    : config_(std::move(other.config_)), core_(std::move(other.core_)), catalog_(std::move(other.catalog_)),
+      logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_), urls_(std::move(other.urls_)) {
+    other.logger_ = &other.defaultLogger_;
+}
+
+FoundryLocalManager& FoundryLocalManager::operator=(FoundryLocalManager&& other) noexcept {
+    if (this != &other) {
+        config_ = std::move(other.config_);
+        core_ = std::move(other.core_);
+        catalog_ = std::move(other.catalog_);
+        logger_ = other.OwnsLogger() ? &defaultLogger_ : other.logger_;
+        urls_ = std::move(other.urls_);
+        other.logger_ = &other.defaultLogger_;
+    }
+    return *this;
+}
+
+FoundryLocalManager::~FoundryLocalManager() {
+    // Unload all loaded models before tearing down.
+    if (catalog_) {
+        try {
+            auto loadedModels = catalog_->GetLoadedModels();
+            for (auto* variant : loadedModels) {
+                try {
+                    variant->Unload();
+                }
+                catch (const std::exception& ex) {
+                    logger_->Log(LogLevel::Warning,
+                                 std::string("Error unloading model during destruction: ") + ex.what());
+                }
+            }
+        }
+        catch (const std::exception& ex) {
+            logger_->Log(LogLevel::Warning,
+                         std::string("Error retrieving loaded models during destruction: ") + ex.what());
+        }
+    }
+
+    if (!urls_.empty()) {
+        try {
+            StopWebService();
+        }
+        catch (const std::exception& ex) {
+            logger_->Log(LogLevel::Warning,
+                         std::string("Error stopping web service during destruction: ") + ex.what());
+        }
+    }
+}
+
+const Catalog& FoundryLocalManager::GetCatalog() const {
+    return *catalog_;
+}
+
+Catalog& FoundryLocalManager::GetCatalog() {
+    return *catalog_;
+}
+
+void FoundryLocalManager::StartWebService() {
+    if (!config_.web) {
+        throw Exception("Web service configuration was not provided.", *logger_);
+    }
+
+    auto response = core_->call("start_service", *logger_);
+    if (response.HasError()) {
+        throw Exception(std::string("Error starting web service: ") + response.error, *logger_);
+    }
+    auto arr = nlohmann::json::parse(response.data);
+    urls_ = arr.get<std::vector<std::string>>();
+}
+
+void FoundryLocalManager::StopWebService() {
+    if (!config_.web) {
+        throw Exception("Web service configuration was not provided.", *logger_);
+    }
+
+    auto response = core_->call("stop_service", *logger_);
+    if (response.HasError()) {
+        throw Exception(std::string("Error stopping web service: ") + response.error, *logger_);
+    }
+    urls_.clear();
+}
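+// Usage sketch (illustrative only; assumes a Configuration whose `web` member
+// is set): start the service, read the bound endpoints, then stop it. The
+// destructor also calls StopWebService() while urls_ is non-empty, so the
+// explicit stop is optional:
+//
+//     FoundryLocalManager manager(std::move(config));
+//     manager.StartWebService();
+//     for (const auto& url : manager.GetUrls()) {
+//         // connect an OpenAI-compatible client to `url`
+//     }
+//     manager.StopWebService();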
Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, *logger_); + } + + logger_->Log(LogLevel::Information, + std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); + } + else { + logger_->Log(LogLevel::Information, + std::string("Model cache directory already set to: ") + cacheResponse.data); + } + } +} + +} // namespace foundry_local diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp new file mode 100644 index 00000000..d017387b --- /dev/null +++ b/sdk/cpp/src/model.cpp @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_helpers.h" +#include "logger.h" + +namespace foundry_local { + +using namespace detail; + +/// ModelVariant + +ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) {} + +const ModelInfo& ModelVariant::GetInfo() const { + return info_; +} + +void ModelVariant::RemoveFromCache() { + auto response = CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error removing model from cache [" + info_.name + "]: " + response.error, *logger_); + } + cachedPath_.clear(); +} + +void ModelVariant::Unload() { + auto response = CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error unloading model [" + info_.name + "]: " + response.error, *logger_); + } +} + +bool ModelVariant::IsLoaded() const { + std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); + for (const auto& id : loadedModelIds) { + if (id == info_.id) { + return true; + } + } + + return false; +} + +bool ModelVariant::IsCached() const { + auto cachedModels = GetCachedModelsInternal(core_, *logger_); + for (const auto& id : cachedModels) { + if (id == info_.id) { + return true; + } + } + return false; +} + +void ModelVariant::Download(DownloadProgressCallback onProgress) { + if (IsCached()) { + logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); + return; + } + + if (onProgress) { + struct ProgressState { + DownloadProgressCallback* cb; + ILogger* logger; + } state{&onProgress, logger_}; + + auto nativeCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + auto* st = static_cast(user); + std::string perc(static_cast(data), static_cast((std::min)(4, static_cast(len)))); + try { + float value = std::stof(perc); + (*(st->cb))(value); + } + catch (...) 
+
+void ModelVariant::Load() {
+    auto response = CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_);
+    if (response.HasError()) {
+        throw Exception("Error loading model [" + info_.name + "]: " + response.error, *logger_);
+    }
+}
+
+const std::filesystem::path& ModelVariant::GetPath() const {
+    if (cachedPath_.empty()) {
+        auto response = CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_);
+        if (response.HasError()) {
+            throw Exception("Error getting model path [" + info_.name + "]: " + response.error, *logger_);
+        }
+        cachedPath_ = std::filesystem::path(response.data);
+    }
+    return cachedPath_;
+}
+
+const std::string& ModelVariant::GetId() const noexcept {
+    return info_.id;
+}
+
+const std::string& ModelVariant::GetAlias() const noexcept {
+    return info_.alias;
+}
+
+uint32_t ModelVariant::GetVersion() const noexcept {
+    return info_.version;
+}
+
+IModel::CoreAccess ModelVariant::GetCoreAccess() const {
+    return {core_, info_.name, logger_};
+}
+
+OpenAIAudioClient ModelVariant::GetAudioClient() const {
+    return OpenAIAudioClient(*this);
+}
+
+OpenAIChatClient ModelVariant::GetChatClient() const {
+    return OpenAIChatClient(*this);
+}
+
+/// Model
+
+Model::Model(gsl::not_null<Internal::IFoundryLocalCore*> core, gsl::not_null<ILogger*> logger)
+    : core_(core), logger_(logger) {}
+
+ModelVariant& Model::SelectedVariant() {
+    if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) {
+        throw Exception("Model has no selected variant", *logger_);
+    }
+    return variants_[*selectedVariantIndex_];
+}
+
+const ModelVariant& Model::SelectedVariant() const {
+    if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) {
+        throw Exception("Model has no selected variant", *logger_);
+    }
+    return variants_[*selectedVariantIndex_];
+}
+
+gsl::span<const ModelVariant> Model::GetAllModelVariants() const {
+    return variants_;
+}
+
+const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const {
+    const auto& targetName = variant.GetInfo().name;
+
+    for (const auto& v : variants_) {
+        if (v.GetInfo().name == targetName) {
+            return v;
+        }
+    }
+
+    throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.",
+                    *logger_);
}
+
+const std::string& Model::GetId() const {
+    return SelectedVariant().GetId();
+}
+
+const std::string& Model::GetAlias() const {
+    return SelectedVariant().GetAlias();
+}
+
+void Model::SelectVariant(const ModelVariant& variant) const {
+    auto it = std::find_if(variants_.begin(), variants_.end(),
+                           [&](const ModelVariant& v) { return &v == &variant; });
+
+    if (it == variants_.end()) {
+        throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.",
+                        *logger_);
+    }
+
+    selectedVariantIndex_ = static_cast<size_t>(std::distance(variants_.begin(), it));
+}
+
+IModel::CoreAccess Model::GetCoreAccess() const {
+    return SelectedVariant().GetCoreAccess();
+}
+
+} // namespace foundry_local
diff --git 
a/sdk/cpp/src/openai_audio_client.cpp b/sdk/cpp/src/openai_audio_client.cpp new file mode 100644 index 00000000..75c0110b --- /dev/null +++ b/sdk/cpp/src/openai_audio_client.cpp @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "core_helpers.h" +#include "logger.h" + +namespace foundry_local { + +OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + +AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + auto coreResponse = core_->call(req.Command(), *logger_, &json); + if (coreResponse.HasError()) { + throw Exception("Audio transcription failed: " + coreResponse.error, *logger_); + } + + AudioCreateTranscriptionResponse response; + response.text = std::move(coreResponse.data); + + return response; +} + +void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + const StreamCallback& onChunk) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& text) { + AudioCreateTranscriptionResponse chunk; + chunk.text = text; + onChunk(chunk); + }, + "Streaming audio transcription failed: "); +} + +OpenAIAudioClient::OpenAIAudioClient(const IModel& model) + : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } +} + +} // namespace foundry_local diff --git a/sdk/cpp/src/openai_chat_client.cpp b/sdk/cpp/src/openai_chat_client.cpp new file mode 100644 index 00000000..3f0d95cd --- /dev/null +++ b/sdk/cpp/src/openai_chat_client.cpp @@ -0,0 +1,139 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "core_helpers.h" +#include "parser.h" +#include "logger.h" + +namespace foundry_local { + +std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { + if (created == 0) + return {}; + std::time_t t = static_cast(created); + std::tm tm{}; +#ifdef _WIN32 + gmtime_s(&tm, &t); +#else + gmtime_r(&t, &tm); +#endif + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); + return buf; +} + +OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + +std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, + gsl::span tools, + const ChatSettings& settings, bool stream) const { + nlohmann::json jMessages = nlohmann::json::array(); + for (const auto& msg : messages) { + nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; + if (msg.tool_call_id) + jMsg["tool_call_id"] = *msg.tool_call_id; + jMessages.push_back(std::move(jMsg)); + } + + nlohmann::json req = {{"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream}}; + + if (!tools.empty()) { + nlohmann::json jTools = nlohmann::json::array(); + for (const auto& tool : tools) { + nlohmann::json jTool; + to_json(jTool, tool); + jTools.push_back(std::move(jTool)); + } + req["tools"] = std::move(jTools); + } + + if (settings.tool_choice) + req["tool_choice"] = ParsingUtils::tool_choice_to_string(*settings.tool_choice); + if (settings.top_k) + req["metadata"] = {{"top_k", *settings.top_k}}; + if (settings.frequency_penalty) + req["frequency_penalty"] = *settings.frequency_penalty; + if (settings.presence_penalty) + req["presence_penalty"] = *settings.presence_penalty; + if (settings.max_tokens) + req["max_completion_tokens"] = *settings.max_tokens; + if (settings.n) + req["n"] = *settings.n; + if (settings.temperature) + req["temperature"] = *settings.temperature; + if (settings.top_p) + req["top_p"] = *settings.top_p; + if (settings.random_seed) + req["seed"] = *settings.random_seed; + + return req.dump(); +} + +ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings) const { + return CompleteChat(messages, {}, settings); +} + +ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + auto response = core_->call(req.Command(), *logger_, &json); + if (response.HasError()) { + throw Exception("Chat completion failed: " + response.error, *logger_); + } + + return nlohmann::json::parse(response.data).get(); +} + +void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, {}, settings, onChunk); +} + +void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, + gsl::span tools, const ChatSettings& settings, + const StreamCallback& onChunk) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); + + CoreInteropRequest 
req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + std::string json = req.ToJson(); + + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& chunk) { + auto parsed = nlohmann::json::parse(chunk).get(); + onChunk(parsed); + }, + "Streaming chat completion failed: "); +} + +OpenAIChatClient::OpenAIChatClient(const IModel& model) + : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } +} + +} // namespace foundry_local diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 7930b7cb..7d28392e 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -8,104 +8,118 @@ #include namespace foundry_local { - inline DeviceType parse_device_type(std::string_view v) { - if (v == "CPU") { - return DeviceType::CPU; - } - if (v == "NPU") { - return DeviceType::NPU; - } - if (v == "GPU") { - return DeviceType::GPU; - } - return DeviceType::Invalid; - } - inline FinishReason parse_finish_reason(std::string_view v) { - if (v == "stop") - return FinishReason::Stop; - if (v == "length") - return FinishReason::Length; - if (v == "tool_calls") - return FinishReason::ToolCalls; - if (v == "content_filter") - return FinishReason::ContentFilter; - return FinishReason::None; - } - - // ---------- Helpers ---------- - inline std::string get_string_or_empty(const nlohmann::json& j, const char* key) { - auto it = j.find(key); - std::string out = ""; - if (it != j.end() && it->is_string()) { - out = it->get(); + class ParsingUtils { + public: + static DeviceType parse_device_type(std::string_view v) { + if (v == "CPU") { + return DeviceType::CPU; + } + if (v == "NPU") { + return DeviceType::NPU; + } + if (v == "GPU") { + return DeviceType::GPU; + } + return DeviceType::Invalid; } - return out; - } - inline void from_json(const nlohmann::json& j, Runtime& r) { - std::string deviceType; - std::string executionProvider; - j.at("deviceType").get_to(deviceType); - j.at("executionProvider").get_to(r.execution_provider); - - r.device_type = parse_device_type(std::move(deviceType)); - } + static FinishReason parse_finish_reason(std::string_view v) { + if (v == "stop") + return FinishReason::Stop; + if (v == "length") + return FinishReason::Length; + if (v == "tool_calls") + return FinishReason::ToolCalls; + if (v == "content_filter") + return FinishReason::ContentFilter; + return FinishReason::None; + } - inline void from_json(const nlohmann::json& j, PromptTemplate& p) { - p.system = get_string_or_empty(j, "system"); - p.user = get_string_or_empty(j, "user"); - p.assistant = get_string_or_empty(j, "assistant"); - p.prompt = get_string_or_empty(j, "prompt"); - } + static std::string get_string_or_empty(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + std::string out = ""; + if (it != j.end() && it->is_string()) { + out = it->get(); + } + return out; + } - inline std::optional get_opt_string(const nlohmann::json& j, const char* key) { - auto it = j.find(key); - if (it == j.end() || it->is_null()) { + static std::optional get_opt_string(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_string()) { + return it->get(); + } return std::nullopt; } - if (it->is_string()) { - return it->get(); - } - 
return std::nullopt; - } - inline std::optional get_opt_int(const nlohmann::json& j, const char* key) { - auto it = j.find(key); - if (it == j.end() || it->is_null()) { + static std::optional get_opt_int(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } return std::nullopt; } - if (it->is_number_integer()) { - return it->get(); - } - return std::nullopt; - } - inline std::optional get_opt_i64(const nlohmann::json& j, const char* key) { - auto it = j.find(key); - if (it == j.end() || it->is_null()) { + static std::optional get_opt_i64(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } return std::nullopt; } - if (it->is_number_integer()) { - return it->get(); - } - return std::nullopt; - } - inline std::optional get_opt_bool(const nlohmann::json& j, const char* key) { - auto it = j.find(key); - if (it == j.end() || it->is_null()) { + static std::optional get_opt_bool(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_boolean()) { + return it->get(); + } return std::nullopt; } - if (it->is_boolean()) { - return it->get(); + + static std::string tool_choice_to_string(ToolChoiceKind kind) { + switch (kind) { + case ToolChoiceKind::Auto: return "auto"; + case ToolChoiceKind::None: return "none"; + case ToolChoiceKind::Required: return "required"; + } + return "auto"; } - return std::nullopt; + }; + + // ---------- from_json / to_json (ADL overloads for nlohmann::json) ---------- + + inline void from_json(const nlohmann::json& j, Runtime& r) { + std::string deviceType; + std::string executionProvider; + j.at("deviceType").get_to(deviceType); + j.at("executionProvider").get_to(r.execution_provider); + + r.device_type = ParsingUtils::parse_device_type(std::move(deviceType)); + } + + inline void from_json(const nlohmann::json& j, PromptTemplate& p) { + p.system = ParsingUtils::get_string_or_empty(j, "system"); + p.user = ParsingUtils::get_string_or_empty(j, "user"); + p.assistant = ParsingUtils::get_string_or_empty(j, "assistant"); + p.prompt = ParsingUtils::get_string_or_empty(j, "prompt"); } inline void from_json(const nlohmann::json& j, Parameter& p) { j.at("name").get_to(p.name); - p.value = get_opt_string(j, "value"); + p.value = ParsingUtils::get_opt_string(j, "value"); } inline void from_json(const nlohmann::json& j, ModelSettings& ms) { @@ -124,18 +138,18 @@ namespace foundry_local { j.at("uri").get_to(m.uri); j.at("modelType").get_to(m.model_type); - m.display_name = get_opt_string(j, "displayName"); - m.publisher = get_opt_string(j, "publisher"); - m.license = get_opt_string(j, "license"); - m.license_description = get_opt_string(j, "licenseDescription"); - m.task = get_opt_string(j, "task"); + m.display_name = ParsingUtils::get_opt_string(j, "displayName"); + m.publisher = ParsingUtils::get_opt_string(j, "publisher"); + m.license = ParsingUtils::get_opt_string(j, "license"); + m.license_description = ParsingUtils::get_opt_string(j, "licenseDescription"); + m.task = ParsingUtils::get_opt_string(j, "task"); if (auto it = j.find("fileSizeMb"); it != j.end() && !it->is_null() && it->is_number_integer()) { auto v = it->get(); m.file_size_mb = (v >= 0) ? 
static_cast(v) : 0u; } - m.supports_tool_calling = get_opt_bool(j, "supportsToolCalling"); - m.max_output_tokens = get_opt_i64(j, "maxOutputTokens"); - m.min_fl_version = get_opt_string(j, "minFLVersion"); + m.supports_tool_calling = ParsingUtils::get_opt_bool(j, "supportsToolCalling"); + m.max_output_tokens = ParsingUtils::get_opt_i64(j, "maxOutputTokens"); + m.min_fl_version = ParsingUtils::get_opt_string(j, "minFLVersion"); if (auto it = j.find("cached"); it != j.end() && it->is_boolean()) { m.cached = it->get(); @@ -214,7 +228,7 @@ namespace foundry_local { // ---------- Tool calling: from_json (deserialization from responses) ---------- inline void from_json(const nlohmann::json& j, FunctionCall& fc) { - fc.name = get_string_or_empty(j, "name"); + fc.name = ParsingUtils::get_string_or_empty(j, "name"); if (j.contains("arguments")) { const auto& args = j.at("arguments"); if (args.is_string()) @@ -225,8 +239,8 @@ namespace foundry_local { } inline void from_json(const nlohmann::json& j, ToolCall& tc) { - tc.id = get_string_or_empty(j, "id"); - tc.type = get_string_or_empty(j, "type"); + tc.id = ParsingUtils::get_string_or_empty(j, "id"); + tc.type = ParsingUtils::get_string_or_empty(j, "type"); if (j.contains("function") && j.at("function").is_object()) tc.function_call = j.at("function").get(); } @@ -237,7 +251,7 @@ namespace foundry_local { if (j.contains("content") && !j.at("content").is_null()) j.at("content").get_to(m.content); - m.tool_call_id = get_opt_string(j, "tool_call_id"); + m.tool_call_id = ParsingUtils::get_opt_string(j, "tool_call_id"); m.tool_calls.clear(); if (j.contains("tool_calls") && j.at("tool_calls").is_array()) { @@ -252,7 +266,7 @@ namespace foundry_local { if (j.contains("index")) j.at("index").get_to(c.index); if (j.contains("finish_reason") && !j.at("finish_reason").is_null()) - c.finish_reason = parse_finish_reason(j.at("finish_reason").get()); + c.finish_reason = ParsingUtils::parse_finish_reason(j.at("finish_reason").get()); if (j.contains("message") && !j.at("message").is_null()) c.message = j.at("message").get(); @@ -264,7 +278,7 @@ namespace foundry_local { inline void from_json(const nlohmann::json& j, ChatCompletionCreateResponse& r) { if (j.contains("created")) j.at("created").get_to(r.created); - r.id = get_string_or_empty(j, "id"); + r.id = ParsingUtils::get_string_or_empty(j, "id"); if (j.contains("IsDelta")) j.at("IsDelta").get_to(r.is_delta); if (j.contains("Successful")) @@ -278,15 +292,4 @@ namespace foundry_local { } } - // ---------- Tool choice helpers ---------- - - inline std::string tool_choice_to_string(ToolChoiceKind kind) { - switch (kind) { - case ToolChoiceKind::Auto: return "auto"; - case ToolChoiceKind::None: return "none"; - case ToolChoiceKind::Required: return "required"; - } - return "auto"; - } - } // namespace foundry_local \ No newline at end of file diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp index 5ed41e5d..fb7af1b3 100644 --- a/sdk/cpp/test/catalog_test.cpp +++ b/sdk/cpp/test/catalog_test.cpp @@ -45,7 +45,7 @@ TEST_F(CatalogTest, GetName) { TEST_F(CatalogTest, Create_ThrowsOnCoreError) { core_.OnCallThrow("get_catalog_name", "catalog error"); - EXPECT_THROW(MockObjectFactory::CreateCatalog(&core_, &logger_), FoundryLocalException); + EXPECT_THROW(MockObjectFactory::CreateCatalog(&core_, &logger_), Exception); } TEST_F(CatalogTest, ListModels_Empty) { @@ -85,12 +85,11 @@ TEST_F(CatalogTest, ListModels_DifferentAliases) { EXPECT_EQ(2u, models.size()); } -TEST_F(CatalogTest, 
ListModels_FiltersOpenAIPrefix) { +TEST_F(CatalogTest, ListModels_IncludesOpenAIPrefix) { core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "my-model"}, {"openai-model", "openai-stuff"}})); auto catalog = MakeCatalog(); auto models = catalog->ListModels(); - ASSERT_EQ(1u, models.size()); - EXPECT_EQ("my-model", models[0]->GetAlias()); + ASSERT_EQ(2u, models.size()); } TEST_F(CatalogTest, GetModel_Found) { @@ -308,13 +307,12 @@ TEST_F(FileBasedCatalogTest, CoreErrorOnModelList) { EXPECT_ANY_THROW(catalog->ListModels()); } -TEST_F(FileBasedCatalogTest, MixedOpenAIAndLocal_FiltersOpenAIPrefix) { +TEST_F(FileBasedCatalogTest, MixedOpenAIAndLocal_IncludesAll) { auto core = FileBackedCore::FromModelList(TestDataPath("mixed_openai_and_local.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); auto models = catalog->ListModels(); - ASSERT_EQ(1u, models.size()); - EXPECT_EQ("phi-4", models[0]->GetAlias()); + ASSERT_EQ(3u, models.size()); } TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel) { diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index bfe6bffb..0ddb0c02 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -127,13 +127,12 @@ nlohmann::json chunk1 = { {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}}; core_.OnCall("chat_completions", - [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { - auto cb = reinterpret_cast(callback); std::string s1 = chunk1.dump(); std::string s2 = chunk2.dump(); - cb(s1.data(), static_cast(s1.size()), userData); - cb(s2.data(), static_cast(s2.size()), userData); + callback(s1.data(), static_cast(s1.size()), userData); + callback(s2.data(), static_cast(s2.size()), userData); } return ""; }); @@ -167,11 +166,10 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}}; core_.OnCall("chat_completions", - [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { - auto cb = reinterpret_cast(callback); std::string s = chunk.dump(); - cb(s.data(), static_cast(s.size()), userData); + callback(s.data(), static_cast(s.size()), userData); } return ""; }); @@ -192,7 +190,7 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) { core_.OnCall("list_loaded_models", R"([])"); auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(OpenAIChatClient client(variant), FoundryLocalException); + EXPECT_THROW(OpenAIChatClient client(variant), Exception); } TEST_F(OpenAIChatClientTest, GetModelId) { @@ -405,13 +403,12 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; core_.OnCall("chat_completions", - [&](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { - auto cb = reinterpret_cast(callback); 
std::string s1 = chunk1.dump(); std::string s2 = chunk2.dump(); - cb(s1.data(), static_cast(s1.size()), userData); - cb(s2.data(), static_cast(s2.size()), userData); + callback(s1.data(), static_cast(s1.size()), userData); + callback(s2.data(), static_cast(s2.size()), userData); } return ""; }); @@ -485,13 +482,12 @@ OpenAIAudioClient client(variant); TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) { core_.OnCall("audio_transcribe", - [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { - auto cb = reinterpret_cast(callback); std::string text1 = "Hello "; std::string text2 = "world!"; - cb(text1.data(), static_cast(text1.size()), userData); - cb(text2.data(), static_cast(text2.size()), userData); + callback(text1.data(), static_cast(text1.size()), userData); + callback(text2.data(), static_cast(text2.size()), userData); } return ""; }); @@ -511,11 +507,10 @@ TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) { TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { core_.OnCall("audio_transcribe", - [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { if (callback && userData) { - auto cb = reinterpret_cast(callback); std::string text = "test"; - cb(text.data(), static_cast(text.size()), userData); + callback(text.data(), static_cast(text.size()), userData); } return ""; }); diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index c2b4f4e6..136a7ff4 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -20,27 +20,22 @@ namespace foundry_local::Testing { /// Register expected command -> response mappings before use. class MockCore final : public Internal::IFoundryLocalCore { public: - using CallbackFn = void (*)(void*, int32_t, void*); - /// Handler signature: (command, dataArgument, callback, userData) -> response string. using Handler = std::function; + NativeCallbackFn callback, void* userData)>; /// Register a fixed response for a command. void OnCall(std::string command, std::string response) { - handlers_[std::move(command)] = [r = std::move(response)](std::string_view, const std::string*, void*, - void*) { return r; }; + handlers_[std::move(command)] = [r = std::move(response)](std::string_view, const std::string*, + NativeCallbackFn, void*) { return r; }; } /// Register a custom handler for a command. void OnCall(std::string command, Handler handler) { handlers_[std::move(command)] = std::move(handler); } - /// Register a handler that throws for a command. + /// Register a handler that returns an error for a command. void OnCallThrow(std::string command, std::string errorMessage) { - handlers_[std::move(command)] = [msg = std::move(errorMessage)](std::string_view, const std::string*, void*, - void*) -> std::string { - throw std::runtime_error(msg); - }; + errorResponses_[std::move(command)] = std::move(errorMessage); } /// Returns the number of times a command was called. 
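+    /// Illustrative registration example (documentation only; exercises just
+    /// the OnCall/OnCallThrow API declared above):
+    ///
+    ///     MockCore core;
+    ///     core.OnCall("get_catalog_name", "TestCatalog");
+    ///     core.OnCallThrow("load_model", "device lost");
+    ///     // call("get_catalog_name", ...) now yields a CoreResponse whose
+    ///     // data is "TestCatalog"; call("load_model", ...) yields one whose
+    ///     // error is set, so the SDK surfaces it as an Exception.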
@@ -60,8 +55,8 @@ namespace foundry_local::Testing { } // IFoundryLocalCore implementation - std::string call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, - void* callback = nullptr, void* data = nullptr) const override { + CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, + NativeCallbackFn callback = nullptr, void* data = nullptr) const override { std::string cmd(command); const_cast(this)->callCounts_[cmd]++; @@ -69,18 +64,28 @@ namespace foundry_local::Testing { const_cast(this)->lastDataArgs_[cmd] = *dataArgument; } + auto errIt = errorResponses_.find(cmd); + if (errIt != errorResponses_.end()) { + CoreResponse resp; + resp.error = errIt->second; + return resp; + } + auto it = handlers_.find(cmd); if (it == handlers_.end()) { throw std::runtime_error("MockCore: no handler registered for command '" + cmd + "'"); } - return it->second(command, dataArgument, callback, data); + CoreResponse resp; + resp.data = it->second(command, dataArgument, callback, data); + return resp; } void unload() override {} private: std::unordered_map handlers_; + std::unordered_map errorResponses_; std::unordered_map callCounts_; std::unordered_map lastDataArgs_; }; @@ -113,31 +118,33 @@ namespace foundry_local::Testing { return FileBackedCore(modelListPath, cachedModelsPath, loadedModelsPath); } - std::string call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, - void* /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { + CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, + NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { - if (command == "get_catalog_name") - return "TestCatalog"; + CoreResponse resp; + + if (command == "get_catalog_name") { + resp.data = "TestCatalog"; + return resp; + } if (command == "get_model_list") { - if (modelListPath_.empty()) - return "[]"; - return ReadFile(modelListPath_); + resp.data = modelListPath_.empty() ? "[]" : ReadFile(modelListPath_); + return resp; } if (command == "get_cached_models") { - if (cachedModelsPath_.empty()) - return "[]"; - return ReadFile(cachedModelsPath_); + resp.data = cachedModelsPath_.empty() ? "[]" : ReadFile(cachedModelsPath_); + return resp; } if (command == "list_loaded_models") { - if (loadedModelsPath_.empty()) - return "[]"; - return ReadFile(loadedModelsPath_); + resp.data = loadedModelsPath_.empty() ? 
"[]" : ReadFile(loadedModelsPath_); + return resp; } - return "{}"; + resp.data = "{}"; + return resp; } void unload() override {} diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index 1da0d2a1..b6b8aaa4 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -99,7 +99,7 @@ TEST_F(ModelVariantTest, Unload_CallsCore) { TEST_F(ModelVariantTest, Unload_ThrowsOnError) { core_.OnCallThrow("unload_model", "unload failed"); auto variant = MakeVariant("test-model"); - EXPECT_THROW(variant.Unload(), FoundryLocalException); + EXPECT_THROW(variant.Unload(), Exception); } TEST_F(ModelVariantTest, Download_NoCallback) { @@ -113,15 +113,14 @@ variant.Download(); TEST_F(ModelVariantTest, Download_WithCallback) { core_.OnCall("get_cached_models", R"([])"); core_.OnCall("download_model", - [](std::string_view, const std::string*, void* callback, void* userData) -> std::string { - // Simulate calling the progress callback - if (callback && userData) { - auto cb = reinterpret_cast(callback); - std::string progress = "50"; - cb(progress.data(), static_cast(progress.size()), userData); - } - return ""; - }); +[](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + // Simulate calling the progress callback + if (callback && userData) { + std::string progress = "50"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; +}); auto variant = MakeVariant("test-model"); float lastProgress = -1.0f; @@ -139,7 +138,7 @@ TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { TEST_F(ModelVariantTest, RemoveFromCache_ThrowsOnError) { core_.OnCallThrow("remove_cached_model", "remove failed"); auto variant = MakeVariant("test-model"); - EXPECT_THROW(variant.RemoveFromCache(), FoundryLocalException); + EXPECT_THROW(variant.RemoveFromCache(), Exception); } TEST_F(ModelVariantTest, GetPath_CallsCore) { @@ -180,7 +179,7 @@ class ModelTest : public ::testing::Test { TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) { auto model = MakeModel(); - EXPECT_THROW(model.GetId(), FoundryLocalException); + EXPECT_THROW(model.GetId(), Exception); } TEST_F(ModelTest, AddVariant_AndSelect) { @@ -219,7 +218,7 @@ TEST_F(ModelTest, SelectVariant_NotFound_Throws) { Factory::SetSelectedVariantIndex(model, 0); auto external = MakeVariant("external", "alias", 1); - EXPECT_THROW(model.SelectVariant(external), FoundryLocalException); + EXPECT_THROW(model.SelectVariant(external), Exception); } TEST_F(ModelTest, GetLatestVariant) { diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index 38d5c992..a6b077ab 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -24,95 +24,95 @@ class ParserTest : public ::testing::Test { }; TEST_F(ParserTest, ParseDeviceType_CPU) { - EXPECT_EQ(DeviceType::CPU, parse_device_type("CPU")); + EXPECT_EQ(DeviceType::CPU, ParsingUtils::parse_device_type("CPU")); } TEST_F(ParserTest, ParseDeviceType_GPU) { - EXPECT_EQ(DeviceType::GPU, parse_device_type("GPU")); + EXPECT_EQ(DeviceType::GPU, ParsingUtils::parse_device_type("GPU")); } TEST_F(ParserTest, ParseDeviceType_NPU) { - EXPECT_EQ(DeviceType::NPU, parse_device_type("NPU")); + EXPECT_EQ(DeviceType::NPU, ParsingUtils::parse_device_type("NPU")); } TEST_F(ParserTest, ParseDeviceType_Unknown) { - EXPECT_EQ(DeviceType::Invalid, parse_device_type("FPGA")); + EXPECT_EQ(DeviceType::Invalid, ParsingUtils::parse_device_type("FPGA")); } 
TEST_F(ParserTest, ParseFinishReason_Stop) { - EXPECT_EQ(FinishReason::Stop, parse_finish_reason("stop")); + EXPECT_EQ(FinishReason::Stop, ParsingUtils::parse_finish_reason("stop")); } TEST_F(ParserTest, ParseFinishReason_Length) { - EXPECT_EQ(FinishReason::Length, parse_finish_reason("length")); + EXPECT_EQ(FinishReason::Length, ParsingUtils::parse_finish_reason("length")); } TEST_F(ParserTest, ParseFinishReason_ToolCalls) { - EXPECT_EQ(FinishReason::ToolCalls, parse_finish_reason("tool_calls")); + EXPECT_EQ(FinishReason::ToolCalls, ParsingUtils::parse_finish_reason("tool_calls")); } TEST_F(ParserTest, ParseFinishReason_ContentFilter) { - EXPECT_EQ(FinishReason::ContentFilter, parse_finish_reason("content_filter")); + EXPECT_EQ(FinishReason::ContentFilter, ParsingUtils::parse_finish_reason("content_filter")); } TEST_F(ParserTest, ParseFinishReason_None) { - EXPECT_EQ(FinishReason::None, parse_finish_reason("unknown_value")); + EXPECT_EQ(FinishReason::None, ParsingUtils::parse_finish_reason("unknown_value")); } TEST_F(ParserTest, GetStringOrEmpty_Present) { nlohmann::json j = {{"key", "value"}}; - EXPECT_EQ("value", get_string_or_empty(j, "key")); + EXPECT_EQ("value", ParsingUtils::get_string_or_empty(j, "key")); } TEST_F(ParserTest, GetStringOrEmpty_Missing) { nlohmann::json j = {{"other", "value"}}; - EXPECT_EQ("", get_string_or_empty(j, "key")); + EXPECT_EQ("", ParsingUtils::get_string_or_empty(j, "key")); } TEST_F(ParserTest, GetStringOrEmpty_NonString) { nlohmann::json j = {{"key", 42}}; - EXPECT_EQ("", get_string_or_empty(j, "key")); + EXPECT_EQ("", ParsingUtils::get_string_or_empty(j, "key")); } TEST_F(ParserTest, GetOptString_Present) { nlohmann::json j = {{"key", "hello"}}; - auto result = get_opt_string(j, "key"); + auto result = ParsingUtils::get_opt_string(j, "key"); ASSERT_TRUE(result.has_value()); EXPECT_EQ("hello", *result); } TEST_F(ParserTest, GetOptString_Null) { nlohmann::json j = {{"key", nullptr}}; - EXPECT_FALSE(get_opt_string(j, "key").has_value()); + EXPECT_FALSE(ParsingUtils::get_opt_string(j, "key").has_value()); } TEST_F(ParserTest, GetOptString_Missing) { nlohmann::json j = {{"other", "v"}}; - EXPECT_FALSE(get_opt_string(j, "key").has_value()); + EXPECT_FALSE(ParsingUtils::get_opt_string(j, "key").has_value()); } TEST_F(ParserTest, GetOptInt_Present) { nlohmann::json j = {{"key", 42}}; - auto result = get_opt_int(j, "key"); + auto result = ParsingUtils::get_opt_int(j, "key"); ASSERT_TRUE(result.has_value()); EXPECT_EQ(42, *result); } TEST_F(ParserTest, GetOptInt_Missing) { nlohmann::json j = {}; - EXPECT_FALSE(get_opt_int(j, "key").has_value()); + EXPECT_FALSE(ParsingUtils::get_opt_int(j, "key").has_value()); } TEST_F(ParserTest, GetOptBool_Present) { nlohmann::json j = {{"key", true}}; - auto result = get_opt_bool(j, "key"); + auto result = ParsingUtils::get_opt_bool(j, "key"); ASSERT_TRUE(result.has_value()); EXPECT_TRUE(*result); } TEST_F(ParserTest, GetOptBool_Missing) { nlohmann::json j = {}; - EXPECT_FALSE(get_opt_bool(j, "key").has_value()); + EXPECT_FALSE(ParsingUtils::get_opt_bool(j, "key").has_value()); } TEST_F(ParserTest, ParseRuntime) { @@ -304,9 +304,9 @@ TEST_F(ParserTest, SerializeToolDefinition_MinimalFunction) { } TEST_F(ParserTest, ToolChoiceToString) { - EXPECT_EQ("auto", tool_choice_to_string(ToolChoiceKind::Auto)); - EXPECT_EQ("none", tool_choice_to_string(ToolChoiceKind::None)); - EXPECT_EQ("required", tool_choice_to_string(ToolChoiceKind::Required)); + EXPECT_EQ("auto", ParsingUtils::tool_choice_to_string(ToolChoiceKind::Auto)); + 
EXPECT_EQ("none", ParsingUtils::tool_choice_to_string(ToolChoiceKind::None)); + EXPECT_EQ("required", ParsingUtils::tool_choice_to_string(ToolChoiceKind::Required)); } TEST_F(ParserTest, ParseChatChoice_NonStreaming) { @@ -411,17 +411,17 @@ TEST(CoreInteropRequestTest, AddParam_Chaining) { } // ============================================================================= -// FoundryLocalException tests +// Exception tests // ============================================================================= -TEST(FoundryLocalExceptionTest, MessageOnly) { - FoundryLocalException ex("test error"); +TEST(ExceptionTest, MessageOnly) { + Exception ex("test error"); EXPECT_STREQ("test error", ex.what()); } -TEST(FoundryLocalExceptionTest, MessageAndLogger) { +TEST(ExceptionTest, MessageAndLogger) { NullLogger logger; - FoundryLocalException ex("logged error", logger); + Exception ex("logged error", logger); EXPECT_STREQ("logged error", ex.what()); } From 14f5d5fe52fdc6652a82e238f0e6d9f4b4b54422 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 31 Mar 2026 19:02:12 -0700 Subject: [PATCH 11/18] Fixes no 3 --- sdk/cpp/include/model.h | 2 +- sdk/cpp/sample/main.cpp | 2 -- sdk/cpp/src/catalog.cpp | 6 +++- sdk/cpp/src/foundry_local_manager.cpp | 13 +++----- sdk/cpp/src/model.cpp | 13 ++++---- sdk/cpp/test/catalog_test.cpp | 44 ++++++++------------------- sdk/cpp/test/mock_object_factory.h | 4 +-- sdk/cpp/test/model_variant_test.cpp | 40 ++++++++++++------------ 8 files changed, 52 insertions(+), 72 deletions(-) diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index a5008ff8..7786b923 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -194,7 +194,7 @@ namespace foundry_local { gsl::not_null core_; std::vector variants_; - mutable std::optional selectedVariantIndex_; + mutable const ModelVariant* selectedVariant_ = nullptr; gsl::not_null logger_; friend class Catalog; diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index e1ed84be..bd0e1879 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -85,7 +85,6 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { std::cout << "\n=== Example 2: Non-Streaming Chat ===\n"; auto& catalog = manager.GetCatalog(); - auto models = catalog.ListModels(); auto* model = catalog.GetModel(alias); if (!model) { @@ -181,7 +180,6 @@ void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, con model->Load(); - const auto& selectedVariant = model->GetAllModelVariants()[0]; OpenAIAudioClient audio(*model); std::cout << "Transcribing: " << audioPath << "\n"; diff --git a/sdk/cpp/src/catalog.cpp b/sdk/cpp/src/catalog.cpp index 0de167f9..3ff9df0d 100644 --- a/sdk/cpp/src/catalog.cpp +++ b/sdk/cpp/src/catalog.cpp @@ -30,14 +30,17 @@ Catalog::Catalog(gsl::not_null injected, gsl::not_ } std::vector Catalog::GetLoadedModels() const { + UpdateModels(); return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); } std::vector Catalog::GetCachedModels() const { + UpdateModels(); return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); } Model* Catalog::GetModel(std::string_view modelId) const { + UpdateModels(); auto it = byAlias_.find(std::string(modelId)); if (it != byAlias_.end()) { return &it->second; @@ -97,7 +100,7 @@ void Catalog::UpdateModels() const { // Auto-select the first variant for each model. 
for (auto& [alias, model] : byAlias_) { if (!model.variants_.empty()) { - model.selectedVariantIndex_ = 0; + model.selectedVariant_ = &model.variants_.front(); } } @@ -105,6 +108,7 @@ void Catalog::UpdateModels() const { } ModelVariant* Catalog::GetModelVariant(std::string_view id) const { + UpdateModels(); auto it = modelIdToModelVariant_.find(std::string(id)); if (it != modelIdToModelVariant_.end()) { return &it->second; diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp index f383abf6..7ee39253 100644 --- a/sdk/cpp/src/foundry_local_manager.cpp +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -130,6 +130,9 @@ void FoundryLocalManager::Initialize() { if (config_.app_data_dir) { initReq.AddParam("AppDataDir", config_.app_data_dir->string()); } + if (config_.model_cache_dir) { + initReq.AddParam("ModelCacheDir", config_.model_cache_dir->string()); + } if (config_.logs_dir) { initReq.AddParam("LogsDir", config_.logs_dir->string()); } @@ -162,15 +165,9 @@ void FoundryLocalManager::Initialize() { std::string setJson = setReq.ToJson(); auto setResponse = core_->call(setReq.Command(), *logger_, &setJson); if (setResponse.HasError()) { - throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, *logger_); + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, + *logger_); } - - logger_->Log(LogLevel::Information, - std::string("Model cache directory updated: ") + config_.model_cache_dir->string()); - } - else { - logger_->Log(LogLevel::Information, - std::string("Model cache directory already set to: ") + cacheResponse.data); } } } diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index d017387b..57e660d0 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -83,7 +83,7 @@ void ModelVariant::Download(DownloadProgressCallback onProgress) { if (!data || len <= 0) return; auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast((std::min)(4, static_cast(len)))); + std::string perc(static_cast(data), static_cast(len)); try { float value = std::stof(perc); (*(st->cb))(value); @@ -155,17 +155,17 @@ Model::Model(gsl::not_null core, gsl::not_null= variants_.size()) { + if (!selectedVariant_) { throw Exception("Model has no selected variant", *logger_); } - return variants_[*selectedVariantIndex_]; + return *const_cast(selectedVariant_); } const ModelVariant& Model::SelectedVariant() const { - if (!selectedVariantIndex_ || *selectedVariantIndex_ >= variants_.size()) { + if (!selectedVariant_) { throw Exception("Model has no selected variant", *logger_); } - return variants_[*selectedVariantIndex_]; + return *selectedVariant_; } gsl::span Model::GetAllModelVariants() const { @@ -176,6 +176,7 @@ const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const { const auto& targetName = variant.GetInfo().name; for (const auto& v : variants_) { + // The variants returned by the catalog are sorted by version, so the first match should always be the latest version. 
 if (v.GetInfo().name == targetName) {
             return v;
         }
@@ -202,7 +203,7 @@ void Model::SelectVariant(const ModelVariant& variant) const {
                         *logger_);
     }
 
-    selectedVariantIndex_ = static_cast<size_t>(std::distance(variants_.begin(), it));
+    selectedVariant_ = &(*it);
 }
 
 IModel::CoreAccess Model::GetCoreAccess() const {
diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp
index fb7af1b3..af019024 100644
--- a/sdk/cpp/test/catalog_test.cpp
+++ b/sdk/cpp/test/catalog_test.cpp
@@ -93,39 +93,35 @@ TEST_F(CatalogTest, ListModels_IncludesOpenAIPrefix) {
 }
 
 TEST_F(CatalogTest, GetModel_Found) {
     core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
     auto catalog = MakeCatalog();
-    catalog->ListModels(); // populate
 
     auto* model = catalog->GetModel("my-model");
     ASSERT_NE(nullptr, model);
     EXPECT_EQ("my-model", model->GetAlias());
 }
 
 TEST_F(CatalogTest, GetModel_NotFound) {
     core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
     auto catalog = MakeCatalog();
-    catalog->ListModels(); // populate
 
     EXPECT_EQ(nullptr, catalog->GetModel("nonexistent"));
 }
 
 TEST_F(CatalogTest, GetModelVariant_Found) {
     core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
     auto catalog = MakeCatalog();
-    catalog->ListModels(); // populate
 
     auto* variant = catalog->GetModelVariant("model-1:1");
     ASSERT_NE(nullptr, variant);
     EXPECT_EQ("model-1:1", variant->GetId());
 }
 
 TEST_F(CatalogTest, GetModelVariant_NotFound) {
     core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}}));
     auto catalog = MakeCatalog();
-    catalog->ListModels();
 
     EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent:1"));
 }
 
 TEST_F(CatalogTest, GetLoadedModels) {
@@ -133,7 +129,6 @@ TEST_F(CatalogTest, GetLoadedModels) {
     core_.OnCall("list_loaded_models", R"(["model-1:1"])");
 
     auto catalog = MakeCatalog();
-    catalog->ListModels(); // populate
 
     auto loaded = catalog->GetLoadedModels();
     ASSERT_EQ(1u, loaded.size());
@@ -145,7 +140,6 @@ TEST_F(CatalogTest, GetCachedModels) {
     core_.OnCall("get_cached_models", R"(["model-1:1", "model-2:1"])");
 
     auto catalog = MakeCatalog();
-    catalog->ListModels(); // populate
 
     auto cached = catalog->GetCachedModels();
     EXPECT_EQ(2u, cached.size());
@@ -200,8 +194,6 @@ TEST_F(FileBasedCatalogTest, RealModelsList_VariantDetails) {
     auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json"));
     auto catalog = Factory::CreateCatalog(&core, &logger_);
 
-    catalog->ListModels(); // populate
-
     const auto* gpuVariant = catalog->GetModelVariant("Phi-4-generic-gpu:1");
 
     ASSERT_NE(nullptr, gpuVariant);
@@ -235,8 +227,6 @@ TEST_F(FileBasedCatalogTest, RealModelsList_CpuVariantDetails) {
     auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json"));
     auto catalog = Factory::CreateCatalog(&core, &logger_);
 
-    catalog->ListModels(); // populate
-
     const auto*
cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu:1"); ASSERT_NE(nullptr, cpuVariant); @@ -286,8 +276,6 @@ TEST_F(FileBasedCatalogTest, CachedModels) { FileBackedCore::FromBoth(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); - catalog->ListModels(); // populate internal maps - auto cached = catalog->GetCachedModels(); ASSERT_EQ(2u, cached.size()); @@ -329,8 +317,6 @@ TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel_CachedSubset) { TestDataPath("single_cached_model.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); - catalog->ListModels(); // populate - auto cached = catalog->GetCachedModels(); ASSERT_EQ(1u, cached.size()); EXPECT_EQ("multi-v1-cpu", cached[0]->GetInfo().name); @@ -340,8 +326,6 @@ TEST_F(FileBasedCatalogTest, GetModelByAlias) { auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); - catalog->ListModels(); // populate - const auto* model = catalog->GetModel("phi-4"); ASSERT_NE(nullptr, model); EXPECT_EQ("phi-4", model->GetAlias()); @@ -355,8 +339,6 @@ TEST_F(FileBasedCatalogTest, GetModelVariant_NotInCatalog) { auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); - catalog->ListModels(); // populate - EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent-variant-id")); } @@ -365,8 +347,6 @@ TEST_F(FileBasedCatalogTest, LoadedModels) { TestDataPath("valid_loaded_models.json")); auto catalog = Factory::CreateCatalog(&core, &logger_); - catalog->ListModels(); // populate - auto loaded = catalog->GetLoadedModels(); ASSERT_EQ(1u, loaded.size()); EXPECT_EQ("Phi-4-generic-gpu", loaded[0]->GetInfo().name); diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h index 6a3d3d14..86331ab3 100644 --- a/sdk/cpp/test/mock_object_factory.h +++ b/sdk/cpp/test/mock_object_factory.h @@ -35,8 +35,8 @@ namespace foundry_local::Testing { model.variants_.push_back(std::move(variant)); } - /// Set the selected variant index on a Model. - static void SetSelectedVariantIndex(Model& model, size_t index) { model.selectedVariantIndex_ = index; } + /// Set the selected variant on a Model. + static void SelectFirstVariant(Model& model) { model.selectedVariant_ = &model.variants_.front(); } /// Helper to build a minimal ModelInfo with defaults. static ModelInfo MakeModelInfo(std::string name, std::string alias = "", uint32_t version = 1) { diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index b6b8aaa4..ac21fe77 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -168,11 +168,11 @@ class ModelTest : public ::testing::Test { return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); } - /// Helper: create a Model with one variant and selectedVariantIndex_=0. + /// Helper: create a Model with one variant and selectedVariant_ set. 
 Model MakeModelWithVariant(const std::string& name = "test-model", const std::string& alias = "test-alias") {
         auto model = MakeModel();
         Factory::AddVariantToModel(model, MakeVariant(name, alias, 1));
-        Factory::SetSelectedVariantIndex(model, 0);
+        Factory::SelectFirstVariant(model);
         return model;
     }
 };
@@ -183,29 +183,29 @@ TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) {
 }
 
 TEST_F(ModelTest, AddVariant_AndSelect) {
-    auto model = MakeModel();
-    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
-    Factory::SetSelectedVariantIndex(model, 0);
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
+    Factory::SelectFirstVariant(model);
 
     EXPECT_EQ("v1:1", model.GetId());
     EXPECT_EQ("alias", model.GetAlias());
 }
 
 TEST_F(ModelTest, GetAllModelVariants) {
-    auto model = MakeModel();
-    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
-    Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2));
-    Factory::SetSelectedVariantIndex(model, 0);
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
+    Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2));
+    Factory::SelectFirstVariant(model);
 
     auto variants = model.GetAllModelVariants();
     EXPECT_EQ(2u, variants.size());
 }
 
 TEST_F(ModelTest, SelectVariant) {
-    auto model = MakeModel();
-    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
-    Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2));
-    Factory::SetSelectedVariantIndex(model, 0);
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
+    Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2));
+    Factory::SelectFirstVariant(model);
 
     const auto& v2 = model.GetAllModelVariants()[1];
     model.SelectVariant(v2);
@@ -213,19 +213,19 @@ TEST_F(ModelTest, SelectVariant) {
 }
 
 TEST_F(ModelTest, SelectVariant_NotFound_Throws) {
-    auto model = MakeModel();
-    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
-    Factory::SetSelectedVariantIndex(model, 0);
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
+    Factory::SelectFirstVariant(model);
 
     auto external = MakeVariant("external", "alias", 1);
     EXPECT_THROW(model.SelectVariant(external), Exception);
 }
 
 TEST_F(ModelTest, GetLatestVariant) {
-    auto model = MakeModel();
-    Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1));
-    Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2));
-    Factory::SetSelectedVariantIndex(model, 0);
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1));
+    Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2));
+    Factory::SelectFirstVariant(model);
 
     const auto& first = model.GetAllModelVariants()[0];
     const auto& latest = model.GetLatestVersion(first);

From 0a904409075a1b864a400536bfcb002629dc6ca2 Mon Sep 17 00:00:00 2001
From: Your Name 
Date: Tue, 31 Mar 2026 20:00:00 -0700
Subject: [PATCH 12/18] vcpkg

---
 sdk/cpp/CMakeLists.txt                       |  50 +-----
 sdk/cpp/CMakePresets.json                    |  10 +-
 sdk/cpp/test/parser_and_types_test.cpp       | 171 +------------------
 sdk/cpp/triplets/x64-windows-static-md.cmake |   3 +
 sdk/cpp/vcpkg-configuration.json             |   6 +
 sdk/cpp/vcpkg.json                           |  10 ++
 6 files changed, 33 insertions(+), 217 deletions(-)
 create mode 100644 sdk/cpp/triplets/x64-windows-static-md.cmake
 create mode 100644 sdk/cpp/vcpkg-configuration.json
 create mode 100644 sdk/cpp/vcpkg.json

diff --git a/sdk/cpp/CMakeLists.txt 
b/sdk/cpp/CMakeLists.txt index f35f2ee9..dd08b76e 100644 --- a/sdk/cpp/CMakeLists.txt +++ b/sdk/cpp/CMakeLists.txt @@ -36,51 +36,13 @@ set(CMAKE_CXX_EXTENSIONS OFF) # Optional: target Windows 10+ APIs (adjust if you need older) add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00) -include(FetchContent) - -# ----------------------------- -# nlohmann_json (clean CMake target) -# ----------------------------- -FetchContent_Declare( - nlohmann_json - GIT_REPOSITORY https://github.com/nlohmann/json.git - GIT_TAG v3.12.0 -) -FetchContent_MakeAvailable(nlohmann_json) - -# ----------------------------- -# WIL (download headers only; DO NOT run WIL's CMake) -# This avoids NuGet/test requirements and missing wil::wil targets. -# ----------------------------- -FetchContent_Declare( - wil_src - GIT_REPOSITORY https://github.com/microsoft/wil.git - GIT_TAG v1.0.250325.1 -) -FetchContent_Populate(wil_src) - # ----------------------------- -# Microsoft GSL (Guidelines Support Library) -# Provides gsl::span for C++17 (std::span is C++20) +# Dependencies (installed via vcpkg) # ----------------------------- -FetchContent_Declare( - gsl - GIT_REPOSITORY https://github.com/microsoft/GSL.git - GIT_TAG v4.0.0 -) -FetchContent_MakeAvailable(gsl) - -# ----------------------------- -# Google Test (for unit tests) -# ----------------------------- -FetchContent_Declare( - googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.14.0 -) -# Prevent GoogleTest from overriding our compiler/linker options on Windows -set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) -FetchContent_MakeAvailable(googletest) +find_package(nlohmann_json CONFIG REQUIRED) +find_package(wil CONFIG REQUIRED) +find_package(Microsoft.GSL CONFIG REQUIRED) +find_package(GTest CONFIG REQUIRED) # ----------------------------- # SDK library (STATIC) @@ -97,7 +59,6 @@ add_library(CppSdk STATIC target_include_directories(CppSdk PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include - ${wil_src_SOURCE_DIR}/include PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src ) @@ -106,6 +67,7 @@ target_link_libraries(CppSdk PUBLIC nlohmann_json::nlohmann_json Microsoft.GSL::GSL + WIL::WIL ) # ----------------------------- diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json index f9ab249d..ddead1b2 100644 --- a/sdk/cpp/CMakePresets.json +++ b/sdk/cpp/CMakePresets.json @@ -7,9 +7,11 @@ "generator": "Ninja", "binaryDir": "${sourceDir}/out/build/${presetName}", "installDir": "${sourceDir}/out/install/${presetName}", + "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake", "cacheVariables": { "CMAKE_C_COMPILER": "cl.exe", - "CMAKE_CXX_COMPILER": "cl.exe" + "CMAKE_CXX_COMPILER": "cl.exe", + "VCPKG_OVERLAY_TRIPLETS": "${sourceDir}/triplets" }, "condition": { "type": "equals", @@ -26,7 +28,8 @@ "strategy": "external" }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Debug" + "CMAKE_BUILD_TYPE": "Debug", + "VCPKG_TARGET_TRIPLET": "x64-windows-static-md" } }, { @@ -38,7 +41,8 @@ "strategy": "external" }, "cacheVariables": { - "CMAKE_BUILD_TYPE": "Release" + "CMAKE_BUILD_TYPE": "Release", + "VCPKG_TARGET_TRIPLET": "x64-windows-static-md" } }, { diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index a6b077ab..00cac8da 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -423,173 +423,4 @@ TEST(ExceptionTest, MessageAndLogger) { NullLogger logger; Exception ex("logged error", logger); EXPECT_STREQ("logged error", ex.what()); -} - -// 
=============================================================================
-// File-based parser tests (read JSON from testdata/)
-// =============================================================================
-
-class FileBasedParserTest : public ::testing::Test {
-protected:
-    static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; }
-
-    static nlohmann::json LoadJsonArray(const std::string& filename) {
-        std::string raw = Testing::ReadFile(TestDataPath(filename));
-        return nlohmann::json::parse(raw);
-    }
-};
-
-TEST_F(FileBasedParserTest, AllFields_RequiredFields) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-    EXPECT_EQ("model-all-fields", info.id);
-    EXPECT_EQ("model-all-fields", info.name);
-    EXPECT_EQ(3u, info.version);
-    EXPECT_EQ("full-model", info.alias);
-    EXPECT_EQ("onnx", info.provider_type);
-    EXPECT_EQ("https://example.com/full-model", info.uri);
-    EXPECT_EQ("text", info.model_type);
-    EXPECT_TRUE(info.cached);
-    EXPECT_EQ(1710000000, info.created_at_unix);
-}
-
-TEST_F(FileBasedParserTest, AllFields_OptionalStrings) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    ASSERT_TRUE(info.display_name.has_value());
-    EXPECT_EQ("Full Model Display Name", *info.display_name);
-    ASSERT_TRUE(info.publisher.has_value());
-    EXPECT_EQ("TestPublisher", *info.publisher);
-    ASSERT_TRUE(info.license.has_value());
-    EXPECT_EQ("Apache-2.0", *info.license);
-    ASSERT_TRUE(info.license_description.has_value());
-    EXPECT_EQ("Permissive open source license", *info.license_description);
-    ASSERT_TRUE(info.task.has_value());
-    EXPECT_EQ("text-generation", *info.task);
-    ASSERT_TRUE(info.min_fl_version.has_value());
-    EXPECT_EQ("1.0.0", *info.min_fl_version);
-}
-
-TEST_F(FileBasedParserTest, AllFields_NumericOptionals) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    ASSERT_TRUE(info.file_size_mb.has_value());
-    EXPECT_EQ(16384u, *info.file_size_mb);
-    ASSERT_TRUE(info.supports_tool_calling.has_value());
-    EXPECT_TRUE(*info.supports_tool_calling);
-    ASSERT_TRUE(info.max_output_tokens.has_value());
-    EXPECT_EQ(8192, *info.max_output_tokens);
-}
-
-TEST_F(FileBasedParserTest, AllFields_Runtime) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    ASSERT_TRUE(info.runtime.has_value());
-    EXPECT_EQ(DeviceType::NPU, info.runtime->device_type);
-    EXPECT_EQ("QNN", info.runtime->execution_provider);
-}
-
-TEST_F(FileBasedParserTest, AllFields_PromptTemplate) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    ASSERT_TRUE(info.prompt_template.has_value());
-    EXPECT_EQ("<|system|>\n", info.prompt_template->system);
-    EXPECT_EQ("<|user|>\n", info.prompt_template->user);
-    EXPECT_EQ("<|assistant|>\n", info.prompt_template->assistant);
-    EXPECT_EQ("<|endoftext|>", info.prompt_template->prompt);
-}
-
-TEST_F(FileBasedParserTest, AllFields_ModelSettings) {
-    auto arr = LoadJsonArray("model_all_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    ASSERT_TRUE(info.model_settings.has_value());
-    ASSERT_EQ(3u, info.model_settings->parameters.size());
-    EXPECT_EQ("temperature", info.model_settings->parameters[0].name);
-    ASSERT_TRUE(info.model_settings->parameters[0].value.has_value());
-    EXPECT_EQ("0.7", *info.model_settings->parameters[0].value);
-    EXPECT_EQ("top_p", info.model_settings->parameters[1].name);
-    ASSERT_TRUE(info.model_settings->parameters[1].value.has_value());
-    EXPECT_EQ("0.9", *info.model_settings->parameters[1].value);
-    EXPECT_EQ("max_tokens", info.model_settings->parameters[2].name);
-    EXPECT_FALSE(info.model_settings->parameters[2].value.has_value());
-}
-
-TEST_F(FileBasedParserTest, MinimalFields_RequiredOnly) {
-    auto arr = LoadJsonArray("model_minimal_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    EXPECT_EQ("minimal-model", info.id);
-    EXPECT_EQ("minimal-model", info.name);
-    EXPECT_EQ(1u, info.version);
-    EXPECT_EQ("minimal", info.alias);
-    EXPECT_EQ("onnx", info.provider_type);
-    EXPECT_EQ("text", info.model_type);
-    EXPECT_FALSE(info.cached);
-    EXPECT_EQ(0, info.created_at_unix);
-}
-
-TEST_F(FileBasedParserTest, MinimalFields_AllOptionalsAbsent) {
-    auto arr = LoadJsonArray("model_minimal_fields.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    EXPECT_FALSE(info.display_name.has_value());
-    EXPECT_FALSE(info.publisher.has_value());
-    EXPECT_FALSE(info.license.has_value());
-    EXPECT_FALSE(info.license_description.has_value());
-    EXPECT_FALSE(info.task.has_value());
-    EXPECT_FALSE(info.file_size_mb.has_value());
-    EXPECT_FALSE(info.supports_tool_calling.has_value());
-    EXPECT_FALSE(info.max_output_tokens.has_value());
-    EXPECT_FALSE(info.min_fl_version.has_value());
-    EXPECT_FALSE(info.runtime.has_value());
-    EXPECT_FALSE(info.prompt_template.has_value());
-    EXPECT_FALSE(info.model_settings.has_value());
-}
-
-TEST_F(FileBasedParserTest, NullOptionals_AllOptionalsAbsent) {
-    auto arr = LoadJsonArray("model_null_optionals.json");
-    ModelInfo info = arr.at(0).get<ModelInfo>();
-
-    EXPECT_EQ("model-null-optionals", info.id);
-    EXPECT_EQ("null-opts", info.alias);
-
-    // All explicitly-null fields should parse as absent
-    EXPECT_FALSE(info.display_name.has_value());
-    EXPECT_FALSE(info.publisher.has_value());
-    EXPECT_FALSE(info.license.has_value());
-    EXPECT_FALSE(info.license_description.has_value());
-    EXPECT_FALSE(info.task.has_value());
-    EXPECT_FALSE(info.file_size_mb.has_value());
-    EXPECT_FALSE(info.supports_tool_calling.has_value());
-    EXPECT_FALSE(info.max_output_tokens.has_value());
-    EXPECT_FALSE(info.min_fl_version.has_value());
-    EXPECT_FALSE(info.runtime.has_value());
-    EXPECT_FALSE(info.prompt_template.has_value());
-    EXPECT_FALSE(info.model_settings.has_value());
-}
-
-TEST_F(FileBasedParserTest, RealModelsList_ParseAllEntries) {
-    auto arr = LoadJsonArray("real_models_list.json");
-    ASSERT_EQ(4u, arr.size());
-
-    for (const auto& j : arr) {
-        EXPECT_NO_THROW({
-            auto info = j.get<ModelInfo>();
-            EXPECT_FALSE(info.id.empty());
-            EXPECT_FALSE(info.name.empty());
-            EXPECT_FALSE(info.alias.empty());
-        });
-    }
-}
-
-TEST_F(FileBasedParserTest, MalformedJson_Throws) {
-    EXPECT_ANY_THROW({
-        std::string raw = Testing::ReadFile(TestDataPath("malformed_models_list.json"));
-        nlohmann::json::parse(raw);
-    });
-}
+}
\ No newline at end of file
diff --git a/sdk/cpp/triplets/x64-windows-static-md.cmake b/sdk/cpp/triplets/x64-windows-static-md.cmake
new file mode 100644
index 00000000..63d6cde2
--- /dev/null
+++ b/sdk/cpp/triplets/x64-windows-static-md.cmake
@@ -0,0 +1,3 @@
+set(VCPKG_TARGET_ARCHITECTURE x64)
+set(VCPKG_CRT_LINKAGE dynamic)
+set(VCPKG_LIBRARY_LINKAGE static)
diff --git a/sdk/cpp/vcpkg-configuration.json b/sdk/cpp/vcpkg-configuration.json
new file mode 100644
index 00000000..a5253fb7
--- /dev/null
+++ b/sdk/cpp/vcpkg-configuration.json
@@ -0,0 +1,6 @@
+{
+    "default-registry": {
+        "kind": "builtin",
+        "baseline": "a9f0cd0345fb29cd227d802f1fd1917c28f8e5a3"
+    }
+}
diff --git a/sdk/cpp/vcpkg.json b/sdk/cpp/vcpkg.json
new file mode 100644
index 00000000..ec08c349
--- /dev/null
+++ b/sdk/cpp/vcpkg.json
@@ -0,0 +1,10 @@
+{
+    "name": "cppsdk",
+    "version-string": "0.1.0",
+    "dependencies": [
+        "nlohmann-json",
+        "wil",
+        "ms-gsl",
+        "gtest"
+    ]
+}

From 444958b9bda7b6a733c85e330d10c8f8439776ee Mon Sep 17 00:00:00 2001
From: Your Name 
Date: Wed, 1 Apr 2026 09:34:38 -0700
Subject: [PATCH 13/18] E2e tests

---
 sdk/cpp/CMakeLists.txt                  |  26 +
 sdk/cpp/client_test.cpp                 | 722 ++++++++++++++++++++++++
 sdk/cpp/include/foundry_local_manager.h |  34 +-
 sdk/cpp/sample/main.cpp                 |   5 +-
 sdk/cpp/src/foundry_local_manager.cpp   |  49 +-
 5 files changed, 812 insertions(+), 24 deletions(-)
 create mode 100644 sdk/cpp/client_test.cpp

diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
index dd08b76e..080815c4 100644
--- a/sdk/cpp/CMakeLists.txt
+++ b/sdk/cpp/CMakeLists.txt
@@ -117,6 +117,32 @@ gtest_discover_tests(CppSdkTests
     WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkTests>
 )
 
+# -----------------------------
+# End-to-end tests (separate executable, requires Core DLL)
+# Exercises the full public API against the real catalog.
+# Tests that need model download are DISABLED by default;
+# run with --gtest_also_run_disabled_tests locally.
+# -----------------------------
+add_executable(CppSdkE2ETests
+    test/e2e_test.cpp
+)
+
+target_include_directories(CppSdkE2ETests
+    PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/test
+        ${CMAKE_CURRENT_SOURCE_DIR}/src
+)
+
+target_link_libraries(CppSdkE2ETests
+    PRIVATE
+        CppSdk
+        GTest::gtest_main
+)
+
+gtest_discover_tests(CppSdkE2ETests
+    WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkE2ETests>
+)
+
 # Make Visual Studio start/debug this target by default
 set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY VS_STARTUP_PROJECT CppSdkSample)
 
diff --git a/sdk/cpp/client_test.cpp b/sdk/cpp/client_test.cpp
new file mode 100644
index 00000000..2a09ac77
--- /dev/null
+++ b/sdk/cpp/client_test.cpp
@@ -0,0 +1,722 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
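+//
+// These tests drive the OpenAI-style clients entirely against MockCore, so no
+// native Core DLL, service, or model download is required. Most tests follow
+// one pattern (a sketch only; the real stub wiring lives in mock_core.h):
+//
+//     core_.OnCall("chat_completions", cannedJson);              // stub the Core entry point
+//     core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); // variant must look loaded
+//     OpenAIChatClient client(variant);                          // ctor verifies loaded state
+//     auto response = client.CompleteChat(messages, settings);   // served from the stub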
+
+#include <gtest/gtest.h>
+
+#include "mock_core.h"
+#include "mock_object_factory.h"
+#include "parser.h"
+#include "foundry_local_exception.h"
+
+#include <nlohmann/json.hpp>
+
+using namespace foundry_local;
+using namespace foundry_local::Testing;
+
+using Factory = MockObjectFactory;
+
+class OpenAIChatClientTest : public ::testing::Test {
+protected:
+    MockCore core_;
+    NullLogger logger_;
+
+    std::string MakeChatResponseJson(const std::string& content = "Hello!") {
+        nlohmann::json resp = {
+            {"created", 1700000000},
+            {"id", "chatcmpl-test"},
+            {"IsDelta", false},
+            {"Successful", true},
+            {"HttpStatusCode", 200},
+            {"choices",
+             {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", content}}}}}}};
+        return resp.dump();
+    }
+
+    ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") {
+        core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]");
+        return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_);
+    }
+};
+
+TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) {
+    core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!"));
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Say hello", {}}};
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    EXPECT_TRUE(response.successful);
+    ASSERT_EQ(1u, response.choices.size());
+    EXPECT_EQ("Hello world!", response.choices[0].message->content);
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_WithSettings) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+    ChatSettings settings;
+    settings.temperature = 0.7f;
+    settings.max_tokens = 100;
+    settings.top_p = 0.9f;
+    settings.frequency_penalty = 0.5f;
+    settings.presence_penalty = 0.3f;
+    settings.n = 2;
+    settings.random_seed = 42;
+    settings.top_k = 10;
+
+    auto response = client.CompleteChat(messages, settings);
+
+    // Verify the request JSON contains the settings
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    EXPECT_NEAR(0.7f, openAiReq["temperature"].get<float>(), 0.001f);
+    EXPECT_EQ(100, openAiReq["max_completion_tokens"].get<int>());
+    EXPECT_NEAR(0.9f, openAiReq["top_p"].get<float>(), 0.001f);
+    EXPECT_NEAR(0.5f, openAiReq["frequency_penalty"].get<float>(), 0.001f);
+    EXPECT_NEAR(0.3f, openAiReq["presence_penalty"].get<float>(), 0.001f);
+    EXPECT_EQ(2, openAiReq["n"].get<int>());
+    EXPECT_EQ(42, openAiReq["seed"].get<int>());
+    EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get<int>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_RequestFormat) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}};
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    EXPECT_EQ("chat-model", openAiReq["model"].get<std::string>());
+    EXPECT_FALSE(openAiReq["stream"].get<bool>());
+    ASSERT_EQ(2u, openAiReq["messages"].size());
+    EXPECT_EQ("system", openAiReq["messages"][0]["role"].get<std::string>());
+    EXPECT_EQ("user", openAiReq["messages"][1]["role"].get<std::string>());
+}
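+
+// Streaming responses do not arrive as one JSON body. As the tests below model
+// it, the Core invokes the registered NativeCallbackFn once per chunk, and each
+// chunk is a self-contained JSON object with "IsDelta": true and a "delta"
+// payload in place of "message". Roughly (sketch only):
+//
+//     std::string chunkJson = chunk.dump();
+//     callback(chunkJson.data(), chunkJson.size(), userData);  // once per chunk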
+
+TEST_F(OpenAIChatClientTest, CompleteChatStreaming) {
+    nlohmann::json chunk1 = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-1"},
+        {"IsDelta", true},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hello"}}}}}}};
+    nlohmann::json chunk2 = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-1"},
+        {"IsDelta", true},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}};
+
+    core_.OnCall("chat_completions",
+                 [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
+                     if (callback && userData) {
+                         std::string s1 = chunk1.dump();
+                         std::string s2 = chunk2.dump();
+                         callback(s1.data(), static_cast(s1.size()), userData);
+                         callback(s2.data(), static_cast(s2.size()), userData);
+                     }
+                     return "";
+                 });
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+    ChatSettings settings;
+
+    std::vector<ChatCompletionCreateResponse> chunks;
+    client.CompleteChatStreaming(messages, settings,
+                                 [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); });
+
+    ASSERT_EQ(2u, chunks.size());
+    EXPECT_TRUE(chunks[0].is_delta);
+    ASSERT_TRUE(chunks[0].choices[0].delta.has_value());
+    EXPECT_EQ("Hello", chunks[0].choices[0].delta->content);
+    EXPECT_EQ(" world", chunks[1].choices[0].delta->content);
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) {
+    nlohmann::json chunk = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-1"},
+        {"IsDelta", true},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}};
+
+    core_.OnCall("chat_completions",
+                 [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
+                     if (callback && userData) {
+                         std::string s = chunk.dump();
+                         callback(s.data(), static_cast(s.size()), userData);
+                     }
+                     return "";
+                 });
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+    ChatSettings settings;
+
+    EXPECT_THROW(client.CompleteChatStreaming(
+                     messages, settings,
+                     [](const ChatCompletionCreateResponse&) { throw std::runtime_error("callback error"); }),
+                 std::runtime_error);
+}
+
+TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) {
+    core_.OnCall("list_loaded_models", R"([])");
+    auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_);
+    EXPECT_THROW(OpenAIChatClient client(variant), Exception);
+}
+
+TEST_F(OpenAIChatClientTest, GetModelId) {
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+    EXPECT_EQ("chat-model", client.GetModelId());
+}
+
+// ---------- Tool calling tests ----------
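+//
+// Tool definitions are serialized into the OpenAI "tools" array. From the
+// assertions below, one function tool is expected to serialize roughly as:
+//
+//     { "type": "function",
+//       "function": { "name": "...", "description": "...",
+//                     "parameters": { "type": "object",
+//                                     "properties": { ... },
+//                                     "required": [ ... ] } } }
+//
+// (Illustrative shape only, inferred from the EXPECTs in these tests.)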
+
+TEST_F(OpenAIChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "What is 7 * 6?", {}}};
+
+    std::vector tools = {{
+        "function",
+        FunctionDefinition{
+            "multiply_numbers",
+            "A tool for multiplying two numbers.",
+            PropertyDefinition{
+                "object",
+                std::nullopt,
+                std::unordered_map<std::string, PropertyDefinition>{
+                    {"first", PropertyDefinition{"integer", "The first number"}},
+                    {"second", PropertyDefinition{"integer", "The second number"}}
+                },
+                std::vector<std::string>{"first", "second"}
+            }
+        }
+    }};
+
+    ChatSettings settings;
+    settings.tool_choice = ToolChoiceKind::Required;
+
+    auto response = client.CompleteChat(messages, tools, settings);
+
+    // Verify the request JSON contains tools and tool_choice
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    ASSERT_TRUE(openAiReq.contains("tools"));
+    ASSERT_TRUE(openAiReq["tools"].is_array());
+    EXPECT_EQ(1u, openAiReq["tools"].size());
+    EXPECT_EQ("function", openAiReq["tools"][0]["type"].get<std::string>());
+    EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get<std::string>());
+    EXPECT_EQ("A tool for multiplying two numbers.", openAiReq["tools"][0]["function"]["description"].get<std::string>());
+    EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get<std::string>());
+    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties"));
+    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first"));
+    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("second"));
+
+    EXPECT_EQ("required", openAiReq["tool_choice"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    EXPECT_FALSE(openAiReq.contains("tools"));
+    EXPECT_FALSE(openAiReq.contains("tool_choice"));
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallResponse_Parsed) {
+    // Simulate a response with tool calls from the model
+    nlohmann::json resp = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-tool"},
+        {"IsDelta", false},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0},
+           {"finish_reason", "tool_calls"},
+           {"message",
+            {{"role", "assistant"},
+             {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"},
+             {"tool_calls",
+              {{{"id", "call_1"},
+                {"type", "function"},
+                {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}};
+
+    core_.OnCall("chat_completions", resp.dump());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "What is 7 * 6?", {}}};
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    ASSERT_EQ(1u, response.choices.size());
+    EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason);
+    ASSERT_TRUE(response.choices[0].message.has_value());
+
+    const auto& msg = *response.choices[0].message;
+    ASSERT_EQ(1u, msg.tool_calls.size());
+    EXPECT_EQ("call_1", msg.tool_calls[0].id);
+    EXPECT_EQ("function", msg.tool_calls[0].type);
+    ASSERT_TRUE(msg.tool_calls[0].function_call.has_value());
+    EXPECT_EQ("multiply_numbers", msg.tool_calls[0].function_call->name);
+    EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments);
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceAuto) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+    ChatSettings settings;
+    settings.tool_choice = ToolChoiceKind::Auto;
+
+    client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    EXPECT_EQ("auto", openAiReq["tool_choice"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceNone) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+    ChatSettings settings;
+    settings.tool_choice = ToolChoiceKind::None;
+
+    client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    EXPECT_EQ("none", openAiReq["tool_choice"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) {
+    core_.OnCall("chat_completions", MakeChatResponseJson());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    ChatMessage toolMsg;
+    toolMsg.role = "tool";
+    toolMsg.content = "42";
+    toolMsg.tool_call_id = "call_1";
+
+    std::vector<ChatMessage> messages = {
+        {"user", "What is 7 * 6?", {}},
+        std::move(toolMsg)
+    };
+    ChatSettings settings;
+    client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    ASSERT_EQ(2u, openAiReq["messages"].size());
+    EXPECT_FALSE(openAiReq["messages"][0].contains("tool_call_id"));
+    EXPECT_EQ("call_1", openAiReq["messages"][1]["tool_call_id"].get<std::string>());
+    EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) {
+    nlohmann::json chunk1 = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-1"},
+        {"IsDelta", true},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0},
+           {"finish_reason", nullptr},
+           {"delta", {{"role", "assistant"}, {"content", ""}}}}}}};
+    nlohmann::json chunk2 = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-1"},
+        {"IsDelta", true},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0},
+           {"finish_reason", "tool_calls"},
+           {"delta",
+            {{"content", ""},
+             {"tool_calls",
+              {{{"id", "call_1"},
+                {"type", "function"},
+                {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}};
+
+    core_.OnCall("chat_completions",
+                 [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
+                     if (callback && userData) {
+                         std::string s1 = chunk1.dump();
+                         std::string s2 = chunk2.dump();
+                         callback(s1.data(), static_cast(s1.size()), userData);
+                         callback(s2.data(), static_cast(s2.size()), userData);
+                     }
+                     return "";
+                 });
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "test", {}}};
+
+    std::vector tools = {{
+        "function",
+        FunctionDefinition{"multiply", "Multiply numbers."}
+    }};
+
+    ChatSettings settings;
+    settings.tool_choice = ToolChoiceKind::Required;
+
+    std::vector<ChatCompletionCreateResponse> chunks;
+    client.CompleteChatStreaming(messages, tools, settings,
+                                 [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); });
+
+    ASSERT_EQ(2u, chunks.size());
+    EXPECT_EQ(FinishReason::ToolCalls, chunks[1].choices[0].finish_reason);
+    ASSERT_TRUE(chunks[1].choices[0].delta.has_value());
+    ASSERT_EQ(1u, chunks[1].choices[0].delta->tool_calls.size());
+    EXPECT_EQ("multiply", chunks[1].choices[0].delta->tool_calls[0].function_call->name);
+
+    // Verify tools were included in the request
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    ASSERT_TRUE(openAiReq.contains("tools"));
+    EXPECT_EQ("required", openAiReq["tool_choice"].get<std::string>());
+}
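+
+// The audio client mirrors the chat client's contract: the constructor checks
+// that the variant is loaded (via "list_loaded_models"), blocking transcription
+// goes through the "audio_transcribe" Core entry point, and the streaming
+// variant delivers incremental text through the same callback mechanism used
+// for chat chunks above.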
+
+class OpenAIAudioClientTest : public ::testing::Test {
+protected:
+    MockCore core_;
+    NullLogger logger_;
+
+    ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") {
+        core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]");
+        return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_);
+    }
+};
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudio) {
+    core_.OnCall("audio_transcribe", "Hello world transcribed text");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+    auto response = client.TranscribeAudio("test.wav");
+
+    EXPECT_EQ("Hello world transcribed text", response.text);
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudio_RequestFormat) {
+    core_.OnCall("audio_transcribe", "text");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+    client.TranscribeAudio("audio.wav");
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    EXPECT_EQ("audio-model", openAiReq["Model"].get<std::string>());
+    EXPECT_EQ("audio.wav", openAiReq["FileName"].get<std::string>());
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) {
+    core_.OnCall("audio_transcribe",
+                 [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
+                     if (callback && userData) {
+                         std::string text1 = "Hello ";
+                         std::string text2 = "world!";
+                         callback(text1.data(), static_cast(text1.size()), userData);
+                         callback(text2.data(), static_cast(text2.size()), userData);
+                     }
+                     return "";
+                 });
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    std::vector<std::string> chunks;
+    client.TranscribeAudioStreaming(
+        "test.wav", [&](const AudioCreateTranscriptionResponse& chunk) { chunks.push_back(chunk.text); });
+
+    ASSERT_EQ(2u, chunks.size());
+    EXPECT_EQ("Hello ", chunks[0]);
+    EXPECT_EQ("world!", chunks[1]);
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) {
+    core_.OnCall("audio_transcribe",
+                 [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
+                     if (callback && userData) {
+                         std::string text = "test";
+                         callback(text.data(), static_cast(text.size()), userData);
+                     }
+                     return "";
+                 });
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    EXPECT_THROW(
+        client.TranscribeAudioStreaming(
+            "test.wav", [](const AudioCreateTranscriptionResponse&) { throw std::runtime_error("streaming error"); }),
+        std::runtime_error);
+}
+
+TEST_F(OpenAIAudioClientTest, Constructor_ThrowsIfNotLoaded) {
+    core_.OnCall("list_loaded_models", R"([])");
+    auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_);
+    EXPECT_THROW(OpenAIAudioClient client(variant), FoundryLocalException);
+}
+
+TEST_F(OpenAIAudioClientTest, GetModelId) {
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+    EXPECT_EQ("audio-model", client.GetModelId());
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudio_CoreError_Throws) {
+    core_.OnCallThrow("audio_transcribe", "transcription failed");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    EXPECT_THROW(client.TranscribeAudio("test.wav"), Exception);
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_CoreError_Throws) {
+    core_.OnCallThrow("audio_transcribe", "streaming transcription failed");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    EXPECT_THROW(
+        client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}),
+        Exception);
+}
+
+// =====================================================================
+// Multi-turn conversation tests
+// =====================================================================
+
+TEST_F(OpenAIChatClientTest, CompleteChat_MultiTurn) {
+    // First turn: user asks a question
+    core_.OnCall("chat_completions", MakeChatResponseJson("42"));
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {
+        {"user", "What is 7 * 6?", {}}
+    };
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    ASSERT_TRUE(response.successful);
+    ASSERT_EQ(1u, response.choices.size());
+    EXPECT_EQ("42", response.choices[0].message->content);
+
+    // Second turn: add assistant response + user follow-up
+    messages.push_back({"assistant", response.choices[0].message->content, {}});
+    messages.push_back({"user", "Is that a real number?", {}});
+
+    core_.OnCall("chat_completions", MakeChatResponseJson("Yes"));
+    auto response2 = client.CompleteChat(messages, settings);
+
+    ASSERT_TRUE(response2.successful);
+    EXPECT_EQ("Yes", response2.choices[0].message->content);
+
+    // Verify the second request contained all 3 messages
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    ASSERT_EQ(3u, openAiReq["messages"].size());
+    EXPECT_EQ("user", openAiReq["messages"][0]["role"].get<std::string>());
+    EXPECT_EQ("assistant", openAiReq["messages"][1]["role"].get<std::string>());
+    EXPECT_EQ("user", openAiReq["messages"][2]["role"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_CoreError_Throws) {
+    core_.OnCallThrow("chat_completions", "inference failed");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
+    ChatSettings settings;
+
+    EXPECT_THROW(client.CompleteChat(messages, settings), Exception);
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChatStreaming_CoreError_Throws) {
+    core_.OnCallThrow("chat_completions", "streaming failed");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
+    ChatSettings settings;
+
+    EXPECT_THROW(
+        client.CompleteChatStreaming(messages, settings,
+                                     [](const ChatCompletionCreateResponse&) {}),
+        Exception);
+}
+
+// =====================================================================
+// Full tool-call round-trip
+// =====================================================================
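+//
+// The round trip below follows the standard OpenAI tool-calling protocol:
+//   1. send the user message plus tool definitions; the model replies with
+//      finish_reason == tool_calls and a tool_calls array,
+//   2. the application executes the tool itself (the model never runs code),
+//   3. the tool result goes back as a {"role":"tool","tool_call_id":...}
+//      message, and the conversation is resent for the final answer.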
tool.", {}}); + + core_.OnCall("chat_completions", MakeChatResponseJson("42")); + settings.tool_choice = ToolChoiceKind::Auto; + + auto response2 = client.CompleteChat(messages, tools, settings); + + ASSERT_TRUE(response2.successful); + EXPECT_EQ("42", response2.choices[0].message->content); + + // Verify the second request contained tool response message + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + // 5 messages: system, user, assistant (tool_call), tool, system (continue) + ASSERT_EQ(5u, openAiReq["messages"].size()); + EXPECT_EQ("tool", openAiReq["messages"][3]["role"].get()); + EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get()); + EXPECT_EQ("auto", openAiReq["tool_choice"].get()); +} diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h index 00621b24..9ff5eda3 100644 --- a/sdk/cpp/include/foundry_local_manager.h +++ b/sdk/cpp/include/foundry_local_manager.h @@ -23,11 +23,26 @@ namespace foundry_local { public: FoundryLocalManager(const FoundryLocalManager&) = delete; FoundryLocalManager& operator=(const FoundryLocalManager&) = delete; - FoundryLocalManager(FoundryLocalManager&& other) noexcept; - FoundryLocalManager& operator=(FoundryLocalManager&& other) noexcept; + FoundryLocalManager(FoundryLocalManager&&) = delete; + FoundryLocalManager& operator=(FoundryLocalManager&&) = delete; - explicit FoundryLocalManager(Configuration configuration, ILogger* logger = nullptr); - ~FoundryLocalManager(); + /// Create the FoundryLocalManager singleton instance. + /// Throws if an instance has already been created. Call Destroy() first to release the current instance. + /// @param configuration Configuration to use. + /// @param logger Optional application logger. Pass nullptr to suppress log output. + static void Create(Configuration configuration, ILogger* logger = nullptr); + + /// Get the singleton instance. + /// Throws if Create() has not been called. + static FoundryLocalManager& Instance(); + + /// Returns true if the singleton instance has been created. + static bool IsInitialized() noexcept; + + /// Destroy the singleton instance, performing deterministic cleanup. + /// Unloads all loaded models and stops the web service if running. + /// After this call, Create() may be called again. + static void Destroy() noexcept; const Catalog& GetCatalog() const; Catalog& GetCatalog(); @@ -49,12 +64,19 @@ namespace foundry_local { void EnsureEpsDownloaded() const; private: - bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; } + explicit FoundryLocalManager(Configuration configuration, ILogger* logger); + ~FoundryLocalManager(); - Configuration config_; + struct Deleter { + void operator()(FoundryLocalManager* p) const noexcept { delete p; } + }; void Initialize(); + void Cleanup() noexcept; + static std::unique_ptr instance_; + + Configuration config_; NullLogger defaultLogger_; std::unique_ptr core_; std::unique_ptr catalog_; diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index bd0e1879..434eeba9 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -331,7 +331,8 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) int main() { try { StdLogger logger; - FoundryLocalManager manager({"SampleApp"}, &logger); + FoundryLocalManager::Create({"SampleApp"}, &logger); + auto& manager = FoundryLocalManager::Instance(); // 1. 
+        /// Create the FoundryLocalManager singleton instance.
+        /// Throws if an instance has already been created. Call Destroy() first to release the current instance.
+        /// @param configuration Configuration to use.
+        /// @param logger Optional application logger. Pass nullptr to suppress log output.
+        static void Create(Configuration configuration, ILogger* logger = nullptr);
+
+        /// Get the singleton instance.
+        /// Throws if Create() has not been called.
+        static FoundryLocalManager& Instance();
+
+        /// Returns true if the singleton instance has been created.
+        static bool IsInitialized() noexcept;
+
+        /// Destroy the singleton instance, performing deterministic cleanup.
+        /// Unloads all loaded models and stops the web service if running.
+        /// After this call, Create() may be called again.
+        static void Destroy() noexcept;
 
         const Catalog& GetCatalog() const;
         Catalog& GetCatalog();
@@ -49,12 +64,19 @@ namespace foundry_local {
         void EnsureEpsDownloaded() const;
 
     private:
-        bool OwnsLogger() const noexcept { return logger_ == &defaultLogger_; }
+        explicit FoundryLocalManager(Configuration configuration, ILogger* logger);
+        ~FoundryLocalManager();
 
-        Configuration config_;
+        struct Deleter {
+            void operator()(FoundryLocalManager* p) const noexcept { delete p; }
+        };
 
         void Initialize();
+        void Cleanup() noexcept;
 
+        static std::unique_ptr<FoundryLocalManager, Deleter> instance_;
+
+        Configuration config_;
         NullLogger defaultLogger_;
         std::unique_ptr core_;
         std::unique_ptr catalog_;
diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp
index bd0e1879..434eeba9 100644
--- a/sdk/cpp/sample/main.cpp
+++ b/sdk/cpp/sample/main.cpp
@@ -331,7 +331,8 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias)
 int main() {
     try {
         StdLogger logger;
-        FoundryLocalManager manager({"SampleApp"}, &logger);
+        FoundryLocalManager::Create({"SampleApp"}, &logger);
+        auto& manager = FoundryLocalManager::Instance();
 
         // 1. Browse the full catalog
         BrowseCatalog(manager);
@@ -348,10 +349,12 @@ int main() {
         // 5. Tool calling (define tools, let the model call them, feed results back)
         ChatWithToolCalling(manager, "phi-3.5-mini");
 
+        FoundryLocalManager::Destroy();
         return 0;
     }
     catch (const std::exception& ex) {
         std::cerr << "Fatal: " << ex.what() << std::endl;
+        FoundryLocalManager::Destroy();
        return 1;
     }
 }
diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp
index 7ee39253..aca77be8 100644
--- a/sdk/cpp/src/foundry_local_manager.cpp
+++ b/sdk/cpp/src/foundry_local_manager.cpp
@@ -17,6 +17,35 @@
 namespace foundry_local {
 
+std::unique_ptr<FoundryLocalManager, FoundryLocalManager::Deleter> FoundryLocalManager::instance_;
+
+void FoundryLocalManager::Create(Configuration configuration, ILogger* logger) {
+    if (instance_) {
+        NullLogger fallback;
+        ILogger& log = logger ? *logger : fallback;
+        throw Exception("FoundryLocalManager has already been created. Call Destroy() first.", log);
+    }
+
+    // Use a local to ensure full initialization before assigning to the static instance.
+    std::unique_ptr<FoundryLocalManager, Deleter> manager(new FoundryLocalManager(std::move(configuration), logger));
+    instance_ = std::move(manager);
+}
+
+FoundryLocalManager& FoundryLocalManager::Instance() {
+    if (!instance_) {
+        throw Exception("FoundryLocalManager has not been created. Call Create() first.");
+    }
+    return *instance_;
+}
+
+bool FoundryLocalManager::IsInitialized() noexcept {
+    return instance_ != nullptr;
+}
+
+void FoundryLocalManager::Destroy() noexcept {
+    instance_.reset();
+}
+
 FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger)
     : config_(std::move(configuration)), core_(std::make_unique()),
       logger_(logger ? logger : &defaultLogger_) {
@@ -25,25 +54,11 @@ FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* l
     catalog_ = Catalog::Create(core_.get(), logger_);
 }
 
-FoundryLocalManager::FoundryLocalManager(FoundryLocalManager&& other) noexcept
-    : config_(std::move(other.config_)), core_(std::move(other.core_)), catalog_(std::move(other.catalog_)),
-      logger_(other.OwnsLogger() ? &defaultLogger_ : other.logger_), urls_(std::move(other.urls_)) {
-    other.logger_ = &other.defaultLogger_;
-}
-
-FoundryLocalManager& FoundryLocalManager::operator=(FoundryLocalManager&& other) noexcept {
-    if (this != &other) {
-        config_ = std::move(other.config_);
-        core_ = std::move(other.core_);
-        catalog_ = std::move(other.catalog_);
-        logger_ = other.OwnsLogger() ? &defaultLogger_ : other.logger_;
-        urls_ = std::move(other.urls_);
-        other.logger_ = &other.defaultLogger_;
-    }
-    return *this;
+FoundryLocalManager::~FoundryLocalManager() {
+    Cleanup();
 }
 
-FoundryLocalManager::~FoundryLocalManager() {
+void FoundryLocalManager::Cleanup() noexcept {
     // Unload all loaded models before tearing down.
     if (catalog_) {
         try {

From f7694815170c5ba5a43a3f8d2572d48ab1d3c283 Mon Sep 17 00:00:00 2001
From: Your Name 
Date: Thu, 2 Apr 2026 11:47:17 -0700
Subject: [PATCH 14/18] e2e tests file separate

---
 sdk/cpp/test/client_test.cpp | 183 +++++++++++
 sdk/cpp/test/e2e_test.cpp    | 574 +++++++++++++++++++++++++++++++++++
 2 files changed, 757 insertions(+)
 create mode 100644 sdk/cpp/test/e2e_test.cpp

diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp
index 0ddb0c02..2a09ac77 100644
--- a/sdk/cpp/test/client_test.cpp
+++ b/sdk/cpp/test/client_test.cpp
@@ -537,3 +537,186 @@ TEST_F(OpenAIAudioClientTest, GetModelId) {
     OpenAIAudioClient client(variant);
     EXPECT_EQ("audio-model", client.GetModelId());
 }
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudio_CoreError_Throws) {
+    core_.OnCallThrow("audio_transcribe", "transcription failed");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    EXPECT_THROW(client.TranscribeAudio("test.wav"), Exception);
+}
+
+TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_CoreError_Throws) {
+    core_.OnCallThrow("audio_transcribe", "streaming transcription failed");
+    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIAudioClient client(variant);
+
+    EXPECT_THROW(
+        client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}),
+        Exception);
+}
+
+// =====================================================================
+// Multi-turn conversation tests
+// =====================================================================
+
+TEST_F(OpenAIChatClientTest, CompleteChat_MultiTurn) {
+    // First turn: user asks a question
+    core_.OnCall("chat_completions", MakeChatResponseJson("42"));
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {
+        {"user", "What is 7 * 6?", {}}
+    };
+    ChatSettings settings;
+    auto response = client.CompleteChat(messages, settings);
+
+    ASSERT_TRUE(response.successful);
+    ASSERT_EQ(1u, response.choices.size());
+    EXPECT_EQ("42", response.choices[0].message->content);
+
+    // Second turn: add assistant response + user follow-up
+    messages.push_back({"assistant", response.choices[0].message->content, {}});
+    messages.push_back({"user", "Is that a real number?", {}});
+
+    core_.OnCall("chat_completions", MakeChatResponseJson("Yes"));
+    auto response2 = client.CompleteChat(messages, settings);
+
+    ASSERT_TRUE(response2.successful);
+    EXPECT_EQ("Yes", response2.choices[0].message->content);
+
+    // Verify the second request contained all 3 messages
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+    ASSERT_EQ(3u, openAiReq["messages"].size());
+    EXPECT_EQ("user", openAiReq["messages"][0]["role"].get<std::string>());
+    EXPECT_EQ("assistant", openAiReq["messages"][1]["role"].get<std::string>());
+    EXPECT_EQ("user", openAiReq["messages"][2]["role"].get<std::string>());
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChat_CoreError_Throws) {
+    core_.OnCallThrow("chat_completions", "inference failed");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
+    ChatSettings settings;
+
+    EXPECT_THROW(client.CompleteChat(messages, settings), Exception);
+}
+
+TEST_F(OpenAIChatClientTest, CompleteChatStreaming_CoreError_Throws) {
+    core_.OnCallThrow("chat_completions", "streaming failed");
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
+    ChatSettings settings;
+
+    EXPECT_THROW(
+        client.CompleteChatStreaming(messages, settings,
+                                     [](const ChatCompletionCreateResponse&) {}),
+        Exception);
+}
+
+// =====================================================================
+// Full tool-call round-trip
+// =====================================================================
+
+TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) {
+    // Step 1: model returns a tool call
+    nlohmann::json toolCallResp = {
+        {"created", 1700000000},
+        {"id", "chatcmpl-tool"},
+        {"IsDelta", false},
+        {"Successful", true},
+        {"HttpStatusCode", 200},
+        {"choices",
+         {{{"index", 0},
+           {"finish_reason", "tool_calls"},
+           {"message",
+            {{"role", "assistant"},
+             {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"},
+             {"tool_calls",
+              {{{"id", "call_1"},
+                {"type", "function"},
+                {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}};
+
+    core_.OnCall("chat_completions", toolCallResp.dump());
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    std::vector<ChatMessage> messages = {
+        {"system", "You are a helpful AI assistant.", {}},
+        {"user", "What is 7 multiplied by 6?", {}}
+    };
+
+    std::vector tools = {{
+        "function",
+        FunctionDefinition{
+            "multiply_numbers",
+            "A tool for multiplying two numbers.",
+            PropertyDefinition{
+                "object",
+                std::nullopt,
+                std::unordered_map<std::string, PropertyDefinition>{
+                    {"first", PropertyDefinition{"integer", "The first number"}},
+                    {"second", PropertyDefinition{"integer", "The second number"}}
+                },
+                std::vector<std::string>{"first", "second"}
+            }
+        }
+    }};
+
+    ChatSettings settings;
+    settings.tool_choice = ToolChoiceKind::Required;
+
+    auto response = client.CompleteChat(messages, tools, settings);
+
+    ASSERT_EQ(1u, response.choices.size());
+    EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason);
+    ASSERT_EQ(1u, response.choices[0].message->tool_calls.size());
+    EXPECT_EQ("multiply_numbers", response.choices[0].message->tool_calls[0].function_call->name);
+
+    // Step 2: send tool response back, model continues with the answer
+    messages.push_back({"assistant", response.choices[0].message->content, {}});
+
+    ChatMessage toolMsg;
+    toolMsg.role = "tool";
+    toolMsg.content = "7 x 6 = 42.";
+    toolMsg.tool_call_id = "call_1";
+    messages.push_back(std::move(toolMsg));
+
+    messages.push_back({"system", "Respond only with the answer generated by the tool.", {}});
+
+    core_.OnCall("chat_completions", MakeChatResponseJson("42"));
+    settings.tool_choice = ToolChoiceKind::Auto;
+
+    auto response2 = client.CompleteChat(messages, tools, settings);
+
+    ASSERT_TRUE(response2.successful);
+    EXPECT_EQ("42", response2.choices[0].message->content);
+
+    // Verify the second request contained tool response message
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    // 5 messages: system, user, assistant (tool_call), tool, system (continue)
+    ASSERT_EQ(5u, openAiReq["messages"].size());
+    EXPECT_EQ("tool", openAiReq["messages"][3]["role"].get<std::string>());
+    EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get<std::string>());
+    EXPECT_EQ("auto", openAiReq["tool_choice"].get<std::string>());
+}
diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp
new file mode 100644
index 00000000..8b506d23
--- /dev/null
+++ b/sdk/cpp/test/e2e_test.cpp
@@ -0,0 +1,574 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// End-to-end tests that exercise the public API with the real Core DLL.
+// Tests marked DISABLED_ are skipped in CI (no Core DLL / no network).
+// Run locally with: --gtest_also_run_disabled_tests
+
+#include <gtest/gtest.h>
+
+#include "foundry_local.h"
+
+#include <cctype>
+#include <cstdlib>
+#include <iostream>
+#include <string>
+
+using namespace foundry_local;
+
+// ---------------------------------------------------------------------------
+// Helper: detect CI environment (mirrors C# SkipInCI logic)
+// ---------------------------------------------------------------------------
+static bool IsRunningInCI() {
+    auto check = [](const char* var) -> bool {
+        const char* val = std::getenv(var);
+        if (!val) return false;
+        std::string s(val);
+        for (auto& c : s) c = static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
+        return s == "true" || s == "1";
+    };
+    return check("TF_BUILD") || check("GITHUB_ACTIONS") || check("CI");
+}
+
+// ---------------------------------------------------------------------------
+// Fixture: creates a real FoundryLocalManager with the Core DLL.
+// All tests in this fixture require the native DLLs next to the test binary.
+// ---------------------------------------------------------------------------
+class EndToEndTest : public ::testing::Test {
+protected:
+    static void SetUpTestSuite() {
+        Configuration config("CppSdkE2ETest");
+        config.log_level = LogLevel::Information;
+        try {
+            FoundryLocalManager::Create(std::move(config));
+        }
+        catch (const std::exception& ex) {
+            std::cerr << "[E2E] Failed to create FoundryLocalManager: " << ex.what() << "\n";
+            GTEST_SKIP() << "Core DLL not available: " << ex.what();
+        }
+    }
+
+    static void TearDownTestSuite() {
+        FoundryLocalManager::Destroy();
+    }
+
+    void SetUp() override {
+        if (!FoundryLocalManager::IsInitialized()) {
+            GTEST_SKIP() << "FoundryLocalManager not available (Core DLL missing?)";
+        }
+    }
+
+    static bool IsAudioModel(const std::string& alias) {
+        return alias.find("whisper") != std::string::npos;
+    }
+
+    /// Find a chat-capable model, preferring cached, then known small models, then any.
+    /// Selects the CPU variant when available to avoid GPU/EP dependency issues.
+ static Model* FindChatModel(Catalog& catalog) { + Model* target = nullptr; + + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + if (!IsAudioModel(variant->GetAlias())) { + target = catalog.GetModel(variant->GetAlias()); + if (target) break; + } + } + + if (!target) { + for (const auto& alias : {"qwen2.5-0.5b", "qwen2.5-coder-0.5b", "phi-4-mini"}) { + target = catalog.GetModel(alias); + if (target) break; + } + } + + if (!target) { + auto models = catalog.ListModels(); + for (auto* model : models) { + if (!IsAudioModel(model->GetAlias())) { + target = model; + break; + } + } + } + + if (target) { + for (const auto& variant : target->GetAllModelVariants()) { + if (variant.GetInfo().runtime.has_value() && + variant.GetInfo().runtime->device_type == DeviceType::CPU) { + target->SelectVariant(variant); + break; + } + } + } + + return target; + } + + /// Find an audio model, preferring cached. + static Model* FindAudioModel(Catalog& catalog) { + Model* target = nullptr; + + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + if (IsAudioModel(variant->GetAlias())) { + target = catalog.GetModel(variant->GetAlias()); + if (target) break; + } + } + + if (!target) { + for (const auto& alias : {"whisper-small", "whisper-tiny"}) { + target = catalog.GetModel(alias); + if (target) break; + } + } + + return target; + } +}; + +// =========================================================================== +// Catalog tests (no model download required) +// =========================================================================== + +TEST_F(EndToEndTest, BrowseCatalog_ListsModels) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + EXPECT_FALSE(catalog.GetName().empty()); + + auto models = catalog.ListModels(); + EXPECT_GT(models.size(), 0u) << "Catalog should have at least one model"; + + for (const auto* model : models) { + EXPECT_FALSE(model->GetAlias().empty()); + EXPECT_FALSE(model->GetAllModelVariants().empty()); + + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + EXPECT_FALSE(info.id.empty()); + EXPECT_FALSE(info.name.empty()); + EXPECT_FALSE(info.alias.empty()); + EXPECT_FALSE(info.provider_type.empty()); + EXPECT_FALSE(info.model_type.empty()); + } + } +} + +TEST_F(EndToEndTest, GetCachedModels_Succeeds) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + EXPECT_FALSE(variant->GetId().empty()); + EXPECT_TRUE(variant->IsCached()); + } +} + +TEST_F(EndToEndTest, GetLoadedModels_Succeeds) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto loaded = catalog.GetLoadedModels(); + for (auto* variant : loaded) { + EXPECT_FALSE(variant->GetId().empty()); + EXPECT_TRUE(variant->IsLoaded()); + } +} + +TEST_F(EndToEndTest, GetModel_NotFound_ReturnsNull) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* model = catalog.GetModel("this-model-does-not-exist-12345"); + EXPECT_EQ(model, nullptr); +} + +TEST_F(EndToEndTest, GetModelVariant_NotFound_ReturnsNull) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* variant = catalog.GetModelVariant("nonexistent-model:999"); + EXPECT_EQ(variant, nullptr); +} + +TEST_F(EndToEndTest, GetModelVariant_Found) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + if (models.empty()) { + GTEST_SKIP() << "No models in catalog"; + } + + const auto& 
firstVariant = models[0]->GetAllModelVariants()[0]; + auto* found = catalog.GetModelVariant(firstVariant.GetId()); + ASSERT_NE(nullptr, found); + EXPECT_EQ(firstVariant.GetId(), found->GetId()); +} + +TEST_F(EndToEndTest, ModelVariantInfo_HasRequiredFields) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + if (models.empty()) { + GTEST_SKIP() << "No models in catalog"; + } + + for (const auto* model : models) { + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + EXPECT_FALSE(info.id.empty()); + EXPECT_FALSE(info.name.empty()); + EXPECT_GT(info.version, 0u); + EXPECT_FALSE(info.alias.empty()); + EXPECT_FALSE(info.uri.empty()); + } + } +} + +TEST_F(EndToEndTest, ModelVariant_SelectVariant) { +auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + + // Find a model with multiple variants + Model* multiVariantModel = nullptr; + for (auto* model : models) { + if (model->GetAllModelVariants().size() > 1) { + multiVariantModel = model; + break; + } + } + + if (!multiVariantModel) { + GTEST_SKIP() << "No model with multiple variants found"; + } + + const auto& variants = multiVariantModel->GetAllModelVariants(); + const auto& secondVariant = variants[1]; + multiVariantModel->SelectVariant(secondVariant); + EXPECT_EQ(secondVariant.GetId(), multiVariantModel->GetId()); + + // Select back the first variant + multiVariantModel->SelectVariant(variants[0]); + EXPECT_EQ(variants[0].GetId(), multiVariantModel->GetId()); +} + +// =========================================================================== +// EnsureEpsDownloaded (no model download, but may download EPs) +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_EnsureEpsDownloaded_Succeeds) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (may require network)"; + } + + EXPECT_NO_THROW(FoundryLocalManager::Instance().EnsureEpsDownloaded()); +} + +// =========================================================================== +// Web service tests +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_WebService_StartAndStop) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI"; + } + + auto& manager = FoundryLocalManager::Instance(); + + // GetUrls should be empty before starting + EXPECT_TRUE(manager.GetUrls().empty()); + + // StartWebService without web config should throw + // Note: the manager was created without web config, so this verifies the guard. 
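+    // For contrast, a web-enabled manager would be created roughly like the
+    // sketch below. This is illustrative only: Configuration::web and its
+    // urls field are exercised by Initialize() in this patch, but the exact
+    // way to populate them (and the port) is assumed here.
+    //   Configuration cfg("CppSdkE2ETest");
+    //   cfg.web.emplace();                        // assumes web is a std::optional
+    //   cfg.web->urls = "http://127.0.0.1:5273";  // arbitrary example URL
+    //   FoundryLocalManager::Create(std::move(cfg));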
+ EXPECT_THROW(manager.StartWebService(), Exception); +} + +// =========================================================================== +// Download, load, chat (non-streaming), unload +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_DownloadLoadChatUnload) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + std::cout << "[E2E] Using model: " << target->GetAlias() + << " variant: " << target->GetId() << "\n"; + + // Download (no-op if already cached) + bool progressCallbackInvoked = false; + target->Download([&](float pct) { + progressCallbackInvoked = true; + std::cout << "\r[E2E] Download: " << pct << "% " << std::flush; + }); + std::cout << "\n"; + + EXPECT_TRUE(target->IsCached()); + + // Load + target->Load(); + EXPECT_TRUE(target->IsLoaded()); + + // Verify it appears in loaded models + auto loaded = catalog.GetLoadedModels(); + bool foundInLoaded = false; + for (auto* v : loaded) { + if (v->GetId() == target->GetId()) { + foundInLoaded = true; + break; + } + } + EXPECT_TRUE(foundInLoaded) << "Model should appear in GetLoadedModels() after Load()"; + + // Chat (non-streaming) + OpenAIChatClient client(*target); + + std::vector messages = {{"user", "Say hello in one word.", {}}}; + ChatSettings settings; + settings.max_tokens = 32; + auto response = client.CompleteChat(messages, settings); + EXPECT_TRUE(response.successful); + ASSERT_FALSE(response.choices.empty()); + ASSERT_TRUE(response.choices[0].message.has_value()); + EXPECT_FALSE(response.choices[0].message->content.empty()); + EXPECT_EQ(FinishReason::Stop, response.choices[0].finish_reason); + std::cout << "[E2E] Response: " << response.choices[0].message->content << "\n"; + + // Unload + target->Unload(); + EXPECT_FALSE(target->IsLoaded()); +} + +// =========================================================================== +// Streaming chat +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_StreamingChat) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Streaming with model: " << target->GetAlias() << "\n"; + + OpenAIChatClient client(*target); + + std::vector messages = {{"user", "Count from 1 to 5.", {}}}; + ChatSettings settings; + settings.max_tokens = 64; + settings.temperature = 0.0f; + + std::vector chunks; + std::string fullContent; + client.CompleteChatStreaming(messages, settings, + [&](const ChatCompletionCreateResponse& chunk) { + chunks.push_back(chunk); + if (!chunk.choices.empty() && chunk.choices[0].delta.has_value() && + !chunk.choices[0].delta->content.empty()) { + fullContent += chunk.choices[0].delta->content; + } + }); + + EXPECT_GT(chunks.size(), 0u) << "Should have received at least one streaming chunk"; + EXPECT_FALSE(fullContent.empty()) << "Accumulated streaming content should not be empty"; + std::cout << "[E2E] Streaming response: " << fullContent << "\n"; + + // Last chunk should have a stop finish reason + 
ASSERT_FALSE(chunks.empty()); + const auto& lastChunk = chunks.back(); + if (!lastChunk.choices.empty()) { + EXPECT_EQ(FinishReason::Stop, lastChunk.choices[0].finish_reason); + } + + target->Unload(); +} + +// =========================================================================== +// Chat with tool calling +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_ChatWithToolCalling) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + // Check if the selected variant supports tool calling + bool supportsCalling = false; + for (const auto& v : target->GetAllModelVariants()) { + if (v.GetInfo().supports_tool_calling.has_value() && *v.GetInfo().supports_tool_calling) { + supportsCalling = true; + break; + } + } + if (!supportsCalling) { + GTEST_SKIP() << "Model does not support tool calling"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Tool calling with model: " << target->GetAlias() << "\n"; + + OpenAIChatClient client(*target); + + std::vector tools = {{ + "function", + FunctionDefinition{ + "get_weather", + "Get the current weather for a city.", + PropertyDefinition{ + "object", + std::nullopt, + std::unordered_map{ + {"city", PropertyDefinition{"string", "The city name"}} + }, + std::vector{"city"} + } + } + }}; + + std::vector messages = { + {"system", "You are a helpful assistant. Use the provided tools when asked about weather."}, + {"user", "What is the weather in Seattle?"} + }; + + ChatSettings settings; + settings.temperature = 0.0f; + settings.max_tokens = 256; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + EXPECT_TRUE(response.successful); + ASSERT_FALSE(response.choices.empty()); + + const auto& choice = response.choices[0]; + // With tool_choice = Required, the model should produce a tool call + if (choice.finish_reason == FinishReason::ToolCalls) { + ASSERT_TRUE(choice.message.has_value()); + ASSERT_FALSE(choice.message->tool_calls.empty()); + const auto& tc = choice.message->tool_calls[0]; + EXPECT_FALSE(tc.id.empty()); + ASSERT_TRUE(tc.function_call.has_value()); + EXPECT_EQ("get_weather", tc.function_call->name); + EXPECT_FALSE(tc.function_call->arguments.empty()); + std::cout << "[E2E] Tool call: " << tc.function_call->name + << " args: " << tc.function_call->arguments << "\n"; + } + + target->Unload(); +} + +// =========================================================================== +// Audio transcription +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_AudioTranscription) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download + audio file)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindAudioModel(catalog); + if (!target) { + GTEST_SKIP() << "No audio model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Audio model: " << target->GetAlias() << "\n"; + + OpenAIAudioClient client(*target); + + // Note: this test requires a valid audio file to be present. + // Skip if no test audio file is available. 
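+    // Example of pointing the tests at a local file (hypothetical path):
+    //   set FL_TEST_AUDIO_PATH=C:\samples\speech.wav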
+ const char* audioPath = std::getenv("FL_TEST_AUDIO_PATH"); + if (!audioPath) { + target->Unload(); + GTEST_SKIP() << "Set FL_TEST_AUDIO_PATH env var to a .wav file to run audio tests"; + } + + auto result = client.TranscribeAudio(audioPath); + EXPECT_FALSE(result.text.empty()); + std::cout << "[E2E] Transcription: " << result.text << "\n"; + + target->Unload(); +} + +TEST_F(EndToEndTest, DISABLED_AudioTranscriptionStreaming) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download + audio file)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindAudioModel(catalog); + if (!target) { + GTEST_SKIP() << "No audio model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + const char* audioPath = std::getenv("FL_TEST_AUDIO_PATH"); + if (!audioPath) { + target->Unload(); + GTEST_SKIP() << "Set FL_TEST_AUDIO_PATH env var to a .wav file to run audio tests"; + } + + OpenAIAudioClient client(*target); + + std::string fullText; + int chunkCount = 0; + client.TranscribeAudioStreaming(audioPath, + [&](const AudioCreateTranscriptionResponse& chunk) { + fullText += chunk.text; + chunkCount++; + }); + + EXPECT_GT(chunkCount, 0) << "Should have received at least one streaming chunk"; + EXPECT_FALSE(fullText.empty()); + std::cout << "[E2E] Streaming transcription (" << chunkCount << " chunks): " << fullText << "\n"; + + target->Unload(); +} + +// =========================================================================== +// RemoveFromCache +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_DownloadAndRemoveFromCache) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + target->Download(); + EXPECT_TRUE(target->IsCached()); + + // RemoveFromCache should succeed without throwing. + EXPECT_NO_THROW(target->RemoveFromCache()); + + std::cout << "[E2E] RemoveFromCache completed for: " << target->GetAlias() + << " (IsCached=" << (target->IsCached() ? 
"true" : "false") << ")\n"; +} From d7ca1f3cb6124282c4a9c07c9152ccba297ebc14 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 2 Apr 2026 12:23:03 -0700 Subject: [PATCH 15/18] fmt --- sdk/cpp/include/catalog.h | 14 +- sdk/cpp/include/log_level.h | 2 +- sdk/cpp/include/openai/openai_audio_client.h | 6 +- sdk/cpp/include/openai/openai_chat_client.h | 12 +- sdk/cpp/include/openai/openai_tool_types.h | 2 +- sdk/cpp/sample/main.cpp | 47 ++- sdk/cpp/src/catalog.cpp | 171 ++++++----- sdk/cpp/src/core.h | 152 +++++----- sdk/cpp/src/core_helpers.h | 202 ++++++------- sdk/cpp/src/flcore_native.h | 6 +- sdk/cpp/src/foundry_local_internal_core.h | 9 +- sdk/cpp/src/foundry_local_manager.cpp | 271 +++++++++--------- sdk/cpp/src/model.cpp | 285 +++++++++---------- sdk/cpp/src/openai_audio_client.cpp | 75 ++--- sdk/cpp/src/openai_chat_client.cpp | 208 +++++++------- sdk/cpp/src/parser.h | 9 +- sdk/cpp/test/catalog_test.cpp | 24 +- sdk/cpp/test/client_test.cpp | 187 +++++------- sdk/cpp/test/e2e_test.cpp | 95 +++---- sdk/cpp/test/mock_core.h | 6 +- sdk/cpp/test/model_variant_test.cpp | 64 ++--- sdk/cpp/test/parser_and_types_test.cpp | 39 +-- 22 files changed, 931 insertions(+), 955 deletions(-) diff --git a/sdk/cpp/include/catalog.h b/sdk/cpp/include/catalog.h index 2cb05a1c..d57137b7 100644 --- a/sdk/cpp/include/catalog.h +++ b/sdk/cpp/include/catalog.h @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -46,12 +47,17 @@ namespace foundry_local { ModelVariant* GetModelVariant(std::string_view modelVariantId) const; private: - void UpdateModels() const; + struct CatalogState { + std::unordered_map byAlias; + std::unordered_map modelIdToModelVariant; + std::chrono::steady_clock::time_point lastFetch{}; + }; - mutable std::chrono::steady_clock::time_point lastFetch_{}; + void UpdateModels() const; + std::shared_ptr GetState() const; - mutable std::unordered_map byAlias_; - mutable std::unordered_map modelIdToModelVariant_; + mutable std::mutex mutex_; + mutable std::shared_ptr state_; explicit Catalog(gsl::not_null injected, gsl::not_null logger); diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h index 887189ec..75dfe667 100644 --- a/sdk/cpp/include/log_level.h +++ b/sdk/cpp/include/log_level.h @@ -7,7 +7,7 @@ namespace foundry_local { -enum class LogLevel { + enum class LogLevel { Verbose, Debug, Information, diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h index 876f93bd..79acccc9 100644 --- a/sdk/cpp/include/openai/openai_audio_client.h +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -15,8 +15,8 @@ namespace foundry_local::Internal { } namespace foundry_local { -class ILogger; -class IModel; + class ILogger; + class IModel; struct AudioCreateTranscriptionResponse { std::string text; @@ -36,7 +36,7 @@ class IModel; private: OpenAIAudioClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger); + gsl::not_null logger); std::string modelId_; gsl::not_null core_; diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h index 49ce0fe8..788a820d 100644 --- a/sdk/cpp/include/openai/openai_chat_client.h +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -21,11 +21,11 @@ namespace foundry_local::Internal { } namespace foundry_local { -class ILogger; -class IModel; + class ILogger; + class IModel; -/// Reason the model stopped generating tokens. -enum class FinishReason { + /// Reason the model stopped generating tokens. 
+ enum class FinishReason { None, Stop, Length, @@ -36,7 +36,7 @@ enum class FinishReason { struct ChatMessage { std::string role; std::string content; - std::optional tool_call_id; ///< For role="tool" responses + std::optional tool_call_id; ///< For role="tool" responses std::vector tool_calls; }; @@ -104,7 +104,7 @@ enum class FinishReason { private: OpenAIChatClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger); + gsl::not_null logger); std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, const ChatSettings& settings, bool stream) const; diff --git a/sdk/cpp/include/openai/openai_tool_types.h b/sdk/cpp/include/openai/openai_tool_types.h index 4130d2b7..105bc49e 100644 --- a/sdk/cpp/include/openai/openai_tool_types.h +++ b/sdk/cpp/include/openai/openai_tool_types.h @@ -34,7 +34,7 @@ namespace foundry_local { /// A parsed function call returned by the model. struct FunctionCall { std::string name; - std::string arguments; ///< JSON string of the arguments + std::string arguments; ///< JSON string of the arguments }; /// A tool call returned by the model in a chat completion response. diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp index 434eeba9..0e6b92e4 100644 --- a/sdk/cpp/sample/main.cpp +++ b/sdk/cpp/sample/main.cpp @@ -66,7 +66,7 @@ void BrowseCatalog(FoundryLocalManager& manager) { std::cout << " device=" << (info.runtime->device_type == DeviceType::GPU ? "GPU" : info.runtime->device_type == DeviceType::NPU ? "NPU" - : "CPU") + : "CPU") << " ep=" << info.runtime->execution_provider; } if (info.file_size_mb) @@ -99,7 +99,8 @@ void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { if (model->IsLoaded()) { std::cout << "Model is loaded and ready for inference.\n"; - } else { + } + else { std::cerr << "Failed to load model.\n"; return; } @@ -130,7 +131,6 @@ void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { auto& catalog = manager.GetCatalog(); - auto* model = catalog.GetModel(alias); if (!model) { std::cerr << "Model '" << alias << "' not found in catalog.\n"; @@ -233,34 +233,29 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) // ── Step 1: Define tools ────────────────────────────────────────────── // Each tool describes a function the model can call. The PropertyDefinition // mirrors a JSON Schema so the model knows what arguments are expected. 
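    // For reference, the tool definition below corresponds to roughly this
    // JSON Schema fragment (illustrative; exact field spellings are assumed):
    //   {"type": "object",
    //    "properties": {"first":  {"type": "integer", "description": "The first number"},
    //                   "second": {"type": "integer", "description": "The second number"}},
    //    "required": ["first", "second"]}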
- std::vector tools = {{ - "function", - FunctionDefinition{ - "multiply_numbers", // function name - "Multiply two integers and return the result.", // description - PropertyDefinition{ - "object", // top-level schema type - std::nullopt, // no top-level description - std::unordered_map{ - {"first", PropertyDefinition{"integer", "The first number"}}, - {"second", PropertyDefinition{"integer", "The second number"}} - }, - std::vector{"first", "second"} // both params are required - } - } - }}; + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", // function name + "Multiply two integers and return the result.", // description + PropertyDefinition{ + "object", // top-level schema type + std::nullopt, // no top-level description + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + std::vector{"first", "second"} // both params are required + }}}}; // ── Step 2: Send the first request ──────────────────────────────────── // tool_choice = Required forces the model to always produce a tool call. // In production you'd typically use Auto so the model decides on its own. std::vector messages = { {"system", "You are a helpful AI assistant. Use the provided tools when appropriate."}, - {"user", "What is 7 multiplied by 6?"} - }; + {"user", "What is 7 multiplied by 6?"}}; ChatSettings settings; settings.temperature = 0.0f; - settings.max_tokens = 500; + settings.max_tokens = 500; settings.tool_choice = ToolChoiceKind::Required; std::cout << "Sending chat request with tool definitions...\n"; @@ -276,9 +271,8 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) const auto& firstChoice = response.choices[0]; // The model signals it wants to call a tool via finish_reason == ToolCalls. - if (firstChoice.finish_reason == FinishReason::ToolCalls && - firstChoice.message && !firstChoice.message->tool_calls.empty()) - { + if (firstChoice.finish_reason == FinishReason::ToolCalls && firstChoice.message && + !firstChoice.message->tool_calls.empty()) { const auto& tc = firstChoice.message->tool_calls[0]; std::cout << "Model requested tool call:\n" << " function : " << (tc.function_call ? tc.function_call->name : "(none)") << "\n" @@ -293,7 +287,8 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) // For brevity we hard-code the expected result here. 
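        // A real handler would parse the call's arguments JSON instead, e.g.:
        //   auto args = nlohmann::json::parse(tc.function_call->arguments);
        //   toolResult = std::to_string(args["first"].get<int>() * args["second"].get<int>());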
toolResult = "7 x 6 = 42."; std::cout << " result : " << toolResult << "\n"; - } else { + } + else { toolResult = "Unknown tool."; } diff --git a/sdk/cpp/src/catalog.cpp b/sdk/cpp/src/catalog.cpp index 3ff9df0d..836cb285 100644 --- a/sdk/cpp/src/catalog.cpp +++ b/sdk/cpp/src/catalog.cpp @@ -18,102 +18,123 @@ namespace foundry_local { -using namespace detail; + using namespace detail; -Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) - : core_(injected), logger_(logger) { - auto response = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); - if (response.HasError()) { - throw Exception(std::string("Error getting catalog name: ") + response.error, *logger_); + Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + : state_(std::make_shared()), core_(injected), logger_(logger) { + auto response = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); + if (response.HasError()) { + throw Exception(std::string("Error getting catalog name: ") + response.error, *logger_); + } + name_ = std::move(response.data); } - name_ = std::move(response.data); -} - -std::vector Catalog::GetLoadedModels() const { - UpdateModels(); - return CollectVariantsByIds(modelIdToModelVariant_, GetLoadedModelsInternal(core_, *logger_)); -} - -std::vector Catalog::GetCachedModels() const { - UpdateModels(); - return CollectVariantsByIds(modelIdToModelVariant_, GetCachedModelsInternal(core_, *logger_)); -} - -Model* Catalog::GetModel(std::string_view modelId) const { - UpdateModels(); - auto it = byAlias_.find(std::string(modelId)); - if (it != byAlias_.end()) { - return &it->second; + + std::shared_ptr Catalog::GetState() const { + std::lock_guard lock(mutex_); + return state_; } - return nullptr; -} -std::vector Catalog::ListModels() const { - UpdateModels(); + std::vector Catalog::GetLoadedModels() const { + UpdateModels(); + auto state = GetState(); + return CollectVariantsByIds(state->modelIdToModelVariant, GetLoadedModelsInternal(core_, *logger_)); + } - std::vector out; - out.reserve(byAlias_.size()); - for (auto& kv : byAlias_) - out.emplace_back(&kv.second); + std::vector Catalog::GetCachedModels() const { + UpdateModels(); + auto state = GetState(); + return CollectVariantsByIds(state->modelIdToModelVariant, GetCachedModelsInternal(core_, *logger_)); + } - return out; -} + Model* Catalog::GetModel(std::string_view modelId) const { + UpdateModels(); + auto state = GetState(); + auto it = state->byAlias.find(std::string(modelId)); + if (it != state->byAlias.end()) { + return const_cast(&it->second); + } + return nullptr; + } -void Catalog::UpdateModels() const { - using clock = std::chrono::steady_clock; + std::vector Catalog::ListModels() const { + UpdateModels(); + auto state = GetState(); - // TODO: make this configurable - constexpr auto kRefreshInterval = std::chrono::hours(6); + std::vector out; + out.reserve(state->byAlias.size()); + for (auto& kv : state->byAlias) + out.emplace_back(const_cast(&kv.second)); - const auto now = clock::now(); - if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) { - return; + return out; } - const auto response = core_->call("get_model_list", *logger_); - if (response.HasError()) { - throw Exception(std::string("Error getting model list: ") + response.error, *logger_); - } - const auto arr = nlohmann::json::parse(response.data); + void Catalog::UpdateModels() const { + using clock = std::chrono::steady_clock; - byAlias_.clear(); - modelIdToModelVariant_.clear(); + // 
TODO: make this configurable
+        constexpr auto kRefreshInterval = std::chrono::hours(6);

-    const auto now = clock::now();
-    if (lastFetch_.time_since_epoch() != clock::duration::zero() && (now - lastFetch_) < kRefreshInterval) {
-        return;
+        const auto now = clock::now();
+        {
+            auto current = GetState();
+            if (current->lastFetch.time_since_epoch() != clock::duration::zero() &&
+                (now - current->lastFetch) < kRefreshInterval) {
+                return;
+            }
+        }

-    const auto response = core_->call("get_model_list", *logger_);
-    if (response.HasError()) {
-        throw Exception(std::string("Error getting model list: ") + response.error, *logger_);
-    }
-    const auto arr = nlohmann::json::parse(response.data);
+        // Fetch outside the lock so the core call doesn't block readers.
+        const auto response = core_->call("get_model_list", *logger_);
+        if (response.HasError()) {
+            throw Exception(std::string("Error getting model list: ") + response.error, *logger_);
+        }
+        const auto arr = nlohmann::json::parse(response.data);

-    byAlias_.clear();
-    modelIdToModelVariant_.clear();
+        // Build the new state locally so no reader can see partial data.
+        auto newState = std::make_shared<CatalogState>();

        for (const auto& j : arr) {
            const std::string alias = j.at("alias").get<std::string>();

-        auto it = byAlias_.find(alias);
-        if (it == byAlias_.end()) {
-            Model m(core_, logger_);
-            it = byAlias_.emplace(alias, std::move(m)).first;
+            auto it = newState->byAlias.find(alias);
+            if (it == newState->byAlias.end()) {
+                Model m(core_, logger_);
+                it = newState->byAlias.emplace(alias, std::move(m)).first;
            }

-        ModelInfo modelVariantInfo;
-        from_json(j, modelVariantInfo);
-        std::string variantId = modelVariantInfo.id;
-        ModelVariant modelVariant(core_, modelVariantInfo, logger_);
-        modelIdToModelVariant_.emplace(variantId, modelVariant);
+            ModelInfo modelVariantInfo;
+            from_json(j, modelVariantInfo);
+            std::string variantId = modelVariantInfo.id;
+            ModelVariant modelVariant(core_, modelVariantInfo, logger_);
+            newState->modelIdToModelVariant.emplace(variantId, modelVariant);

-        it->second.variants_.emplace_back(std::move(modelVariant));
-    }
+            it->second.variants_.emplace_back(std::move(modelVariant));
+        }

-    // Auto-select the first variant for each model.
-    for (auto& [alias, model] : byAlias_) {
-        if (!model.variants_.empty()) {
-            model.selectedVariant_ = &model.variants_.front();
+        // Auto-select the first variant for each model.
+        for (auto& [alias, model] : newState->byAlias) {
+            if (!model.variants_.empty()) {
+                model.selectedVariant_ = &model.variants_.front();
+            }
        }
-    }

-    lastFetch_ = now;
-}
+        newState->lastFetch = now;

-ModelVariant* Catalog::GetModelVariant(std::string_view id) const {
-    UpdateModels();
-    auto it = modelIdToModelVariant_.find(std::string(id));
-    if (it != modelIdToModelVariant_.end()) {
-        return &it->second;
+        // Atomic swap — readers that already hold the old shared_ptr keep it alive.
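+        // Reader's view of the copy-on-write scheme, as a sketch:
+        //   auto snap = GetState();  // pin the current immutable snapshot
+        //   ...a concurrent UpdateModels() may swap state_ here...
+        //   snap->byAlias;           // still valid; the snapshot owns its data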
+ { + std::lock_guard lock(mutex_); + state_ = std::move(newState); + } + } + + ModelVariant* Catalog::GetModelVariant(std::string_view id) const { + UpdateModels(); + auto state = GetState(); + auto it = state->modelIdToModelVariant.find(std::string(id)); + if (it != state->modelIdToModelVariant.end()) { + return const_cast(&it->second); + } + return nullptr; } - return nullptr; -} } // namespace foundry_local diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h index d0f3b682..c7f73d5d 100644 --- a/sdk/cpp/src/core.h +++ b/sdk/cpp/src/core.h @@ -21,94 +21,94 @@ namespace foundry_local { -namespace { -inline std::filesystem::path getExecutableDir() { - auto exePath = wil::GetModuleFileNameW(nullptr); - return std::filesystem::path(exePath.get()).parent_path(); -} - -inline void* RequireProc(HMODULE mod, const char* name) { - if (void* p = ::GetProcAddress(mod, name)) - return p; - throw std::runtime_error(std::string("GetProcAddress failed for ") + name); -} -} // namespace - -struct Core : Internal::IFoundryLocalCore { - using ResponseHandle = std::unique_ptr; - - Core() = default; - ~Core() = default; - - void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); } - - void unload() override { - module_.reset(); - execCmd_ = nullptr; - execCbCmd_ = nullptr; - freeResCmd_ = nullptr; - } - - CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, - NativeCallbackFn callback = nullptr, void* data = nullptr) const override { - if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { - throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger); + namespace { + inline std::filesystem::path getExecutableDir() { + auto exePath = wil::GetModuleFileNameW(nullptr); + return std::filesystem::path(exePath.get()).parent_path(); } - RequestBuffer request{}; - request.Command = command.empty() ? 
nullptr : command.data(); - request.CommandLength = static_cast(command.size()); - - if (dataArgument && !dataArgument->empty()) { - request.Data = dataArgument->data(); - request.DataLength = static_cast(dataArgument->size()); + inline void* RequireProc(HMODULE mod, const char* name) { + if (void* p = ::GetProcAddress(mod, name)) + return p; + throw std::runtime_error(std::string("GetProcAddress failed for ") + name); } + } // namespace - ResponseBuffer response{}; - auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { - if (fn) - fn(buf); - }; - std::unique_ptr responseGuard(&response, safeDeleter); + struct Core : Internal::IFoundryLocalCore { + using ResponseHandle = std::unique_ptr; - if (callback != nullptr) { - execCbCmd_(&request, &response, reinterpret_cast(callback), data); - } - else { - execCmd_(&request, &response); - } + Core() = default; + ~Core() = default; - CoreResponse result; - if (response.Error && response.ErrorLength > 0) { - result.error.assign(static_cast(response.Error), response.ErrorLength); - return result; - } + void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); } - if (response.Data && response.DataLength > 0) { - result.data.assign(static_cast(response.Data), response.DataLength); + void unload() override { + module_.reset(); + execCmd_ = nullptr; + execCbCmd_ = nullptr; + freeResCmd_ = nullptr; } - return result; - } + CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr, + NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) { + throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger); + } + + RequestBuffer request{}; + request.Command = command.empty() ? 
nullptr : command.data(); + request.CommandLength = static_cast(command.size()); + + if (dataArgument && !dataArgument->empty()) { + request.Data = dataArgument->data(); + request.DataLength = static_cast(dataArgument->size()); + } + + ResponseBuffer response{}; + auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { + if (fn) + fn(buf); + }; + std::unique_ptr responseGuard(&response, safeDeleter); + + if (callback != nullptr) { + execCbCmd_(&request, &response, reinterpret_cast(callback), data); + } + else { + execCmd_(&request, &response); + } + + CoreResponse result; + if (response.Error && response.ErrorLength > 0) { + result.error.assign(static_cast(response.Error), response.ErrorLength); + return result; + } + + if (response.Data && response.DataLength > 0) { + result.data.assign(static_cast(response.Data), response.DataLength); + } -private: - wil::unique_hmodule module_; - execute_command_fn execCmd_{}; - execute_command_with_callback_fn execCbCmd_{}; - free_response_fn freeResCmd_{}; + return result; + } - void loadFromPath(const std::filesystem::path& path) { - wil::unique_hmodule m(::LoadLibraryW(path.c_str())); - if (!m) - throw std::runtime_error("LoadLibraryW failed"); + private: + wil::unique_hmodule module_; + execute_command_fn execCmd_{}; + execute_command_with_callback_fn execCbCmd_{}; + free_response_fn freeResCmd_{}; - execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); - execCbCmd_ = reinterpret_cast( - RequireProc(m.get(), "execute_command_with_callback")); - freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); + void loadFromPath(const std::filesystem::path& path) { + wil::unique_hmodule m(::LoadLibraryW(path.c_str())); + if (!m) + throw std::runtime_error("LoadLibraryW failed"); - module_ = std::move(m); - } -}; + execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); + execCbCmd_ = reinterpret_cast( + RequireProc(m.get(), "execute_command_with_callback")); + freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); + + module_ = std::move(m); + } + }; } // namespace foundry_local diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h index d35d87b4..76d52ed3 100644 --- a/sdk/cpp/src/core_helpers.h +++ b/sdk/cpp/src/core_helpers.h @@ -22,125 +22,125 @@ namespace foundry_local::detail { -// Wrap Params: { ... } into a request object -inline nlohmann::json MakeParams(nlohmann::json params) { - return nlohmann::json{{"Params", std::move(params)}}; -} - -// Most common: Params { "Model": } -inline nlohmann::json MakeModelParams(std::string_view model) { - return MakeParams(nlohmann::json{{"Model", std::string(model)}}); -} - -// Serialize + call -inline CoreResponse CallWithJson(Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& requestJson, ILogger& logger) { - std::string payload = requestJson.dump(); - return core->call(command, logger, &payload); -} - -// Serialize + call with native callback -inline CoreResponse CallWithJsonAndCallback(Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& requestJson, ILogger& logger, - NativeCallbackFn callback, void* userData) { - std::string payload = requestJson.dump(); - return core->call(command, logger, &payload, callback, userData); -} - -// Serialize + call with a streaming chunk handler. -// Wraps the caller-supplied onChunk with the native callback boilerplate -// (null/length checks, exception capture, rethrow after the call). 
-// The errorContext string is used to prefix any core-layer error message. -inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, - const std::string& payload, ILogger& logger, - const std::function& onChunk, - std::string_view errorContext) { - struct State { - const std::function* cb; - std::exception_ptr exception; - } state{&onChunk, nullptr}; - - auto nativeCallback = [](void* data, int32_t len, void* user) { - if (!data || len <= 0) - return; - - auto* st = static_cast(user); - if (st->exception) - return; - - try { - std::string chunk(static_cast(data), static_cast(len)); - (*(st->cb))(chunk); - } - catch (...) { - st->exception = std::current_exception(); - } - }; + // Wrap Params: { ... } into a request object + inline nlohmann::json MakeParams(nlohmann::json params) { + return nlohmann::json{{"Params", std::move(params)}}; + } - auto response = core->call(command, logger, &payload, +nativeCallback, &state); - if (response.HasError()) { - throw Exception(std::string(errorContext) + response.error, logger); + // Most common: Params { "Model": } + inline nlohmann::json MakeModelParams(std::string_view model) { + return MakeParams(nlohmann::json{{"Model", std::string(model)}}); } - if (state.exception) { - std::rethrow_exception(state.exception); + // Serialize + call + inline CoreResponse CallWithJson(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload); } - return response; -} + // Serialize + call with native callback + inline CoreResponse CallWithJsonAndCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger, + NativeCallbackFn callback, void* userData) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload, callback, userData); + } -// Overload: allow Params object directly -inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command, - const nlohmann::json& params, ILogger& logger) { - return CallWithJson(core, command, MakeParams(params), logger); -} + // Serialize + call with a streaming chunk handler. + // Wraps the caller-supplied onChunk with the native callback boilerplate + // (null/length checks, exception capture, rethrow after the call). + // The errorContext string is used to prefix any core-layer error message. + inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const std::string& payload, ILogger& logger, + const std::function& onChunk, + std::string_view errorContext) { + struct State { + const std::function* cb; + std::exception_ptr exception; + } state{&onChunk, nullptr}; + + auto nativeCallback = [](void* data, int32_t len, void* user) { + if (!data || len <= 0) + return; + + auto* st = static_cast(user); + if (st->exception) + return; + + try { + std::string chunk(static_cast(data), static_cast(len)); + (*(st->cb))(chunk); + } + catch (...) 
{ + st->exception = std::current_exception(); + } + }; + + auto response = core->call(command, logger, &payload, +nativeCallback, &state); + if (response.HasError()) { + throw Exception(std::string(errorContext) + response.error, logger); + } -// Overload: no payload -inline CoreResponse CallNoArgs(Internal::IFoundryLocalCore* core, std::string_view command, ILogger& logger) { - return core->call(command, logger, nullptr); -} + if (state.exception) { + std::rethrow_exception(state.exception); + } -inline std::vector GetLoadedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { - auto response = core->call("list_loaded_models", logger); - if (response.HasError()) { - throw Exception("Failed to get loaded models: " + response.error, logger); + return response; } - try { - auto parsed = nlohmann::json::parse(response.data); - return parsed.get>(); - } - catch (const nlohmann::json::exception& e) { - throw Exception("Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); + + // Overload: allow Params object directly + inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& params, ILogger& logger) { + return CallWithJson(core, command, MakeParams(params), logger); } -} -inline std::vector GetCachedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { - auto response = core->call("get_cached_models", logger); - if (response.HasError()) { - throw Exception("Failed to get cached models: " + response.error, logger); + // Overload: no payload + inline CoreResponse CallNoArgs(Internal::IFoundryLocalCore* core, std::string_view command, ILogger& logger) { + return core->call(command, logger, nullptr); } - try { - auto parsed = nlohmann::json::parse(response.data); - return parsed.get>(); + inline std::vector GetLoadedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { + auto response = core->call("list_loaded_models", logger); + if (response.HasError()) { + throw Exception("Failed to get loaded models: " + response.error, logger); + } + try { + auto parsed = nlohmann::json::parse(response.data); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw Exception("Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger); + } } - catch (const nlohmann::json::exception& e) { - throw Exception("Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); + + inline std::vector GetCachedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) { + auto response = core->call("get_cached_models", logger); + if (response.HasError()) { + throw Exception("Failed to get cached models: " + response.error, logger); + } + + try { + auto parsed = nlohmann::json::parse(response.data); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw Exception("Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); + } } -} -inline std::vector CollectVariantsByIds( - std::unordered_map& modelIdToModelVariant, std::vector ids) { - std::vector out; - out.reserve(ids.size()); + inline std::vector CollectVariantsByIds( + const std::unordered_map& modelIdToModelVariant, std::vector ids) { + std::vector out; + out.reserve(ids.size()); - for (const auto& id : ids) { - auto it = modelIdToModelVariant.find(id); - if (it != modelIdToModelVariant.end()) { - out.emplace_back(&it->second); + for (const auto& id : ids) { + auto it = modelIdToModelVariant.find(id); + if (it 
!= modelIdToModelVariant.end()) { + out.emplace_back(const_cast(&it->second)); + } } + return out; } - return out; -} } // namespace foundry_local::detail diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h index dffcdad0..c62ca192 100644 --- a/sdk/cpp/src/flcore_native.h +++ b/sdk/cpp/src/flcore_native.h @@ -4,7 +4,8 @@ #pragma once #include -extern "C" { +extern "C" +{ // Layout must match C# structs exactly #pragma pack(push, 8) struct RequestBuffer { @@ -26,7 +27,8 @@ extern "C" { // Exported function pointer types using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*); - using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, void* /*userData*/); + using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/, + void* /*userData*/); using free_response_fn = void(__cdecl*)(ResponseBuffer*); static_assert(std::is_standard_layout::value, "RequestBuffer must be standard layout"); diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h index 0cbc3d68..1e5af79d 100644 --- a/sdk/cpp/src/foundry_local_internal_core.h +++ b/sdk/cpp/src/foundry_local_internal_core.h @@ -12,7 +12,7 @@ namespace foundry_local { /// Native callback signature used by the core DLL interop. /// Parameters: (data, dataLength, userData). - using NativeCallbackFn = void(*)(void*, int32_t, void*); + using NativeCallbackFn = void (*)(void*, int32_t, void*); /// Value returned by IFoundryLocalCore::call(). /// On success, `data` contains the response payload and `error` is empty. @@ -24,14 +24,13 @@ namespace foundry_local { bool HasError() const noexcept { return !error.empty(); } }; -namespace Internal { + namespace Internal { struct IFoundryLocalCore { virtual ~IFoundryLocalCore() = default; virtual CoreResponse call(std::string_view command, ILogger& logger, - const std::string* dataArgument = nullptr, - NativeCallbackFn callback = nullptr, - void* data = nullptr) const = 0; + const std::string* dataArgument = nullptr, NativeCallbackFn callback = nullptr, + void* data = nullptr) const = 0; virtual void unload() = 0; }; diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp index aca77be8..d1ab35bb 100644 --- a/sdk/cpp/src/foundry_local_manager.cpp +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -17,174 +17,175 @@ namespace foundry_local { -std::unique_ptr FoundryLocalManager::instance_; + std::unique_ptr FoundryLocalManager::instance_; -void FoundryLocalManager::Create(Configuration configuration, ILogger* logger) { - if (instance_) { - NullLogger fallback; - ILogger& log = logger ? *logger : fallback; - throw Exception("FoundryLocalManager has already been created. Call Destroy() first.", log); + void FoundryLocalManager::Create(Configuration configuration, ILogger* logger) { + if (instance_) { + NullLogger fallback; + ILogger& log = logger ? *logger : fallback; + throw Exception("FoundryLocalManager has already been created. Call Destroy() first.", log); + } + + // Use a local to ensure full initialization before assigning to the static instance. + std::unique_ptr manager( + new FoundryLocalManager(std::move(configuration), logger)); + instance_ = std::move(manager); } - // Use a local to ensure full initialization before assigning to the static instance. 
- std::unique_ptr manager(new FoundryLocalManager(std::move(configuration), logger)); - instance_ = std::move(manager); -} + FoundryLocalManager& FoundryLocalManager::Instance() { + if (!instance_) { + throw Exception("FoundryLocalManager has not been created. Call Create() first."); + } + return *instance_; + } -FoundryLocalManager& FoundryLocalManager::Instance() { - if (!instance_) { - throw Exception("FoundryLocalManager has not been created. Call Create() first."); + bool FoundryLocalManager::IsInitialized() noexcept { + return instance_ != nullptr; } - return *instance_; -} - -bool FoundryLocalManager::IsInitialized() noexcept { - return instance_ != nullptr; -} - -void FoundryLocalManager::Destroy() noexcept { - instance_.reset(); -} - -FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger) - : config_(std::move(configuration)), core_(std::make_unique()), - logger_(logger ? logger : &defaultLogger_) { - static_cast(core_.get())->loadEmbedded(); - Initialize(); - catalog_ = Catalog::Create(core_.get(), logger_); -} - -FoundryLocalManager::~FoundryLocalManager() { - Cleanup(); -} - -void FoundryLocalManager::Cleanup() noexcept { - // Unload all loaded models before tearing down. - if (catalog_) { - try { - auto loadedModels = catalog_->GetLoadedModels(); - for (auto* variant : loadedModels) { - try { - variant->Unload(); - } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error unloading model during destruction: ") + ex.what()); + + void FoundryLocalManager::Destroy() noexcept { + instance_.reset(); + } + + FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger) + : config_(std::move(configuration)), core_(std::make_unique()), + logger_(logger ? logger : &defaultLogger_) { + static_cast(core_.get())->loadEmbedded(); + Initialize(); + catalog_ = Catalog::Create(core_.get(), logger_); + } + + FoundryLocalManager::~FoundryLocalManager() { + Cleanup(); + } + + void FoundryLocalManager::Cleanup() noexcept { + // Unload all loaded models before tearing down. 
+ if (catalog_) { + try { + auto loadedModels = catalog_->GetLoadedModels(); + for (auto* variant : loadedModels) { + try { + variant->Unload(); + } + catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error unloading model during destruction: ") + ex.what()); + } } } + catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error retrieving loaded models during destruction: ") + ex.what()); + } } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error retrieving loaded models during destruction: ") + ex.what()); + + if (!urls_.empty()) { + try { + StopWebService(); + } + catch (const std::exception& ex) { + logger_->Log(LogLevel::Warning, + std::string("Error stopping web service during destruction: ") + ex.what()); + } } } - if (!urls_.empty()) { - try { - StopWebService(); - } - catch (const std::exception& ex) { - logger_->Log(LogLevel::Warning, - std::string("Error stopping web service during destruction: ") + ex.what()); - } + const Catalog& FoundryLocalManager::GetCatalog() const { + return *catalog_; } -} -const Catalog& FoundryLocalManager::GetCatalog() const { - return *catalog_; -} + Catalog& FoundryLocalManager::GetCatalog() { + return *catalog_; + } -Catalog& FoundryLocalManager::GetCatalog() { - return *catalog_; -} + void FoundryLocalManager::StartWebService() { + if (!config_.web) { + throw Exception("Web service configuration was not provided.", *logger_); + } -void FoundryLocalManager::StartWebService() { - if (!config_.web) { - throw Exception("Web service configuration was not provided.", *logger_); + auto response = core_->call("start_service", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error starting web service: ") + response.error, *logger_); + } + auto arr = nlohmann::json::parse(response.data); + urls_ = arr.get>(); } - auto response = core_->call("start_service", *logger_); - if (response.HasError()) { - throw Exception(std::string("Error starting web service: ") + response.error, *logger_); - } - auto arr = nlohmann::json::parse(response.data); - urls_ = arr.get>(); -} + void FoundryLocalManager::StopWebService() { + if (!config_.web) { + throw Exception("Web service configuration was not provided.", *logger_); + } -void FoundryLocalManager::StopWebService() { - if (!config_.web) { - throw Exception("Web service configuration was not provided.", *logger_); + auto response = core_->call("stop_service", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error stopping web service: ") + response.error, *logger_); + } + urls_.clear(); } - auto response = core_->call("stop_service", *logger_); - if (response.HasError()) { - throw Exception(std::string("Error stopping web service: ") + response.error, *logger_); + gsl::span FoundryLocalManager::GetUrls() const noexcept { + return urls_; } - urls_.clear(); -} - -gsl::span FoundryLocalManager::GetUrls() const noexcept { - return urls_; -} - -void FoundryLocalManager::EnsureEpsDownloaded() const { - auto response = core_->call("ensure_eps_downloaded", *logger_); - if (response.HasError()) { - throw Exception(std::string("Error ensuring execution providers downloaded: ") + response.error, - *logger_); + + void FoundryLocalManager::EnsureEpsDownloaded() const { + auto response = core_->call("ensure_eps_downloaded", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error ensuring execution providers downloaded: ") + response.error, *logger_); + } } -} -void 
FoundryLocalManager::Initialize() { - config_.Validate(); + void FoundryLocalManager::Initialize() { + config_.Validate(); - CoreInteropRequest initReq("initialize"); - initReq.AddParam("AppName", config_.app_name); - initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); + CoreInteropRequest initReq("initialize"); + initReq.AddParam("AppName", config_.app_name); + initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); - if (config_.app_data_dir) { - initReq.AddParam("AppDataDir", config_.app_data_dir->string()); - } - if (config_.model_cache_dir) { - initReq.AddParam("ModelCacheDir", config_.model_cache_dir->string()); - } - if (config_.logs_dir) { - initReq.AddParam("LogsDir", config_.logs_dir->string()); - } - if (config_.web && config_.web->urls) { - initReq.AddParam("WebServiceUrls", *config_.web->urls); - } - if (config_.additional_settings) { - for (const auto& [key, value] : *config_.additional_settings) { - if (!key.empty()) { - initReq.AddParam(key, value); + if (config_.app_data_dir) { + initReq.AddParam("AppDataDir", config_.app_data_dir->string()); + } + if (config_.model_cache_dir) { + initReq.AddParam("ModelCacheDir", config_.model_cache_dir->string()); + } + if (config_.logs_dir) { + initReq.AddParam("LogsDir", config_.logs_dir->string()); + } + if (config_.web && config_.web->urls) { + initReq.AddParam("WebServiceUrls", *config_.web->urls); + } + if (config_.additional_settings) { + for (const auto& [key, value] : *config_.additional_settings) { + if (!key.empty()) { + initReq.AddParam(key, value); + } } } - } - std::string initJson = initReq.ToJson(); - auto initResponse = core_->call(initReq.Command(), *logger_, &initJson); - if (initResponse.HasError()) { - throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + initResponse.error, *logger_); - } - - if (config_.model_cache_dir) { - auto cacheResponse = core_->call("get_cache_directory", *logger_); - if (cacheResponse.HasError()) { - throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + cacheResponse.error, *logger_); + std::string initJson = initReq.ToJson(); + auto initResponse = core_->call(initReq.Command(), *logger_, &initJson); + if (initResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + initResponse.error, *logger_); } - if (cacheResponse.data != config_.model_cache_dir->string()) { - CoreInteropRequest setReq("set_cache_directory"); - setReq.AddParam("Directory", config_.model_cache_dir->string()); - std::string setJson = setReq.ToJson(); - auto setResponse = core_->call(setReq.Command(), *logger_, &setJson); - if (setResponse.HasError()) { - throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, + if (config_.model_cache_dir) { + auto cacheResponse = core_->call("get_cache_directory", *logger_); + if (cacheResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + cacheResponse.error, *logger_); } + + if (cacheResponse.data != config_.model_cache_dir->string()) { + CoreInteropRequest setReq("set_cache_directory"); + setReq.AddParam("Directory", config_.model_cache_dir->string()); + std::string setJson = setReq.ToJson(); + auto setResponse = core_->call(setReq.Command(), *logger_, &setJson); + if (setResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, + *logger_); + } + } } } -} } // namespace foundry_local diff --git 
a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index 57e660d0..17b43021 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -19,195 +19,194 @@ namespace foundry_local { -using namespace detail; + using namespace detail; -/// ModelVariant + /// ModelVariant -ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, - gsl::not_null logger) - : core_(core), info_(std::move(info)), logger_(logger) {} + ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) {} -const ModelInfo& ModelVariant::GetInfo() const { - return info_; -} - -void ModelVariant::RemoveFromCache() { - auto response = CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); - if (response.HasError()) { - throw Exception("Error removing model from cache [" + info_.name + "]: " + response.error, *logger_); + const ModelInfo& ModelVariant::GetInfo() const { + return info_; } - cachedPath_.clear(); -} -void ModelVariant::Unload() { - auto response = CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); - if (response.HasError()) { - throw Exception("Error unloading model [" + info_.name + "]: " + response.error, *logger_); + void ModelVariant::RemoveFromCache() { + auto response = CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error removing model from cache [" + info_.name + "]: " + response.error, *logger_); + } + cachedPath_.clear(); } -} -bool ModelVariant::IsLoaded() const { - std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); - for (const auto& id : loadedModelIds) { - if (id == info_.id) { - return true; + void ModelVariant::Unload() { + auto response = CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error unloading model [" + info_.name + "]: " + response.error, *logger_); } } - return false; -} + bool ModelVariant::IsLoaded() const { + std::vector loadedModelIds = GetLoadedModelsInternal(core_, *logger_); + for (const auto& id : loadedModelIds) { + if (id == info_.id) { + return true; + } + } + + return false; + } -bool ModelVariant::IsCached() const { - auto cachedModels = GetCachedModelsInternal(core_, *logger_); - for (const auto& id : cachedModels) { - if (id == info_.id) { - return true; + bool ModelVariant::IsCached() const { + auto cachedModels = GetCachedModelsInternal(core_, *logger_); + for (const auto& id : cachedModels) { + if (id == info_.id) { + return true; + } } + return false; } - return false; -} - -void ModelVariant::Download(DownloadProgressCallback onProgress) { - if (IsCached()) { - logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); - return; - } - - if (onProgress) { - struct ProgressState { - DownloadProgressCallback* cb; - ILogger* logger; - } state{&onProgress, logger_}; - - auto nativeCallback = [](void* data, int32_t len, void* user) { - if (!data || len <= 0) - return; - auto* st = static_cast(user); - std::string perc(static_cast(data), static_cast(len)); - try { - float value = std::stof(perc); - (*(st->cb))(value); + + void ModelVariant::Download(DownloadProgressCallback onProgress) { + if (IsCached()) { + logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download."); + return; + } + + if (onProgress) { + struct ProgressState { + DownloadProgressCallback* cb; + 
+                ILogger* logger;
+            } state{&onProgress, logger_};
+
+            auto nativeCallback = [](void* data, int32_t len, void* user) {
+                if (!data || len <= 0)
+                    return;
+                auto* st = static_cast<ProgressState*>(user);
+                std::string perc(static_cast<const char*>(data), static_cast<size_t>(len));
+                try {
+                    float value = std::stof(perc);
+                    (*(st->cb))(value);
+                }
+                catch (...) {
+                    st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc);
+                }
+            };
+
+            auto response = CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_,
+                                                    +nativeCallback, &state);
+            if (response.HasError()) {
+                throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_);
+            }
+        }
-    else {
-        auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_);
-        if (response.HasError()) {
-            throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_);
-        }
-    }
-}
+        else {
+            auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_);
+            if (response.HasError()) {
+                throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_);
+            }
+        }
+    }
 
-void ModelVariant::Load() {
-    auto response = CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_);
-    if (response.HasError()) {
-        throw Exception("Error loading model [" + info_.name + "]: " + response.error, *logger_);
-    }
-}
+    void ModelVariant::Load() {
+        auto response = CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_);
+        if (response.HasError()) {
+            throw Exception("Error loading model [" + info_.name + "]: " + response.error, *logger_);
+        }
+    }
 
-const std::filesystem::path& ModelVariant::GetPath() const {
-    if (cachedPath_.empty()) {
-        auto response = CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_);
-        if (response.HasError()) {
-            throw Exception("Error getting model path [" + info_.name + "]: " + response.error, *logger_);
+    const std::filesystem::path& ModelVariant::GetPath() const {
+        if (cachedPath_.empty()) {
+            auto response = CallWithJson(core_, "get_model_path", MakeModelParams(info_.name), *logger_);
+            if (response.HasError()) {
+                throw Exception("Error getting model path [" + info_.name + "]: " + response.error, *logger_);
+            }
+            cachedPath_ = std::filesystem::path(response.data);
         }
-        cachedPath_ = std::filesystem::path(response.data);
+        return cachedPath_;
     }
-    return cachedPath_;
-}
 
-const std::string& ModelVariant::GetId() const noexcept {
-    return info_.id;
-}
+    const std::string& ModelVariant::GetId() const noexcept {
+        return info_.id;
+    }
 
-const std::string& ModelVariant::GetAlias() const noexcept {
-    return info_.alias;
-}
+    const std::string& ModelVariant::GetAlias() const noexcept {
+        return info_.alias;
+    }
 
-uint32_t ModelVariant::GetVersion() const noexcept {
-    return info_.version;
-}
+    uint32_t ModelVariant::GetVersion() const noexcept {
+        return info_.version;
+    }
 
-IModel::CoreAccess ModelVariant::GetCoreAccess() const {
-    return {core_, info_.name, logger_};
-}
+    IModel::CoreAccess ModelVariant::GetCoreAccess() const {
+        return {core_, info_.name, logger_};
+    }
 
-OpenAIAudioClient ModelVariant::GetAudioClient() const {
-    return OpenAIAudioClient(*this);
-}
+    OpenAIAudioClient ModelVariant::GetAudioClient() const {
+        return OpenAIAudioClient(*this);
+    }
 
-OpenAIChatClient ModelVariant::GetChatClient() const {
-    return OpenAIChatClient(*this);
-}
+    OpenAIChatClient ModelVariant::GetChatClient() const {
+        return OpenAIChatClient(*this);
+    }
 
-/// Model
+    /// Model
 
-Model::Model(gsl::not_null<IFoundryLocalCore*> core, gsl::not_null<ILogger*> logger)
-    : core_(core), logger_(logger) {}
+    Model::Model(gsl::not_null<IFoundryLocalCore*> core, gsl::not_null<ILogger*> logger)
+        : core_(core), logger_(logger) {}
 
-ModelVariant& Model::SelectedVariant() {
-    if (!selectedVariant_) {
-        throw Exception("Model has no selected variant", *logger_);
+    ModelVariant& Model::SelectedVariant() {
+        if (!selectedVariant_) {
+            throw Exception("Model has no selected variant", *logger_);
+        }
+        return *const_cast<ModelVariant*>(selectedVariant_);
     }
-    return *const_cast<ModelVariant*>(selectedVariant_);
-}
 
-const ModelVariant& Model::SelectedVariant() const {
-    if (!selectedVariant_) {
-        throw Exception("Model has no selected variant", *logger_);
+    const ModelVariant& Model::SelectedVariant() const {
+        if (!selectedVariant_) {
+            throw Exception("Model has no selected variant", *logger_);
+        }
+        return *selectedVariant_;
     }
-    return *selectedVariant_;
-}
 
-gsl::span<const ModelVariant> Model::GetAllModelVariants() const {
-    return variants_;
-}
+    gsl::span<const ModelVariant> Model::GetAllModelVariants() const {
+        return variants_;
+    }
 
-const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const {
-    const auto& targetName = variant.GetInfo().name;
+    const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const {
+        const auto& targetName = variant.GetInfo().name;
 
-    for (const auto& v : variants_) {
-        // The variants returned by the catalog are sorted by version, so the first match should always be the latest version.
-        if (v.GetInfo().name == targetName) {
-            return v;
+        for (const auto& v : variants_) {
+            // The variants returned by the catalog are sorted by version, so the first match should always be the
+            // latest version.
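+            // (If the catalog ever stopped returning variants sorted newest-first, this
+            // lookup would need to compare GetVersion() explicitly and keep the
+            // highest-version match instead of returning the first name hit.)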
+ if (v.GetInfo().name == targetName) { + return v; + } } + + throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); } - throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", - *logger_); -} + const std::string& Model::GetId() const { + return SelectedVariant().GetId(); + } -const std::string& Model::GetId() const { - return SelectedVariant().GetId(); -} + const std::string& Model::GetAlias() const { + return SelectedVariant().GetAlias(); + } -const std::string& Model::GetAlias() const { - return SelectedVariant().GetAlias(); -} + void Model::SelectVariant(const ModelVariant& variant) const { + auto it = + std::find_if(variants_.begin(), variants_.end(), [&](const ModelVariant& v) { return &v == &variant; }); -void Model::SelectVariant(const ModelVariant& variant) const { - auto it = std::find_if(variants_.begin(), variants_.end(), - [&](const ModelVariant& v) { return &v == &variant; }); + if (it == variants_.end()) { + throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); + } - if (it == variants_.end()) { - throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", - *logger_); + selectedVariant_ = &(*it); } - selectedVariant_ = &(*it); -} - -IModel::CoreAccess Model::GetCoreAccess() const { - return SelectedVariant().GetCoreAccess(); -} + IModel::CoreAccess Model::GetCoreAccess() const { + return SelectedVariant().GetCoreAccess(); + } } // namespace foundry_local diff --git a/sdk/cpp/src/openai_audio_client.cpp b/sdk/cpp/src/openai_audio_client.cpp index 75c0110b..d4409d1f 100644 --- a/sdk/cpp/src/openai_audio_client.cpp +++ b/sdk/cpp/src/openai_audio_client.cpp @@ -18,52 +18,53 @@ namespace foundry_local { -OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) {} + OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} -AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio(const std::filesystem::path& audioFilePath) const { - nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; - CoreInteropRequest req("audio_transcribe"); - req.AddParam("OpenAICreateRequest", openAiReq.dump()); + AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio( + const std::filesystem::path& audioFilePath) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); - std::string json = req.ToJson(); + std::string json = req.ToJson(); - auto coreResponse = core_->call(req.Command(), *logger_, &json); - if (coreResponse.HasError()) { - throw Exception("Audio transcription failed: " + coreResponse.error, *logger_); - } + auto coreResponse = core_->call(req.Command(), *logger_, &json); + if (coreResponse.HasError()) { + throw Exception("Audio transcription failed: " + coreResponse.error, *logger_); + } - AudioCreateTranscriptionResponse response; - response.text = std::move(coreResponse.data); + AudioCreateTranscriptionResponse response; + response.text = std::move(coreResponse.data); - return response; -} + return response; + } -void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, - const 
StreamCallback& onChunk) const { - nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; - CoreInteropRequest req("audio_transcribe"); - req.AddParam("OpenAICreateRequest", openAiReq.dump()); + void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + const StreamCallback& onChunk) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); - std::string json = req.ToJson(); + std::string json = req.ToJson(); - detail::CallWithStreamingCallback( - core_, req.Command(), json, *logger_, - [&onChunk](const std::string& text) { - AudioCreateTranscriptionResponse chunk; - chunk.text = text; - onChunk(chunk); - }, - "Streaming audio transcription failed: "); -} + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& text) { + AudioCreateTranscriptionResponse chunk; + chunk.text = text; + onChunk(chunk); + }, + "Streaming audio transcription failed: "); + } -OpenAIAudioClient::OpenAIAudioClient(const IModel& model) - : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { - if (!model.IsLoaded()) { - throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", - *model.GetCoreAccess().logger); + OpenAIAudioClient::OpenAIAudioClient(const IModel& model) + : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } } -} } // namespace foundry_local diff --git a/sdk/cpp/src/openai_chat_client.cpp b/sdk/cpp/src/openai_chat_client.cpp index 3f0d95cd..84dc1a24 100644 --- a/sdk/cpp/src/openai_chat_client.cpp +++ b/sdk/cpp/src/openai_chat_client.cpp @@ -19,121 +19,121 @@ namespace foundry_local { -std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { - if (created == 0) - return {}; - std::time_t t = static_cast(created); - std::tm tm{}; + std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { + if (created == 0) + return {}; + std::time_t t = static_cast(created); + std::tm tm{}; #ifdef _WIN32 - gmtime_s(&tm, &t); + gmtime_s(&tm, &t); #else - gmtime_r(&t, &tm); + gmtime_r(&t, &tm); #endif - char buf[32]; - std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); - return buf; -} - -OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, - gsl::not_null logger) - : core_(core), modelId_(modelId), logger_(logger) {} - -std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, - gsl::span tools, - const ChatSettings& settings, bool stream) const { - nlohmann::json jMessages = nlohmann::json::array(); - for (const auto& msg : messages) { - nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; - if (msg.tool_call_id) - jMsg["tool_call_id"] = *msg.tool_call_id; - jMessages.push_back(std::move(jMsg)); + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); + return buf; } - nlohmann::json req = {{"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream}}; + OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + + 
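+    // Illustrative sketch (not SDK code): for a single user message with default settings,
+    // BuildChatRequestJson below produces the OpenAI-style payload that is handed to the
+    // core as the "OpenAICreateRequest" parameter. Reproduced standalone with nlohmann::json,
+    // using a hypothetical model id:
+    //
+    //     nlohmann::json req = {{"model", "chat-model"},
+    //                           {"stream", false},
+    //                           {"messages", nlohmann::json::array(
+    //                               {{{"role", "user"}, {"content", "Hello"}}})}};
+    //     std::string wire = req.dump();
+    //
+    // Tools, tool_choice, and the sampling settings are appended only when the caller sets
+    // them, so this minimal request is exactly what an empty ChatSettings yields.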
std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, + gsl::span tools, + const ChatSettings& settings, bool stream) const { + nlohmann::json jMessages = nlohmann::json::array(); + for (const auto& msg : messages) { + nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; + if (msg.tool_call_id) + jMsg["tool_call_id"] = *msg.tool_call_id; + jMessages.push_back(std::move(jMsg)); + } + + nlohmann::json req = {{"model", modelId_}, {"messages", std::move(jMessages)}, {"stream", stream}}; + + if (!tools.empty()) { + nlohmann::json jTools = nlohmann::json::array(); + for (const auto& tool : tools) { + nlohmann::json jTool; + to_json(jTool, tool); + jTools.push_back(std::move(jTool)); + } + req["tools"] = std::move(jTools); + } + + if (settings.tool_choice) + req["tool_choice"] = ParsingUtils::tool_choice_to_string(*settings.tool_choice); + if (settings.top_k) + req["metadata"] = {{"top_k", *settings.top_k}}; + if (settings.frequency_penalty) + req["frequency_penalty"] = *settings.frequency_penalty; + if (settings.presence_penalty) + req["presence_penalty"] = *settings.presence_penalty; + if (settings.max_tokens) + req["max_completion_tokens"] = *settings.max_tokens; + if (settings.n) + req["n"] = *settings.n; + if (settings.temperature) + req["temperature"] = *settings.temperature; + if (settings.top_p) + req["top_p"] = *settings.top_p; + if (settings.random_seed) + req["seed"] = *settings.random_seed; + + return req.dump(); + } + + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings) const { + return CompleteChat(messages, {}, settings); + } + + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); - if (!tools.empty()) { - nlohmann::json jTools = nlohmann::json::array(); - for (const auto& tool : tools) { - nlohmann::json jTool; - to_json(jTool, tool); - jTools.push_back(std::move(jTool)); + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + auto response = core_->call(req.Command(), *logger_, &json); + if (response.HasError()) { + throw Exception("Chat completion failed: " + response.error, *logger_); } - req["tools"] = std::move(jTools); + + return nlohmann::json::parse(response.data).get(); } - if (settings.tool_choice) - req["tool_choice"] = ParsingUtils::tool_choice_to_string(*settings.tool_choice); - if (settings.top_k) - req["metadata"] = {{"top_k", *settings.top_k}}; - if (settings.frequency_penalty) - req["frequency_penalty"] = *settings.frequency_penalty; - if (settings.presence_penalty) - req["presence_penalty"] = *settings.presence_penalty; - if (settings.max_tokens) - req["max_completion_tokens"] = *settings.max_tokens; - if (settings.n) - req["n"] = *settings.n; - if (settings.temperature) - req["temperature"] = *settings.temperature; - if (settings.top_p) - req["top_p"] = *settings.top_p; - if (settings.random_seed) - req["seed"] = *settings.random_seed; - - return req.dump(); -} - -ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, - const ChatSettings& settings) const { - return CompleteChat(messages, {}, settings); -} - -ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, - gsl::span tools, - const ChatSettings& settings) const { - std::string openAiReqJson = 
BuildChatRequestJson(messages, tools, settings, /*stream=*/false); - - CoreInteropRequest req("chat_completions"); - req.AddParam("OpenAICreateRequest", openAiReqJson); - - std::string json = req.ToJson(); - auto response = core_->call(req.Command(), *logger_, &json); - if (response.HasError()) { - throw Exception("Chat completion failed: " + response.error, *logger_); + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, {}, settings, onChunk); } - return nlohmann::json::parse(response.data).get(); -} - -void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, - const StreamCallback& onChunk) const { - CompleteChatStreaming(messages, {}, settings, onChunk); -} - -void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, - gsl::span tools, const ChatSettings& settings, - const StreamCallback& onChunk) const { - std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); - - CoreInteropRequest req("chat_completions"); - req.AddParam("OpenAICreateRequest", openAiReqJson); - std::string json = req.ToJson(); - - detail::CallWithStreamingCallback( - core_, req.Command(), json, *logger_, - [&onChunk](const std::string& chunk) { - auto parsed = nlohmann::json::parse(chunk).get(); - onChunk(parsed); - }, - "Streaming chat completion failed: "); -} - -OpenAIChatClient::OpenAIChatClient(const IModel& model) - : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { - if (!model.IsLoaded()) { - throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", - *model.GetCoreAccess().logger); + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, + gsl::span tools, const ChatSettings& settings, + const StreamCallback& onChunk) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + std::string json = req.ToJson(); + + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& chunk) { + auto parsed = nlohmann::json::parse(chunk).get(); + onChunk(parsed); + }, + "Streaming chat completion failed: "); + } + + OpenAIChatClient::OpenAIChatClient(const IModel& model) + : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. 
Call Load() first.", + *model.GetCoreAccess().logger); + } } -} } // namespace foundry_local diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 7d28392e..6c568374 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -91,9 +91,12 @@ namespace foundry_local { static std::string tool_choice_to_string(ToolChoiceKind kind) { switch (kind) { - case ToolChoiceKind::Auto: return "auto"; - case ToolChoiceKind::None: return "none"; - case ToolChoiceKind::Required: return "required"; + case ToolChoiceKind::Auto: + return "auto"; + case ToolChoiceKind::None: + return "none"; + case ToolChoiceKind::Required: + return "required"; } return "auto"; } diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp index af019024..3f60e0b4 100644 --- a/sdk/cpp/test/catalog_test.cpp +++ b/sdk/cpp/test/catalog_test.cpp @@ -93,35 +93,35 @@ TEST_F(CatalogTest, ListModels_IncludesOpenAIPrefix) { } TEST_F(CatalogTest, GetModel_Found) { -core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); -auto catalog = MakeCatalog(); + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); -auto* model = catalog->GetModel("my-model"); + auto* model = catalog->GetModel("my-model"); ASSERT_NE(nullptr, model); EXPECT_EQ("my-model", model->GetAlias()); } TEST_F(CatalogTest, GetModel_NotFound) { -core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); -auto catalog = MakeCatalog(); + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); -EXPECT_EQ(nullptr, catalog->GetModel("nonexistent")); + EXPECT_EQ(nullptr, catalog->GetModel("nonexistent")); } TEST_F(CatalogTest, GetModelVariant_Found) { -core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); -auto catalog = MakeCatalog(); + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); -auto* variant = catalog->GetModelVariant("model-1:1"); + auto* variant = catalog->GetModelVariant("model-1:1"); ASSERT_NE(nullptr, variant); EXPECT_EQ("model-1:1", variant->GetId()); } TEST_F(CatalogTest, GetModelVariant_NotFound) { -core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); -auto catalog = MakeCatalog(); + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); -EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent:1")); + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent:1")); } TEST_F(CatalogTest, GetLoadedModels) { diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index 2a09ac77..201fd965 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -55,16 +55,16 @@ TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) { } TEST_F(OpenAIChatClientTest, CompleteChat_WithSettings) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); -std::vector messages = {{"user", "test", {}}}; -ChatSettings settings; -settings.temperature = 0.7f; -settings.max_tokens = 100; + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.temperature = 0.7f; + 
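+    // Note the wire-format renames asserted below: max_tokens is serialized as
+    // "max_completion_tokens", random_seed as "seed", and top_k travels inside
+    // "metadata" rather than as a top-level field.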
settings.max_tokens = 100; settings.top_p = 0.9f; settings.frequency_penalty = 0.5f; settings.presence_penalty = 0.3f; @@ -89,15 +89,15 @@ settings.max_tokens = 100; } TEST_F(OpenAIChatClientTest, CompleteChat_RequestFormat) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); -std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; -ChatSettings settings; -auto response = client.CompleteChat(messages, settings); + std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); @@ -110,7 +110,7 @@ auto response = client.CompleteChat(messages, settings); } TEST_F(OpenAIChatClientTest, CompleteChatStreaming) { -nlohmann::json chunk1 = { + nlohmann::json chunk1 = { {"created", 1700000000}, {"id", "chatcmpl-1"}, {"IsDelta", true}, @@ -203,30 +203,22 @@ TEST_F(OpenAIChatClientTest, GetModelId) { // ---------- Tool calling tests ---------- TEST_F(OpenAIChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); std::vector messages = {{"user", "What is 7 * 6?", {}}}; - std::vector tools = {{ - "function", - FunctionDefinition{ - "multiply_numbers", - "A tool for multiplying two numbers.", - PropertyDefinition{ - "object", - std::nullopt, - std::unordered_map{ - {"first", PropertyDefinition{"integer", "The first number"}}, - {"second", PropertyDefinition{"integer", "The second number"}} - }, - std::vector{"first", "second"} - } - } - }}; + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", "A tool for multiplying two numbers.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + std::vector{"first", "second"}}}}}; ChatSettings settings; settings.tool_choice = ToolChoiceKind::Required; @@ -242,7 +234,8 @@ OpenAIChatClient client(variant); EXPECT_EQ(1u, openAiReq["tools"].size()); EXPECT_EQ("function", openAiReq["tools"][0]["type"].get()); EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get()); - EXPECT_EQ("A tool for multiplying two numbers.", openAiReq["tools"][0]["function"]["description"].get()); + EXPECT_EQ("A tool for multiplying two numbers.", + openAiReq["tools"][0]["function"]["description"].get()); EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get()); EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties")); EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first")); @@ -252,11 +245,11 @@ 
OpenAIChatClient client(variant); } TEST_F(OpenAIChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); std::vector messages = {{"user", "Hello", {}}}; ChatSettings settings; @@ -282,7 +275,8 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallResponse_Parsed) { {"finish_reason", "tool_calls"}, {"message", {{"role", "assistant"}, - {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": " + "6}}]"}, {"tool_calls", {{{"id", "call_1"}, {"type", "function"}, @@ -359,10 +353,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) { toolMsg.content = "42"; toolMsg.tool_call_id = "call_1"; - std::vector messages = { - {"user", "What is 7 * 6?", {}}, - std::move(toolMsg) - }; + std::vector messages = {{"user", "What is 7 * 6?", {}}, std::move(toolMsg)}; ChatSettings settings; client.CompleteChat(messages, settings); @@ -383,24 +374,21 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { {"Successful", true}, {"HttpStatusCode", 200}, {"choices", - {{{"index", 0}, - {"finish_reason", nullptr}, - {"delta", {{"role", "assistant"}, {"content", ""}}}}}}}; - nlohmann::json chunk2 = { - {"created", 1700000000}, - {"id", "chatcmpl-1"}, - {"IsDelta", true}, - {"Successful", true}, - {"HttpStatusCode", 200}, - {"choices", - {{{"index", 0}, - {"finish_reason", "tool_calls"}, - {"delta", - {{"content", ""}, - {"tool_calls", - {{{"id", "call_1"}, - {"type", "function"}, - {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", ""}}}}}}}; + nlohmann::json chunk2 = {{"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"delta", + {{"content", ""}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; core_.OnCall("chat_completions", [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { @@ -419,10 +407,7 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { std::vector messages = {{"user", "test", {}}}; - std::vector tools = {{ - "function", - FunctionDefinition{"multiply", "Multiply numbers."} - }}; + std::vector tools = {{"function", FunctionDefinition{"multiply", "Multiply numbers."}}}; ChatSettings settings; settings.tool_choice = ToolChoiceKind::Required; @@ -456,22 +441,22 @@ class OpenAIAudioClientTest : public ::testing::Test { }; TEST_F(OpenAIAudioClientTest, TranscribeAudio) { -core_.OnCall("audio_transcribe", "Hello world transcribed text"); -core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + core_.OnCall("audio_transcribe", "Hello world transcribed text"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIAudioClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIAudioClient 
client(variant); auto response = client.TranscribeAudio("test.wav"); EXPECT_EQ("Hello world transcribed text", response.text); } TEST_F(OpenAIAudioClientTest, TranscribeAudio_RequestFormat) { -core_.OnCall("audio_transcribe", "text"); -core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + core_.OnCall("audio_transcribe", "text"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); -auto variant = MakeLoadedVariant(); -OpenAIAudioClient client(variant); + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); client.TranscribeAudio("audio.wav"); auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); @@ -555,9 +540,8 @@ TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_CoreError_Throws) { auto variant = MakeLoadedVariant(); OpenAIAudioClient client(variant); - EXPECT_THROW( - client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}), - Exception); + EXPECT_THROW(client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}), + Exception); } // ===================================================================== @@ -572,9 +556,7 @@ TEST_F(OpenAIChatClientTest, CompleteChat_MultiTurn) { auto variant = MakeLoadedVariant(); OpenAIChatClient client(variant); - std::vector messages = { - {"user", "What is 7 * 6?", {}} - }; + std::vector messages = {{"user", "What is 7 * 6?", {}}}; ChatSettings settings; auto response = client.CompleteChat(messages, settings); @@ -624,10 +606,8 @@ TEST_F(OpenAIChatClientTest, CompleteChatStreaming_CoreError_Throws) { std::vector messages = {{"user", "Hello", {}}}; ChatSettings settings; - EXPECT_THROW( - client.CompleteChatStreaming(messages, settings, - [](const ChatCompletionCreateResponse&) {}), - Exception); + EXPECT_THROW(client.CompleteChatStreaming(messages, settings, [](const ChatCompletionCreateResponse&) {}), + Exception); } // ===================================================================== @@ -647,7 +627,8 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) { {"finish_reason", "tool_calls"}, {"message", {{"role", "assistant"}, - {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": " + "6}}]"}, {"tool_calls", {{{"id", "call_1"}, {"type", "function"}, @@ -659,27 +640,17 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) { auto variant = MakeLoadedVariant(); OpenAIChatClient client(variant); - std::vector messages = { - {"system", "You are a helpful AI assistant.", {}}, - {"user", "What is 7 multiplied by 6?", {}} - }; - - std::vector tools = {{ - "function", - FunctionDefinition{ - "multiply_numbers", - "A tool for multiplying two numbers.", - PropertyDefinition{ - "object", - std::nullopt, - std::unordered_map{ - {"first", PropertyDefinition{"integer", "The first number"}}, - {"second", PropertyDefinition{"integer", "The second number"}} - }, - std::vector{"first", "second"} - } - } - }}; + std::vector messages = {{"system", "You are a helpful AI assistant.", {}}, + {"user", "What is 7 multiplied by 6?", {}}}; + + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", "A tool for multiplying two numbers.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + std::vector{"first", 
"second"}}}}}; ChatSettings settings; settings.tool_choice = ToolChoiceKind::Required; diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp index 8b506d23..56625902 100644 --- a/sdk/cpp/test/e2e_test.cpp +++ b/sdk/cpp/test/e2e_test.cpp @@ -23,9 +23,11 @@ using namespace foundry_local; static bool IsRunningInCI() { auto check = [](const char* var) -> bool { const char* val = std::getenv(var); - if (!val) return false; + if (!val) + return false; std::string s(val); - for (auto& c : s) c = static_cast(std::tolower(static_cast(c))); + for (auto& c : s) + c = static_cast(std::tolower(static_cast(c))); return s == "true" || s == "1"; }; return check("TF_BUILD") || check("GITHUB_ACTIONS") || check("CI"); @@ -49,9 +51,7 @@ class EndToEndTest : public ::testing::Test { } } - static void TearDownTestSuite() { - FoundryLocalManager::Destroy(); - } + static void TearDownTestSuite() { FoundryLocalManager::Destroy(); } void SetUp() override { if (!FoundryLocalManager::IsInitialized()) { @@ -59,9 +59,7 @@ class EndToEndTest : public ::testing::Test { } } - static bool IsAudioModel(const std::string& alias) { - return alias.find("whisper") != std::string::npos; - } + static bool IsAudioModel(const std::string& alias) { return alias.find("whisper") != std::string::npos; } /// Find a chat-capable model, preferring cached, then known small models, then any. /// Selects the CPU variant when available to avoid GPU/EP dependency issues. @@ -72,14 +70,16 @@ class EndToEndTest : public ::testing::Test { for (auto* variant : cached) { if (!IsAudioModel(variant->GetAlias())) { target = catalog.GetModel(variant->GetAlias()); - if (target) break; + if (target) + break; } } if (!target) { for (const auto& alias : {"qwen2.5-0.5b", "qwen2.5-coder-0.5b", "phi-4-mini"}) { target = catalog.GetModel(alias); - if (target) break; + if (target) + break; } } @@ -114,14 +114,16 @@ class EndToEndTest : public ::testing::Test { for (auto* variant : cached) { if (IsAudioModel(variant->GetAlias())) { target = catalog.GetModel(variant->GetAlias()); - if (target) break; + if (target) + break; } } if (!target) { for (const auto& alias : {"whisper-small", "whisper-tiny"}) { target = catalog.GetModel(alias); - if (target) break; + if (target) + break; } } @@ -134,7 +136,7 @@ class EndToEndTest : public ::testing::Test { // =========================================================================== TEST_F(EndToEndTest, BrowseCatalog_ListsModels) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); EXPECT_FALSE(catalog.GetName().empty()); auto models = catalog.ListModels(); @@ -156,7 +158,7 @@ auto& catalog = FoundryLocalManager::Instance().GetCatalog(); } TEST_F(EndToEndTest, GetCachedModels_Succeeds) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto cached = catalog.GetCachedModels(); for (auto* variant : cached) { EXPECT_FALSE(variant->GetId().empty()); @@ -165,7 +167,7 @@ auto& catalog = FoundryLocalManager::Instance().GetCatalog(); } TEST_F(EndToEndTest, GetLoadedModels_Succeeds) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto loaded = catalog.GetLoadedModels(); for (auto* variant : loaded) { EXPECT_FALSE(variant->GetId().empty()); @@ -174,19 +176,19 @@ auto& catalog = FoundryLocalManager::Instance().GetCatalog(); } TEST_F(EndToEndTest, GetModel_NotFound_ReturnsNull) { -auto& 
catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto* model = catalog.GetModel("this-model-does-not-exist-12345"); EXPECT_EQ(model, nullptr); } TEST_F(EndToEndTest, GetModelVariant_NotFound_ReturnsNull) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto* variant = catalog.GetModelVariant("nonexistent-model:999"); EXPECT_EQ(variant, nullptr); } TEST_F(EndToEndTest, GetModelVariant_Found) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto models = catalog.ListModels(); if (models.empty()) { GTEST_SKIP() << "No models in catalog"; @@ -199,7 +201,7 @@ auto& catalog = FoundryLocalManager::Instance().GetCatalog(); } TEST_F(EndToEndTest, ModelVariantInfo_HasRequiredFields) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto models = catalog.ListModels(); if (models.empty()) { GTEST_SKIP() << "No models in catalog"; @@ -218,7 +220,7 @@ auto& catalog = FoundryLocalManager::Instance().GetCatalog(); } TEST_F(EndToEndTest, ModelVariant_SelectVariant) { -auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); auto models = catalog.ListModels(); // Find a model with multiple variants @@ -290,8 +292,7 @@ TEST_F(EndToEndTest, DISABLED_DownloadLoadChatUnload) { GTEST_SKIP() << "No chat-capable model found in catalog"; } - std::cout << "[E2E] Using model: " << target->GetAlias() - << " variant: " << target->GetId() << "\n"; + std::cout << "[E2E] Using model: " << target->GetAlias() << " variant: " << target->GetId() << "\n"; // Download (no-op if already cached) bool progressCallbackInvoked = false; @@ -367,14 +368,12 @@ TEST_F(EndToEndTest, DISABLED_StreamingChat) { std::vector chunks; std::string fullContent; - client.CompleteChatStreaming(messages, settings, - [&](const ChatCompletionCreateResponse& chunk) { - chunks.push_back(chunk); - if (!chunk.choices.empty() && chunk.choices[0].delta.has_value() && - !chunk.choices[0].delta->content.empty()) { - fullContent += chunk.choices[0].delta->content; - } - }); + client.CompleteChatStreaming(messages, settings, [&](const ChatCompletionCreateResponse& chunk) { + chunks.push_back(chunk); + if (!chunk.choices.empty() && chunk.choices[0].delta.has_value() && !chunk.choices[0].delta->content.empty()) { + fullContent += chunk.choices[0].delta->content; + } + }); EXPECT_GT(chunks.size(), 0u) << "Should have received at least one streaming chunk"; EXPECT_FALSE(fullContent.empty()) << "Accumulated streaming content should not be empty"; @@ -425,26 +424,16 @@ TEST_F(EndToEndTest, DISABLED_ChatWithToolCalling) { OpenAIChatClient client(*target); - std::vector tools = {{ - "function", - FunctionDefinition{ - "get_weather", - "Get the current weather for a city.", - PropertyDefinition{ - "object", - std::nullopt, - std::unordered_map{ - {"city", PropertyDefinition{"string", "The city name"}} - }, - std::vector{"city"} - } - } - }}; + std::vector tools = { + {"function", FunctionDefinition{"get_weather", "Get the current weather for a city.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"city", PropertyDefinition{"string", "The city name"}}}, + std::vector{"city"}}}}}; std::vector messages = { {"system", "You are a helpful assistant. 
Use the provided tools when asked about weather."}, - {"user", "What is the weather in Seattle?"} - }; + {"user", "What is the weather in Seattle?"}}; ChatSettings settings; settings.temperature = 0.0f; @@ -465,8 +454,7 @@ TEST_F(EndToEndTest, DISABLED_ChatWithToolCalling) { ASSERT_TRUE(tc.function_call.has_value()); EXPECT_EQ("get_weather", tc.function_call->name); EXPECT_FALSE(tc.function_call->arguments.empty()); - std::cout << "[E2E] Tool call: " << tc.function_call->name - << " args: " << tc.function_call->arguments << "\n"; + std::cout << "[E2E] Tool call: " << tc.function_call->name << " args: " << tc.function_call->arguments << "\n"; } target->Unload(); @@ -535,11 +523,10 @@ TEST_F(EndToEndTest, DISABLED_AudioTranscriptionStreaming) { std::string fullText; int chunkCount = 0; - client.TranscribeAudioStreaming(audioPath, - [&](const AudioCreateTranscriptionResponse& chunk) { - fullText += chunk.text; - chunkCount++; - }); + client.TranscribeAudioStreaming(audioPath, [&](const AudioCreateTranscriptionResponse& chunk) { + fullText += chunk.text; + chunkCount++; + }); EXPECT_GT(chunkCount, 0) << "Should have received at least one streaming chunk"; EXPECT_FALSE(fullText.empty()); diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h index 136a7ff4..f89af91a 100644 --- a/sdk/cpp/test/mock_core.h +++ b/sdk/cpp/test/mock_core.h @@ -22,7 +22,7 @@ namespace foundry_local::Testing { public: /// Handler signature: (command, dataArgument, callback, userData) -> response string. using Handler = std::function; + NativeCallbackFn callback, void* userData)>; /// Register a fixed response for a command. void OnCall(std::string command, std::string response) { @@ -56,7 +56,7 @@ namespace foundry_local::Testing { // IFoundryLocalCore implementation CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, - NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + NativeCallbackFn callback = nullptr, void* data = nullptr) const override { std::string cmd(command); const_cast(this)->callCounts_[cmd]++; @@ -119,7 +119,7 @@ namespace foundry_local::Testing { } CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, - NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { + NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { CoreResponse resp; diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp index ac21fe77..060f629c 100644 --- a/sdk/cpp/test/model_variant_test.cpp +++ b/sdk/cpp/test/model_variant_test.cpp @@ -103,24 +103,24 @@ TEST_F(ModelVariantTest, Unload_ThrowsOnError) { } TEST_F(ModelVariantTest, Download_NoCallback) { -core_.OnCall("get_cached_models", R"([])"); -core_.OnCall("download_model", ""); -auto variant = MakeVariant("test-model"); -variant.Download(); + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", ""); + auto variant = MakeVariant("test-model"); + variant.Download(); EXPECT_EQ(1, core_.GetCallCount("download_model")); } TEST_F(ModelVariantTest, Download_WithCallback) { -core_.OnCall("get_cached_models", R"([])"); -core_.OnCall("download_model", -[](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { - // Simulate calling the progress callback - if (callback && userData) { - std::string progress = "50"; - callback(progress.data(), static_cast(progress.size()), userData); - } - 
return ""; -}); + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + // Simulate calling the progress callback + if (callback && userData) { + std::string progress = "50"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); auto variant = MakeVariant("test-model"); float lastProgress = -1.0f; @@ -183,29 +183,29 @@ TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) { } TEST_F(ModelTest, AddVariant_AndSelect) { -auto model = MakeModel(); -Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); -Factory::SelectFirstVariant(model); + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SelectFirstVariant(model); EXPECT_EQ("v1:1", model.GetId()); EXPECT_EQ("alias", model.GetAlias()); } TEST_F(ModelTest, GetAllModelVariants) { -auto model = MakeModel(); -Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); -Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); -Factory::SelectFirstVariant(model); + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SelectFirstVariant(model); auto variants = model.GetAllModelVariants(); EXPECT_EQ(2u, variants.size()); } TEST_F(ModelTest, SelectVariant) { -auto model = MakeModel(); -Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); -Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); -Factory::SelectFirstVariant(model); + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SelectFirstVariant(model); const auto& v2 = model.GetAllModelVariants()[1]; model.SelectVariant(v2); @@ -213,19 +213,19 @@ Factory::SelectFirstVariant(model); } TEST_F(ModelTest, SelectVariant_NotFound_Throws) { -auto model = MakeModel(); -Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); -Factory::SelectFirstVariant(model); + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SelectFirstVariant(model); auto external = MakeVariant("external", "alias", 1); EXPECT_THROW(model.SelectVariant(external), Exception); } TEST_F(ModelTest, GetLatestVariant) { -auto model = MakeModel(); -Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1)); -Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); -Factory::SelectFirstVariant(model); + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); + Factory::SelectFirstVariant(model); const auto& first = model.GetAllModelVariants()[0]; const auto& latest = model.GetLatestVersion(first); diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp index 00cac8da..681e912f 100644 --- a/sdk/cpp/test/parser_and_types_test.cpp +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -17,7 +17,7 @@ using namespace foundry_local::Testing; class ParserTest : public ::testing::Test { protected: static nlohmann::json MinimalModelJson() { - return nlohmann::json{{"id", "model-1:1"}, {"name", "model-1"}, {"version", 1}, + return nlohmann::json{{"id", "model-1:1"}, {"name", "model-1"}, {"version", 
1}, {"alias", "my-model"}, {"providerType", "onnx"}, {"uri", "https://example.com/model"}, {"modelType", "text"}, {"cached", false}, {"createdAt", 1700000000}}; } @@ -208,13 +208,12 @@ TEST_F(ParserTest, ParseChatMessage) { } TEST_F(ParserTest, ParseChatMessage_WithToolCalls) { - nlohmann::json j = { - {"role", "assistant"}, - {"content", "I'll call a tool."}, - {"tool_calls", - {{{"id", "call_abc123"}, - {"type", "function"}, - {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Seattle\"}"}}}}}}}; + nlohmann::json j = {{"role", "assistant"}, + {"content", "I'll call a tool."}, + {"tool_calls", + {{{"id", "call_abc123"}, + {"type", "function"}, + {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Seattle\"}"}}}}}}}; ChatMessage msg = j.get(); EXPECT_EQ("assistant", msg.role); ASSERT_EQ(1u, msg.tool_calls.size()); @@ -226,10 +225,7 @@ TEST_F(ParserTest, ParseChatMessage_WithToolCalls) { } TEST_F(ParserTest, ParseChatMessage_WithToolCallId) { - nlohmann::json j = { - {"role", "tool"}, - {"content", "72 degrees and sunny"}, - {"tool_call_id", "call_abc123"}}; + nlohmann::json j = {{"role", "tool"}, {"content", "72 degrees and sunny"}, {"tool_call_id", "call_abc123"}}; ChatMessage msg = j.get(); EXPECT_EQ("tool", msg.role); EXPECT_EQ("72 degrees and sunny", msg.content); @@ -252,10 +248,9 @@ TEST_F(ParserTest, ParseFunctionCall_ObjectArguments) { } TEST_F(ParserTest, ParseToolCall) { - nlohmann::json j = { - {"id", "call_1"}, - {"type", "function"}, - {"function", {{"name", "search"}, {"arguments", "{\"query\": \"test\"}"}}}}; + nlohmann::json j = {{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "search"}, {"arguments", "{\"query\": \"test\"}"}}}}; ToolCall tc = j.get(); EXPECT_EQ("call_1", tc.id); EXPECT_EQ("function", tc.type); @@ -268,14 +263,10 @@ TEST_F(ParserTest, SerializeToolDefinition) { tool.type = "function"; tool.function.name = "get_weather"; tool.function.description = "Get the current weather"; - tool.function.parameters = PropertyDefinition{ - "object", - std::nullopt, - std::unordered_map{ - {"location", PropertyDefinition{"string", "The city name"}} - }, - std::vector{"location"} - }; + tool.function.parameters = PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"location", PropertyDefinition{"string", "The city name"}}}, + std::vector{"location"}}; nlohmann::json j; to_json(j, tool); From 3caafd206bfbd8f5d48f5f2e822f7b87ada27ad2 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 2 Apr 2026 14:15:03 -0700 Subject: [PATCH 16/18] nit fixes --- sdk/cpp/CMakeLists.txt | 2 - sdk/cpp/client_test.cpp | 722 ------------------- sdk/cpp/include/model.h | 22 +- sdk/cpp/include/openai/openai_audio_client.h | 3 - sdk/cpp/include/openai/openai_chat_client.h | 3 - sdk/cpp/src/model.cpp | 8 - sdk/cpp/test/e2e_test.cpp | 1 - 7 files changed, 5 insertions(+), 756 deletions(-) delete mode 100644 sdk/cpp/client_test.cpp diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt index 080815c4..bc3f4844 100644 --- a/sdk/cpp/CMakeLists.txt +++ b/sdk/cpp/CMakeLists.txt @@ -1,7 +1,5 @@ cmake_minimum_required(VERSION 3.20) -cmake_minimum_required(VERSION 3.20) - # VS hot reload policy (safe-guarded) if (POLICY CMP0141) cmake_policy(SET CMP0141 NEW) diff --git a/sdk/cpp/client_test.cpp b/sdk/cpp/client_test.cpp deleted file mode 100644 index 2a09ac77..00000000 --- a/sdk/cpp/client_test.cpp +++ /dev/null @@ -1,722 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
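// NOTE: the hunk below deletes sdk/cpp/client_test.cpp wholesale. Its pre-image blob hash
// (2a09ac77) matches sdk/cpp/test/client_test.cpp before the reformat, so this appears to be
// a stale duplicate that had been committed at the SDK root by mistake.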
- -#include - -#include "mock_core.h" -#include "mock_object_factory.h" -#include "parser.h" -#include "foundry_local_exception.h" - -#include - -using namespace foundry_local; -using namespace foundry_local::Testing; - -using Factory = MockObjectFactory; - -class OpenAIChatClientTest : public ::testing::Test { -protected: - MockCore core_; - NullLogger logger_; - - std::string MakeChatResponseJson(const std::string& content = "Hello!") { - nlohmann::json resp = { - {"created", 1700000000}, - {"id", "chatcmpl-test"}, - {"IsDelta", false}, - {"Successful", true}, - {"HttpStatusCode", 200}, - {"choices", - {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", content}}}}}}}; - return resp.dump(); - } - - ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") { - core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]"); - return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); - } -}; - -TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) { - core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!")); - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - - auto variant = MakeLoadedVariant(); - OpenAIChatClient client(variant); - - std::vector messages = {{"user", "Say hello", {}}}; - ChatSettings settings; - auto response = client.CompleteChat(messages, settings); - - EXPECT_TRUE(response.successful); - ASSERT_EQ(1u, response.choices.size()); - EXPECT_EQ("Hello world!", response.choices[0].message->content); -} - -TEST_F(OpenAIChatClientTest, CompleteChat_WithSettings) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); - -std::vector messages = {{"user", "test", {}}}; -ChatSettings settings; -settings.temperature = 0.7f; -settings.max_tokens = 100; - settings.top_p = 0.9f; - settings.frequency_penalty = 0.5f; - settings.presence_penalty = 0.3f; - settings.n = 2; - settings.random_seed = 42; - settings.top_k = 10; - - auto response = client.CompleteChat(messages, settings); - - // Verify the request JSON contains the settings - auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); - auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); - - EXPECT_NEAR(0.7f, openAiReq["temperature"].get(), 0.001f); - EXPECT_EQ(100, openAiReq["max_completion_tokens"].get()); - EXPECT_NEAR(0.9f, openAiReq["top_p"].get(), 0.001f); - EXPECT_NEAR(0.5f, openAiReq["frequency_penalty"].get(), 0.001f); - EXPECT_NEAR(0.3f, openAiReq["presence_penalty"].get(), 0.001f); - EXPECT_EQ(2, openAiReq["n"].get()); - EXPECT_EQ(42, openAiReq["seed"].get()); - EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get()); -} - -TEST_F(OpenAIChatClientTest, CompleteChat_RequestFormat) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - -auto variant = MakeLoadedVariant(); -OpenAIChatClient client(variant); - -std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; -ChatSettings settings; -auto response = client.CompleteChat(messages, settings); - - auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); - auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); - - EXPECT_EQ("chat-model", openAiReq["model"].get()); - EXPECT_FALSE(openAiReq["stream"].get()); - 
ASSERT_EQ(2u, openAiReq["messages"].size()); - EXPECT_EQ("system", openAiReq["messages"][0]["role"].get()); - EXPECT_EQ("user", openAiReq["messages"][1]["role"].get()); -} - -TEST_F(OpenAIChatClientTest, CompleteChatStreaming) { -nlohmann::json chunk1 = { - {"created", 1700000000}, - {"id", "chatcmpl-1"}, - {"IsDelta", true}, - {"Successful", true}, - {"HttpStatusCode", 200}, - {"choices", - {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hello"}}}}}}}; - nlohmann::json chunk2 = { - {"created", 1700000000}, - {"id", "chatcmpl-1"}, - {"IsDelta", true}, - {"Successful", true}, - {"HttpStatusCode", 200}, - {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}}; - - core_.OnCall("chat_completions", - [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { - if (callback && userData) { - std::string s1 = chunk1.dump(); - std::string s2 = chunk2.dump(); - callback(s1.data(), static_cast(s1.size()), userData); - callback(s2.data(), static_cast(s2.size()), userData); - } - return ""; - }); - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - - auto variant = MakeLoadedVariant(); - OpenAIChatClient client(variant); - - std::vector messages = {{"user", "test", {}}}; - ChatSettings settings; - - std::vector chunks; - client.CompleteChatStreaming(messages, settings, - [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); - - ASSERT_EQ(2u, chunks.size()); - EXPECT_TRUE(chunks[0].is_delta); - ASSERT_TRUE(chunks[0].choices[0].delta.has_value()); - EXPECT_EQ("Hello", chunks[0].choices[0].delta->content); - EXPECT_EQ(" world", chunks[1].choices[0].delta->content); -} - -TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { - nlohmann::json chunk = { - {"created", 1700000000}, - {"id", "chatcmpl-1"}, - {"IsDelta", true}, - {"Successful", true}, - {"HttpStatusCode", 200}, - {"choices", - {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}}; - - core_.OnCall("chat_completions", - [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { - if (callback && userData) { - std::string s = chunk.dump(); - callback(s.data(), static_cast(s.size()), userData); - } - return ""; - }); - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - - auto variant = MakeLoadedVariant(); - OpenAIChatClient client(variant); - - std::vector messages = {{"user", "test", {}}}; - ChatSettings settings; - - EXPECT_THROW(client.CompleteChatStreaming( - messages, settings, - [](const ChatCompletionCreateResponse&) { throw std::runtime_error("callback error"); }), - std::runtime_error); -} - -TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) { - core_.OnCall("list_loaded_models", R"([])"); - auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); - EXPECT_THROW(OpenAIChatClient client(variant), Exception); -} - -TEST_F(OpenAIChatClientTest, GetModelId) { - core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - auto variant = MakeLoadedVariant(); - OpenAIChatClient client(variant); - EXPECT_EQ("chat-model", client.GetModelId()); -} - -// ---------- Tool calling tests ---------- - -TEST_F(OpenAIChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { -core_.OnCall("chat_completions", MakeChatResponseJson()); -core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); - -auto 
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "What is 7 * 6?", {}}};
-
-    std::vector<Tool> tools = {{
-        "function",
-        FunctionDefinition{
-            "multiply_numbers",
-            "A tool for multiplying two numbers.",
-            PropertyDefinition{
-                "object",
-                std::nullopt,
-                std::unordered_map<std::string, PropertyDefinition>{
-                    {"first", PropertyDefinition{"integer", "The first number"}},
-                    {"second", PropertyDefinition{"integer", "The second number"}}
-                },
-                std::vector<std::string>{"first", "second"}
-            }
-        }
-    }};
-
-    ChatSettings settings;
-    settings.tool_choice = ToolChoiceKind::Required;
-
-    auto response = client.CompleteChat(messages, tools, settings);
-
-    // Verify the request JSON contains tools and tool_choice
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-
-    ASSERT_TRUE(openAiReq.contains("tools"));
-    ASSERT_TRUE(openAiReq["tools"].is_array());
-    EXPECT_EQ(1u, openAiReq["tools"].size());
-    EXPECT_EQ("function", openAiReq["tools"][0]["type"].get<std::string>());
-    EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get<std::string>());
-    EXPECT_EQ("A tool for multiplying two numbers.", openAiReq["tools"][0]["function"]["description"].get<std::string>());
-    EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get<std::string>());
-    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties"));
-    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first"));
-    EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("second"));
-
-    EXPECT_EQ("required", openAiReq["tool_choice"].get<std::string>());
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) {
-    core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
-    ChatSettings settings;
-    auto response = client.CompleteChat(messages, settings);
-
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-
-    EXPECT_FALSE(openAiReq.contains("tools"));
-    EXPECT_FALSE(openAiReq.contains("tool_choice"));
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallResponse_Parsed) {
-    // Simulate a response with tool calls from the model
-    nlohmann::json resp = {
-        {"created", 1700000000},
-        {"id", "chatcmpl-tool"},
-        {"IsDelta", false},
-        {"Successful", true},
-        {"HttpStatusCode", 200},
-        {"choices",
-         {{{"index", 0},
-           {"finish_reason", "tool_calls"},
-           {"message",
-            {{"role", "assistant"},
-             {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"},
-             {"tool_calls",
-              {{{"id", "call_1"},
-                {"type", "function"},
-                {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}};
-
-    core_.OnCall("chat_completions", resp.dump());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "What is 7 * 6?", {}}};
-    ChatSettings settings;
-    auto response = client.CompleteChat(messages, settings);
-
-    ASSERT_EQ(1u, response.choices.size());
-    EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason);
-    ASSERT_TRUE(response.choices[0].message.has_value());
-
-    const auto& msg = *response.choices[0].message;
-    ASSERT_EQ(1u, msg.tool_calls.size());
-    EXPECT_EQ("call_1", msg.tool_calls[0].id);
-    EXPECT_EQ("function", msg.tool_calls[0].type);
-    ASSERT_TRUE(msg.tool_calls[0].function_call.has_value());
-    EXPECT_EQ("multiply_numbers", msg.tool_calls[0].function_call->name);
-    EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments);
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceAuto) {
-    core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "test", {}}};
-    ChatSettings settings;
-    settings.tool_choice = ToolChoiceKind::Auto;
-
-    client.CompleteChat(messages, settings);
-
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-    EXPECT_EQ("auto", openAiReq["tool_choice"].get<std::string>());
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceNone) {
-    core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "test", {}}};
-    ChatSettings settings;
-    settings.tool_choice = ToolChoiceKind::None;
-
-    client.CompleteChat(messages, settings);
-
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-    EXPECT_EQ("none", openAiReq["tool_choice"].get<std::string>());
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) {
-    core_.OnCall("chat_completions", MakeChatResponseJson());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    ChatMessage toolMsg;
-    toolMsg.role = "tool";
-    toolMsg.content = "42";
-    toolMsg.tool_call_id = "call_1";
-
-    std::vector<ChatMessage> messages = {
-        {"user", "What is 7 * 6?", {}},
-        std::move(toolMsg)
-    };
-    ChatSettings settings;
-    client.CompleteChat(messages, settings);
-
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-
-    ASSERT_EQ(2u, openAiReq["messages"].size());
-    EXPECT_FALSE(openAiReq["messages"][0].contains("tool_call_id"));
-    EXPECT_EQ("call_1", openAiReq["messages"][1]["tool_call_id"].get<std::string>());
-    EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get<std::string>());
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) {
-    nlohmann::json chunk1 = {
-        {"created", 1700000000},
-        {"id", "chatcmpl-1"},
-        {"IsDelta", true},
-        {"Successful", true},
-        {"HttpStatusCode", 200},
-        {"choices",
-         {{{"index", 0},
-           {"finish_reason", nullptr},
-           {"delta", {{"role", "assistant"}, {"content", ""}}}}}}};
-    nlohmann::json chunk2 = {
-        {"created", 1700000000},
-        {"id", "chatcmpl-1"},
-        {"IsDelta", true},
-        {"Successful", true},
-        {"HttpStatusCode", 200},
-        {"choices",
-         {{{"index", 0},
-           {"finish_reason", "tool_calls"},
-           {"delta",
-            {{"content", ""},
-             {"tool_calls",
-              {{{"id", "call_1"},
-                {"type", "function"},
-                {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}};
-
-    core_.OnCall("chat_completions",
-                 [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
-                     if (callback && userData) {
-                         std::string s1 = chunk1.dump();
-                         std::string s2 = chunk2.dump();
-                         callback(s1.data(), static_cast<uint32_t>(s1.size()), userData);
-                         callback(s2.data(), static_cast<uint32_t>(s2.size()), userData);
-                     }
-                     return "";
-                 });
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "test", {}}};
-
-    std::vector<Tool> tools = {{
-        "function",
-        FunctionDefinition{"multiply", "Multiply numbers."}
-    }};
-
-    ChatSettings settings;
-    settings.tool_choice = ToolChoiceKind::Required;
-
-    std::vector<ChatCompletionCreateResponse> chunks;
-    client.CompleteChatStreaming(messages, tools, settings,
-                                 [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); });
-
-    ASSERT_EQ(2u, chunks.size());
-    EXPECT_EQ(FinishReason::ToolCalls, chunks[1].choices[0].finish_reason);
-    ASSERT_TRUE(chunks[1].choices[0].delta.has_value());
-    ASSERT_EQ(1u, chunks[1].choices[0].delta->tool_calls.size());
-    EXPECT_EQ("multiply", chunks[1].choices[0].delta->tool_calls[0].function_call->name);
-
-    // Verify tools were included in the request
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-    ASSERT_TRUE(openAiReq.contains("tools"));
-    EXPECT_EQ("required", openAiReq["tool_choice"].get<std::string>());
-}
-
-class OpenAIAudioClientTest : public ::testing::Test {
-protected:
-    MockCore core_;
-    NullLogger logger_;
-
-    ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") {
-        core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]");
-        return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_);
-    }
-};
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudio) {
-    core_.OnCall("audio_transcribe", "Hello world transcribed text");
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    auto response = client.TranscribeAudio("test.wav");
-
-    EXPECT_EQ("Hello world transcribed text", response.text);
-}
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudio_RequestFormat) {
-    core_.OnCall("audio_transcribe", "text");
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    client.TranscribeAudio("audio.wav");
-
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-    EXPECT_EQ("audio-model", openAiReq["Model"].get<std::string>());
-    EXPECT_EQ("audio.wav", openAiReq["FileName"].get<std::string>());
-}
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) {
-    core_.OnCall("audio_transcribe",
-                 [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
-                     if (callback && userData) {
-                         std::string text1 = "Hello ";
-                         std::string text2 = "world!";
-                         callback(text1.data(), static_cast<uint32_t>(text1.size()), userData);
-                         callback(text2.data(), static_cast<uint32_t>(text2.size()), userData);
-                     }
-                     return "";
-                 });
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    std::vector<std::string> chunks;
-    client.TranscribeAudioStreaming(
-        "test.wav", [&](const AudioCreateTranscriptionResponse& chunk) { chunks.push_back(chunk.text); });
-
-    ASSERT_EQ(2u, chunks.size());
-    EXPECT_EQ("Hello ", chunks[0]);
-    EXPECT_EQ("world!", chunks[1]);
-}
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) {
-    core_.OnCall("audio_transcribe",
-                 [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string {
-                     if (callback && userData) {
-                         std::string text = "test";
-                         callback(text.data(), static_cast<uint32_t>(text.size()), userData);
-                     }
-                     return "";
-                 });
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    EXPECT_THROW(
-        client.TranscribeAudioStreaming(
-            "test.wav", [](const AudioCreateTranscriptionResponse&) { throw std::runtime_error("streaming error"); }),
-        std::runtime_error);
-}
-
-TEST_F(OpenAIAudioClientTest, Constructor_ThrowsIfNotLoaded) {
-    core_.OnCall("list_loaded_models", R"([])");
-    auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_);
-    EXPECT_THROW(OpenAIAudioClient client(variant), FoundryLocalException);
-}
-
-TEST_F(OpenAIAudioClientTest, GetModelId) {
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-    EXPECT_EQ("audio-model", client.GetModelId());
-}
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudio_CoreError_Throws) {
-    core_.OnCallThrow("audio_transcribe", "transcription failed");
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    EXPECT_THROW(client.TranscribeAudio("test.wav"), Exception);
-}
-
-TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_CoreError_Throws) {
-    core_.OnCallThrow("audio_transcribe", "streaming transcription failed");
-    core_.OnCall("list_loaded_models", R"(["audio-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIAudioClient client(variant);
-
-    EXPECT_THROW(
-        client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}),
-        Exception);
-}
-
-// =====================================================================
-// Multi-turn conversation tests
-// =====================================================================
-
-TEST_F(OpenAIChatClientTest, CompleteChat_MultiTurn) {
-    // First turn: user asks a question
-    core_.OnCall("chat_completions", MakeChatResponseJson("42"));
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {
-        {"user", "What is 7 * 6?", {}}
-    };
-    ChatSettings settings;
-    auto response = client.CompleteChat(messages, settings);
-
-    ASSERT_TRUE(response.successful);
-    ASSERT_EQ(1u, response.choices.size());
-    EXPECT_EQ("42", response.choices[0].message->content);
-
-    // Second turn: add assistant response + user follow-up
-    messages.push_back({"assistant", response.choices[0].message->content, {}});
-    messages.push_back({"user", "Is that a real number?", {}});
-
-    core_.OnCall("chat_completions", MakeChatResponseJson("Yes"));
-    auto response2 = client.CompleteChat(messages, settings);
-
-    ASSERT_TRUE(response2.successful);
-    EXPECT_EQ("Yes", response2.choices[0].message->content);
-
-    // Verify the second request contained all 3 messages
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-    ASSERT_EQ(3u, openAiReq["messages"].size());
-    EXPECT_EQ("user", openAiReq["messages"][0]["role"].get<std::string>());
-    EXPECT_EQ("assistant", openAiReq["messages"][1]["role"].get<std::string>());
-    EXPECT_EQ("user", openAiReq["messages"][2]["role"].get<std::string>());
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChat_CoreError_Throws) {
-    core_.OnCallThrow("chat_completions", "inference failed");
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
-    ChatSettings settings;
-
-    EXPECT_THROW(client.CompleteChat(messages, settings), Exception);
-}
-
-TEST_F(OpenAIChatClientTest, CompleteChatStreaming_CoreError_Throws) {
-    core_.OnCallThrow("chat_completions", "streaming failed");
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {{"user", "Hello", {}}};
-    ChatSettings settings;
-
-    EXPECT_THROW(
-        client.CompleteChatStreaming(messages, settings,
-                                     [](const ChatCompletionCreateResponse&) {}),
-        Exception);
-}
-
-// =====================================================================
-// Full tool-call round-trip
-// =====================================================================
-
-TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) {
-    // Step 1: model returns a tool call
-    nlohmann::json toolCallResp = {
-        {"created", 1700000000},
-        {"id", "chatcmpl-tool"},
-        {"IsDelta", false},
-        {"Successful", true},
-        {"HttpStatusCode", 200},
-        {"choices",
-         {{{"index", 0},
-           {"finish_reason", "tool_calls"},
-           {"message",
-            {{"role", "assistant"},
-             {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": 6}}]"},
-             {"tool_calls",
-              {{{"id", "call_1"},
-                {"type", "function"},
-                {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}};
-
-    core_.OnCall("chat_completions", toolCallResp.dump());
-    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
-
-    auto variant = MakeLoadedVariant();
-    OpenAIChatClient client(variant);
-
-    std::vector<ChatMessage> messages = {
-        {"system", "You are a helpful AI assistant.", {}},
-        {"user", "What is 7 multiplied by 6?", {}}
-    };
-
-    std::vector<Tool> tools = {{
-        "function",
-        FunctionDefinition{
-            "multiply_numbers",
-            "A tool for multiplying two numbers.",
-            PropertyDefinition{
-                "object",
-                std::nullopt,
-                std::unordered_map<std::string, PropertyDefinition>{
-                    {"first", PropertyDefinition{"integer", "The first number"}},
-                    {"second", PropertyDefinition{"integer", "The second number"}}
-                },
-                std::vector<std::string>{"first", "second"}
-            }
-        }
-    }};
-
-    ChatSettings settings;
-    settings.tool_choice = ToolChoiceKind::Required;
-
-    auto response = client.CompleteChat(messages, tools, settings);
-
-    ASSERT_EQ(1u, response.choices.size());
-    EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason);
-    ASSERT_EQ(1u, response.choices[0].message->tool_calls.size());
-    EXPECT_EQ("multiply_numbers", response.choices[0].message->tool_calls[0].function_call->name);
-
-    // Step 2: send tool response back, model continues with the answer
-    messages.push_back({"assistant", response.choices[0].message->content, {}});
-
-    ChatMessage toolMsg;
-    toolMsg.role = "tool";
-    toolMsg.content = "7 x 6 = 42.";
-    toolMsg.tool_call_id = "call_1";
-    messages.push_back(std::move(toolMsg));
-
-    messages.push_back({"system", "Respond only with the answer generated by the tool.", {}});
-
-    core_.OnCall("chat_completions", MakeChatResponseJson("42"));
-    settings.tool_choice = ToolChoiceKind::Auto;
-
-    auto response2 = client.CompleteChat(messages, tools, settings);
-
-    ASSERT_TRUE(response2.successful);
-    EXPECT_EQ("42", response2.choices[0].message->content);
-
-    // Verify the second request contained tool response message
-    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
-    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
-
-    // 5 messages: system, user, assistant (tool_call), tool, system (continue)
-    ASSERT_EQ(5u, openAiReq["messages"].size());
-    EXPECT_EQ("tool", openAiReq["messages"][3]["role"].get<std::string>());
-    EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get<std::string>());
-    EXPECT_EQ("auto", openAiReq["tool_choice"].get<std::string>());
-}
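[Editor's note: the round-trip test above doubles as a protocol recipe: the first response carries tool_calls, the application executes the named function itself, and the follow-up request replays the conversation plus a "tool" message keyed by tool_call_id. Below is a minimal application-side sketch of that dispatch step, assuming the public ChatMessage/ToolCall types keep the member names asserted in the test; ExecuteTool is a hypothetical app helper, not an SDK API.]

#include <string>
#include <vector>

#include "foundry_local.h"

using namespace foundry_local;

// Hypothetical, app-defined: parses the JSON arguments string and runs the named function.
std::string ExecuteTool(const std::string& name, const std::string& argumentsJson);

void AppendToolResults(const ChatCompletionCreateResponse& response,
                       std::vector<ChatMessage>& messages) {
    const auto& msg = *response.choices[0].message;
    for (const auto& tc : msg.tool_calls) {
        if (!tc.function_call)
            continue;

        // Run the requested function; arguments arrive as a JSON string.
        std::string result = ExecuteTool(tc.function_call->name, tc.function_call->arguments);

        // Echo the result back as a "tool" message keyed by the call id,
        // the exact shape the ToolMessageWithToolCallId test asserts on.
        ChatMessage toolMsg;
        toolMsg.role = "tool";
        toolMsg.content = result;
        toolMsg.tool_call_id = tc.id;
        messages.push_back(std::move(toolMsg));
    }
}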
tool.", {}}); - - core_.OnCall("chat_completions", MakeChatResponseJson("42")); - settings.tool_choice = ToolChoiceKind::Auto; - - auto response2 = client.CompleteChat(messages, tools, settings); - - ASSERT_TRUE(response2.successful); - EXPECT_EQ("42", response2.choices[0].message->content); - - // Verify the second request contained tool response message - auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); - auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); - - // 5 messages: system, user, assistant (tool_call), tool, system (continue) - ASSERT_EQ(5u, openAiReq["messages"].size()); - EXPECT_EQ("tool", openAiReq["messages"][3]["role"].get()); - EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get()); - EXPECT_EQ("auto", openAiReq["tool_choice"].get()); -} diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h index 7786b923..25f04699 100644 --- a/sdk/cpp/include/model.h +++ b/sdk/cpp/include/model.h @@ -15,8 +15,11 @@ #include #include "logger.h" -#include "openai/openai_chat_client.h" -#include "openai/openai_audio_client.h" + +namespace foundry_local { + class OpenAIChatClient; + class OpenAIAudioClient; +} namespace foundry_local::Internal { struct IFoundryLocalCore; @@ -125,12 +128,6 @@ namespace foundry_local { void Unload() override; void RemoveFromCache() override; - [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] - OpenAIAudioClient GetAudioClient() const; - - [[deprecated("Use OpenAIChatClient(model) constructor instead")]] - OpenAIChatClient GetChatClient() const; - const std::string& GetId() const noexcept override; const std::string& GetAlias() const noexcept override; uint32_t GetVersion() const noexcept; @@ -169,15 +166,6 @@ namespace foundry_local { void Load() override { SelectedVariant().Load(); } void Unload() override { SelectedVariant().Unload(); } void RemoveFromCache() override { SelectedVariant().RemoveFromCache(); } - [[deprecated("Use OpenAIAudioClient(model) constructor instead")]] - OpenAIAudioClient GetAudioClient() const { - return SelectedVariant().GetAudioClient(); - } - - [[deprecated("Use OpenAIChatClient(model) constructor instead")]] - OpenAIChatClient GetChatClient() const { - return SelectedVariant().GetChatClient(); - } const std::string& GetId() const override; const std::string& GetAlias() const override; diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h index 79acccc9..ac1ce719 100644 --- a/sdk/cpp/include/openai/openai_audio_client.h +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -43,7 +43,4 @@ namespace foundry_local { gsl::not_null logger_; }; - /// Backward-compatible alias. - using AudioClient = OpenAIAudioClient; - } // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h index 788a820d..084e55d7 100644 --- a/sdk/cpp/include/openai/openai_chat_client.h +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -114,7 +114,4 @@ namespace foundry_local { gsl::not_null logger_; }; - /// Backward-compatible alias. 
-    using ChatClient = OpenAIChatClient;
-
 } // namespace foundry_local
diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp
index 17b43021..1631bcda 100644
--- a/sdk/cpp/src/model.cpp
+++ b/sdk/cpp/src/model.cpp
@@ -141,14 +141,6 @@ namespace foundry_local {
         return {core_, info_.name, logger_};
     }

-    OpenAIAudioClient ModelVariant::GetAudioClient() const {
-        return OpenAIAudioClient(*this);
-    }
-
-    OpenAIChatClient ModelVariant::GetChatClient() const {
-        return OpenAIChatClient(*this);
-    }
-
     /// Model

     Model::Model(gsl::not_null<Internal::IFoundryLocalCore*> core, gsl::not_null<const Logger*> logger)
diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp
index 56625902..50f1afbc 100644
--- a/sdk/cpp/test/e2e_test.cpp
+++ b/sdk/cpp/test/e2e_test.cpp
@@ -1,5 +1,4 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
-// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 //
 // End-to-end tests that exercise the public API with the real Core DLL.
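[Editor's note: taken together, the removals in this patch retire the deprecated GetChatClient()/GetAudioClient() accessors and the ChatClient/AudioClient aliases; the deprecation messages point at direct construction instead. A before/after sketch under that assumption, where variant is any loaded model obtained from the catalog:]

#include "foundry_local.h"

using namespace foundry_local;

void MigrateToDirectConstruction(const ModelVariant& variant) {
    // Before (removed above):
    //     ChatClient chat = variant.GetChatClient();
    // After: construct the client directly. Per the Constructor_ThrowsIfNotLoaded
    // tests, the constructor throws if the model is not currently loaded.
    OpenAIChatClient chat(variant);
    OpenAIAudioClient audio(variant);
}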
From 7e60423c23a3f34f04f513460b00b368724a97c6 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 2 Apr 2026 15:06:30 -0700
Subject: [PATCH 17/18] Copilot fixes 1

---
 sdk/cpp/CMakeLists.txt      | 127 +++++++++++++++++++-----------
 sdk/cpp/include/catalog.h   |   2 +-
 sdk/cpp/src/catalog.cpp     |  11 ++--
 sdk/cpp/src/core_helpers.h  |   4 +-
 sdk/cpp/src/flcore_native.h |   1 +
 sdk/cpp/src/parser.h        |   1 -
 sdk/cpp/vcpkg.json          |   4 +-
 7 files changed, 79 insertions(+), 71 deletions(-)

diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
index bc3f4844..7e32b7fb 100644
--- a/sdk/cpp/CMakeLists.txt
+++ b/sdk/cpp/CMakeLists.txt
@@ -40,7 +40,10 @@ add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00)
 find_package(nlohmann_json CONFIG REQUIRED)
 find_package(wil CONFIG REQUIRED)
 find_package(Microsoft.GSL CONFIG REQUIRED)
-find_package(GTest CONFIG REQUIRED)
+option(BUILD_TESTING "Build unit and end-to-end tests" ON)
+if (BUILD_TESTING)
+    find_package(GTest CONFIG REQUIRED)
+endif()

 # -----------------------------
 # SDK library (STATIC)
@@ -80,66 +83,68 @@ target_link_libraries(CppSdkSample PRIVATE CppSdk)
 # -----------------------------
 # Unit tests
 # -----------------------------
-enable_testing()
-
-add_executable(CppSdkTests
-    test/parser_and_types_test.cpp
-    test/model_variant_test.cpp
-    test/catalog_test.cpp
-    test/client_test.cpp
-)
-
-target_include_directories(CppSdkTests
-    PRIVATE
-        ${CMAKE_CURRENT_SOURCE_DIR}/test
-        ${CMAKE_CURRENT_SOURCE_DIR}/src
-)
-
-target_compile_definitions(CppSdkTests PRIVATE FL_TESTS)
-
-target_link_libraries(CppSdkTests
-    PRIVATE
-        CppSdk
-        GTest::gtest_main
-)
-
-# Copy testdata files next to the test executable so file-based tests can find them.
-add_custom_command(TARGET CppSdkTests POST_BUILD
-    COMMAND ${CMAKE_COMMAND} -E copy_directory
-            ${CMAKE_CURRENT_SOURCE_DIR}/test/testdata
-            $<TARGET_FILE_DIR:CppSdkTests>/testdata
-)
-
-include(GoogleTest)
-gtest_discover_tests(CppSdkTests
-    WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkTests>
-)
-
-# -----------------------------
-# End-to-end tests (separate executable, requires Core DLL)
-# Exercises the full public API against the real catalog.
-# Tests that need model download are DISABLED by default;
-# run with --gtest_also_run_disabled_tests locally.
-# -----------------------------
-add_executable(CppSdkE2ETests
-    test/e2e_test.cpp
-)
-
-target_include_directories(CppSdkE2ETests
-    PRIVATE
-        ${CMAKE_CURRENT_SOURCE_DIR}/test
-        ${CMAKE_CURRENT_SOURCE_DIR}/src
-)
-
-target_link_libraries(CppSdkE2ETests
-    PRIVATE
-        CppSdk
-        GTest::gtest_main
-)
-
-gtest_discover_tests(CppSdkE2ETests
-    WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkE2ETests>
-)
+if (BUILD_TESTING)
+    enable_testing()
+
+    add_executable(CppSdkTests
+        test/parser_and_types_test.cpp
+        test/model_variant_test.cpp
+        test/catalog_test.cpp
+        test/client_test.cpp
+    )
+
+    target_include_directories(CppSdkTests
+        PRIVATE
+            ${CMAKE_CURRENT_SOURCE_DIR}/test
+            ${CMAKE_CURRENT_SOURCE_DIR}/src
+    )
+
+    target_compile_definitions(CppSdkTests PRIVATE FL_TESTS)
+
+    target_link_libraries(CppSdkTests
+        PRIVATE
+            CppSdk
+            GTest::gtest_main
+    )
+
+    # Copy testdata files next to the test executable so file-based tests can find them.
+    add_custom_command(TARGET CppSdkTests POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy_directory
+                ${CMAKE_CURRENT_SOURCE_DIR}/test/testdata
+                $<TARGET_FILE_DIR:CppSdkTests>/testdata
+    )
+
+    include(GoogleTest)
+    gtest_discover_tests(CppSdkTests
+        WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkTests>
+    )
+
+    # -----------------------------
+    # End-to-end tests (separate executable, requires Core DLL)
+    # Exercises the full public API against the real catalog.
+    # Tests that need model download are DISABLED by default;
+    # run with --gtest_also_run_disabled_tests locally.
+    # -----------------------------
+    add_executable(CppSdkE2ETests
+        test/e2e_test.cpp
+    )
+
+    target_include_directories(CppSdkE2ETests
+        PRIVATE
+            ${CMAKE_CURRENT_SOURCE_DIR}/test
+            ${CMAKE_CURRENT_SOURCE_DIR}/src
+    )
+
+    target_link_libraries(CppSdkE2ETests
+        PRIVATE
+            CppSdk
+            GTest::gtest_main
+    )
+
+    gtest_discover_tests(CppSdkE2ETests
+        WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkE2ETests>
+    )
+endif()

 # Make Visual Studio start/debug this target by default
 set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
diff --git a/sdk/cpp/include/catalog.h b/sdk/cpp/include/catalog.h
index d57137b7..2e32e3f8 100644
--- a/sdk/cpp/include/catalog.h
+++ b/sdk/cpp/include/catalog.h
@@ -49,7 +49,7 @@ namespace foundry_local {
     private:
         struct CatalogState {
             std::unordered_map<std::string, Model> byAlias;
-            std::unordered_map<std::string, ModelVariant> modelIdToModelVariant;
+            std::unordered_map<std::string, ModelVariant*> modelIdToModelVariant;
             std::chrono::steady_clock::time_point lastFetch{};
         };

diff --git a/sdk/cpp/src/catalog.cpp b/sdk/cpp/src/catalog.cpp
index 836cb285..b67c1a30 100644
--- a/sdk/cpp/src/catalog.cpp
+++ b/sdk/cpp/src/catalog.cpp
@@ -104,15 +104,16 @@ namespace foundry_local {
             ModelInfo modelVariantInfo;
             from_json(j, modelVariantInfo);
-            std::string variantId = modelVariantInfo.id;
             ModelVariant modelVariant(core_, modelVariantInfo, logger_);
-            newState->modelIdToModelVariant.emplace(variantId, modelVariant);
-
             it->second.variants_.emplace_back(std::move(modelVariant));
         }

-        // Auto-select the first variant for each model.
+        // Build the lookup map from pointers into the owning Model::variants_ vectors,
+        // and auto-select the first variant for each model.
         for (auto& [alias, model] : newState->byAlias) {
+            for (auto& variant : model.variants_) {
+                newState->modelIdToModelVariant.emplace(variant.GetId(), &variant);
+            }
             if (!model.variants_.empty()) {
                 model.selectedVariant_ = &model.variants_.front();
             }
         }
@@ -132,7 +133,7 @@ namespace foundry_local {
         auto state = GetState();
         auto it = state->modelIdToModelVariant.find(std::string(id));
         if (it != state->modelIdToModelVariant.end()) {
-            return const_cast<ModelVariant*>(&it->second);
+            return it->second;
         }
         return nullptr;
     }
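[Editor's note: the reworked loop fills modelIdToModelVariant only after every variants_ vector has stopped growing. That ordering matters because the map now stores raw pointers into those vectors, and taking an element's address while the vector can still reallocate leaves the pointer dangling. A self-contained illustration of the hazard the two-pass version avoids, using toy types rather than SDK code:]

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
    std::vector<std::string> variants;
    std::unordered_map<std::string, std::string*> byId;

    // Hazardous pattern: take a pointer while the vector is still growing.
    variants.reserve(1);                    // tiny capacity to force reallocation below
    variants.push_back("v1");
    byId.emplace("v1", &variants.back());   // pointer into the vector's current buffer
    variants.push_back("v2");               // may reallocate: byId["v1"] now dangles

    // Safe pattern (what catalog.cpp does): finish filling the vector, then map.
    byId.clear();
    for (auto& v : variants) {
        byId.emplace(v, &v);                // vector is complete; pointers stay valid
    }
    std::cout << *byId["v1"] << "\n";       // prints "v1"
}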
diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h
index 76d52ed3..67bd40d5 100644
--- a/sdk/cpp/src/core_helpers.h
+++ b/sdk/cpp/src/core_helpers.h
@@ -130,14 +130,14 @@ namespace foundry_local {
     }

     inline std::vector<ModelVariant*> CollectVariantsByIds(
-        const std::unordered_map<std::string, ModelVariant>& modelIdToModelVariant, std::vector<std::string> ids) {
+        const std::unordered_map<std::string, ModelVariant*>& modelIdToModelVariant, std::vector<std::string> ids) {
         std::vector<ModelVariant*> out;
         out.reserve(ids.size());

         for (const auto& id : ids) {
             auto it = modelIdToModelVariant.find(id);
             if (it != modelIdToModelVariant.end()) {
-                out.emplace_back(const_cast<ModelVariant*>(&it->second));
+                out.emplace_back(it->second);
             }
         }
         return out;
diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h
index c62ca192..b0778116 100644
--- a/sdk/cpp/src/flcore_native.h
+++ b/sdk/cpp/src/flcore_native.h
@@ -3,6 +3,7 @@
 #pragma once

 #include
+#include

 extern "C" {
diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h
index 6c568374..8d12411d 100644
--- a/sdk/cpp/src/parser.h
+++ b/sdk/cpp/src/parser.h
@@ -106,7 +106,6 @@ namespace foundry_local {
     inline void from_json(const nlohmann::json& j, Runtime& r) {
         std::string deviceType;
-        std::string executionProvider;

         j.at("deviceType").get_to(deviceType);
         j.at("executionProvider").get_to(r.execution_provider);
diff --git a/sdk/cpp/vcpkg.json b/sdk/cpp/vcpkg.json
index ec08c349..c4511497 100644
--- a/sdk/cpp/vcpkg.json
+++ b/sdk/cpp/vcpkg.json
@@ -4,7 +4,9 @@
   "dependencies": [
     "nlohmann-json",
     "wil",
-    "ms-gsl",
+    "ms-gsl"
+  ],
+  "dev-dependencies": [
     "gtest"
   ]
 }

From 7a8e53f8412f6827e8830acba6cc05602f04a41f Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 2 Apr 2026 19:04:55 -0700
Subject: [PATCH 18/18] Copilot fixes 2

---
 sdk/cpp/include/openai/openai_chat_client.h |  1 -
 sdk/cpp/sample/main.cpp                     | 12 ++---
 sdk/cpp/src/model.cpp                       |  5 +-
 sdk/cpp/src/openai_chat_client.cpp          |  9 ++++
 sdk/cpp/src/parser.h                        | 15 ++++++
 sdk/cpp/test/client_test.cpp                | 52 +++++++++++++++++++++
 sdk/cpp/test/e2e_test.cpp                   |  1 +
 sdk/cpp/test/model_variant_test.cpp         | 15 +++++-
 8 files changed, 100 insertions(+), 10 deletions(-)

diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h
index 084e55d7..c16b9481 100644
--- a/sdk/cpp/include/openai/openai_chat_client.h
+++ b/sdk/cpp/include/openai/openai_chat_client.h
@@ -1,5 +1,4 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
-// Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

 #pragma once
diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp
index 0e6b92e4..7b014c4a 100644
--- a/sdk/cpp/sample/main.cpp
+++ b/sdk/cpp/sample/main.cpp
@@ -1,7 +1,6 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.

-#include
 #include "foundry_local.h"

 #include
-#include #include "foundry_local.h" #include @@ -293,12 +292,13 @@ void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) } // ── Step 5: Feed the tool result back ───────────────────────────── - // Add the assistant's message (including the raw tool_call content) - // and then a "tool" message with the result. - messages.push_back({"tool", toolResult}); + // First, append the assistant message that contains the tool_calls + // so the model sees its own request in the conversation history. + messages.push_back({"assistant", "", std::nullopt, firstChoice.message->tool_calls}); - // Add a follow-up system instruction so the model uses the tool output. - messages.push_back({"system", "Respond only with the answer generated by the tool."}); + // Then add a "tool" message with the result, referencing the + // tool_call_id so the model can match it to the call it made. + messages.push_back({"tool", toolResult, tc.id}); // Switch to Auto so the model can answer without calling tools again. settings.tool_choice = ToolChoiceKind::Auto; diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp index 1631bcda..880e91e1 100644 --- a/sdk/cpp/src/model.cpp +++ b/sdk/cpp/src/model.cpp @@ -187,8 +187,9 @@ namespace foundry_local { } void Model::SelectVariant(const ModelVariant& variant) const { - auto it = - std::find_if(variants_.begin(), variants_.end(), [&](const ModelVariant& v) { return &v == &variant; }); + const auto& targetId = variant.GetId(); + auto it = std::find_if(variants_.begin(), variants_.end(), + [&](const ModelVariant& v) { return v.GetId() == targetId; }); if (it == variants_.end()) { throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); diff --git a/sdk/cpp/src/openai_chat_client.cpp b/sdk/cpp/src/openai_chat_client.cpp index 84dc1a24..5c19a0ba 100644 --- a/sdk/cpp/src/openai_chat_client.cpp +++ b/sdk/cpp/src/openai_chat_client.cpp @@ -46,6 +46,15 @@ namespace foundry_local { nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; if (msg.tool_call_id) jMsg["tool_call_id"] = *msg.tool_call_id; + if (!msg.tool_calls.empty()) { + nlohmann::json jToolCalls = nlohmann::json::array(); + for (const auto& tc : msg.tool_calls) { + nlohmann::json jtc; + to_json(jtc, tc); + jToolCalls.push_back(std::move(jtc)); + } + jMsg["tool_calls"] = std::move(jToolCalls); + } jMessages.push_back(std::move(jMsg)); } diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h index 8d12411d..3596579c 100644 --- a/sdk/cpp/src/parser.h +++ b/sdk/cpp/src/parser.h @@ -227,6 +227,21 @@ namespace foundry_local { j["function"] = std::move(fj); } + // ---------- Tool calling: to_json for response types (needed for multi-turn serialization) ---------- + + inline void to_json(nlohmann::json& j, const FunctionCall& fc) { + j = nlohmann::json{{"name", fc.name}, {"arguments", fc.arguments}}; + } + + inline void to_json(nlohmann::json& j, const ToolCall& tc) { + j = nlohmann::json{{"id", tc.id}, {"type", tc.type}}; + if (tc.function_call) { + nlohmann::json fj; + to_json(fj, *tc.function_call); + j["function"] = std::move(fj); + } + } + // ---------- Tool calling: from_json (deserialization from responses) ---------- inline void from_json(const nlohmann::json& j, FunctionCall& fc) { diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp index 201fd965..9864ec99 100644 --- a/sdk/cpp/test/client_test.cpp +++ b/sdk/cpp/test/client_test.cpp @@ -366,6 +366,58 @@ TEST_F(OpenAIChatClientTest, 
diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp
index 201fd965..9864ec99 100644
--- a/sdk/cpp/test/client_test.cpp
+++ b/sdk/cpp/test/client_test.cpp
@@ -366,6 +366,58 @@ TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) {
     EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get<std::string>());
 }

+TEST_F(OpenAIChatClientTest, CompleteChat_AssistantToolCallsSerialized) {
+    // Multi-turn tool calling: the assistant message with tool_calls must be sent back
+    // alongside the tool result message for the model to match the tool response.
+    core_.OnCall("chat_completions", MakeChatResponseJson("The answer is 42."));
+    core_.OnCall("list_loaded_models", R"(["chat-model:1"])");
+
+    auto variant = MakeLoadedVariant();
+    OpenAIChatClient client(variant);
+
+    ChatMessage assistantMsg;
+    assistantMsg.role = "assistant";
+    assistantMsg.content = "";
+    assistantMsg.tool_calls = {
+        {"call_1", "function", FunctionCall{"multiply_numbers", "{\"first\": 7, \"second\": 6}"}}};
+
+    ChatMessage toolMsg;
+    toolMsg.role = "tool";
+    toolMsg.content = "42";
+    toolMsg.tool_call_id = "call_1";
+
+    std::vector<ChatMessage> messages = {
+        {"user", "What is 7 * 6?", {}}, std::move(assistantMsg), std::move(toolMsg)};
+    ChatSettings settings;
+    client.CompleteChat(messages, settings);
+
+    auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions"));
+    auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get<std::string>());
+
+    ASSERT_EQ(3u, openAiReq["messages"].size());
+
+    // User message: no tool_calls
+    EXPECT_FALSE(openAiReq["messages"][0].contains("tool_calls"));
+
+    // Assistant message: must include tool_calls
+    const auto& assistantJson = openAiReq["messages"][1];
+    EXPECT_EQ("assistant", assistantJson["role"].get<std::string>());
+    ASSERT_TRUE(assistantJson.contains("tool_calls"));
+    ASSERT_TRUE(assistantJson["tool_calls"].is_array());
+    ASSERT_EQ(1u, assistantJson["tool_calls"].size());
+    EXPECT_EQ("call_1", assistantJson["tool_calls"][0]["id"].get<std::string>());
+    EXPECT_EQ("function", assistantJson["tool_calls"][0]["type"].get<std::string>());
+    EXPECT_EQ("multiply_numbers", assistantJson["tool_calls"][0]["function"]["name"].get<std::string>());
+    EXPECT_EQ("{\"first\": 7, \"second\": 6}",
+              assistantJson["tool_calls"][0]["function"]["arguments"].get<std::string>());
+
+    // Tool message: must include tool_call_id
+    const auto& toolJson = openAiReq["messages"][2];
+    EXPECT_EQ("tool", toolJson["role"].get<std::string>());
+    EXPECT_EQ("call_1", toolJson["tool_call_id"].get<std::string>());
+    EXPECT_FALSE(toolJson.contains("tool_calls"));
+}
+
 TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) {
     nlohmann::json chunk1 = {
         {"created", 1700000000},
diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp
index 50f1afbc..4bf33348 100644
--- a/sdk/cpp/test/e2e_test.cpp
+++ b/sdk/cpp/test/e2e_test.cpp
@@ -9,6 +9,7 @@

 #include "foundry_local.h"

+#include
 #include
 #include
 #include
diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp
index 060f629c..023112d9 100644
--- a/sdk/cpp/test/model_variant_test.cpp
+++ b/sdk/cpp/test/model_variant_test.cpp
@@ -217,10 +217,23 @@ TEST_F(ModelTest, SelectVariant_NotFound_Throws) {
     Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
     Factory::SelectFirstVariant(model);

-    auto external = MakeVariant("external", "alias", 1);
+    auto external = MakeVariant("external", "ext-alias", 1);
     EXPECT_THROW(model.SelectVariant(external), Exception);
 }

+TEST_F(ModelTest, SelectVariant_ByIdFromExternalInstance) {
+    auto model = MakeModel();
+    Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1));
+    Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2));
+    Factory::SelectFirstVariant(model);
+
+    // Simulate a variant obtained externally (e.g. from Catalog::GetModelVariant)
+    // with the same id as v2 but a different object instance.
+    auto externalV2 = MakeVariant("v2", "alias", 2);
+    model.SelectVariant(externalV2);
+    EXPECT_EQ("v2:2", model.GetId());
+}
+
 TEST_F(ModelTest, GetLatestVariant) {
     auto model = MakeModel();
     Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1));