diff --git a/sdk/cpp/.clang-format b/sdk/cpp/.clang-format
new file mode 100644
index 00000000..751f30aa
--- /dev/null
+++ b/sdk/cpp/.clang-format
@@ -0,0 +1,47 @@
+---
+Language: Cpp
+BasedOnStyle: Microsoft
+
+# Match the existing project style
+Standard: c++17
+ColumnLimit: 120
+
+# Indentation
+IndentWidth: 4
+TabWidth: 4
+UseTab: Never
+AccessModifierOffset: -4
+IndentCaseLabels: false
+NamespaceIndentation: All
+
+# Braces
+BreakBeforeBraces: Custom
+BraceWrapping:
+  AfterCaseLabel: false
+  AfterClass: false
+  AfterControlStatement: Never
+  AfterEnum: false
+  AfterFunction: false
+  AfterNamespace: false
+  AfterStruct: false
+  BeforeCatch: true
+  BeforeElse: true
+  IndentBraces: false
+
+# Alignment
+AlignAfterOpenBracket: Align
+AlignOperands: Align
+AlignTrailingComments: true
+
+# Includes
+SortIncludes: false
+IncludeBlocks: Preserve
+
+# Misc
+AllowShortFunctionsOnASingleLine: Inline
+AllowShortIfStatementsOnASingleLine: Never
+AllowShortLoopsOnASingleLine: false
+AllowShortBlocksOnASingleLine: Empty
+PointerAlignment: Left
+SpaceAfterCStyleCast: false
+SpaceBeforeParens: ControlStatements
diff --git a/sdk/cpp/CMakeLists.txt b/sdk/cpp/CMakeLists.txt
new file mode 100644
index 00000000..7e32b7fb
--- /dev/null
+++ b/sdk/cpp/CMakeLists.txt
@@ -0,0 +1,151 @@
+cmake_minimum_required(VERSION 3.20)
+
+# VS hot reload policy (safe-guarded)
+if (POLICY CMP0141)
+    cmake_policy(SET CMP0141 NEW)
+    if (MSVC)
+        # Use PDB debug info in Debug/RelWithDebInfo so VS hot reload works.
+        # NOTE(review): the pasted patch had the inner <CONFIG:...> span stripped;
+        # reconstructed as the standard VS template expression.
+        set(CMAKE_MSVC_DEBUG_INFORMATION_FORMAT
+            "$<$<CONFIG:Debug,RelWithDebInfo>:ProgramDatabase>")
+    endif()
+endif()
+
+project(CppSdk LANGUAGES CXX)
+
+# -----------------------------
+# Windows-only + compiler guard
+# -----------------------------
+if (NOT WIN32)
+    message(FATAL_ERROR "CppSdk is Windows-only for now (uses Win32/WIL headers).")
+endif()
+
+# Accept MSVC OR clang-cl (Clang in MSVC compatibility mode).
+# VS CMake Open-Folder often uses clang-cl by default.
+if (NOT (MSVC OR (CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND CMAKE_CXX_SIMULATE_ID STREQUAL "MSVC")))
+    message(STATUS "CMAKE_CXX_COMPILER_ID = ${CMAKE_CXX_COMPILER_ID}")
+    message(STATUS "CMAKE_CXX_COMPILER    = ${CMAKE_CXX_COMPILER}")
+    message(STATUS "CMAKE_CXX_SIMULATE_ID = ${CMAKE_CXX_SIMULATE_ID}")
+    message(FATAL_ERROR "Need MSVC or clang-cl (MSVC-compatible toolchain).")
+endif()
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+set(CMAKE_CXX_EXTENSIONS OFF)
+
+# Optional: target Windows 10+ APIs (adjust if you need older)
+add_compile_definitions(_WIN32_WINNT=0x0A00 WINVER=0x0A00)
+
+# -----------------------------
+# Dependencies (installed via vcpkg)
+# -----------------------------
+find_package(nlohmann_json CONFIG REQUIRED)
+find_package(wil CONFIG REQUIRED)
+find_package(Microsoft.GSL CONFIG REQUIRED)
+option(BUILD_TESTING "Build unit and end-to-end tests" ON)
+if (BUILD_TESTING)
+    find_package(GTest CONFIG REQUIRED)
+endif()
+
+# -----------------------------
+# SDK library (STATIC)
+# List ONLY .cpp files here.
+# -----------------------------
+add_library(CppSdk STATIC
+    src/model.cpp
+    src/catalog.cpp
+    src/openai_chat_client.cpp
+    src/openai_audio_client.cpp
+    src/foundry_local_manager.cpp
+)
+
+target_include_directories(CppSdk
+    PUBLIC
+        ${CMAKE_CURRENT_SOURCE_DIR}/include
+    PRIVATE
+        ${CMAKE_CURRENT_SOURCE_DIR}/src
+)
+
+target_link_libraries(CppSdk
+    PUBLIC
+        nlohmann_json::nlohmann_json
+        Microsoft.GSL::GSL
+        WIL::WIL
+)
+
+# -----------------------------
+# Sample executable
+# -----------------------------
+add_executable(CppSdkSample
+    sample/main.cpp
+)
+
+target_link_libraries(CppSdkSample PRIVATE CppSdk)
+
+# -----------------------------
+# Unit tests
+# -----------------------------
+if (BUILD_TESTING)
+    enable_testing()
+
+    add_executable(CppSdkTests
+        test/parser_and_types_test.cpp
+        test/model_variant_test.cpp
+        test/catalog_test.cpp
+        test/client_test.cpp
+    )
+
+    target_include_directories(CppSdkTests
+        PRIVATE
+            ${CMAKE_CURRENT_SOURCE_DIR}/test
+            ${CMAKE_CURRENT_SOURCE_DIR}/src
+    )
+
+    target_compile_definitions(CppSdkTests PRIVATE FL_TESTS)
+
+    target_link_libraries(CppSdkTests
+        PRIVATE
+            CppSdk
+            GTest::gtest_main
+    )
+
+    # Copy testdata files next to the test executable so file-based tests can find them.
+    # NOTE(review): the <TARGET_FILE_DIR:...> genexes below were stripped to bare "$"
+    # in the pasted patch; reconstructed from the surrounding intent.
+    add_custom_command(TARGET CppSdkTests POST_BUILD
+        COMMAND ${CMAKE_COMMAND} -E copy_directory
+            ${CMAKE_CURRENT_SOURCE_DIR}/test/testdata
+            $<TARGET_FILE_DIR:CppSdkTests>/testdata
+    )
+
+    include(GoogleTest)
+    gtest_discover_tests(CppSdkTests
+        WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkTests>
+    )
+
+    # -----------------------------
+    # End-to-end tests (separate executable, requires Core DLL)
+    # Exercises the full public API against the real catalog.
+    # Tests that need model download are DISABLED by default;
+    # run with --gtest_also_run_disabled_tests locally.
+    # -----------------------------
+    add_executable(CppSdkE2ETests
+        test/e2e_test.cpp
+    )
+
+    target_include_directories(CppSdkE2ETests
+        PRIVATE
+            ${CMAKE_CURRENT_SOURCE_DIR}/test
+            ${CMAKE_CURRENT_SOURCE_DIR}/src
+    )
+
+    target_link_libraries(CppSdkE2ETests
+        PRIVATE
+            CppSdk
+            GTest::gtest_main
+    )
+
+    # NOTE(review): the <TARGET_FILE_DIR:...> genex was stripped to a bare "$"
+    # in the pasted patch; reconstructed to match the unit-test target above.
+    gtest_discover_tests(CppSdkE2ETests
+        WORKING_DIRECTORY $<TARGET_FILE_DIR:CppSdkE2ETests>
+    )
+endif()
+
+# Make Visual Studio start/debug this target by default
+set_property(DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
+    PROPERTY VS_STARTUP_PROJECT CppSdkSample)
diff --git a/sdk/cpp/CMakePresets.json b/sdk/cpp/CMakePresets.json
new file mode 100644
index 00000000..ddead1b2
--- /dev/null
+++ b/sdk/cpp/CMakePresets.json
@@ -0,0 +1,109 @@
+{
+  "version": 6,
+  "configurePresets": [
+    {
+      "name": "windows-base",
+      "hidden": true,
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
+      "toolchainFile": "$env{VCPKG_ROOT}/scripts/buildsystems/vcpkg.cmake",
+      "cacheVariables": {
+        "CMAKE_C_COMPILER": "cl.exe",
+        "CMAKE_CXX_COMPILER": "cl.exe",
+        "VCPKG_OVERLAY_TRIPLETS": "${sourceDir}/triplets"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Windows"
+      }
+    },
+    {
+      "name": "x64-debug",
+      "displayName": "MSVC x64 Debug",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x64",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug",
+        "VCPKG_TARGET_TRIPLET": "x64-windows-static-md"
+      }
+    },
+    {
+      "name": "x64-release",
+      "displayName": "MSVC x64 Release",
+      "inherits": "windows-base",
+      "architecture": {
+        "value": "x64",
+        "strategy": "external"
+      },
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "VCPKG_TARGET_TRIPLET": "x64-windows-static-md"
+      }
+    },
+    {
+      "name": "linux-debug",
+      "displayName": "Linux Debug",
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Linux"
+      }
+    },
+    {
+      "name": "macos-debug",
+      "displayName": "macOS Debug",
+      "generator": "Ninja",
+      "binaryDir": "${sourceDir}/out/build/${presetName}",
+      "installDir": "${sourceDir}/out/install/${presetName}",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Debug"
+      },
+      "condition": {
+        "type": "equals",
+        "lhs": "${hostSystemName}",
+        "rhs": "Darwin"
+      }
+    }
+  ],
+  "buildPresets": [
+    {
+      "name": "x64-debug",
+      "configurePreset": "x64-debug",
+      "displayName": "MSVC x64 Debug Build"
+    },
+    {
+      "name": "x64-release",
+      "configurePreset": "x64-release",
+      "displayName": "MSVC x64 Release Build"
+    }
+  ],
+  "testPresets": [
+    {
+      "name": "x64-debug",
+      "configurePreset": "x64-debug",
+      "displayName": "MSVC x64 Debug Tests",
+      "output": {
+        "outputOnFailure": true
+      }
+    },
+    {
+      "name": "x64-release",
+      "configurePreset": "x64-release",
+      "displayName": "MSVC x64 Release Tests",
+      "output": {
+        "outputOnFailure": true
+      }
+    }
+  ]
+}
diff --git a/sdk/cpp/include/catalog.h b/sdk/cpp/include/catalog.h
new file mode 100644
index 00000000..2e32e3f8
--- /dev/null
+++ b/sdk/cpp/include/catalog.h
@@ -0,0 +1,75 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "model.h" + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { +#ifdef FL_TESTS + namespace Testing { + struct MockObjectFactory; + } +#endif + + class Catalog final { + public: + Catalog(const Catalog&) = delete; + Catalog& operator=(const Catalog&) = delete; + Catalog(Catalog&&) = delete; + Catalog& operator=(Catalog&&) = delete; + + static std::unique_ptr Create(gsl::not_null core, + gsl::not_null logger) { + return std::unique_ptr(new Catalog(core, logger)); + } + + const std::string& GetName() const { return name_; } + std::vector ListModels() const; + std::vector GetLoadedModels() const; + std::vector GetCachedModels() const; + + Model* GetModel(std::string_view modelId) const; + ModelVariant* GetModelVariant(std::string_view modelVariantId) const; + + private: + struct CatalogState { + std::unordered_map byAlias; + std::unordered_map modelIdToModelVariant; + std::chrono::steady_clock::time_point lastFetch{}; + }; + + void UpdateModels() const; + std::shared_ptr GetState() const; + + mutable std::mutex mutex_; + mutable std::shared_ptr state_; + + explicit Catalog(gsl::not_null injected, + gsl::not_null logger); + + gsl::not_null core_; + std::string name_; + gsl::not_null logger_; + + friend class FoundryLocalManager; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/configuration.h b/sdk/cpp/include/configuration.h new file mode 100644 index 00000000..21c40473 --- /dev/null +++ b/sdk/cpp/include/configuration.h @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include +#include +#include "log_level.h" + +namespace foundry_local { + + /// Optional configuration for the built-in web service. 
+ struct WebServiceConfig { + // URL/s to bind the web service to. + // Default: 127.0.0.1:0 (random ephemeral port). + // Multiple URLs can be specified as a semicolon-separated list. + std::optional urls; + + // If the web service is running in a separate process, provide its URL here. + std::optional external_url; + }; + + struct Configuration { + // Construct a Configuration with just an application name. + // All other fields use their defaults. + Configuration(std::string name) : app_name(std::move(name)) {} + + // Your application name. MUST be set to a valid name. + std::string app_name; + + // Application data directory. + // Default: {home}/.{appname}, where {home} is the user's home directory and {appname} is the app_name value. + std::optional app_data_dir; + + // Model cache directory. + // Default: {appdata}/cache/models, where {appdata} is the app_data_dir value. + std::optional model_cache_dir; + + // Log directory. + // Default: {appdata}/logs + std::optional logs_dir; + + // Logging level. + // Valid values are: Verbose, Debug, Information, Warning, Error, Fatal. + // Default: LogLevel.Warning + LogLevel log_level = LogLevel::Warning; + + // Optional web service configuration. + std::optional web; + + // Additional settings that Foundry Local Core can consume. + std::optional> additional_settings; + + void Validate() const { + if (app_name.empty()) { + throw std::invalid_argument("Configuration app_name must be set to a valid application name."); + } + + constexpr std::string_view invalidChars = R"(\/:?\"<>|)"; + if (app_name.find_first_of(invalidChars) != std::string::npos) { + throw std::invalid_argument("Configuration app_name value contains invalid characters."); + } + } + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/foundry_local.h b/sdk/cpp/include/foundry_local.h new file mode 100644 index 00000000..c16337e1 --- /dev/null +++ b/sdk/cpp/include/foundry_local.h @@ -0,0 +1,18 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. +// +// Umbrella header – includes every public header for convenience. +// Consumers may also include individual headers directly. + +#pragma once + +#include "configuration.h" +#include "foundry_local_exception.h" +#include "log_level.h" +#include "logger.h" +#include "model.h" +#include "catalog.h" +#include "foundry_local_manager.h" +#include "openai/openai_tool_types.h" +#include "openai/openai_chat_client.h" +#include "openai/openai_audio_client.h" diff --git a/sdk/cpp/include/foundry_local_exception.h b/sdk/cpp/include/foundry_local_exception.h new file mode 100644 index 00000000..1dba9119 --- /dev/null +++ b/sdk/cpp/include/foundry_local_exception.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "logger.h" + +namespace foundry_local { + + class Exception final : public std::runtime_error { + public: + explicit Exception(std::string message) : std::runtime_error(std::move(message)) {} + + Exception(std::string message, ILogger& logger) : std::runtime_error(std::move(message)) { + logger.Log(LogLevel::Error, what()); + } + }; + + // Backward compatibility alias. + using FoundryLocalException = Exception; + +} // namespace foundry_local diff --git a/sdk/cpp/include/foundry_local_manager.h b/sdk/cpp/include/foundry_local_manager.h new file mode 100644 index 00000000..9ff5eda3 --- /dev/null +++ b/sdk/cpp/include/foundry_local_manager.h @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include + +#include +#include + +#include "configuration.h" +#include "logger.h" +#include "catalog.h" + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { + + class FoundryLocalManager final { + public: + FoundryLocalManager(const FoundryLocalManager&) = delete; + FoundryLocalManager& operator=(const FoundryLocalManager&) = delete; + FoundryLocalManager(FoundryLocalManager&&) = delete; + FoundryLocalManager& operator=(FoundryLocalManager&&) = delete; + + /// Create the FoundryLocalManager singleton instance. + /// Throws if an instance has already been created. Call Destroy() first to release the current instance. + /// @param configuration Configuration to use. + /// @param logger Optional application logger. Pass nullptr to suppress log output. + static void Create(Configuration configuration, ILogger* logger = nullptr); + + /// Get the singleton instance. + /// Throws if Create() has not been called. + static FoundryLocalManager& Instance(); + + /// Returns true if the singleton instance has been created. + static bool IsInitialized() noexcept; + + /// Destroy the singleton instance, performing deterministic cleanup. + /// Unloads all loaded models and stops the web service if running. + /// After this call, Create() may be called again. + static void Destroy() noexcept; + + const Catalog& GetCatalog() const; + Catalog& GetCatalog(); + + /// Start the optional built-in web service. + /// Provides an OpenAI-compatible REST endpoint. + /// After startup, GetUrls() returns the actual bound URL/s. + /// Requires Configuration::Web to be set. + void StartWebService(); + + /// Stop the web service if started. + void StopWebService(); + + /// Returns the bound URL/s after StartWebService(), or empty if not started. + gsl::span GetUrls() const noexcept; + + /// Ensure execution providers are downloaded and registered. 
+ /// Once downloaded, EPs are not re-downloaded unless a new version is available. + void EnsureEpsDownloaded() const; + + private: + explicit FoundryLocalManager(Configuration configuration, ILogger* logger); + ~FoundryLocalManager(); + + struct Deleter { + void operator()(FoundryLocalManager* p) const noexcept { delete p; } + }; + + void Initialize(); + void Cleanup() noexcept; + + static std::unique_ptr instance_; + + Configuration config_; + NullLogger defaultLogger_; + std::unique_ptr core_; + std::unique_ptr catalog_; + ILogger* logger_; + std::vector urls_; + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/log_level.h b/sdk/cpp/include/log_level.h new file mode 100644 index 00000000..75dfe667 --- /dev/null +++ b/sdk/cpp/include/log_level.h @@ -0,0 +1,37 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace foundry_local { + + enum class LogLevel { + Verbose, + Debug, + Information, + Warning, + Error, + Fatal + }; + + inline std::string_view LogLevelToString(LogLevel level) noexcept { + switch (level) { + case LogLevel::Verbose: + return "Verbose"; + case LogLevel::Debug: + return "Debug"; + case LogLevel::Information: + return "Information"; + case LogLevel::Warning: + return "Warning"; + case LogLevel::Error: + return "Error"; + case LogLevel::Fatal: + return "Fatal"; + } + return "Unknown"; + } + +} // namespace foundry_local diff --git a/sdk/cpp/include/logger.h b/sdk/cpp/include/logger.h new file mode 100644 index 00000000..d0b05b4e --- /dev/null +++ b/sdk/cpp/include/logger.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include "log_level.h" + +namespace foundry_local { + class ILogger { + public: + virtual ~ILogger() = default; + virtual void Log(LogLevel level, std::string_view message) noexcept = 0; + }; + + class NullLogger final : public ILogger { + public: + void Log(LogLevel, std::string_view) noexcept override {} + }; +} // namespace foundry_local diff --git a/sdk/cpp/include/model.h b/sdk/cpp/include/model.h new file mode 100644 index 00000000..25f04699 --- /dev/null +++ b/sdk/cpp/include/model.h @@ -0,0 +1,194 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "logger.h" + +namespace foundry_local { + class OpenAIChatClient; + class OpenAIAudioClient; +} + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { +#ifdef FL_TESTS + namespace Testing { + struct MockObjectFactory; + } +#endif + + using DownloadProgressCallback = std::function; + + class IModel { + public: + virtual ~IModel() = default; + + virtual const std::string& GetId() const = 0; + virtual const std::string& GetAlias() const = 0; + virtual bool IsLoaded() const = 0; + virtual bool IsCached() const = 0; + virtual const std::filesystem::path& GetPath() const = 0; + virtual void Download(DownloadProgressCallback onProgress = nullptr) = 0; + virtual void Load() = 0; + virtual void Unload() = 0; + virtual void RemoveFromCache() = 0; + + protected: + struct CoreAccess { + gsl::not_null core; + std::string modelName; + gsl::not_null logger; + }; + + virtual CoreAccess GetCoreAccess() const = 0; + + friend class OpenAIChatClient; + friend class OpenAIAudioClient; + }; + + enum class DeviceType { + Invalid, + CPU, + GPU, + NPU + }; + + struct Runtime { + DeviceType device_type = DeviceType::Invalid; + std::string execution_provider; + }; + + struct PromptTemplate { + 
std::string system; + std::string user; + std::string assistant; + std::string prompt; + }; + + // Forward declarations + class ModelVariant; + + struct Parameter { + std::string name; + std::optional value; + }; + + struct ModelSettings { + std::vector parameters; + }; + + struct ModelInfo { + std::string id; + std::string name; + uint32_t version = 0; + std::string alias; + std::optional display_name; + std::string provider_type; + std::string uri; + std::string model_type; + std::optional prompt_template; + std::optional publisher; + std::optional model_settings; + std::optional license; + std::optional license_description; + bool cached = false; + std::optional task; + std::optional runtime; + std::optional file_size_mb; + std::optional supports_tool_calling; + std::optional max_output_tokens; + std::optional min_fl_version; + int64_t created_at_unix = 0; + }; + + class ModelVariant final : public IModel { + public: + const ModelInfo& GetInfo() const; + const std::filesystem::path& GetPath() const override; + void Download(DownloadProgressCallback onProgress = nullptr) override; + void Load() override; + + bool IsLoaded() const override; + bool IsCached() const override; + void Unload() override; + void RemoveFromCache() override; + + const std::string& GetId() const noexcept override; + const std::string& GetAlias() const noexcept override; + uint32_t GetVersion() const noexcept; + + protected: + CoreAccess GetCoreAccess() const override; + + private: + static std::string MakeModelParamRequest(std::string_view modelId); + explicit ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger); + + ModelInfo info_; + mutable std::filesystem::path cachedPath_; + gsl::not_null core_; + gsl::not_null logger_; + + friend class Catalog; + friend class Model; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + + class Model final : public IModel { + public: + gsl::span GetAllModelVariants() const; + const ModelVariant& 
GetLatestVersion(const ModelVariant& variant) const; + + bool IsLoaded() const override { return SelectedVariant().IsLoaded(); } + bool IsCached() const override { return SelectedVariant().IsCached(); } + const std::filesystem::path& GetPath() const override { return SelectedVariant().GetPath(); } + void Download(DownloadProgressCallback onProgress = nullptr) override { + SelectedVariant().Download(std::move(onProgress)); + } + void Load() override { SelectedVariant().Load(); } + void Unload() override { SelectedVariant().Unload(); } + void RemoveFromCache() override { SelectedVariant().RemoveFromCache(); } + + const std::string& GetId() const override; + const std::string& GetAlias() const override; + void SelectVariant(const ModelVariant& variant) const; + + protected: + CoreAccess GetCoreAccess() const override; + + private: + explicit Model(gsl::not_null core, gsl::not_null logger); + ModelVariant& SelectedVariant(); + const ModelVariant& SelectedVariant() const; + + gsl::not_null core_; + + std::vector variants_; + mutable const ModelVariant* selectedVariant_ = nullptr; + gsl::not_null logger_; + + friend class Catalog; +#ifdef FL_TESTS + friend struct Testing::MockObjectFactory; +#endif + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_audio_client.h b/sdk/cpp/include/openai/openai_audio_client.h new file mode 100644 index 00000000..ac1ce719 --- /dev/null +++ b/sdk/cpp/include/openai/openai_audio_client.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include +#include +#include + +#include + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { + class ILogger; + class IModel; + + struct AudioCreateTranscriptionResponse { + std::string text; + }; + + class OpenAIAudioClient final { + public: + explicit OpenAIAudioClient(const IModel& model); + + /// Returns the model ID this client was created for. + const std::string& GetModelId() const noexcept { return modelId_; } + + AudioCreateTranscriptionResponse TranscribeAudio(const std::filesystem::path& audioFilePath) const; + + using StreamCallback = std::function; + void TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, const StreamCallback& onChunk) const; + + private: + OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_chat_client.h b/sdk/cpp/include/openai/openai_chat_client.h new file mode 100644 index 00000000..c16b9481 --- /dev/null +++ b/sdk/cpp/include/openai/openai_chat_client.h @@ -0,0 +1,116 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "openai_tool_types.h" + +namespace foundry_local::Internal { + struct IFoundryLocalCore; +} + +namespace foundry_local { + class ILogger; + class IModel; + + /// Reason the model stopped generating tokens. 
+ enum class FinishReason { + None, + Stop, + Length, + ToolCalls, + ContentFilter + }; + + struct ChatMessage { + std::string role; + std::string content; + std::optional tool_call_id; ///< For role="tool" responses + std::vector tool_calls; + }; + + struct ChatChoice { + int index = 0; + FinishReason finish_reason = FinishReason::None; + + // non-streaming + std::optional message; + + // streaming + std::optional delta; + }; + + struct ChatCompletionCreateResponse { + int64_t created = 0; + std::string id; + + bool is_delta = false; + bool successful = false; + int http_status_code = 0; + + std::vector choices; + + /// Returns the object type string. Derived from is_delta — no allocation. + const char* GetObject() const noexcept { return is_delta ? "chat.completion.chunk" : "chat.completion"; } + + /// Returns the created timestamp as an ISO 8601 string. + /// Computed lazily, only allocates when called. + std::string GetCreatedAtIso() const; + }; + + struct ChatSettings { + std::optional frequency_penalty; + std::optional max_tokens; + std::optional n; + std::optional temperature; + std::optional presence_penalty; + std::optional random_seed; + std::optional top_k; + std::optional top_p; + std::optional tool_choice; + }; + + class OpenAIChatClient final { + public: + explicit OpenAIChatClient(const IModel& model); + + /// Returns the model ID this client was created for. 
+ const std::string& GetModelId() const noexcept { return modelId_; } + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + const ChatSettings& settings) const; + + ChatCompletionCreateResponse CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const; + + using StreamCallback = std::function; + void CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const; + + void CompleteChatStreaming(gsl::span messages, gsl::span tools, + const ChatSettings& settings, const StreamCallback& onChunk) const; + + private: + OpenAIChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger); + + std::string BuildChatRequestJson(gsl::span messages, gsl::span tools, + const ChatSettings& settings, bool stream) const; + + std::string modelId_; + gsl::not_null core_; + gsl::not_null logger_; + }; + +} // namespace foundry_local diff --git a/sdk/cpp/include/openai/openai_tool_types.h b/sdk/cpp/include/openai/openai_tool_types.h new file mode 100644 index 00000000..105bc49e --- /dev/null +++ b/sdk/cpp/include/openai/openai_tool_types.h @@ -0,0 +1,54 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include + +namespace foundry_local { + + /// JSON Schema property definition used to describe tool function parameters. + struct PropertyDefinition { + std::string type; + std::optional description; + std::optional> properties; + std::optional> required; + }; + + /// Describes a function that a model may call. + struct FunctionDefinition { + std::string name; + std::optional description; + std::optional parameters; + }; + + /// A tool definition following the OpenAI tool calling spec. + struct ToolDefinition { + std::string type = "function"; + FunctionDefinition function; + }; + + /// A parsed function call returned by the model. 
+ struct FunctionCall { + std::string name; + std::string arguments; ///< JSON string of the arguments + }; + + /// A tool call returned by the model in a chat completion response. + struct ToolCall { + std::string id; + std::string type; + std::optional function_call; + }; + + /// Controls whether and how the model calls tools. + enum class ToolChoiceKind { + Auto, + None, + Required + }; + +} // namespace foundry_local diff --git a/sdk/cpp/sample/main.cpp b/sdk/cpp/sample/main.cpp new file mode 100644 index 00000000..7b014c4a --- /dev/null +++ b/sdk/cpp/sample/main.cpp @@ -0,0 +1,355 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "foundry_local.h" + +#include +#include +#include + +using namespace foundry_local; + +// --------------------------------------------------------------------------- +// Logger +// --------------------------------------------------------------------------- +class StdLogger final : public ILogger { +public: + void Log(LogLevel level, std::string_view message) noexcept override { + const char* tag = "UNK"; + switch (level) { + case LogLevel::Information: + tag = "INFO"; + break; + case LogLevel::Warning: + tag = "WARN"; + break; + case LogLevel::Error: + tag = "ERROR"; + break; + default: + tag = "DEBUG"; + break; + } + std::cout << "[FoundryLocal][" << tag << "] " << message << "\n"; + } +}; + +// --------------------------------------------------------------------------- +// Example 1 – Browse the catalog +// --------------------------------------------------------------------------- +void BrowseCatalog(FoundryLocalManager& manager) { + std::cout << "\n=== Example 1: Browse Catalog ===\n"; + + auto& catalog = manager.GetCatalog(); + std::cout << "Catalog: " << catalog.GetName() << "\n"; + + auto models = catalog.ListModels(); + std::cout << "Models in catalog: " << models.size() << "\n"; + + for (const auto* model : models) { + std::cout << " - " << model->GetAlias() << " 
(" << model->GetId() << ")" + << " cached=" << (model->IsCached() ? "yes" : "no") + << " loaded=" << (model->IsLoaded() ? "yes" : "no") << "\n"; + + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + std::cout << " variant: " << info.name << " v" << info.version + << " cached=" << (variant.IsCached() ? "yes" : "no"); + if (info.display_name) + std::cout << " display=\"" << *info.display_name << "\""; + if (info.publisher) + std::cout << " publisher=" << *info.publisher; + if (info.license) + std::cout << " license=" << *info.license; + if (info.runtime) { + std::cout << " device=" + << (info.runtime->device_type == DeviceType::GPU ? "GPU" + : info.runtime->device_type == DeviceType::NPU ? "NPU" + : "CPU") + << " ep=" << info.runtime->execution_provider; + } + if (info.file_size_mb) + std::cout << " size=" << *info.file_size_mb << "MB"; + if (info.supports_tool_calling) + std::cout << " tools=" << (*info.supports_tool_calling ? "yes" : "no"); + std::cout << "\n"; + } + } +} + +// --------------------------------------------------------------------------- +// Example 2 – Download, load, chat (non-streaming), then unload +// --------------------------------------------------------------------------- +void ChatNonStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 2: Non-Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + + auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + + if (model->IsLoaded()) { + std::cout << "Model is loaded and ready for inference.\n"; + } + else { + std::cerr << "Failed to load model.\n"; + return; + } + + OpenAIChatClient chat(*model); + + std::vector messages = {{"user", "What is the capital of Croatia?"}}; + 
+ ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 128; + + auto response = chat.CompleteChat(messages, settings); + + if (!response.choices.empty() && response.choices[0].message) { + std::cout << "Assistant: " << response.choices[0].message->content << "\n"; + } + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + +// --------------------------------------------------------------------------- +// Example 3 – Streaming chat +// --------------------------------------------------------------------------- +void ChatStreaming(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 3: Streaming Chat ===\n"; + + auto& catalog = manager.GetCatalog(); + + auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Load(); + + OpenAIChatClient chat(*model); + + std::vector messages = {{"user", "Explain quantum computing in three sentences."}}; + + ChatSettings settings; + settings.temperature = 0.9f; + settings.max_tokens = 256; + + std::cout << "Assistant: "; + chat.CompleteChatStreaming(messages, settings, [](const ChatCompletionCreateResponse& chunk) { + if (chunk.choices.empty()) + return; + const auto& choice = chunk.choices[0]; + if (choice.delta && !choice.delta->content.empty()) { + std::cout << choice.delta->content << std::flush; + } + }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 4 – Audio transcription +// --------------------------------------------------------------------------- +void TranscribeAudio(FoundryLocalManager& manager, const std::string& alias, const std::string& audioPath) { + std::cout << "\n=== Example 4: Audio Transcription ===\n"; + + auto& catalog = manager.GetCatalog(); + + auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; 
+ return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + + OpenAIAudioClient audio(*model); + + std::cout << "Transcribing: " << audioPath << "\n"; + auto result = audio.TranscribeAudio(audioPath); + std::cout << "Transcription: " << result.text << "\n"; + + // Streaming alternative: + audio.TranscribeAudioStreaming( + audioPath, [](const AudioCreateTranscriptionResponse& chunk) { std::cout << chunk.text << std::flush; }); + std::cout << "\n"; + + model->Unload(); +} + +// --------------------------------------------------------------------------- +// Example 5 – Tool calling +// --------------------------------------------------------------------------- +// Tool calling lets you define functions that the model can decide to invoke. +// The flow is: +// 1. You describe your tools (functions) as ToolDefinition objects. +// 2. You send a chat request with those tools attached. +// 3. The model may respond with finish_reason = ToolCalls and include +// ToolCall objects in the message, each containing the function name +// and a JSON string of arguments. +// 4. YOUR CODE executes the real function using those arguments. +// 5. You add a message with role = "tool" containing the result, then +// send the conversation back so the model can formulate a final answer. +// +// This lets the model "reach out" to external capabilities (calculators, +// databases, APIs, etc.) while keeping the actual execution in your code. 
+// --------------------------------------------------------------------------- +void ChatWithToolCalling(FoundryLocalManager& manager, const std::string& alias) { + std::cout << "\n=== Example 5: Tool Calling ===\n"; + + auto& catalog = manager.GetCatalog(); + + auto* model = catalog.GetModel(alias); + if (!model) { + std::cerr << "Model '" << alias << "' not found in catalog.\n"; + return; + } + + model->Download([](float pct) { std::cout << "\rDownloading: " << pct << "% " << std::flush; }); + std::cout << "\n"; + + model->Load(); + std::cout << "Model loaded: " << model->GetAlias() << "\n"; + + OpenAIChatClient chat(*model); + + // ── Step 1: Define tools ────────────────────────────────────────────── + // Each tool describes a function the model can call. The PropertyDefinition + // mirrors a JSON Schema so the model knows what arguments are expected. + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", // function name + "Multiply two integers and return the result.", // description + PropertyDefinition{ + "object", // top-level schema type + std::nullopt, // no top-level description + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + std::vector{"first", "second"} // both params are required + }}}}; + + // ── Step 2: Send the first request ──────────────────────────────────── + // tool_choice = Required forces the model to always produce a tool call. + // In production you'd typically use Auto so the model decides on its own. + std::vector messages = { + {"system", "You are a helpful AI assistant. 
Use the provided tools when appropriate."}, + {"user", "What is 7 multiplied by 6?"}}; + + ChatSettings settings; + settings.temperature = 0.0f; + settings.max_tokens = 500; + settings.tool_choice = ToolChoiceKind::Required; + + std::cout << "Sending chat request with tool definitions...\n"; + auto response = chat.CompleteChat(messages, tools, settings); + + // ── Step 3: Inspect the model's tool call ───────────────────────────── + if (response.choices.empty()) { + std::cerr << "No choices returned.\n"; + model->Unload(); + return; + } + + const auto& firstChoice = response.choices[0]; + + // The model signals it wants to call a tool via finish_reason == ToolCalls. + if (firstChoice.finish_reason == FinishReason::ToolCalls && firstChoice.message && + !firstChoice.message->tool_calls.empty()) { + const auto& tc = firstChoice.message->tool_calls[0]; + std::cout << "Model requested tool call:\n" + << " function : " << (tc.function_call ? tc.function_call->name : "(none)") << "\n" + << " arguments: " << (tc.function_call ? tc.function_call->arguments : "{}") << "\n"; + + // ── Step 4: Execute the tool locally ────────────────────────────── + // Parse the arguments JSON and perform the actual computation. + // In a real application this could be a web request, DB query, etc. + std::string toolResult; + if (tc.function_call && tc.function_call->name == "multiply_numbers") { + // The arguments string is JSON, e.g. {"first": 7, "second": 6} + // For brevity we hard-code the expected result here. + toolResult = "7 x 6 = 42."; + std::cout << " result : " << toolResult << "\n"; + } + else { + toolResult = "Unknown tool."; + } + + // ── Step 5: Feed the tool result back ───────────────────────────── + // First, append the assistant message that contains the tool_calls + // so the model sees its own request in the conversation history. 
+ messages.push_back({"assistant", "", std::nullopt, firstChoice.message->tool_calls}); + + // Then add a "tool" message with the result, referencing the + // tool_call_id so the model can match it to the call it made. + messages.push_back({"tool", toolResult, tc.id}); + + // Switch to Auto so the model can answer without calling tools again. + settings.tool_choice = ToolChoiceKind::Auto; + + std::cout << "\nSending tool result back to model...\n"; + auto followUp = chat.CompleteChat(messages, tools, settings); + + if (!followUp.choices.empty() && followUp.choices[0].message) { + std::cout << "Assistant: " << followUp.choices[0].message->content << "\n"; + } + } + else { + // The model answered directly without a tool call. + if (firstChoice.message) + std::cout << "Assistant: " << firstChoice.message->content << "\n"; + } + + model->Unload(); + std::cout << "Model unloaded.\n"; +} + +// --------------------------------------------------------------------------- +// main +// --------------------------------------------------------------------------- +int main() { + try { + StdLogger logger; + FoundryLocalManager::Create({"SampleApp"}, &logger); + auto& manager = FoundryLocalManager::Instance(); + + // 1. Browse the full catalog + BrowseCatalog(manager); + + // 2. Non-streaming chat (change alias to a model in your catalog) + ChatNonStreaming(manager, "phi-3.5-mini"); + + // 3. Streaming chat + ChatStreaming(manager, "phi-3.5-mini"); + + // 4. Audio transcription (uncomment and set a valid alias + wav path) + // TranscribeAudio(manager, "whisper-small", R"(C:\path\to\your\audio.wav)"); + + // 5. 
Tool calling (define tools, let the model call them, feed results back) + ChatWithToolCalling(manager, "phi-3.5-mini"); + + FoundryLocalManager::Destroy(); + return 0; + } + catch (const std::exception& ex) { + std::cerr << "Fatal: " << ex.what() << std::endl; + FoundryLocalManager::Destroy(); + return 1; + } +} diff --git a/sdk/cpp/src/catalog.cpp b/sdk/cpp/src/catalog.cpp new file mode 100644 index 00000000..b67c1a30 --- /dev/null +++ b/sdk/cpp/src/catalog.cpp @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_helpers.h" +#include "parser.h" +#include "logger.h" + +namespace foundry_local { + + using namespace detail; + + Catalog::Catalog(gsl::not_null injected, gsl::not_null logger) + : state_(std::make_shared()), core_(injected), logger_(logger) { + auto response = core_->call("get_catalog_name", *logger_, /*dataArgument*/ nullptr); + if (response.HasError()) { + throw Exception(std::string("Error getting catalog name: ") + response.error, *logger_); + } + name_ = std::move(response.data); + } + + std::shared_ptr Catalog::GetState() const { + std::lock_guard lock(mutex_); + return state_; + } + + std::vector Catalog::GetLoadedModels() const { + UpdateModels(); + auto state = GetState(); + return CollectVariantsByIds(state->modelIdToModelVariant, GetLoadedModelsInternal(core_, *logger_)); + } + + std::vector Catalog::GetCachedModels() const { + UpdateModels(); + auto state = GetState(); + return CollectVariantsByIds(state->modelIdToModelVariant, GetCachedModelsInternal(core_, *logger_)); + } + + Model* Catalog::GetModel(std::string_view modelId) const { + UpdateModels(); + auto state = GetState(); + auto it = state->byAlias.find(std::string(modelId)); + if (it != state->byAlias.end()) { + return 
const_cast(&it->second);
+        }
+        return nullptr;
+    }
+
+    // Returns raw pointers into the current catalog snapshot.
+    // NOTE(review): the pointers remain valid only while the snapshot that
+    // owns them is alive; a later UpdateModels() swaps the state and a caller
+    // holding a stale Model* could dangle — confirm the intended lifetime
+    // contract with callers.
+    std::vector Catalog::ListModels() const {
+        UpdateModels();
+        auto state = GetState();
+
+        std::vector out;
+        out.reserve(state->byAlias.size());
+        // const_cast: the map is reached through a const accessor while the
+        // public API hands out mutable Model*.
+        for (auto& kv : state->byAlias)
+            out.emplace_back(const_cast(&kv.second));
+
+        return out;
+    }
+
+    // Refreshes the catalog from the core at most once per kRefreshInterval.
+    // Builds a complete replacement state off to the side and swaps it in
+    // under the lock, so readers never observe a partially built catalog.
+    void Catalog::UpdateModels() const {
+        using clock = std::chrono::steady_clock;
+
+        // TODO: make this configurable
+        constexpr auto kRefreshInterval = std::chrono::hours(6);
+
+        const auto now = clock::now();
+        {
+            auto current = GetState();
+            // A default-constructed time_point (zero duration since epoch)
+            // means "never fetched" and always forces a refresh.
+            if (current->lastFetch.time_since_epoch() != clock::duration::zero() &&
+                (now - current->lastFetch) < kRefreshInterval) {
+                return;
+            }
+        }
+
+        // Fetch outside the lock so the core call doesn't block readers.
+        const auto response = core_->call("get_model_list", *logger_);
+        if (response.HasError()) {
+            throw Exception(std::string("Error getting model list: ") + response.error, *logger_);
+        }
+        const auto arr = nlohmann::json::parse(response.data);
+
+        // Build the new state locally so no reader can see partial data.
+        auto newState = std::make_shared();
+
+        for (const auto& j : arr) {
+            const std::string alias = j.at("alias").get();
+
+            // One Model per alias; every catalog entry with that alias
+            // becomes a variant of it.
+            auto it = newState->byAlias.find(alias);
+            if (it == newState->byAlias.end()) {
+                Model m(core_, logger_);
+                it = newState->byAlias.emplace(alias, std::move(m)).first;
+            }
+
+            ModelInfo modelVariantInfo;
+            from_json(j, modelVariantInfo);
+            ModelVariant modelVariant(core_, modelVariantInfo, logger_);
+            it->second.variants_.emplace_back(std::move(modelVariant));
+        }
+
+        // Build the lookup map from pointers into the owning Model::variants_ vectors,
+        // and auto-select the first variant for each model.
+        for (auto& [alias, model] : newState->byAlias) {
+            // modelIdToModelVariant stores pointers into model.variants_;
+            // they stay valid because the vectors are not resized after this
+            // point and the map lives in the same state snapshot.
+            for (auto& variant : model.variants_) {
+                newState->modelIdToModelVariant.emplace(variant.GetId(), &variant);
+            }
+            if (!model.variants_.empty()) {
+                model.selectedVariant_ = &model.variants_.front();
+            }
+        }
+
+        newState->lastFetch = now;
+
+        // Atomic swap — readers that already hold the old shared_ptr keep it alive.
+        {
+            std::lock_guard lock(mutex_);
+            state_ = std::move(newState);
+        }
+    }
+
+    // Looks up a variant by catalog id; returns nullptr when unknown.
+    ModelVariant* Catalog::GetModelVariant(std::string_view id) const {
+        UpdateModels();
+        auto state = GetState();
+        auto it = state->modelIdToModelVariant.find(std::string(id));
+        if (it != state->modelIdToModelVariant.end()) {
+            return it->second;
+        }
+        return nullptr;
+    }
+
+} // namespace foundry_local
diff --git a/sdk/cpp/src/core.h b/sdk/cpp/src/core.h
new file mode 100644
index 00000000..c7f73d5d
--- /dev/null
+++ b/sdk/cpp/src/core.h
@@ -0,0 +1,114 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// Core DLL interop – loads Microsoft.AI.Foundry.Local.Core.dll at runtime.
+// Internal header, not part of the public API.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include "foundry_local_internal_core.h"
+#include "foundry_local_exception.h"
+#include "flcore_native.h"
+#include "logger.h"
+
+namespace foundry_local {
+
+    // NOTE(review): an unnamed namespace in a header gives each including TU
+    // its own copy of these helpers; consider a named detail namespace.
+    namespace {
+        // Directory containing the current executable — used to locate the
+        // core DLL that ships next to the application binary.
+        inline std::filesystem::path getExecutableDir() {
+            auto exePath = wil::GetModuleFileNameW(nullptr);
+            return std::filesystem::path(exePath.get()).parent_path();
+        }
+
+        // GetProcAddress that throws instead of returning nullptr, so a
+        // missing export fails loudly at DLL-load time.
+        inline void* RequireProc(HMODULE mod, const char* name) {
+            if (void* p = ::GetProcAddress(mod, name))
+                return p;
+            throw std::runtime_error(std::string("GetProcAddress failed for ") + name);
+        }
+    } // namespace
+
+    // RAII wrapper over the dynamically loaded core DLL. Resolves the three
+    // exported entry points once at load and funnels every command through
+    // call(); the module handle owns the DLL's lifetime.
+    struct Core : Internal::IFoundryLocalCore {
+        using ResponseHandle = std::unique_ptr;
+
+        Core() = default;
+        ~Core() = default;
+
+        // Loads the core DLL expected to sit next to the executable.
+        void loadEmbedded() { loadFromPath(getExecutableDir() / "Microsoft.AI.Foundry.Local.Core.dll"); }
+
+        // Releases the module and forgets the resolved entry points;
+        // call() throws until the DLL is loaded again.
+        void unload() override {
+            module_.reset();
+            execCmd_ = nullptr;
+            execCbCmd_ = nullptr;
+            freeResCmd_ = nullptr;
+        }
+
+        // Sends one command (plus optional payload and callback) to the core
+        // DLL and converts the native response buffers into a CoreResponse.
+        CoreResponse call(std::string_view command, ILogger& logger, const std::string* dataArgument = nullptr,
+                          NativeCallbackFn callback = nullptr, void* data = nullptr) const override {
+            if (!module_ || !execCmd_ || !execCbCmd_ || !freeResCmd_) {
+                throw Exception("Core is not loaded. Cannot call command: " + std::string(command), logger);
+            }
+
+            RequestBuffer request{};
+            request.Command = command.empty() ?
nullptr : command.data(); + request.CommandLength = static_cast(command.size()); + + if (dataArgument && !dataArgument->empty()) { + request.Data = dataArgument->data(); + request.DataLength = static_cast(dataArgument->size()); + } + + ResponseBuffer response{}; + auto safeDeleter = [fn = freeResCmd_](ResponseBuffer* buf) { + if (fn) + fn(buf); + }; + std::unique_ptr responseGuard(&response, safeDeleter); + + if (callback != nullptr) { + execCbCmd_(&request, &response, reinterpret_cast(callback), data); + } + else { + execCmd_(&request, &response); + } + + CoreResponse result; + if (response.Error && response.ErrorLength > 0) { + result.error.assign(static_cast(response.Error), response.ErrorLength); + return result; + } + + if (response.Data && response.DataLength > 0) { + result.data.assign(static_cast(response.Data), response.DataLength); + } + + return result; + } + + private: + wil::unique_hmodule module_; + execute_command_fn execCmd_{}; + execute_command_with_callback_fn execCbCmd_{}; + free_response_fn freeResCmd_{}; + + void loadFromPath(const std::filesystem::path& path) { + wil::unique_hmodule m(::LoadLibraryW(path.c_str())); + if (!m) + throw std::runtime_error("LoadLibraryW failed"); + + execCmd_ = reinterpret_cast(RequireProc(m.get(), "execute_command")); + execCbCmd_ = reinterpret_cast( + RequireProc(m.get(), "execute_command_with_callback")); + freeResCmd_ = reinterpret_cast(RequireProc(m.get(), "free_response")); + + module_ = std::move(m); + } + }; + +} // namespace foundry_local diff --git a/sdk/cpp/src/core_helpers.h b/sdk/cpp/src/core_helpers.h new file mode 100644 index 00000000..67bd40d5 --- /dev/null +++ b/sdk/cpp/src/core_helpers.h @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// Internal helpers shared across implementation files. +// Not part of the public API. 
+ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "logger.h" +#include "model.h" + +namespace foundry_local::detail { + + // Wrap Params: { ... } into a request object + inline nlohmann::json MakeParams(nlohmann::json params) { + return nlohmann::json{{"Params", std::move(params)}}; + } + + // Most common: Params { "Model": } + inline nlohmann::json MakeModelParams(std::string_view model) { + return MakeParams(nlohmann::json{{"Model", std::string(model)}}); + } + + // Serialize + call + inline CoreResponse CallWithJson(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload); + } + + // Serialize + call with native callback + inline CoreResponse CallWithJsonAndCallback(Internal::IFoundryLocalCore* core, std::string_view command, + const nlohmann::json& requestJson, ILogger& logger, + NativeCallbackFn callback, void* userData) { + std::string payload = requestJson.dump(); + return core->call(command, logger, &payload, callback, userData); + } + + // Serialize + call with a streaming chunk handler. + // Wraps the caller-supplied onChunk with the native callback boilerplate + // (null/length checks, exception capture, rethrow after the call). + // The errorContext string is used to prefix any core-layer error message. 
+    inline CoreResponse CallWithStreamingCallback(Internal::IFoundryLocalCore* core, std::string_view command,
+                                                  const std::string& payload, ILogger& logger,
+                                                  const std::function& onChunk,
+                                                  std::string_view errorContext) {
+        // Passed through the C callback's user-data pointer; the lambda below
+        // must stay capture-less so it decays to NativeCallbackFn.
+        struct State {
+            const std::function* cb;
+            std::exception_ptr exception;
+        } state{&onChunk, nullptr};
+
+        auto nativeCallback = [](void* data, int32_t len, void* user) {
+            if (!data || len <= 0)
+                return;
+
+            auto* st = static_cast(user);
+            // After the handler has thrown once, drop further chunks; the
+            // captured exception is rethrown after the core call returns.
+            if (st->exception)
+                return;
+
+            try {
+                std::string chunk(static_cast(data), static_cast(len));
+                (*(st->cb))(chunk);
+            }
+            catch (...) {
+                st->exception = std::current_exception();
+            }
+        };
+
+        auto response = core->call(command, logger, &payload, +
nativeCallback, &state);
+        if (response.HasError()) {
+            throw Exception(std::string(errorContext) + response.error, logger);
+        }
+
+        // Core errors take precedence; a chunk-handler exception surfaces
+        // only when the transport itself succeeded.
+        if (state.exception) {
+            std::rethrow_exception(state.exception);
+        }
+
+        return response;
+    }
+
+    // Overload: allow Params object directly
+    inline CoreResponse CallWithParams(Internal::IFoundryLocalCore* core, std::string_view command,
+                                       const nlohmann::json& params, ILogger& logger) {
+        return CallWithJson(core, command, MakeParams(params), logger);
+    }
+
+    // Overload: no payload
+    inline CoreResponse CallNoArgs(Internal::IFoundryLocalCore* core, std::string_view command, ILogger& logger) {
+        return core->call(command, logger, nullptr);
+    }
+
+    // Asks the core for the ids of the currently loaded models. Throws on a
+    // core error or when the response is not a JSON array of strings.
+    inline std::vector GetLoadedModelsInternal(Internal::IFoundryLocalCore* core, ILogger& logger) {
+        auto response = core->call("list_loaded_models", logger);
+        if (response.HasError()) {
+            throw Exception("Failed to get loaded models: " + response.error, logger);
+        }
+        try {
+            auto parsed = nlohmann::json::parse(response.data);
+            return parsed.get>();
+        }
+        catch (const nlohmann::json::exception& e) {
+            throw Exception("Catalog::GetLoadedModelsInternal() JSON error: " + std::string(e.what()), logger);
+        }
+    }
+
+    // Asks the core for the ids of models present in the on-disk cache.
+    inline std::vector GetCachedModelsInternal(Internal::IFoundryLocalCore*
core, ILogger& logger) { + auto response = core->call("get_cached_models", logger); + if (response.HasError()) { + throw Exception("Failed to get cached models: " + response.error, logger); + } + + try { + auto parsed = nlohmann::json::parse(response.data); + return parsed.get>(); + } + catch (const nlohmann::json::exception& e) { + throw Exception("Catalog::GetCachedModelsInternal JSON error: " + std::string(e.what()), logger); + } + } + + inline std::vector CollectVariantsByIds( + const std::unordered_map& modelIdToModelVariant, std::vector ids) { + std::vector out; + out.reserve(ids.size()); + + for (const auto& id : ids) { + auto it = modelIdToModelVariant.find(id); + if (it != modelIdToModelVariant.end()) { + out.emplace_back(it->second); + } + } + return out; + } + +} // namespace foundry_local::detail diff --git a/sdk/cpp/src/core_interop_request.h b/sdk/cpp/src/core_interop_request.h new file mode 100644 index 00000000..67ef1590 --- /dev/null +++ b/sdk/cpp/src/core_interop_request.h @@ -0,0 +1,46 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+
+#pragma once
+#include
+#include
+#include
+#include
+
+namespace foundry_local {
+
+    // Fluent builder for the JSON request envelope sent to the core DLL:
+    //   { "Params": { <key>: <value>, ... } }
+    // Collects key/value parameters and serializes them via ToJson() for
+    // Core::call().
+    class CoreInteropRequest final {
+    public:
+        explicit CoreInteropRequest(std::string command) : command_(std::move(command)) {}
+
+        // Adds a string-valued parameter.
+        CoreInteropRequest& AddParam(std::string_view key, std::string_view value) {
+            params_[std::string(key)] = std::string(value);
+            return *this;
+        }
+
+        // Adds a parameter of any type nlohmann::json can store directly.
+        template CoreInteropRequest& AddParam(std::string_view key, const T& value) {
+            params_[std::string(key)] = value;
+            return *this;
+        }
+
+        // Adds a parameter whose value is a pre-built JSON document, stored
+        // in its serialized (string) form rather than as nested JSON.
+        CoreInteropRequest& AddJsonParam(std::string_view key, const nlohmann::json& jsonValue) {
+            params_[std::string(key)] = jsonValue.dump();
+            return *this;
+        }
+
+        // Serializes the envelope. NOTE(review): with zero parameters the
+        // default-constructed json dumps as "null", not "{}" — confirm the
+        // core accepts that for parameterless commands.
+        std::string ToJson() const {
+            nlohmann::json wrapper;
+            if (!params_.empty()) {
+                wrapper["Params"] = params_;
+            }
+            return wrapper.dump();
+        }
+
+        const std::string& Command() const noexcept { return command_; }
+
+    private:
+        std::string command_;
+        nlohmann::json params_;
+    };
+
+} // namespace foundry_local
diff --git a/sdk/cpp/src/flcore_native.h b/sdk/cpp/src/flcore_native.h
new file mode 100644
index 00000000..b0778116
--- /dev/null
+++ b/sdk/cpp/src/flcore_native.h
@@ -0,0 +1,39 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+#include
+
+extern "C"
+{
+    // Layout must match C# structs exactly
+#pragma pack(push, 8)
+    // Request sent into the core: a command name plus an optional opaque
+    // payload. Lengths are explicit — the buffers are not null-terminated.
+    struct RequestBuffer {
+        const void* Command;
+        int32_t CommandLength;
+        const void* Data;
+        int32_t DataLength;
+    };
+
+    // Response filled in by the core; released via free_response.
+    struct ResponseBuffer {
+        void* Data;
+        int32_t DataLength;
+        void* Error;
+        int32_t ErrorLength;
+    };
+
+    // Callback signature: void(__cdecl*)(void* data, int32_t length, void* userData)
+    using UserCallbackFn = void(__cdecl*)(void*, int32_t, void*);
+
+    // Exported function pointer types
+    using execute_command_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*);
+    using execute_command_with_callback_fn = void(__cdecl*)(RequestBuffer*, ResponseBuffer*, void* /*callback*/,
+                                                            void* /*userData*/);
+    using free_response_fn = void(__cdecl*)(ResponseBuffer*);
+
+    // Guard the ABI: these cross a C interop boundary, so both structs must
+    // remain standard-layout.
+    static_assert(std::is_standard_layout::value, "RequestBuffer must be standard layout");
+    static_assert(std::is_standard_layout::value, "ResponseBuffer must be standard layout");
+
+#pragma pack(pop)
+}
diff --git a/sdk/cpp/src/foundry_local_internal_core.h b/sdk/cpp/src/foundry_local_internal_core.h
new file mode 100644
index 00000000..1e5af79d
--- /dev/null
+++ b/sdk/cpp/src/foundry_local_internal_core.h
@@ -0,0 +1,38 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include
+#include
+#include
+#include "logger.h"
+
+namespace foundry_local {
+
+    /// Native callback signature used by the core DLL interop.
+    /// Parameters: (data, dataLength, userData).
+    using NativeCallbackFn = void (*)(void*, int32_t, void*);
+
+    /// Value returned by IFoundryLocalCore::call().
+    /// On success, `data` contains the response payload and `error` is empty.
+    /// On failure, `error` contains the error message from the core layer.
+ struct CoreResponse { + std::string data; + std::string error; + + bool HasError() const noexcept { return !error.empty(); } + }; + + namespace Internal { + struct IFoundryLocalCore { + virtual ~IFoundryLocalCore() = default; + + virtual CoreResponse call(std::string_view command, ILogger& logger, + const std::string* dataArgument = nullptr, NativeCallbackFn callback = nullptr, + void* data = nullptr) const = 0; + virtual void unload() = 0; + }; + + } // namespace Internal +} // namespace foundry_local \ No newline at end of file diff --git a/sdk/cpp/src/foundry_local_manager.cpp b/sdk/cpp/src/foundry_local_manager.cpp new file mode 100644 index 00000000..d1ab35bb --- /dev/null +++ b/sdk/cpp/src/foundry_local_manager.cpp @@ -0,0 +1,191 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "core.h" +#include "logger.h" + +namespace foundry_local { + + std::unique_ptr FoundryLocalManager::instance_; + + void FoundryLocalManager::Create(Configuration configuration, ILogger* logger) { + if (instance_) { + NullLogger fallback; + ILogger& log = logger ? *logger : fallback; + throw Exception("FoundryLocalManager has already been created. Call Destroy() first.", log); + } + + // Use a local to ensure full initialization before assigning to the static instance. + std::unique_ptr manager( + new FoundryLocalManager(std::move(configuration), logger)); + instance_ = std::move(manager); + } + + FoundryLocalManager& FoundryLocalManager::Instance() { + if (!instance_) { + throw Exception("FoundryLocalManager has not been created. 
Call Create() first.");
+        }
+        return *instance_;
+    }
+
+    // True once Create() has succeeded and until Destroy() is called.
+    bool FoundryLocalManager::IsInitialized() noexcept {
+        return instance_ != nullptr;
+    }
+
+    // Tears down the singleton; safe to call when no instance exists.
+    void FoundryLocalManager::Destroy() noexcept {
+        instance_.reset();
+    }
+
+    FoundryLocalManager::FoundryLocalManager(Configuration configuration, ILogger* logger)
+        : config_(std::move(configuration)), core_(std::make_unique()),
+          logger_(logger ? logger : &defaultLogger_) {
+        // The core DLL must be loaded first; Initialize() and the catalog
+        // both issue commands through it.
+        static_cast(core_.get())->loadEmbedded();
+        Initialize();
+        catalog_ = Catalog::Create(core_.get(), logger_);
+    }
+
+    FoundryLocalManager::~FoundryLocalManager() {
+        Cleanup();
+    }
+
+    // Best-effort teardown: unloads every loaded model and stops the web
+    // service if it was started. Failures are logged as warnings and never
+    // propagate (this runs from the destructor and is noexcept).
+    void FoundryLocalManager::Cleanup() noexcept {
+        // Unload all loaded models before tearing down.
+        if (catalog_) {
+            try {
+                auto loadedModels = catalog_->GetLoadedModels();
+                for (auto* variant : loadedModels) {
+                    try {
+                        variant->Unload();
+                    }
+                    catch (const std::exception& ex) {
+                        logger_->Log(LogLevel::Warning,
+                                     std::string("Error unloading model during destruction: ") + ex.what());
+                    }
+                }
+            }
+            catch (const std::exception& ex) {
+                logger_->Log(LogLevel::Warning,
+                             std::string("Error retrieving loaded models during destruction: ") + ex.what());
+            }
+        }
+
+        // Non-empty urls_ means StartWebService() succeeded earlier.
+        if (!urls_.empty()) {
+            try {
+                StopWebService();
+            }
+            catch (const std::exception& ex) {
+                logger_->Log(LogLevel::Warning,
+                             std::string("Error stopping web service during destruction: ") + ex.what());
+            }
+        }
+    }
+
+    const Catalog& FoundryLocalManager::GetCatalog() const {
+        return *catalog_;
+    }
+
+    Catalog& FoundryLocalManager::GetCatalog() {
+        return *catalog_;
+    }
+
+    // Starts the core's web service and records the URLs it listens on.
+    // Requires Configuration::web to be set.
+    void FoundryLocalManager::StartWebService() {
+        if (!config_.web) {
+            throw Exception("Web service configuration was not provided.", *logger_);
+        }
+
+        auto response = core_->call("start_service", *logger_);
+        if (response.HasError()) {
+            throw Exception(std::string("Error starting web service: ") + response.error, *logger_);
+        }
+        auto arr = nlohmann::json::parse(response.data);
+        urls_ = arr.get>();
+    }
+
+    void
FoundryLocalManager::StopWebService() { + if (!config_.web) { + throw Exception("Web service configuration was not provided.", *logger_); + } + + auto response = core_->call("stop_service", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error stopping web service: ") + response.error, *logger_); + } + urls_.clear(); + } + + gsl::span FoundryLocalManager::GetUrls() const noexcept { + return urls_; + } + + void FoundryLocalManager::EnsureEpsDownloaded() const { + auto response = core_->call("ensure_eps_downloaded", *logger_); + if (response.HasError()) { + throw Exception(std::string("Error ensuring execution providers downloaded: ") + response.error, *logger_); + } + } + + void FoundryLocalManager::Initialize() { + config_.Validate(); + + CoreInteropRequest initReq("initialize"); + initReq.AddParam("AppName", config_.app_name); + initReq.AddParam("LogLevel", std::string(LogLevelToString(config_.log_level))); + + if (config_.app_data_dir) { + initReq.AddParam("AppDataDir", config_.app_data_dir->string()); + } + if (config_.model_cache_dir) { + initReq.AddParam("ModelCacheDir", config_.model_cache_dir->string()); + } + if (config_.logs_dir) { + initReq.AddParam("LogsDir", config_.logs_dir->string()); + } + if (config_.web && config_.web->urls) { + initReq.AddParam("WebServiceUrls", *config_.web->urls); + } + if (config_.additional_settings) { + for (const auto& [key, value] : *config_.additional_settings) { + if (!key.empty()) { + initReq.AddParam(key, value); + } + } + } + + std::string initJson = initReq.ToJson(); + auto initResponse = core_->call(initReq.Command(), *logger_, &initJson); + if (initResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + initResponse.error, *logger_); + } + + if (config_.model_cache_dir) { + auto cacheResponse = core_->call("get_cache_directory", *logger_); + if (cacheResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + 
cacheResponse.error, + *logger_); + } + + if (cacheResponse.data != config_.model_cache_dir->string()) { + CoreInteropRequest setReq("set_cache_directory"); + setReq.AddParam("Directory", config_.model_cache_dir->string()); + std::string setJson = setReq.ToJson(); + auto setResponse = core_->call(setReq.Command(), *logger_, &setJson); + if (setResponse.HasError()) { + throw Exception(std::string("FoundryLocalManager::Initialize failed: ") + setResponse.error, + *logger_); + } + } + } + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/model.cpp b/sdk/cpp/src/model.cpp new file mode 100644 index 00000000..880e91e1 --- /dev/null +++ b/sdk/cpp/src/model.cpp @@ -0,0 +1,205 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_helpers.h" +#include "logger.h" + +namespace foundry_local { + + using namespace detail; + + /// ModelVariant + + ModelVariant::ModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) + : core_(core), info_(std::move(info)), logger_(logger) {} + + const ModelInfo& ModelVariant::GetInfo() const { + return info_; + } + + void ModelVariant::RemoveFromCache() { + auto response = CallWithJson(core_, "remove_cached_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error removing model from cache [" + info_.name + "]: " + response.error, *logger_); + } + cachedPath_.clear(); + } + + void ModelVariant::Unload() { + auto response = CallWithJson(core_, "unload_model", MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error unloading model [" + info_.name + "]: " + response.error, *logger_); + } + } + + bool ModelVariant::IsLoaded() const { + std::vector loadedModelIds = 
GetLoadedModelsInternal(core_, *logger_);
+        for (const auto& id : loadedModelIds) {
+            if (id == info_.id) {
+                return true;
+            }
+        }
+
+        return false;
+    }
+
+    // True when this variant's id appears in the core's cached-model list.
+    bool ModelVariant::IsCached() const {
+        auto cachedModels = GetCachedModelsInternal(core_, *logger_);
+        for (const auto& id : cachedModels) {
+            if (id == info_.id) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    // Downloads the model unless it is already cached. When a progress
+    // callback is supplied, the core's progress payload is parsed as a float
+    // percentage and forwarded to it.
+    void ModelVariant::Download(DownloadProgressCallback onProgress) {
+        if (IsCached()) {
+            logger_->Log(LogLevel::Information, "Model '" + info_.name + "' is already cached, skipping download.");
+            return;
+        }
+
+        if (onProgress) {
+            // Smuggled through the C callback's user-data pointer; the lambda
+            // below must stay capture-less so it decays to a function pointer.
+            struct ProgressState {
+                DownloadProgressCallback* cb;
+                ILogger* logger;
+            } state{&onProgress, logger_};
+
+            auto nativeCallback = [](void* data, int32_t len, void* user) {
+                if (!data || len <= 0)
+                    return;
+                auto* st = static_cast(user);
+                std::string perc(static_cast(data), static_cast(len));
+                try {
+                    float value = std::stof(perc);
+                    (*(st->cb))(value);
+                }
+                catch (...) {
+                    // A malformed progress string is logged, never fatal.
+                    st->logger->Log(LogLevel::Warning, "Failed to parse download progress: " + perc);
+                }
+            };
+
+            // Unary '+' forces the capture-less lambda to a function pointer.
+            auto response = CallWithJsonAndCallback(core_, "download_model", MakeModelParams(info_.name), *logger_,
+                                                    +nativeCallback, &state);
+            if (response.HasError()) {
+                throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_);
+            }
+        }
+        else {
+            auto response = CallWithJson(core_, "download_model", MakeModelParams(info_.name), *logger_);
+            if (response.HasError()) {
+                throw Exception("Error downloading model [" + info_.name + "]: " + response.error, *logger_);
+            }
+        }
+    }
+
+    // Loads the model into the runtime so it can serve requests.
+    void ModelVariant::Load() {
+        auto response = CallWithJson(core_, "load_model", MakeModelParams(info_.name), *logger_);
+        if (response.HasError()) {
+            throw Exception("Error loading model [" + info_.name + "]: " + response.error, *logger_);
+        }
+    }
+
+    // Resolves (and memoizes in cachedPath_) the on-disk path of the model.
+    const std::filesystem::path& ModelVariant::GetPath() const {
+        if (cachedPath_.empty()) {
+            auto response = CallWithJson(core_, "get_model_path",
MakeModelParams(info_.name), *logger_); + if (response.HasError()) { + throw Exception("Error getting model path [" + info_.name + "]: " + response.error, *logger_); + } + cachedPath_ = std::filesystem::path(response.data); + } + return cachedPath_; + } + + const std::string& ModelVariant::GetId() const noexcept { + return info_.id; + } + + const std::string& ModelVariant::GetAlias() const noexcept { + return info_.alias; + } + + uint32_t ModelVariant::GetVersion() const noexcept { + return info_.version; + } + + IModel::CoreAccess ModelVariant::GetCoreAccess() const { + return {core_, info_.name, logger_}; + } + + /// Model + + Model::Model(gsl::not_null core, gsl::not_null logger) + : core_(core), logger_(logger) {} + + ModelVariant& Model::SelectedVariant() { + if (!selectedVariant_) { + throw Exception("Model has no selected variant", *logger_); + } + return *const_cast(selectedVariant_); + } + + const ModelVariant& Model::SelectedVariant() const { + if (!selectedVariant_) { + throw Exception("Model has no selected variant", *logger_); + } + return *selectedVariant_; + } + + gsl::span Model::GetAllModelVariants() const { + return variants_; + } + + const ModelVariant& Model::GetLatestVersion(const ModelVariant& variant) const { + const auto& targetName = variant.GetInfo().name; + + for (const auto& v : variants_) { + // The variants returned by the catalog are sorted by version, so the first match should always be the + // latest version. 
+ if (v.GetInfo().name == targetName) { + return v; + } + } + + throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); + } + + const std::string& Model::GetId() const { + return SelectedVariant().GetId(); + } + + const std::string& Model::GetAlias() const { + return SelectedVariant().GetAlias(); + } + + void Model::SelectVariant(const ModelVariant& variant) const { + const auto& targetId = variant.GetId(); + auto it = std::find_if(variants_.begin(), variants_.end(), + [&](const ModelVariant& v) { return v.GetId() == targetId; }); + + if (it == variants_.end()) { + throw Exception("Model " + GetAlias() + " does not have a " + variant.GetId() + " variant.", *logger_); + } + + selectedVariant_ = &(*it); + } + + IModel::CoreAccess Model::GetCoreAccess() const { + return SelectedVariant().GetCoreAccess(); + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/openai_audio_client.cpp b/sdk/cpp/src/openai_audio_client.cpp new file mode 100644 index 00000000..d4409d1f --- /dev/null +++ b/sdk/cpp/src/openai_audio_client.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "core_helpers.h" +#include "logger.h" + +namespace foundry_local { + + OpenAIAudioClient::OpenAIAudioClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + + AudioCreateTranscriptionResponse OpenAIAudioClient::TranscribeAudio( + const std::filesystem::path& audioFilePath) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + auto coreResponse = core_->call(req.Command(), *logger_, &json); + if (coreResponse.HasError()) { + throw Exception("Audio transcription failed: " + coreResponse.error, *logger_); + } + + AudioCreateTranscriptionResponse response; + response.text = std::move(coreResponse.data); + + return response; + } + + void OpenAIAudioClient::TranscribeAudioStreaming(const std::filesystem::path& audioFilePath, + const StreamCallback& onChunk) const { + nlohmann::json openAiReq = {{"Model", modelId_}, {"FileName", audioFilePath.string()}}; + CoreInteropRequest req("audio_transcribe"); + req.AddParam("OpenAICreateRequest", openAiReq.dump()); + + std::string json = req.ToJson(); + + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& text) { + AudioCreateTranscriptionResponse chunk; + chunk.text = text; + onChunk(chunk); + }, + "Streaming audio transcription failed: "); + } + + OpenAIAudioClient::OpenAIAudioClient(const IModel& model) + : OpenAIAudioClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + 
model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/openai_chat_client.cpp b/sdk/cpp/src/openai_chat_client.cpp new file mode 100644 index 00000000..5c19a0ba --- /dev/null +++ b/sdk/cpp/src/openai_chat_client.cpp @@ -0,0 +1,148 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include + +#include +#include + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" +#include "core_helpers.h" +#include "parser.h" +#include "logger.h" + +namespace foundry_local { + + std::string ChatCompletionCreateResponse::GetCreatedAtIso() const { + if (created == 0) + return {}; + std::time_t t = static_cast(created); + std::tm tm{}; +#ifdef _WIN32 + gmtime_s(&tm, &t); +#else + gmtime_r(&t, &tm); +#endif + char buf[32]; + std::strftime(buf, sizeof(buf), "%Y-%m-%dT%H:%M:%SZ", &tm); + return buf; + } + + OpenAIChatClient::OpenAIChatClient(gsl::not_null core, std::string_view modelId, + gsl::not_null logger) + : core_(core), modelId_(modelId), logger_(logger) {} + + std::string OpenAIChatClient::BuildChatRequestJson(gsl::span messages, + gsl::span tools, + const ChatSettings& settings, bool stream) const { + nlohmann::json jMessages = nlohmann::json::array(); + for (const auto& msg : messages) { + nlohmann::json jMsg = {{"role", msg.role}, {"content", msg.content}}; + if (msg.tool_call_id) + jMsg["tool_call_id"] = *msg.tool_call_id; + if (!msg.tool_calls.empty()) { + nlohmann::json jToolCalls = nlohmann::json::array(); + for (const auto& tc : msg.tool_calls) { + nlohmann::json jtc; + to_json(jtc, tc); + jToolCalls.push_back(std::move(jtc)); + } + jMsg["tool_calls"] = std::move(jToolCalls); + } + jMessages.push_back(std::move(jMsg)); + } + + nlohmann::json req = {{"model", modelId_}, {"messages", 
std::move(jMessages)}, {"stream", stream}}; + + if (!tools.empty()) { + nlohmann::json jTools = nlohmann::json::array(); + for (const auto& tool : tools) { + nlohmann::json jTool; + to_json(jTool, tool); + jTools.push_back(std::move(jTool)); + } + req["tools"] = std::move(jTools); + } + + if (settings.tool_choice) + req["tool_choice"] = ParsingUtils::tool_choice_to_string(*settings.tool_choice); + if (settings.top_k) + req["metadata"] = {{"top_k", *settings.top_k}}; + if (settings.frequency_penalty) + req["frequency_penalty"] = *settings.frequency_penalty; + if (settings.presence_penalty) + req["presence_penalty"] = *settings.presence_penalty; + if (settings.max_tokens) + req["max_completion_tokens"] = *settings.max_tokens; + if (settings.n) + req["n"] = *settings.n; + if (settings.temperature) + req["temperature"] = *settings.temperature; + if (settings.top_p) + req["top_p"] = *settings.top_p; + if (settings.random_seed) + req["seed"] = *settings.random_seed; + + return req.dump(); + } + + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + const ChatSettings& settings) const { + return CompleteChat(messages, {}, settings); + } + + ChatCompletionCreateResponse OpenAIChatClient::CompleteChat(gsl::span messages, + gsl::span tools, + const ChatSettings& settings) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/false); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + + std::string json = req.ToJson(); + auto response = core_->call(req.Command(), *logger_, &json); + if (response.HasError()) { + throw Exception("Chat completion failed: " + response.error, *logger_); + } + + return nlohmann::json::parse(response.data).get(); + } + + void OpenAIChatClient::CompleteChatStreaming(gsl::span messages, const ChatSettings& settings, + const StreamCallback& onChunk) const { + CompleteChatStreaming(messages, {}, settings, onChunk); + } + + void 
OpenAIChatClient::CompleteChatStreaming(gsl::span messages, + gsl::span tools, const ChatSettings& settings, + const StreamCallback& onChunk) const { + std::string openAiReqJson = BuildChatRequestJson(messages, tools, settings, /*stream=*/true); + + CoreInteropRequest req("chat_completions"); + req.AddParam("OpenAICreateRequest", openAiReqJson); + std::string json = req.ToJson(); + + detail::CallWithStreamingCallback( + core_, req.Command(), json, *logger_, + [&onChunk](const std::string& chunk) { + auto parsed = nlohmann::json::parse(chunk).get(); + onChunk(parsed); + }, + "Streaming chat completion failed: "); + } + + OpenAIChatClient::OpenAIChatClient(const IModel& model) + : OpenAIChatClient(model.GetCoreAccess().core, model.GetCoreAccess().modelName, model.GetCoreAccess().logger) { + if (!model.IsLoaded()) { + throw Exception("Model " + model.GetCoreAccess().modelName + " is not loaded. Call Load() first.", + *model.GetCoreAccess().logger); + } + } + +} // namespace foundry_local diff --git a/sdk/cpp/src/parser.h b/sdk/cpp/src/parser.h new file mode 100644 index 00000000..3596579c --- /dev/null +++ b/sdk/cpp/src/parser.h @@ -0,0 +1,312 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once +#include +#include +#include "foundry_local.h" +#include + +namespace foundry_local { + + class ParsingUtils { + public: + static DeviceType parse_device_type(std::string_view v) { + if (v == "CPU") { + return DeviceType::CPU; + } + if (v == "NPU") { + return DeviceType::NPU; + } + if (v == "GPU") { + return DeviceType::GPU; + } + return DeviceType::Invalid; + } + + static FinishReason parse_finish_reason(std::string_view v) { + if (v == "stop") + return FinishReason::Stop; + if (v == "length") + return FinishReason::Length; + if (v == "tool_calls") + return FinishReason::ToolCalls; + if (v == "content_filter") + return FinishReason::ContentFilter; + return FinishReason::None; + } + + static std::string get_string_or_empty(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + std::string out = ""; + if (it != j.end() && it->is_string()) { + out = it->get(); + } + return out; + } + + static std::optional get_opt_string(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_string()) { + return it->get(); + } + return std::nullopt; + } + + static std::optional get_opt_int(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } + return std::nullopt; + } + + static std::optional get_opt_i64(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_number_integer()) { + return it->get(); + } + return std::nullopt; + } + + static std::optional get_opt_bool(const nlohmann::json& j, const char* key) { + auto it = j.find(key); + if (it == j.end() || it->is_null()) { + return std::nullopt; + } + if (it->is_boolean()) { + return it->get(); + } + return std::nullopt; + } + + static std::string 
tool_choice_to_string(ToolChoiceKind kind) { + switch (kind) { + case ToolChoiceKind::Auto: + return "auto"; + case ToolChoiceKind::None: + return "none"; + case ToolChoiceKind::Required: + return "required"; + } + return "auto"; + } + }; + + // ---------- from_json / to_json (ADL overloads for nlohmann::json) ---------- + + inline void from_json(const nlohmann::json& j, Runtime& r) { + std::string deviceType; + j.at("deviceType").get_to(deviceType); + j.at("executionProvider").get_to(r.execution_provider); + + r.device_type = ParsingUtils::parse_device_type(std::move(deviceType)); + } + + inline void from_json(const nlohmann::json& j, PromptTemplate& p) { + p.system = ParsingUtils::get_string_or_empty(j, "system"); + p.user = ParsingUtils::get_string_or_empty(j, "user"); + p.assistant = ParsingUtils::get_string_or_empty(j, "assistant"); + p.prompt = ParsingUtils::get_string_or_empty(j, "prompt"); + } + + inline void from_json(const nlohmann::json& j, Parameter& p) { + j.at("name").get_to(p.name); + p.value = ParsingUtils::get_opt_string(j, "value"); + } + + inline void from_json(const nlohmann::json& j, ModelSettings& ms) { + ms.parameters.clear(); + if (auto it = j.find("parameters"); it != j.end() && it->is_array()) { + ms.parameters = it->get>(); + } + } + + inline void from_json(const nlohmann::json& j, ModelInfo& m) { + j.at("id").get_to(m.id); + j.at("name").get_to(m.name); + j.at("version").get_to(m.version); + j.at("alias").get_to(m.alias); + j.at("providerType").get_to(m.provider_type); + j.at("uri").get_to(m.uri); + j.at("modelType").get_to(m.model_type); + + m.display_name = ParsingUtils::get_opt_string(j, "displayName"); + m.publisher = ParsingUtils::get_opt_string(j, "publisher"); + m.license = ParsingUtils::get_opt_string(j, "license"); + m.license_description = ParsingUtils::get_opt_string(j, "licenseDescription"); + m.task = ParsingUtils::get_opt_string(j, "task"); + if (auto it = j.find("fileSizeMb"); it != j.end() && !it->is_null() && 
it->is_number_integer()) { + auto v = it->get(); + m.file_size_mb = (v >= 0) ? static_cast(v) : 0u; + } + m.supports_tool_calling = ParsingUtils::get_opt_bool(j, "supportsToolCalling"); + m.max_output_tokens = ParsingUtils::get_opt_i64(j, "maxOutputTokens"); + m.min_fl_version = ParsingUtils::get_opt_string(j, "minFLVersion"); + + if (auto it = j.find("cached"); it != j.end() && it->is_boolean()) { + m.cached = it->get(); + } + else { + m.cached = false; + } + + if (auto it = j.find("createdAt"); it != j.end() && it->is_number_integer()) { + m.created_at_unix = it->get(); + } + else { + m.created_at_unix = 0; + } + + // nested optional objects + if (auto it = j.find("modelSettings"); it != j.end() && it->is_object()) { + m.model_settings = it->get(); + } + else { + m.model_settings.reset(); + } + + if (auto it = j.find("promptTemplate"); it != j.end() && it->is_object()) { + m.prompt_template = it->get(); + } + else { + m.prompt_template.reset(); + } + + if (auto it = j.find("runtime"); it != j.end() && it->is_object()) { + m.runtime = it->get(); + } + else { + m.runtime.reset(); + } + } + + // ---------- Tool calling: to_json (serialization for requests) ---------- + + inline void to_json(nlohmann::json& j, const PropertyDefinition& pd) { + j = nlohmann::json{{"type", pd.type}}; + if (pd.description) + j["description"] = *pd.description; + if (pd.properties) { + nlohmann::json props = nlohmann::json::object(); + for (const auto& [key, val] : *pd.properties) { + nlohmann::json pj; + to_json(pj, val); + props[key] = std::move(pj); + } + j["properties"] = std::move(props); + } + if (pd.required) + j["required"] = *pd.required; + } + + inline void to_json(nlohmann::json& j, const FunctionDefinition& fd) { + j = nlohmann::json{{"name", fd.name}}; + if (fd.description) + j["description"] = *fd.description; + if (fd.parameters) { + nlohmann::json pj; + to_json(pj, *fd.parameters); + j["parameters"] = std::move(pj); + } + } + + inline void to_json(nlohmann::json& j, const 
ToolDefinition& td) { + j = nlohmann::json{{"type", td.type}}; + nlohmann::json fj; + to_json(fj, td.function); + j["function"] = std::move(fj); + } + + // ---------- Tool calling: to_json for response types (needed for multi-turn serialization) ---------- + + inline void to_json(nlohmann::json& j, const FunctionCall& fc) { + j = nlohmann::json{{"name", fc.name}, {"arguments", fc.arguments}}; + } + + inline void to_json(nlohmann::json& j, const ToolCall& tc) { + j = nlohmann::json{{"id", tc.id}, {"type", tc.type}}; + if (tc.function_call) { + nlohmann::json fj; + to_json(fj, *tc.function_call); + j["function"] = std::move(fj); + } + } + + // ---------- Tool calling: from_json (deserialization from responses) ---------- + + inline void from_json(const nlohmann::json& j, FunctionCall& fc) { + fc.name = ParsingUtils::get_string_or_empty(j, "name"); + if (j.contains("arguments")) { + const auto& args = j.at("arguments"); + if (args.is_string()) + fc.arguments = args.get(); + else + fc.arguments = args.dump(); + } + } + + inline void from_json(const nlohmann::json& j, ToolCall& tc) { + tc.id = ParsingUtils::get_string_or_empty(j, "id"); + tc.type = ParsingUtils::get_string_or_empty(j, "type"); + if (j.contains("function") && j.at("function").is_object()) + tc.function_call = j.at("function").get(); + } + + inline void from_json(const nlohmann::json& j, ChatMessage& m) { + if (j.contains("role")) + j.at("role").get_to(m.role); + if (j.contains("content") && !j.at("content").is_null()) + j.at("content").get_to(m.content); + + m.tool_call_id = ParsingUtils::get_opt_string(j, "tool_call_id"); + + m.tool_calls.clear(); + if (j.contains("tool_calls") && j.at("tool_calls").is_array()) { + for (const auto& tc : j.at("tool_calls")) { + if (tc.is_object()) + m.tool_calls.push_back(tc.get()); + } + } + } + + inline void from_json(const nlohmann::json& j, ChatChoice& c) { + if (j.contains("index")) + j.at("index").get_to(c.index); + if (j.contains("finish_reason") && 
!j.at("finish_reason").is_null()) + c.finish_reason = ParsingUtils::parse_finish_reason(j.at("finish_reason").get()); + + if (j.contains("message") && !j.at("message").is_null()) + c.message = j.at("message").get(); + + if (j.contains("delta") && !j.at("delta").is_null()) + c.delta = j.at("delta").get(); + } + + inline void from_json(const nlohmann::json& j, ChatCompletionCreateResponse& r) { + if (j.contains("created")) + j.at("created").get_to(r.created); + r.id = ParsingUtils::get_string_or_empty(j, "id"); + if (j.contains("IsDelta")) + j.at("IsDelta").get_to(r.is_delta); + if (j.contains("Successful")) + j.at("Successful").get_to(r.successful); + if (j.contains("HttpStatusCode")) + j.at("HttpStatusCode").get_to(r.http_status_code); + + r.choices.clear(); + if (j.contains("choices") && j.at("choices").is_array()) { + r.choices = j.at("choices").get>(); + } + } + +} // namespace foundry_local \ No newline at end of file diff --git a/sdk/cpp/test/catalog_test.cpp b/sdk/cpp/test/catalog_test.cpp new file mode 100644 index 00000000..3f60e0b4 --- /dev/null +++ b/sdk/cpp/test/catalog_test.cpp @@ -0,0 +1,353 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include +#include +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace foundry_local; +using namespace foundry_local::Testing; + +using Factory = MockObjectFactory; + +class CatalogTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + std::string MakeModelListJson(const std::vector>& models) { + nlohmann::json arr = nlohmann::json::array(); + for (const auto& [name, alias] : models) { + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson(name, alias))); + } + return arr.dump(); + } + + std::unique_ptr MakeCatalog() { + core_.OnCall("get_catalog_name", "test-catalog"); + return Factory::CreateCatalog(&core_, &logger_); + } +}; + +TEST_F(CatalogTest, GetName) { + auto catalog = MakeCatalog(); + EXPECT_EQ("test-catalog", catalog->GetName()); +} + +TEST_F(CatalogTest, Create_ThrowsOnCoreError) { + core_.OnCallThrow("get_catalog_name", "catalog error"); + EXPECT_THROW(MockObjectFactory::CreateCatalog(&core_, &logger_), Exception); +} + +TEST_F(CatalogTest, ListModels_Empty) { + core_.OnCall("get_model_list", "[]"); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + EXPECT_TRUE(models.empty()); +} + +TEST_F(CatalogTest, ListModels_SingleModel) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ("my-model", models[0]->GetAlias()); +} + +TEST_F(CatalogTest, ListModels_MultipleVariantsSameAlias) { + // Two variants of the same model (same alias, different names) + nlohmann::json arr = nlohmann::json::array(); + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v1", "my-model", 1))); + arr.push_back(nlohmann::json::parse(Factory::MakeModelInfoJson("model-v2", "my-model", 2))); + core_.OnCall("get_model_list", arr.dump()); + + 
auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + + // Should be grouped into one Model + ASSERT_EQ(1u, models.size()); + EXPECT_EQ(2u, models[0]->GetAllModelVariants().size()); +} + +TEST_F(CatalogTest, ListModels_DifferentAliases) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "alias-a"}, {"model-b", "alias-b"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + EXPECT_EQ(2u, models.size()); +} + +TEST_F(CatalogTest, ListModels_IncludesOpenAIPrefix) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-a", "my-model"}, {"openai-model", "openai-stuff"}})); + auto catalog = MakeCatalog(); + auto models = catalog->ListModels(); + ASSERT_EQ(2u, models.size()); +} + +TEST_F(CatalogTest, GetModel_Found) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + auto* model = catalog->GetModel("my-model"); + ASSERT_NE(nullptr, model); + EXPECT_EQ("my-model", model->GetAlias()); +} + +TEST_F(CatalogTest, GetModel_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + EXPECT_EQ(nullptr, catalog->GetModel("nonexistent")); +} + +TEST_F(CatalogTest, GetModelVariant_Found) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + auto* variant = catalog->GetModelVariant("model-1:1"); + ASSERT_NE(nullptr, variant); + EXPECT_EQ("model-1:1", variant->GetId()); +} + +TEST_F(CatalogTest, GetModelVariant_NotFound) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent:1")); +} + +TEST_F(CatalogTest, GetLoadedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("list_loaded_models", R"(["model-1:1"])"); + + auto catalog = 
MakeCatalog(); + + auto loaded = catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("model-1:1", loaded[0]->GetId()); +} + +TEST_F(CatalogTest, GetCachedModels) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "alias-1"}, {"model-2", "alias-2"}})); + core_.OnCall("get_cached_models", R"(["model-1:1", "model-2:1"])"); + + auto catalog = MakeCatalog(); + + auto cached = catalog->GetCachedModels(); + EXPECT_EQ(2u, cached.size()); +} + +TEST_F(CatalogTest, ListModels_CachesResults) { + core_.OnCall("get_model_list", MakeModelListJson({{"model-1", "my-model"}})); + auto catalog = MakeCatalog(); + + catalog->ListModels(); + catalog->ListModels(); + + // Should only call get_model_list once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_list")); +} + +class FileBasedCatalogTest : public ::testing::Test { +protected: + NullLogger logger_; + + static std::string TestDataPath(const std::string& filename) { return "testdata/" + filename; } +}; + +TEST_F(FileBasedCatalogTest, RealModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(2u, models.size()); + + int phi_models = 0, mistral_models = 0; + size_t phi_variants = 0, mistral_variants = 0; + + for (const auto* model : models) { + if (model->GetAlias() == "phi-4") { + phi_models++; + phi_variants = model->GetAllModelVariants().size(); + } + else if (model->GetAlias() == "mistral-7b-v0.2") { + mistral_models++; + mistral_variants = model->GetAllModelVariants().size(); + } + } + + EXPECT_EQ(1, phi_models); + EXPECT_EQ(1, mistral_models); + EXPECT_EQ(2u, phi_variants); + EXPECT_EQ(2u, mistral_variants); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_VariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + const auto* 
gpuVariant = catalog->GetModelVariant("Phi-4-generic-gpu:1"); + ASSERT_NE(nullptr, gpuVariant); + + const auto& info = gpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-gpu:1", info.id); + EXPECT_EQ("Phi-4-generic-gpu", info.name); + EXPECT_EQ("phi-4", info.alias); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("Phi-4 (GPU)", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("Microsoft", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::GPU, info.runtime->device_type); + EXPECT_EQ("DML", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(8192u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.prompt_template.has_value()); + EXPECT_EQ("<|system|>", info.prompt_template->system); + EXPECT_EQ("<|user|>", info.prompt_template->user); + EXPECT_EQ("<|assistant|>", info.prompt_template->assistant); + EXPECT_EQ("<|prompt|>", info.prompt_template->prompt); +} + +TEST_F(FileBasedCatalogTest, RealModelsList_CpuVariantDetails) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + const auto* cpuVariant = catalog->GetModelVariant("Phi-4-generic-cpu:1"); + ASSERT_NE(nullptr, cpuVariant); + + const auto& info = cpuVariant->GetInfo(); + EXPECT_EQ("Phi-4-generic-cpu", info.name); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(4096u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_FALSE(*info.supports_tool_calling); + 
EXPECT_FALSE(info.prompt_template.has_value()); +} + +TEST_F(FileBasedCatalogTest, EmptyModelsList) { + auto core = FileBackedCore::FromModelList(TestDataPath("empty_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + EXPECT_TRUE(models.empty()); +} + +TEST_F(FileBasedCatalogTest, MalformedJson) { + auto core = FileBackedCore::FromModelList(TestDataPath("malformed_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MissingNameField) { + auto core = FileBackedCore::FromModelList(TestDataPath("missing_name_field_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + try { + catalog->ListModels(); + FAIL() << "Expected exception for missing 'name' field"; + } + catch (const std::exception& e) { + std::string msg = e.what(); + EXPECT_NE(std::string::npos, msg.find("name")) << "Actual: " << msg; + } +} + +TEST_F(FileBasedCatalogTest, CachedModels) { + auto core = + FileBackedCore::FromBoth(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(2u, cached.size()); + + std::vector names; + names.reserve(cached.size()); + for (const auto* mv : cached) + names.push_back(mv->GetInfo().name); + + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-gpu"), names.end()); + EXPECT_NE(std::find(names.begin(), names.end(), "Phi-4-generic-cpu"), names.end()); +} + +TEST_F(FileBasedCatalogTest, CoreErrorOnModelList) { + auto core = FileBackedCore::FromModelList("testdata/nonexistent_file.json"); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + EXPECT_ANY_THROW(catalog->ListModels()); +} + +TEST_F(FileBasedCatalogTest, MixedOpenAIAndLocal_IncludesAll) { + auto core = 
FileBackedCore::FromModelList(TestDataPath("mixed_openai_and_local.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(3u, models.size()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel) { + auto core = FileBackedCore::FromModelList(TestDataPath("three_variants_one_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto models = catalog->ListModels(); + ASSERT_EQ(1u, models.size()); + EXPECT_EQ(3u, models[0]->GetAllModelVariants().size()); +} + +TEST_F(FileBasedCatalogTest, ThreeVariantsOneModel_CachedSubset) { + auto core = FileBackedCore::FromBoth(TestDataPath("three_variants_one_model.json"), + TestDataPath("single_cached_model.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto cached = catalog->GetCachedModels(); + ASSERT_EQ(1u, cached.size()); + EXPECT_EQ("multi-v1-cpu", cached[0]->GetInfo().name); +} + +TEST_F(FileBasedCatalogTest, GetModelByAlias) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + const auto* model = catalog->GetModel("phi-4"); + ASSERT_NE(nullptr, model); + EXPECT_EQ("phi-4", model->GetAlias()); + EXPECT_EQ(2u, model->GetAllModelVariants().size()); + + const auto* missing = catalog->GetModel("nonexistent-alias"); + EXPECT_EQ(nullptr, missing); +} + +TEST_F(FileBasedCatalogTest, GetModelVariant_NotInCatalog) { + auto core = FileBackedCore::FromModelList(TestDataPath("real_models_list.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + EXPECT_EQ(nullptr, catalog->GetModelVariant("nonexistent-variant-id")); +} + +TEST_F(FileBasedCatalogTest, LoadedModels) { + auto core = FileBackedCore::FromAll(TestDataPath("real_models_list.json"), TestDataPath("valid_cached_models.json"), + TestDataPath("valid_loaded_models.json")); + auto catalog = Factory::CreateCatalog(&core, &logger_); + + auto loaded = 
catalog->GetLoadedModels(); + ASSERT_EQ(1u, loaded.size()); + EXPECT_EQ("Phi-4-generic-gpu", loaded[0]->GetInfo().name); +} diff --git a/sdk/cpp/test/client_test.cpp b/sdk/cpp/test/client_test.cpp new file mode 100644 index 00000000..9864ec99 --- /dev/null +++ b/sdk/cpp/test/client_test.cpp @@ -0,0 +1,745 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace foundry_local; +using namespace foundry_local::Testing; + +using Factory = MockObjectFactory; + +class OpenAIChatClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + std::string MakeChatResponseJson(const std::string& content = "Hello!") { + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-test"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", content}}}}}}}; + return resp.dump(); + } + + ModelVariant MakeLoadedVariant(const std::string& name = "chat-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(OpenAIChatClientTest, CompleteChat_BasicResponse) { + core_.OnCall("chat_completions", MakeChatResponseJson("Hello world!")); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "Say hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + EXPECT_TRUE(response.successful); + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ("Hello world!", response.choices[0].message->content); +} + +TEST_F(OpenAIChatClientTest, 
CompleteChat_WithSettings) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.temperature = 0.7f; + settings.max_tokens = 100; + settings.top_p = 0.9f; + settings.frequency_penalty = 0.5f; + settings.presence_penalty = 0.3f; + settings.n = 2; + settings.random_seed = 42; + settings.top_k = 10; + + auto response = client.CompleteChat(messages, settings); + + // Verify the request JSON contains the settings + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_NEAR(0.7f, openAiReq["temperature"].get(), 0.001f); + EXPECT_EQ(100, openAiReq["max_completion_tokens"].get()); + EXPECT_NEAR(0.9f, openAiReq["top_p"].get(), 0.001f); + EXPECT_NEAR(0.5f, openAiReq["frequency_penalty"].get(), 0.001f); + EXPECT_NEAR(0.3f, openAiReq["presence_penalty"].get(), 0.001f); + EXPECT_EQ(2, openAiReq["n"].get()); + EXPECT_EQ(42, openAiReq["seed"].get()); + EXPECT_EQ(10, openAiReq["metadata"]["top_k"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_RequestFormat) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"system", "You are helpful", {}}, {"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_EQ("chat-model", openAiReq["model"].get()); + EXPECT_FALSE(openAiReq["stream"].get()); + ASSERT_EQ(2u, 
openAiReq["messages"].size()); + EXPECT_EQ("system", openAiReq["messages"][0]["role"].get()); + EXPECT_EQ("user", openAiReq["messages"][1]["role"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChatStreaming) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hello"}}}}}}}; + nlohmann::json chunk2 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", {{{"index", 0}, {"finish_reason", "stop"}, {"delta", {{"content", " world"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + callback(s1.data(), static_cast(s1.size()), userData); + callback(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + std::vector chunks; + client.CompleteChatStreaming(messages, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_TRUE(chunks[0].is_delta); + ASSERT_TRUE(chunks[0].choices[0].delta.has_value()); + EXPECT_EQ("Hello", chunks[0].choices[0].delta->content); + EXPECT_EQ(" world", chunks[1].choices[0].delta->content); +} + +TEST_F(OpenAIChatClientTest, CompleteChatStreaming_PropagatesCallbackException) { + nlohmann::json chunk = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, 
{"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string s = chunk.dump(); + callback(s.data(), static_cast(s.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + + EXPECT_THROW(client.CompleteChatStreaming( + messages, settings, + [](const ChatCompletionCreateResponse&) { throw std::runtime_error("callback error"); }), + std::runtime_error); +} + +TEST_F(OpenAIChatClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(OpenAIChatClient client(variant), Exception); +} + +TEST_F(OpenAIChatClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + EXPECT_EQ("chat-model", client.GetModelId()); +} + +// ---------- Tool calling tests ---------- + +TEST_F(OpenAIChatClientTest, CompleteChat_WithTools_IncludesToolsInRequest) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", "A tool for multiplying two numbers.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + 
std::vector{"first", "second"}}}}}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + + // Verify the request JSON contains tools and tool_choice + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_TRUE(openAiReq.contains("tools")); + ASSERT_TRUE(openAiReq["tools"].is_array()); + EXPECT_EQ(1u, openAiReq["tools"].size()); + EXPECT_EQ("function", openAiReq["tools"][0]["type"].get()); + EXPECT_EQ("multiply_numbers", openAiReq["tools"][0]["function"]["name"].get()); + EXPECT_EQ("A tool for multiplying two numbers.", + openAiReq["tools"][0]["function"]["description"].get()); + EXPECT_EQ("object", openAiReq["tools"][0]["function"]["parameters"]["type"].get()); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"].contains("properties")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("first")); + EXPECT_TRUE(openAiReq["tools"][0]["function"]["parameters"]["properties"].contains("second")); + + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_WithoutTools_OmitsToolsField) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "Hello", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + EXPECT_FALSE(openAiReq.contains("tools")); + EXPECT_FALSE(openAiReq.contains("tool_choice")); +} + +TEST_F(OpenAIChatClientTest, 
CompleteChat_ToolCallResponse_Parsed) { + // Simulate a response with tool calls from the model + nlohmann::json resp = { + {"created", 1700000000}, + {"id", "chatcmpl-tool"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"message", + {{"role", "assistant"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": " + "6}}]"}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", resp.dump()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason); + ASSERT_TRUE(response.choices[0].message.has_value()); + + const auto& msg = *response.choices[0].message; + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_1", msg.tool_calls[0].id); + EXPECT_EQ("function", msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("multiply_numbers", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"first\": 7, \"second\": 6}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceAuto) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Auto; + + client.CompleteChat(messages, settings); + + auto requestJson = 
nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("auto", openAiReq["tool_choice"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_ToolChoiceNone) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::None; + + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("none", openAiReq["tool_choice"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_ToolMessageWithToolCallId) { + core_.OnCall("chat_completions", MakeChatResponseJson()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + ChatMessage toolMsg; + toolMsg.role = "tool"; + toolMsg.content = "42"; + toolMsg.tool_call_id = "call_1"; + + std::vector messages = {{"user", "What is 7 * 6?", {}}, std::move(toolMsg)}; + ChatSettings settings; + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_EQ(2u, openAiReq["messages"].size()); + EXPECT_FALSE(openAiReq["messages"][0].contains("tool_call_id")); + EXPECT_EQ("call_1", openAiReq["messages"][1]["tool_call_id"].get()); + EXPECT_EQ("tool", openAiReq["messages"][1]["role"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_AssistantToolCallsSerialized) { + // Multi-turn tool calling: the assistant message with tool_calls must be 
sent back + // alongside the tool result message for the model to match the tool response. + core_.OnCall("chat_completions", MakeChatResponseJson("The answer is 42.")); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + ChatMessage assistantMsg; + assistantMsg.role = "assistant"; + assistantMsg.content = ""; + assistantMsg.tool_calls = { + {"call_1", "function", FunctionCall{"multiply_numbers", "{\"first\": 7, \"second\": 6}"}}}; + + ChatMessage toolMsg; + toolMsg.role = "tool"; + toolMsg.content = "42"; + toolMsg.tool_call_id = "call_1"; + + std::vector messages = { + {"user", "What is 7 * 6?", {}}, std::move(assistantMsg), std::move(toolMsg)}; + ChatSettings settings; + client.CompleteChat(messages, settings); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + + ASSERT_EQ(3u, openAiReq["messages"].size()); + + // User message: no tool_calls + EXPECT_FALSE(openAiReq["messages"][0].contains("tool_calls")); + + // Assistant message: must include tool_calls + const auto& assistantJson = openAiReq["messages"][1]; + EXPECT_EQ("assistant", assistantJson["role"].get()); + ASSERT_TRUE(assistantJson.contains("tool_calls")); + ASSERT_TRUE(assistantJson["tool_calls"].is_array()); + ASSERT_EQ(1u, assistantJson["tool_calls"].size()); + EXPECT_EQ("call_1", assistantJson["tool_calls"][0]["id"].get()); + EXPECT_EQ("function", assistantJson["tool_calls"][0]["type"].get()); + EXPECT_EQ("multiply_numbers", assistantJson["tool_calls"][0]["function"]["name"].get()); + EXPECT_EQ("{\"first\": 7, \"second\": 6}", + assistantJson["tool_calls"][0]["function"]["arguments"].get()); + + // Tool message: must include tool_call_id + const auto& toolJson = openAiReq["messages"][2]; + EXPECT_EQ("tool", toolJson["role"].get()); + EXPECT_EQ("call_1", 
toolJson["tool_call_id"].get()); + EXPECT_FALSE(toolJson.contains("tool_calls")); +} + +TEST_F(OpenAIChatClientTest, CompleteChatStreaming_WithTools) { + nlohmann::json chunk1 = { + {"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", ""}}}}}}}; + nlohmann::json chunk2 = {{"created", 1700000000}, + {"id", "chatcmpl-1"}, + {"IsDelta", true}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"delta", + {{"content", ""}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply"}, {"arguments", "{\"a\":1}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", + [&](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string s1 = chunk1.dump(); + std::string s2 = chunk2.dump(); + callback(s1.data(), static_cast(s1.size()), userData); + callback(s2.data(), static_cast(s2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "test", {}}}; + + std::vector tools = {{"function", FunctionDefinition{"multiply", "Multiply numbers."}}}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + std::vector chunks; + client.CompleteChatStreaming(messages, tools, settings, + [&](const ChatCompletionCreateResponse& chunk) { chunks.push_back(chunk); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ(FinishReason::ToolCalls, chunks[1].choices[0].finish_reason); + ASSERT_TRUE(chunks[1].choices[0].delta.has_value()); + ASSERT_EQ(1u, chunks[1].choices[0].delta->tool_calls.size()); + EXPECT_EQ("multiply", 
chunks[1].choices[0].delta->tool_calls[0].function_call->name); + + // Verify tools were included in the request + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + ASSERT_TRUE(openAiReq.contains("tools")); + EXPECT_EQ("required", openAiReq["tool_choice"].get()); +} + +class OpenAIAudioClientTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeLoadedVariant(const std::string& name = "audio-model") { + core_.OnCall("list_loaded_models", "[\"" + name + ":1\"]"); + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, "alias"), &logger_); + } +}; + +TEST_F(OpenAIAudioClientTest, TranscribeAudio) { + core_.OnCall("audio_transcribe", "Hello world transcribed text"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + auto response = client.TranscribeAudio("test.wav"); + + EXPECT_EQ("Hello world transcribed text", response.text); +} + +TEST_F(OpenAIAudioClientTest, TranscribeAudio_RequestFormat) { + core_.OnCall("audio_transcribe", "text"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + client.TranscribeAudio("audio.wav"); + + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("audio_transcribe")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + EXPECT_EQ("audio-model", openAiReq["Model"].get()); + EXPECT_EQ("audio.wav", openAiReq["FileName"].get()); +} + +TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string text1 = "Hello "; + std::string text2 = 
"world!"; + callback(text1.data(), static_cast(text1.size()), userData); + callback(text2.data(), static_cast(text2.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + + std::vector chunks; + client.TranscribeAudioStreaming( + "test.wav", [&](const AudioCreateTranscriptionResponse& chunk) { chunks.push_back(chunk.text); }); + + ASSERT_EQ(2u, chunks.size()); + EXPECT_EQ("Hello ", chunks[0]); + EXPECT_EQ("world!", chunks[1]); +} + +TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_PropagatesCallbackException) { + core_.OnCall("audio_transcribe", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + if (callback && userData) { + std::string text = "test"; + callback(text.data(), static_cast(text.size()), userData); + } + return ""; + }); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + + EXPECT_THROW( + client.TranscribeAudioStreaming( + "test.wav", [](const AudioCreateTranscriptionResponse&) { throw std::runtime_error("streaming error"); }), + std::runtime_error); +} + +TEST_F(OpenAIAudioClientTest, Constructor_ThrowsIfNotLoaded) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = Factory::CreateModelVariant(&core_, Factory::MakeModelInfo("unloaded-model", "alias"), &logger_); + EXPECT_THROW(OpenAIAudioClient client(variant), FoundryLocalException); +} + +TEST_F(OpenAIAudioClientTest, GetModelId) { + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + EXPECT_EQ("audio-model", client.GetModelId()); +} + +TEST_F(OpenAIAudioClientTest, TranscribeAudio_CoreError_Throws) { + core_.OnCallThrow("audio_transcribe", "transcription failed"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + 
+ auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + + EXPECT_THROW(client.TranscribeAudio("test.wav"), Exception); +} + +TEST_F(OpenAIAudioClientTest, TranscribeAudioStreaming_CoreError_Throws) { + core_.OnCallThrow("audio_transcribe", "streaming transcription failed"); + core_.OnCall("list_loaded_models", R"(["audio-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIAudioClient client(variant); + + EXPECT_THROW(client.TranscribeAudioStreaming("test.wav", [](const AudioCreateTranscriptionResponse&) {}), + Exception); +} + +// ===================================================================== +// Multi-turn conversation tests +// ===================================================================== + +TEST_F(OpenAIChatClientTest, CompleteChat_MultiTurn) { + // First turn: user asks a question + core_.OnCall("chat_completions", MakeChatResponseJson("42")); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "What is 7 * 6?", {}}}; + ChatSettings settings; + auto response = client.CompleteChat(messages, settings); + + ASSERT_TRUE(response.successful); + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ("42", response.choices[0].message->content); + + // Second turn: add assistant response + user follow-up + messages.push_back({"assistant", response.choices[0].message->content, {}}); + messages.push_back({"user", "Is that a real number?", {}}); + + core_.OnCall("chat_completions", MakeChatResponseJson("Yes")); + auto response2 = client.CompleteChat(messages, settings); + + ASSERT_TRUE(response2.successful); + EXPECT_EQ("Yes", response2.choices[0].message->content); + + // Verify the second request contained all 3 messages + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); + 
ASSERT_EQ(3u, openAiReq["messages"].size()); + EXPECT_EQ("user", openAiReq["messages"][0]["role"].get()); + EXPECT_EQ("assistant", openAiReq["messages"][1]["role"].get()); + EXPECT_EQ("user", openAiReq["messages"][2]["role"].get()); +} + +TEST_F(OpenAIChatClientTest, CompleteChat_CoreError_Throws) { + core_.OnCallThrow("chat_completions", "inference failed"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "Hello", {}}}; + ChatSettings settings; + + EXPECT_THROW(client.CompleteChat(messages, settings), Exception); +} + +TEST_F(OpenAIChatClientTest, CompleteChatStreaming_CoreError_Throws) { + core_.OnCallThrow("chat_completions", "streaming failed"); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"user", "Hello", {}}}; + ChatSettings settings; + + EXPECT_THROW(client.CompleteChatStreaming(messages, settings, [](const ChatCompletionCreateResponse&) {}), + Exception); +} + +// ===================================================================== +// Full tool-call round-trip +// ===================================================================== + +TEST_F(OpenAIChatClientTest, CompleteChat_ToolCallRoundTrip) { + // Step 1: model returns a tool call + nlohmann::json toolCallResp = { + {"created", 1700000000}, + {"id", "chatcmpl-tool"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, + {"finish_reason", "tool_calls"}, + {"message", + {{"role", "assistant"}, + {"content", "[{\"name\": \"multiply_numbers\", \"parameters\": {\"first\": 7, \"second\": " + "6}}]"}, + {"tool_calls", + {{{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "multiply_numbers"}, {"arguments", "{\"first\": 7, \"second\": 6}"}}}}}}}}}}}}; + + core_.OnCall("chat_completions", 
toolCallResp.dump()); + core_.OnCall("list_loaded_models", R"(["chat-model:1"])"); + + auto variant = MakeLoadedVariant(); + OpenAIChatClient client(variant); + + std::vector messages = {{"system", "You are a helpful AI assistant.", {}}, + {"user", "What is 7 multiplied by 6?", {}}}; + + std::vector tools = { + {"function", + FunctionDefinition{"multiply_numbers", "A tool for multiplying two numbers.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"first", PropertyDefinition{"integer", "The first number"}}, + {"second", PropertyDefinition{"integer", "The second number"}}}, + std::vector{"first", "second"}}}}}; + + ChatSettings settings; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + + ASSERT_EQ(1u, response.choices.size()); + EXPECT_EQ(FinishReason::ToolCalls, response.choices[0].finish_reason); + ASSERT_EQ(1u, response.choices[0].message->tool_calls.size()); + EXPECT_EQ("multiply_numbers", response.choices[0].message->tool_calls[0].function_call->name); + + // Step 2: send tool response back, model continues with the answer + messages.push_back({"assistant", response.choices[0].message->content, {}}); + + ChatMessage toolMsg; + toolMsg.role = "tool"; + toolMsg.content = "7 x 6 = 42."; + toolMsg.tool_call_id = "call_1"; + messages.push_back(std::move(toolMsg)); + + messages.push_back({"system", "Respond only with the answer generated by the tool.", {}}); + + core_.OnCall("chat_completions", MakeChatResponseJson("42")); + settings.tool_choice = ToolChoiceKind::Auto; + + auto response2 = client.CompleteChat(messages, tools, settings); + + ASSERT_TRUE(response2.successful); + EXPECT_EQ("42", response2.choices[0].message->content); + + // Verify the second request contained tool response message + auto requestJson = nlohmann::json::parse(core_.GetLastDataArg("chat_completions")); + auto openAiReq = nlohmann::json::parse(requestJson["Params"]["OpenAICreateRequest"].get()); 
+ + // 5 messages: system, user, assistant (tool_call), tool, system (continue) + ASSERT_EQ(5u, openAiReq["messages"].size()); + EXPECT_EQ("tool", openAiReq["messages"][3]["role"].get()); + EXPECT_EQ("call_1", openAiReq["messages"][3]["tool_call_id"].get()); + EXPECT_EQ("auto", openAiReq["tool_choice"].get()); +} diff --git a/sdk/cpp/test/e2e_test.cpp b/sdk/cpp/test/e2e_test.cpp new file mode 100644 index 00000000..4bf33348 --- /dev/null +++ b/sdk/cpp/test/e2e_test.cpp @@ -0,0 +1,561 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// End-to-end tests that exercise the public API with the real Core DLL. +// Tests marked DISABLED_ are skipped in CI (no Core DLL / no network). +// Run locally with: --gtest_also_run_disabled_tests + +#include + +#include "foundry_local.h" + +#include +#include +#include +#include +#include + +using namespace foundry_local; + +// --------------------------------------------------------------------------- +// Helper: detect CI environment (mirrors C# SkipInCI logic) +// --------------------------------------------------------------------------- +static bool IsRunningInCI() { + auto check = [](const char* var) -> bool { + const char* val = std::getenv(var); + if (!val) + return false; + std::string s(val); + for (auto& c : s) + c = static_cast(std::tolower(static_cast(c))); + return s == "true" || s == "1"; + }; + return check("TF_BUILD") || check("GITHUB_ACTIONS") || check("CI"); +} + +// --------------------------------------------------------------------------- +// Fixture: creates a real FoundryLocalManager with the Core DLL. +// All tests in this fixture require the native DLLs next to the test binary. 
+// --------------------------------------------------------------------------- +class EndToEndTest : public ::testing::Test { +protected: + static void SetUpTestSuite() { + Configuration config("CppSdkE2ETest"); + config.log_level = LogLevel::Information; + try { + FoundryLocalManager::Create(std::move(config)); + } + catch (const std::exception& ex) { + std::cerr << "[E2E] Failed to create FoundryLocalManager: " << ex.what() << "\n"; + GTEST_SKIP() << "Core DLL not available: " << ex.what(); + } + } + + static void TearDownTestSuite() { FoundryLocalManager::Destroy(); } + + void SetUp() override { + if (!FoundryLocalManager::IsInitialized()) { + GTEST_SKIP() << "FoundryLocalManager not available (Core DLL missing?)"; + } + } + + static bool IsAudioModel(const std::string& alias) { return alias.find("whisper") != std::string::npos; } + + /// Find a chat-capable model, preferring cached, then known small models, then any. + /// Selects the CPU variant when available to avoid GPU/EP dependency issues. + static Model* FindChatModel(Catalog& catalog) { + Model* target = nullptr; + + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + if (!IsAudioModel(variant->GetAlias())) { + target = catalog.GetModel(variant->GetAlias()); + if (target) + break; + } + } + + if (!target) { + for (const auto& alias : {"qwen2.5-0.5b", "qwen2.5-coder-0.5b", "phi-4-mini"}) { + target = catalog.GetModel(alias); + if (target) + break; + } + } + + if (!target) { + auto models = catalog.ListModels(); + for (auto* model : models) { + if (!IsAudioModel(model->GetAlias())) { + target = model; + break; + } + } + } + + if (target) { + for (const auto& variant : target->GetAllModelVariants()) { + if (variant.GetInfo().runtime.has_value() && + variant.GetInfo().runtime->device_type == DeviceType::CPU) { + target->SelectVariant(variant); + break; + } + } + } + + return target; + } + + /// Find an audio model, preferring cached. 
+ static Model* FindAudioModel(Catalog& catalog) { + Model* target = nullptr; + + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + if (IsAudioModel(variant->GetAlias())) { + target = catalog.GetModel(variant->GetAlias()); + if (target) + break; + } + } + + if (!target) { + for (const auto& alias : {"whisper-small", "whisper-tiny"}) { + target = catalog.GetModel(alias); + if (target) + break; + } + } + + return target; + } +}; + +// =========================================================================== +// Catalog tests (no model download required) +// =========================================================================== + +TEST_F(EndToEndTest, BrowseCatalog_ListsModels) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + EXPECT_FALSE(catalog.GetName().empty()); + + auto models = catalog.ListModels(); + EXPECT_GT(models.size(), 0u) << "Catalog should have at least one model"; + + for (const auto* model : models) { + EXPECT_FALSE(model->GetAlias().empty()); + EXPECT_FALSE(model->GetAllModelVariants().empty()); + + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + EXPECT_FALSE(info.id.empty()); + EXPECT_FALSE(info.name.empty()); + EXPECT_FALSE(info.alias.empty()); + EXPECT_FALSE(info.provider_type.empty()); + EXPECT_FALSE(info.model_type.empty()); + } + } +} + +TEST_F(EndToEndTest, GetCachedModels_Succeeds) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto cached = catalog.GetCachedModels(); + for (auto* variant : cached) { + EXPECT_FALSE(variant->GetId().empty()); + EXPECT_TRUE(variant->IsCached()); + } +} + +TEST_F(EndToEndTest, GetLoadedModels_Succeeds) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto loaded = catalog.GetLoadedModels(); + for (auto* variant : loaded) { + EXPECT_FALSE(variant->GetId().empty()); + EXPECT_TRUE(variant->IsLoaded()); + } +} + +TEST_F(EndToEndTest, GetModel_NotFound_ReturnsNull) { + auto& 
catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* model = catalog.GetModel("this-model-does-not-exist-12345"); + EXPECT_EQ(model, nullptr); +} + +TEST_F(EndToEndTest, GetModelVariant_NotFound_ReturnsNull) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* variant = catalog.GetModelVariant("nonexistent-model:999"); + EXPECT_EQ(variant, nullptr); +} + +TEST_F(EndToEndTest, GetModelVariant_Found) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + if (models.empty()) { + GTEST_SKIP() << "No models in catalog"; + } + + const auto& firstVariant = models[0]->GetAllModelVariants()[0]; + auto* found = catalog.GetModelVariant(firstVariant.GetId()); + ASSERT_NE(nullptr, found); + EXPECT_EQ(firstVariant.GetId(), found->GetId()); +} + +TEST_F(EndToEndTest, ModelVariantInfo_HasRequiredFields) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + if (models.empty()) { + GTEST_SKIP() << "No models in catalog"; + } + + for (const auto* model : models) { + for (const auto& variant : model->GetAllModelVariants()) { + const auto& info = variant.GetInfo(); + EXPECT_FALSE(info.id.empty()); + EXPECT_FALSE(info.name.empty()); + EXPECT_GT(info.version, 0u); + EXPECT_FALSE(info.alias.empty()); + EXPECT_FALSE(info.uri.empty()); + } + } +} + +TEST_F(EndToEndTest, ModelVariant_SelectVariant) { + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto models = catalog.ListModels(); + + // Find a model with multiple variants + Model* multiVariantModel = nullptr; + for (auto* model : models) { + if (model->GetAllModelVariants().size() > 1) { + multiVariantModel = model; + break; + } + } + + if (!multiVariantModel) { + GTEST_SKIP() << "No model with multiple variants found"; + } + + const auto& variants = multiVariantModel->GetAllModelVariants(); + const auto& secondVariant = variants[1]; + multiVariantModel->SelectVariant(secondVariant); + 
EXPECT_EQ(secondVariant.GetId(), multiVariantModel->GetId()); + + // Select back the first variant + multiVariantModel->SelectVariant(variants[0]); + EXPECT_EQ(variants[0].GetId(), multiVariantModel->GetId()); +} + +// =========================================================================== +// EnsureEpsDownloaded (no model download, but may download EPs) +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_EnsureEpsDownloaded_Succeeds) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (may require network)"; + } + + EXPECT_NO_THROW(FoundryLocalManager::Instance().EnsureEpsDownloaded()); +} + +// =========================================================================== +// Web service tests +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_WebService_StartAndStop) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI"; + } + + auto& manager = FoundryLocalManager::Instance(); + + // GetUrls should be empty before starting + EXPECT_TRUE(manager.GetUrls().empty()); + + // StartWebService without web config should throw + // Note: the manager was created without web config, so this verifies the guard. 
+ EXPECT_THROW(manager.StartWebService(), Exception); +} + +// =========================================================================== +// Download, load, chat (non-streaming), unload +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_DownloadLoadChatUnload) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + std::cout << "[E2E] Using model: " << target->GetAlias() << " variant: " << target->GetId() << "\n"; + + // Download (no-op if already cached) + bool progressCallbackInvoked = false; + target->Download([&](float pct) { + progressCallbackInvoked = true; + std::cout << "\r[E2E] Download: " << pct << "% " << std::flush; + }); + std::cout << "\n"; + + EXPECT_TRUE(target->IsCached()); + + // Load + target->Load(); + EXPECT_TRUE(target->IsLoaded()); + + // Verify it appears in loaded models + auto loaded = catalog.GetLoadedModels(); + bool foundInLoaded = false; + for (auto* v : loaded) { + if (v->GetId() == target->GetId()) { + foundInLoaded = true; + break; + } + } + EXPECT_TRUE(foundInLoaded) << "Model should appear in GetLoadedModels() after Load()"; + + // Chat (non-streaming) + OpenAIChatClient client(*target); + + std::vector messages = {{"user", "Say hello in one word.", {}}}; + ChatSettings settings; + settings.max_tokens = 32; + auto response = client.CompleteChat(messages, settings); + EXPECT_TRUE(response.successful); + ASSERT_FALSE(response.choices.empty()); + ASSERT_TRUE(response.choices[0].message.has_value()); + EXPECT_FALSE(response.choices[0].message->content.empty()); + EXPECT_EQ(FinishReason::Stop, response.choices[0].finish_reason); + std::cout << "[E2E] Response: " << response.choices[0].message->content << "\n"; + + // Unload + target->Unload(); + 
EXPECT_FALSE(target->IsLoaded()); +} + +// =========================================================================== +// Streaming chat +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_StreamingChat) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Streaming with model: " << target->GetAlias() << "\n"; + + OpenAIChatClient client(*target); + + std::vector messages = {{"user", "Count from 1 to 5.", {}}}; + ChatSettings settings; + settings.max_tokens = 64; + settings.temperature = 0.0f; + + std::vector chunks; + std::string fullContent; + client.CompleteChatStreaming(messages, settings, [&](const ChatCompletionCreateResponse& chunk) { + chunks.push_back(chunk); + if (!chunk.choices.empty() && chunk.choices[0].delta.has_value() && !chunk.choices[0].delta->content.empty()) { + fullContent += chunk.choices[0].delta->content; + } + }); + + EXPECT_GT(chunks.size(), 0u) << "Should have received at least one streaming chunk"; + EXPECT_FALSE(fullContent.empty()) << "Accumulated streaming content should not be empty"; + std::cout << "[E2E] Streaming response: " << fullContent << "\n"; + + // Last chunk should have a stop finish reason + ASSERT_FALSE(chunks.empty()); + const auto& lastChunk = chunks.back(); + if (!lastChunk.choices.empty()) { + EXPECT_EQ(FinishReason::Stop, lastChunk.choices[0].finish_reason); + } + + target->Unload(); +} + +// =========================================================================== +// Chat with tool calling +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_ChatWithToolCalling) { + if 
(IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + // Check if the selected variant supports tool calling + bool supportsCalling = false; + for (const auto& v : target->GetAllModelVariants()) { + if (v.GetInfo().supports_tool_calling.has_value() && *v.GetInfo().supports_tool_calling) { + supportsCalling = true; + break; + } + } + if (!supportsCalling) { + GTEST_SKIP() << "Model does not support tool calling"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Tool calling with model: " << target->GetAlias() << "\n"; + + OpenAIChatClient client(*target); + + std::vector tools = { + {"function", FunctionDefinition{"get_weather", "Get the current weather for a city.", + PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"city", PropertyDefinition{"string", "The city name"}}}, + std::vector{"city"}}}}}; + + std::vector messages = { + {"system", "You are a helpful assistant. 
Use the provided tools when asked about weather."}, + {"user", "What is the weather in Seattle?"}}; + + ChatSettings settings; + settings.temperature = 0.0f; + settings.max_tokens = 256; + settings.tool_choice = ToolChoiceKind::Required; + + auto response = client.CompleteChat(messages, tools, settings); + EXPECT_TRUE(response.successful); + ASSERT_FALSE(response.choices.empty()); + + const auto& choice = response.choices[0]; + // With tool_choice = Required, the model should produce a tool call + if (choice.finish_reason == FinishReason::ToolCalls) { + ASSERT_TRUE(choice.message.has_value()); + ASSERT_FALSE(choice.message->tool_calls.empty()); + const auto& tc = choice.message->tool_calls[0]; + EXPECT_FALSE(tc.id.empty()); + ASSERT_TRUE(tc.function_call.has_value()); + EXPECT_EQ("get_weather", tc.function_call->name); + EXPECT_FALSE(tc.function_call->arguments.empty()); + std::cout << "[E2E] Tool call: " << tc.function_call->name << " args: " << tc.function_call->arguments << "\n"; + } + + target->Unload(); +} + +// =========================================================================== +// Audio transcription +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_AudioTranscription) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download + audio file)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindAudioModel(catalog); + if (!target) { + GTEST_SKIP() << "No audio model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + std::cout << "[E2E] Audio model: " << target->GetAlias() << "\n"; + + OpenAIAudioClient client(*target); + + // Note: this test requires a valid audio file to be present. + // Skip if no test audio file is available. 
+ const char* audioPath = std::getenv("FL_TEST_AUDIO_PATH"); + if (!audioPath) { + target->Unload(); + GTEST_SKIP() << "Set FL_TEST_AUDIO_PATH env var to a .wav file to run audio tests"; + } + + auto result = client.TranscribeAudio(audioPath); + EXPECT_FALSE(result.text.empty()); + std::cout << "[E2E] Transcription: " << result.text << "\n"; + + target->Unload(); +} + +TEST_F(EndToEndTest, DISABLED_AudioTranscriptionStreaming) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download + audio file)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindAudioModel(catalog); + if (!target) { + GTEST_SKIP() << "No audio model found in catalog"; + } + + target->Download(); + target->Load(); + ASSERT_TRUE(target->IsLoaded()); + + const char* audioPath = std::getenv("FL_TEST_AUDIO_PATH"); + if (!audioPath) { + target->Unload(); + GTEST_SKIP() << "Set FL_TEST_AUDIO_PATH env var to a .wav file to run audio tests"; + } + + OpenAIAudioClient client(*target); + + std::string fullText; + int chunkCount = 0; + client.TranscribeAudioStreaming(audioPath, [&](const AudioCreateTranscriptionResponse& chunk) { + fullText += chunk.text; + chunkCount++; + }); + + EXPECT_GT(chunkCount, 0) << "Should have received at least one streaming chunk"; + EXPECT_FALSE(fullText.empty()); + std::cout << "[E2E] Streaming transcription (" << chunkCount << " chunks): " << fullText << "\n"; + + target->Unload(); +} + +// =========================================================================== +// RemoveFromCache +// =========================================================================== + +TEST_F(EndToEndTest, DISABLED_DownloadAndRemoveFromCache) { + if (IsRunningInCI()) { + GTEST_SKIP() << "Skipped in CI (requires model download)"; + } + + auto& catalog = FoundryLocalManager::Instance().GetCatalog(); + auto* target = FindChatModel(catalog); + if (!target) { + GTEST_SKIP() << "No chat-capable model found in catalog"; + } + + 
target->Download(); + EXPECT_TRUE(target->IsCached()); + + // RemoveFromCache should succeed without throwing. + EXPECT_NO_THROW(target->RemoveFromCache()); + + std::cout << "[E2E] RemoveFromCache completed for: " << target->GetAlias() + << " (IsCached=" << (target->IsCached() ? "true" : "false") << ")\n"; +} diff --git a/sdk/cpp/test/mock_core.h b/sdk/cpp/test/mock_core.h new file mode 100644 index 00000000..f89af91a --- /dev/null +++ b/sdk/cpp/test/mock_core.h @@ -0,0 +1,158 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "foundry_local_internal_core.h" +#include "logger.h" + +namespace foundry_local::Testing { + + /// A mock implementation of IFoundryLocalCore for unit testing. + /// Register expected command -> response mappings before use. + class MockCore final : public Internal::IFoundryLocalCore { + public: + /// Handler signature: (command, dataArgument, callback, userData) -> response string. + using Handler = std::function; + + /// Register a fixed response for a command. + void OnCall(std::string command, std::string response) { + handlers_[std::move(command)] = [r = std::move(response)](std::string_view, const std::string*, + NativeCallbackFn, void*) { return r; }; + } + + /// Register a custom handler for a command. + void OnCall(std::string command, Handler handler) { handlers_[std::move(command)] = std::move(handler); } + + /// Register a handler that returns an error for a command. + void OnCallThrow(std::string command, std::string errorMessage) { + errorResponses_[std::move(command)] = std::move(errorMessage); + } + + /// Returns the number of times a command was called. + int GetCallCount(const std::string& command) const { + auto it = callCounts_.find(command); + return it != callCounts_.end() ? it->second : 0; + } + + /// Returns the last data argument passed for a command. 
+ const std::string& GetLastDataArg(const std::string& command) const { + auto it = lastDataArgs_.find(command); + if (it == lastDataArgs_.end()) { + static const std::string empty; + return empty; + } + return it->second; + } + + // IFoundryLocalCore implementation + CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* dataArgument = nullptr, + NativeCallbackFn callback = nullptr, void* data = nullptr) const override { + + std::string cmd(command); + const_cast(this)->callCounts_[cmd]++; + if (dataArgument) { + const_cast(this)->lastDataArgs_[cmd] = *dataArgument; + } + + auto errIt = errorResponses_.find(cmd); + if (errIt != errorResponses_.end()) { + CoreResponse resp; + resp.error = errIt->second; + return resp; + } + + auto it = handlers_.find(cmd); + if (it == handlers_.end()) { + throw std::runtime_error("MockCore: no handler registered for command '" + cmd + "'"); + } + + CoreResponse resp; + resp.data = it->second(command, dataArgument, callback, data); + return resp; + } + + void unload() override {} + + private: + std::unordered_map handlers_; + std::unordered_map errorResponses_; + std::unordered_map callCounts_; + std::unordered_map lastDataArgs_; + }; + + /// Read a file into a string. Throws on failure. + inline std::string ReadFile(const std::string& path) { + std::ifstream in(path, std::ios::in | std::ios::binary); + if (!in) + throw std::runtime_error("Failed to open test data file: " + path); + std::ostringstream contents; + contents << in.rdbuf(); + return contents.str(); + } + + /// A mock core that reads model list, cached models and loaded models from JSON files on disk. 
+ class FileBackedCore final : public Internal::IFoundryLocalCore { + public: + FileBackedCore(std::string modelListPath, std::string cachedModelsPath, std::string loadedModelsPath = "") + : modelListPath_(std::move(modelListPath)), cachedModelsPath_(std::move(cachedModelsPath)), + loadedModelsPath_(std::move(loadedModelsPath)) {} + + static FileBackedCore FromModelList(const std::string& path) { return FileBackedCore(path, ""); } + + static FileBackedCore FromBoth(const std::string& modelListPath, const std::string& cachedModelsPath) { + return FileBackedCore(modelListPath, cachedModelsPath); + } + + static FileBackedCore FromAll(const std::string& modelListPath, const std::string& cachedModelsPath, + const std::string& loadedModelsPath) { + return FileBackedCore(modelListPath, cachedModelsPath, loadedModelsPath); + } + + CoreResponse call(std::string_view command, ILogger& /*logger*/, const std::string* /*dataArgument*/ = nullptr, + NativeCallbackFn /*callback*/ = nullptr, void* /*data*/ = nullptr) const override { + + CoreResponse resp; + + if (command == "get_catalog_name") { + resp.data = "TestCatalog"; + return resp; + } + + if (command == "get_model_list") { + resp.data = modelListPath_.empty() ? "[]" : ReadFile(modelListPath_); + return resp; + } + + if (command == "get_cached_models") { + resp.data = cachedModelsPath_.empty() ? "[]" : ReadFile(cachedModelsPath_); + return resp; + } + + if (command == "list_loaded_models") { + resp.data = loadedModelsPath_.empty() ? 
"[]" : ReadFile(loadedModelsPath_); + return resp; + } + + resp.data = "{}"; + return resp; + } + + void unload() override {} + + private: + std::string modelListPath_; + std::string cachedModelsPath_; + std::string loadedModelsPath_; + }; + +} // namespace foundry_local::Testing diff --git a/sdk/cpp/test/mock_object_factory.h b/sdk/cpp/test/mock_object_factory.h new file mode 100644 index 00000000..86331ab3 --- /dev/null +++ b/sdk/cpp/test/mock_object_factory.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifndef FL_TESTS +#define FL_TESTS +#endif + +#include "foundry_local.h" +#include "foundry_local_internal_core.h" +#include "logger.h" + +namespace foundry_local::Testing { + + /// Factory to construct private-constructor types for testing. + /// Declared as a friend (Testing::MockObjectFactory) in ModelVariant, Model, and Catalog when FL_TESTS is defined. + struct MockObjectFactory { + static ModelVariant CreateModelVariant(gsl::not_null core, ModelInfo info, + gsl::not_null logger) { + return ModelVariant(core, std::move(info), logger); + } + + static std::unique_ptr CreateCatalog(gsl::not_null core, + gsl::not_null logger) { + return std::unique_ptr(new Catalog(core, logger)); + } + + static Model CreateModel(gsl::not_null core, gsl::not_null logger) { + return Model(core, logger); + } + + /// Push a variant into a Model's internal variant list. + static void AddVariantToModel(Model& model, ModelVariant variant) { + model.variants_.push_back(std::move(variant)); + } + + /// Set the selected variant on a Model. + static void SelectFirstVariant(Model& model) { model.selectedVariant_ = &model.variants_.front(); } + + /// Helper to build a minimal ModelInfo with defaults. 
+ static ModelInfo MakeModelInfo(std::string name, std::string alias = "", uint32_t version = 1) { + ModelInfo info; + info.id = name + ":" + std::to_string(version); + info.name = std::move(name); + info.alias = alias.empty() ? info.name : std::move(alias); + info.version = version; + info.provider_type = "test"; + info.uri = "test://uri"; + info.model_type = "text"; + return info; + } + + /// Helper to build a JSON string representing a model list entry. + static std::string MakeModelInfoJson(const std::string& name, const std::string& alias = "", + uint32_t version = 1, bool cached = false) { + std::string a = alias.empty() ? name : alias; + std::string id = name + ":" + std::to_string(version); + return R"({"id":")" + id + R"(","name":")" + name + R"(","version":)" + std::to_string(version) + + R"(,"alias":")" + a + R"(","providerType":"test","uri":"test://uri","modelType":"text","cached":)" + + (cached ? "true" : "false") + R"(,"createdAt":0})"; + } + }; + +} // namespace foundry_local::Testing diff --git a/sdk/cpp/test/model_variant_test.cpp b/sdk/cpp/test/model_variant_test.cpp new file mode 100644 index 00000000..023112d9 --- /dev/null +++ b/sdk/cpp/test/model_variant_test.cpp @@ -0,0 +1,266 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" + +#include + +using namespace foundry_local; +using namespace foundry_local::Testing; + +using Factory = MockObjectFactory; + +class ModelVariantTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } +}; + +TEST_F(ModelVariantTest, GetInfo) { + auto variant = MakeVariant("my-model", "my-alias", 3); + const auto& info = variant.GetInfo(); + EXPECT_EQ("my-model", info.name); + EXPECT_EQ("my-alias", info.alias); + EXPECT_EQ(3u, info.version); +} + +TEST_F(ModelVariantTest, GetId) { + auto variant = MakeVariant("my-model"); + EXPECT_EQ("my-model:1", variant.GetId()); +} + +TEST_F(ModelVariantTest, GetAlias) { + auto variant = MakeVariant("name", "alias"); + EXPECT_EQ("alias", variant.GetAlias()); +} + +TEST_F(ModelVariantTest, GetVersion) { + auto variant = MakeVariant("name", "alias", 5); + EXPECT_EQ(5u, variant.GetVersion()); +} + +TEST_F(ModelVariantTest, IsLoaded_True) { + core_.OnCall("list_loaded_models", R"(["test-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_False) { + core_.OnCall("list_loaded_models", R"(["other-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsLoaded_EmptyList) { + core_.OnCall("list_loaded_models", R"([])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsLoaded()); +} + +TEST_F(ModelVariantTest, IsCached_True) { + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_TRUE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, 
IsCached_False) { + core_.OnCall("get_cached_models", R"(["other-model:1"])"); + auto variant = MakeVariant("test-model"); + EXPECT_FALSE(variant.IsCached()); +} + +TEST_F(ModelVariantTest, Load_CallsCore) { + core_.OnCall("load_model", ""); + auto variant = MakeVariant("test-model"); + variant.Load(); + EXPECT_EQ(1, core_.GetCallCount("load_model")); + + // Verify the data argument contains the model name + auto parsed = nlohmann::json::parse(core_.GetLastDataArg("load_model")); + EXPECT_EQ("test-model", parsed["Params"]["Model"].get()); +} + +TEST_F(ModelVariantTest, Unload_CallsCore) { + core_.OnCall("unload_model", ""); + auto variant = MakeVariant("test-model"); + variant.Unload(); + EXPECT_EQ(1, core_.GetCallCount("unload_model")); +} + +TEST_F(ModelVariantTest, Unload_ThrowsOnError) { + core_.OnCallThrow("unload_model", "unload failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.Unload(), Exception); +} + +TEST_F(ModelVariantTest, Download_NoCallback) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", ""); + auto variant = MakeVariant("test-model"); + variant.Download(); + EXPECT_EQ(1, core_.GetCallCount("download_model")); +} + +TEST_F(ModelVariantTest, Download_WithCallback) { + core_.OnCall("get_cached_models", R"([])"); + core_.OnCall("download_model", + [](std::string_view, const std::string*, NativeCallbackFn callback, void* userData) -> std::string { + // Simulate calling the progress callback + if (callback && userData) { + std::string progress = "50"; + callback(progress.data(), static_cast(progress.size()), userData); + } + return ""; + }); + + auto variant = MakeVariant("test-model"); + float lastProgress = -1.0f; + variant.Download([&](float pct) { lastProgress = pct; }); + EXPECT_NEAR(50.0f, lastProgress, 0.01f); +} + +TEST_F(ModelVariantTest, RemoveFromCache_CallsCore) { + core_.OnCall("remove_cached_model", ""); + auto variant = MakeVariant("test-model"); + variant.RemoveFromCache(); + 
EXPECT_EQ(1, core_.GetCallCount("remove_cached_model")); +} + +TEST_F(ModelVariantTest, RemoveFromCache_ThrowsOnError) { + core_.OnCallThrow("remove_cached_model", "remove failed"); + auto variant = MakeVariant("test-model"); + EXPECT_THROW(variant.RemoveFromCache(), Exception); +} + +TEST_F(ModelVariantTest, GetPath_CallsCore) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + const auto& path = variant.GetPath(); + EXPECT_EQ(std::filesystem::path(R"(C:\models\test)"), path); +} + +TEST_F(ModelVariantTest, GetPath_CachesResult) { + core_.OnCall("get_model_path", R"(C:\models\test)"); + auto variant = MakeVariant("test-model"); + variant.GetPath(); + variant.GetPath(); + // Should only call once due to caching + EXPECT_EQ(1, core_.GetCallCount("get_model_path")); +} + +class ModelTest : public ::testing::Test { +protected: + MockCore core_; + NullLogger logger_; + + Model MakeModel() { return Factory::CreateModel(&core_, &logger_); } + + ModelVariant MakeVariant(std::string name = "test-model", std::string alias = "test-alias", uint32_t version = 1) { + return Factory::CreateModelVariant(&core_, Factory::MakeModelInfo(name, alias, version), &logger_); + } + + /// Helper: create a Model with one variant and selectedVariant_ set. 
+ Model MakeModelWithVariant(const std::string& name = "test-model", const std::string& alias = "test-alias") { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant(name, alias, 1)); + Factory::SelectFirstVariant(model); + return model; + } +}; + +TEST_F(ModelTest, SelectedVariant_ThrowsWhenEmpty) { + auto model = MakeModel(); + EXPECT_THROW(model.GetId(), Exception); +} + +TEST_F(ModelTest, AddVariant_AndSelect) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SelectFirstVariant(model); + + EXPECT_EQ("v1:1", model.GetId()); + EXPECT_EQ("alias", model.GetAlias()); +} + +TEST_F(ModelTest, GetAllModelVariants) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SelectFirstVariant(model); + + auto variants = model.GetAllModelVariants(); + EXPECT_EQ(2u, variants.size()); +} + +TEST_F(ModelTest, SelectVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SelectFirstVariant(model); + + const auto& v2 = model.GetAllModelVariants()[1]; + model.SelectVariant(v2); + EXPECT_EQ("v2:2", model.GetId()); +} + +TEST_F(ModelTest, SelectVariant_NotFound_Throws) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::SelectFirstVariant(model); + + auto external = MakeVariant("external", "ext-alias", 1); + EXPECT_THROW(model.SelectVariant(external), Exception); +} + +TEST_F(ModelTest, SelectVariant_ByIdFromExternalInstance) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("v1", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("v2", "alias", 2)); + Factory::SelectFirstVariant(model); + + // Simulate a variant obtained externally (e.g. 
from Catalog::GetModelVariant) + // with the same id as v2 but a different object instance. + auto externalV2 = MakeVariant("v2", "alias", 2); + model.SelectVariant(externalV2); + EXPECT_EQ("v2:2", model.GetId()); +} + +TEST_F(ModelTest, GetLatestVariant) { + auto model = MakeModel(); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 1)); + Factory::AddVariantToModel(model, MakeVariant("target-model", "alias", 2)); + Factory::SelectFirstVariant(model); + + const auto& first = model.GetAllModelVariants()[0]; + const auto& latest = model.GetLatestVersion(first); + // Should return the first one with matching name (which is variants_[0]) + EXPECT_EQ(&first, &latest); +} + +TEST_F(ModelTest, DelegationMethods) { + // Test that Model delegates to SelectedVariant + core_.OnCall("list_loaded_models", R"(["test-model:1"])"); + core_.OnCall("get_cached_models", R"(["test-model:1"])"); + core_.OnCall("load_model", ""); + core_.OnCall("unload_model", ""); + core_.OnCall("download_model", ""); + core_.OnCall("get_model_path", R"(C:\test)"); + + auto model = MakeModelWithVariant("test-model", "alias"); + + EXPECT_TRUE(model.IsLoaded()); + EXPECT_TRUE(model.IsCached()); + model.Load(); + model.Unload(); + model.Download(); + EXPECT_EQ(std::filesystem::path(R"(C:\test)"), model.GetPath()); +} diff --git a/sdk/cpp/test/parser_and_types_test.cpp b/sdk/cpp/test/parser_and_types_test.cpp new file mode 100644 index 00000000..681e912f --- /dev/null +++ b/sdk/cpp/test/parser_and_types_test.cpp @@ -0,0 +1,417 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include + +#include "mock_core.h" +#include "mock_object_factory.h" +#include "parser.h" +#include "foundry_local_exception.h" +#include "core_interop_request.h" + +#include + +using namespace foundry_local; +using namespace foundry_local::Testing; + +class ParserTest : public ::testing::Test { +protected: + static nlohmann::json MinimalModelJson() { + return nlohmann::json{{"id", "model-1:1"}, {"name", "model-1"}, {"version", 1}, + {"alias", "my-model"}, {"providerType", "onnx"}, {"uri", "https://example.com/model"}, + {"modelType", "text"}, {"cached", false}, {"createdAt", 1700000000}}; + } +}; + +TEST_F(ParserTest, ParseDeviceType_CPU) { + EXPECT_EQ(DeviceType::CPU, ParsingUtils::parse_device_type("CPU")); +} + +TEST_F(ParserTest, ParseDeviceType_GPU) { + EXPECT_EQ(DeviceType::GPU, ParsingUtils::parse_device_type("GPU")); +} + +TEST_F(ParserTest, ParseDeviceType_NPU) { + EXPECT_EQ(DeviceType::NPU, ParsingUtils::parse_device_type("NPU")); +} + +TEST_F(ParserTest, ParseDeviceType_Unknown) { + EXPECT_EQ(DeviceType::Invalid, ParsingUtils::parse_device_type("FPGA")); +} + +TEST_F(ParserTest, ParseFinishReason_Stop) { + EXPECT_EQ(FinishReason::Stop, ParsingUtils::parse_finish_reason("stop")); +} + +TEST_F(ParserTest, ParseFinishReason_Length) { + EXPECT_EQ(FinishReason::Length, ParsingUtils::parse_finish_reason("length")); +} + +TEST_F(ParserTest, ParseFinishReason_ToolCalls) { + EXPECT_EQ(FinishReason::ToolCalls, ParsingUtils::parse_finish_reason("tool_calls")); +} + +TEST_F(ParserTest, ParseFinishReason_ContentFilter) { + EXPECT_EQ(FinishReason::ContentFilter, ParsingUtils::parse_finish_reason("content_filter")); +} + +TEST_F(ParserTest, ParseFinishReason_None) { + EXPECT_EQ(FinishReason::None, ParsingUtils::parse_finish_reason("unknown_value")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Present) { + nlohmann::json j = {{"key", "value"}}; + EXPECT_EQ("value", ParsingUtils::get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_Missing) { + 
nlohmann::json j = {{"other", "value"}}; + EXPECT_EQ("", ParsingUtils::get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetStringOrEmpty_NonString) { + nlohmann::json j = {{"key", 42}}; + EXPECT_EQ("", ParsingUtils::get_string_or_empty(j, "key")); +} + +TEST_F(ParserTest, GetOptString_Present) { + nlohmann::json j = {{"key", "hello"}}; + auto result = ParsingUtils::get_opt_string(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ("hello", *result); +} + +TEST_F(ParserTest, GetOptString_Null) { + nlohmann::json j = {{"key", nullptr}}; + EXPECT_FALSE(ParsingUtils::get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptString_Missing) { + nlohmann::json j = {{"other", "v"}}; + EXPECT_FALSE(ParsingUtils::get_opt_string(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptInt_Present) { + nlohmann::json j = {{"key", 42}}; + auto result = ParsingUtils::get_opt_int(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(42, *result); +} + +TEST_F(ParserTest, GetOptInt_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(ParsingUtils::get_opt_int(j, "key").has_value()); +} + +TEST_F(ParserTest, GetOptBool_Present) { + nlohmann::json j = {{"key", true}}; + auto result = ParsingUtils::get_opt_bool(j, "key"); + ASSERT_TRUE(result.has_value()); + EXPECT_TRUE(*result); +} + +TEST_F(ParserTest, GetOptBool_Missing) { + nlohmann::json j = {}; + EXPECT_FALSE(ParsingUtils::get_opt_bool(j, "key").has_value()); +} + +TEST_F(ParserTest, ParseRuntime) { + nlohmann::json j = {{"deviceType", "GPU"}, {"executionProvider", "DML"}}; + Runtime r = j.get(); + EXPECT_EQ(DeviceType::GPU, r.device_type); + EXPECT_EQ("DML", r.execution_provider); +} + +TEST_F(ParserTest, ParsePromptTemplate) { + nlohmann::json j = {{"system", "sys"}, {"user", "usr"}, {"assistant", "asst"}, {"prompt", "p"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("usr", pt.user); + EXPECT_EQ("asst", pt.assistant); + EXPECT_EQ("p", pt.prompt); +} + 
+TEST_F(ParserTest, ParsePromptTemplate_MissingFields) { + nlohmann::json j = {{"system", "sys"}}; + PromptTemplate pt = j.get(); + EXPECT_EQ("sys", pt.system); + EXPECT_EQ("", pt.user); + EXPECT_EQ("", pt.assistant); + EXPECT_EQ("", pt.prompt); +} + +TEST_F(ParserTest, ParseModelInfo_Minimal) { + auto j = MinimalModelJson(); + ModelInfo info = j.get(); + EXPECT_EQ("model-1:1", info.id); + EXPECT_EQ("model-1", info.name); + EXPECT_EQ(1u, info.version); + EXPECT_EQ("my-model", info.alias); + EXPECT_EQ("onnx", info.provider_type); + EXPECT_EQ("https://example.com/model", info.uri); + EXPECT_EQ("text", info.model_type); + EXPECT_FALSE(info.cached); + EXPECT_EQ(1700000000, info.created_at_unix); + EXPECT_FALSE(info.display_name.has_value()); + EXPECT_FALSE(info.publisher.has_value()); + EXPECT_FALSE(info.runtime.has_value()); + EXPECT_FALSE(info.prompt_template.has_value()); + EXPECT_FALSE(info.model_settings.has_value()); +} + +TEST_F(ParserTest, ParseModelInfo_WithOptionals) { + auto j = MinimalModelJson(); + j["displayName"] = "My Model"; + j["publisher"] = "TestPublisher"; + j["license"] = "MIT"; + j["fileSizeMb"] = 512; + j["supportsToolCalling"] = true; + j["maxOutputTokens"] = 4096; + j["runtime"] = {{"deviceType", "CPU"}, {"executionProvider", "ORT"}}; + + ModelInfo info = j.get(); + ASSERT_TRUE(info.display_name.has_value()); + EXPECT_EQ("My Model", *info.display_name); + ASSERT_TRUE(info.publisher.has_value()); + EXPECT_EQ("TestPublisher", *info.publisher); + ASSERT_TRUE(info.license.has_value()); + EXPECT_EQ("MIT", *info.license); + ASSERT_TRUE(info.file_size_mb.has_value()); + EXPECT_EQ(512u, *info.file_size_mb); + ASSERT_TRUE(info.supports_tool_calling.has_value()); + EXPECT_TRUE(*info.supports_tool_calling); + ASSERT_TRUE(info.max_output_tokens.has_value()); + EXPECT_EQ(4096, *info.max_output_tokens); + ASSERT_TRUE(info.runtime.has_value()); + EXPECT_EQ(DeviceType::CPU, info.runtime->device_type); + EXPECT_EQ("ORT", info.runtime->execution_provider); +} + 
+TEST_F(ParserTest, ParseModelSettings) { + nlohmann::json j = {{"parameters", {{{"name", "p1"}, {"value", "v1"}}, {{"name", "p2"}}}}}; + ModelSettings ms = j.get(); + ASSERT_EQ(2u, ms.parameters.size()); + EXPECT_EQ("p1", ms.parameters[0].name); + ASSERT_TRUE(ms.parameters[0].value.has_value()); + EXPECT_EQ("v1", *ms.parameters[0].value); + EXPECT_EQ("p2", ms.parameters[1].name); + EXPECT_FALSE(ms.parameters[1].value.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage) { + nlohmann::json j = {{"role", "user"}, {"content", "hello"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("user", msg.role); + EXPECT_EQ("hello", msg.content); + EXPECT_TRUE(msg.tool_calls.empty()); + EXPECT_FALSE(msg.tool_call_id.has_value()); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCalls) { + nlohmann::json j = {{"role", "assistant"}, + {"content", "I'll call a tool."}, + {"tool_calls", + {{{"id", "call_abc123"}, + {"type", "function"}, + {"function", {{"name", "get_weather"}, {"arguments", "{\"city\": \"Seattle\"}"}}}}}}}; + ChatMessage msg = j.get(); + EXPECT_EQ("assistant", msg.role); + ASSERT_EQ(1u, msg.tool_calls.size()); + EXPECT_EQ("call_abc123", msg.tool_calls[0].id); + EXPECT_EQ("function", msg.tool_calls[0].type); + ASSERT_TRUE(msg.tool_calls[0].function_call.has_value()); + EXPECT_EQ("get_weather", msg.tool_calls[0].function_call->name); + EXPECT_EQ("{\"city\": \"Seattle\"}", msg.tool_calls[0].function_call->arguments); +} + +TEST_F(ParserTest, ParseChatMessage_WithToolCallId) { + nlohmann::json j = {{"role", "tool"}, {"content", "72 degrees and sunny"}, {"tool_call_id", "call_abc123"}}; + ChatMessage msg = j.get(); + EXPECT_EQ("tool", msg.role); + EXPECT_EQ("72 degrees and sunny", msg.content); + ASSERT_TRUE(msg.tool_call_id.has_value()); + EXPECT_EQ("call_abc123", *msg.tool_call_id); +} + +TEST_F(ParserTest, ParseFunctionCall) { + nlohmann::json j = {{"name", "multiply"}, {"arguments", "{\"a\": 1, \"b\": 2}"}}; + FunctionCall fc = j.get(); + EXPECT_EQ("multiply", 
fc.name); + EXPECT_EQ("{\"a\": 1, \"b\": 2}", fc.arguments); +} + +TEST_F(ParserTest, ParseFunctionCall_ObjectArguments) { + nlohmann::json j = {{"name", "add"}, {"arguments", {{"x", 10}}}}; + FunctionCall fc = j.get(); + EXPECT_EQ("add", fc.name); + EXPECT_EQ("{\"x\":10}", fc.arguments); +} + +TEST_F(ParserTest, ParseToolCall) { + nlohmann::json j = {{"id", "call_1"}, + {"type", "function"}, + {"function", {{"name", "search"}, {"arguments", "{\"query\": \"test\"}"}}}}; + ToolCall tc = j.get(); + EXPECT_EQ("call_1", tc.id); + EXPECT_EQ("function", tc.type); + ASSERT_TRUE(tc.function_call.has_value()); + EXPECT_EQ("search", tc.function_call->name); +} + +TEST_F(ParserTest, SerializeToolDefinition) { + ToolDefinition tool; + tool.type = "function"; + tool.function.name = "get_weather"; + tool.function.description = "Get the current weather"; + tool.function.parameters = PropertyDefinition{"object", std::nullopt, + std::unordered_map{ + {"location", PropertyDefinition{"string", "The city name"}}}, + std::vector{"location"}}; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("get_weather", j["function"]["name"].get()); + EXPECT_EQ("Get the current weather", j["function"]["description"].get()); + EXPECT_EQ("object", j["function"]["parameters"]["type"].get()); + ASSERT_TRUE(j["function"]["parameters"]["properties"].contains("location")); + EXPECT_EQ("string", j["function"]["parameters"]["properties"]["location"]["type"].get()); + ASSERT_EQ(1u, j["function"]["parameters"]["required"].size()); + EXPECT_EQ("location", j["function"]["parameters"]["required"][0].get()); +} + +TEST_F(ParserTest, SerializeToolDefinition_MinimalFunction) { + ToolDefinition tool; + tool.function.name = "noop"; + + nlohmann::json j; + to_json(j, tool); + + EXPECT_EQ("function", j["type"].get()); + EXPECT_EQ("noop", j["function"]["name"].get()); + EXPECT_FALSE(j["function"].contains("description")); + 
EXPECT_FALSE(j["function"].contains("parameters")); +} + +TEST_F(ParserTest, ToolChoiceToString) { + EXPECT_EQ("auto", ParsingUtils::tool_choice_to_string(ToolChoiceKind::Auto)); + EXPECT_EQ("none", ParsingUtils::tool_choice_to_string(ToolChoiceKind::None)); + EXPECT_EQ("required", ParsingUtils::tool_choice_to_string(ToolChoiceKind::Required)); +} + +TEST_F(ParserTest, ParseChatChoice_NonStreaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hi there!"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(0, c.index); + EXPECT_EQ(FinishReason::Stop, c.finish_reason); + ASSERT_TRUE(c.message.has_value()); + EXPECT_EQ("assistant", c.message->role); + EXPECT_EQ("Hi there!", c.message->content); + EXPECT_FALSE(c.delta.has_value()); +} + +TEST_F(ParserTest, ParseChatChoice_Streaming) { + nlohmann::json j = { + {"index", 0}, {"finish_reason", nullptr}, {"delta", {{"role", "assistant"}, {"content", "Hi"}}}}; + ChatChoice c = j.get(); + EXPECT_EQ(FinishReason::None, c.finish_reason); + EXPECT_FALSE(c.message.has_value()); + ASSERT_TRUE(c.delta.has_value()); + EXPECT_EQ("Hi", c.delta->content); +} + +TEST_F(ParserTest, ParseChatCompletionCreateResponse) { + nlohmann::json j = { + {"created", 1700000000}, + {"id", "chatcmpl-123"}, + {"IsDelta", false}, + {"Successful", true}, + {"HttpStatusCode", 200}, + {"choices", + {{{"index", 0}, {"finish_reason", "stop"}, {"message", {{"role", "assistant"}, {"content", "Hello!"}}}}}}}; + ChatCompletionCreateResponse r = j.get(); + EXPECT_EQ(1700000000, r.created); + EXPECT_EQ("chatcmpl-123", r.id); + EXPECT_FALSE(r.is_delta); + EXPECT_TRUE(r.successful); + EXPECT_EQ(200, r.http_status_code); + ASSERT_EQ(1u, r.choices.size()); + EXPECT_EQ("Hello!", r.choices[0].message->content); +} + +TEST(ChatCompletionCreateResponseTest, GetObject_NonDelta) { + ChatCompletionCreateResponse r; + r.is_delta = false; + EXPECT_STREQ("chat.completion", r.GetObject()); +} + 
+TEST(ChatCompletionCreateResponseTest, GetObject_Delta) { + ChatCompletionCreateResponse r; + r.is_delta = true; + EXPECT_STREQ("chat.completion.chunk", r.GetObject()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_Zero) { + ChatCompletionCreateResponse r; + r.created = 0; + EXPECT_EQ("", r.GetCreatedAtIso()); +} + +TEST(ChatCompletionCreateResponseTest, GetCreatedAtIso_ValidTimestamp) { + ChatCompletionCreateResponse r; + r.created = 1700000000; // 2023-11-14T22:13:20Z + std::string iso = r.GetCreatedAtIso(); + EXPECT_FALSE(iso.empty()); + EXPECT_EQ('Z', iso.back()); + EXPECT_NE(std::string::npos, iso.find("2023")); +} + +// ============================================================================= +// CoreInteropRequest tests +// ============================================================================= + +TEST(CoreInteropRequestTest, Command) { + CoreInteropRequest req("test_command"); + EXPECT_EQ("test_command", req.Command()); +} + +TEST(CoreInteropRequestTest, ToJson_NoParams) { + CoreInteropRequest req("cmd"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + EXPECT_FALSE(parsed.contains("Params")); +} + +TEST(CoreInteropRequestTest, ToJson_WithParams) { + CoreInteropRequest req("cmd"); + req.AddParam("key1", "value1"); + req.AddParam("key2", "value2"); + std::string json = req.ToJson(); + auto parsed = nlohmann::json::parse(json); + ASSERT_TRUE(parsed.contains("Params")); + EXPECT_EQ("value1", parsed["Params"]["key1"].get()); + EXPECT_EQ("value2", parsed["Params"]["key2"].get()); +} + +TEST(CoreInteropRequestTest, AddParam_Chaining) { + CoreInteropRequest req("cmd"); + auto& ref = req.AddParam("a", "1").AddParam("b", "2"); + EXPECT_EQ(&req, &ref); +} + +// ============================================================================= +// Exception tests +// ============================================================================= + +TEST(ExceptionTest, MessageOnly) { + Exception ex("test error"); + 
EXPECT_STREQ("test error", ex.what()); +} + +TEST(ExceptionTest, MessageAndLogger) { + NullLogger logger; + Exception ex("logged error", logger); + EXPECT_STREQ("logged error", ex.what()); +} \ No newline at end of file diff --git a/sdk/cpp/test/testdata/empty_models_list.json b/sdk/cpp/test/testdata/empty_models_list.json new file mode 100644 index 00000000..fe51488c --- /dev/null +++ b/sdk/cpp/test/testdata/empty_models_list.json @@ -0,0 +1 @@ +[] diff --git a/sdk/cpp/test/testdata/malformed_models_list.json b/sdk/cpp/test/testdata/malformed_models_list.json new file mode 100644 index 00000000..a04360f5 --- /dev/null +++ b/sdk/cpp/test/testdata/malformed_models_list.json @@ -0,0 +1 @@ +{this is not valid json[} diff --git a/sdk/cpp/test/testdata/missing_name_field_models_list.json b/sdk/cpp/test/testdata/missing_name_field_models_list.json new file mode 100644 index 00000000..da1e9465 --- /dev/null +++ b/sdk/cpp/test/testdata/missing_name_field_models_list.json @@ -0,0 +1,12 @@ +[ + { + "id": "model-missing-name:1", + "version": 1, + "alias": "test", + "providerType": "onnx", + "uri": "https://example.com/model", + "modelType": "text", + "cached": false, + "createdAt": 0 + } +] diff --git a/sdk/cpp/test/testdata/mixed_openai_and_local.json b/sdk/cpp/test/testdata/mixed_openai_and_local.json new file mode 100644 index 00000000..9d8de80b --- /dev/null +++ b/sdk/cpp/test/testdata/mixed_openai_and_local.json @@ -0,0 +1,35 @@ +[ + { + "id": "openai-gpt4:1", + "name": "openai-gpt4", + "version": 1, + "alias": "openai-gpt4", + "providerType": "openai", + "uri": "https://example.com/openai-gpt4", + "modelType": "text", + "cached": false, + "createdAt": 0 + }, + { + "id": "openai-whisper:1", + "name": "openai-whisper", + "version": 1, + "alias": "openai-whisper", + "providerType": "openai", + "uri": "https://example.com/openai-whisper", + "modelType": "audio", + "cached": false, + "createdAt": 0 + }, + { + "id": "local-phi-4:1", + "name": "local-phi-4", + "version": 1, + 
"alias": "phi-4", + "providerType": "onnx", + "uri": "https://example.com/phi-4", + "modelType": "text", + "cached": false, + "createdAt": 1700000000 + } +] diff --git a/sdk/cpp/test/testdata/real_models_list.json b/sdk/cpp/test/testdata/real_models_list.json new file mode 100644 index 00000000..284d3a1a --- /dev/null +++ b/sdk/cpp/test/testdata/real_models_list.json @@ -0,0 +1,88 @@ +[ + { + "id": "Phi-4-generic-gpu:1", + "name": "Phi-4-generic-gpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-gpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 8192, + "supportsToolCalling": true, + "maxOutputTokens": 4096, + "cached": false, + "createdAt": 1700000000, + "runtime": { + "deviceType": "GPU", + "executionProvider": "DML" + }, + "promptTemplate": { + "system": "<|system|>", + "user": "<|user|>", + "assistant": "<|assistant|>", + "prompt": "<|prompt|>" + } + }, + { + "id": "Phi-4-generic-cpu:1", + "name": "Phi-4-generic-cpu", + "version": 1, + "alias": "phi-4", + "displayName": "Phi-4 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/phi-4-cpu", + "modelType": "text", + "publisher": "Microsoft", + "license": "MIT", + "fileSizeMb": 4096, + "supportsToolCalling": false, + "maxOutputTokens": 2048, + "cached": false, + "createdAt": 1700000000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + }, + { + "id": "Mistral-7b-v0.2-generic-gpu:1", + "name": "Mistral-7b-v0.2-generic-gpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (GPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-gpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 14000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "GPU", + "executionProvider": "DML" + } + }, + { + "id": "Mistral-7b-v0.2-generic-cpu:1", + "name": 
"Mistral-7b-v0.2-generic-cpu", + "version": 1, + "alias": "mistral-7b-v0.2", + "displayName": "Mistral 7B v0.2 (CPU)", + "providerType": "onnx", + "uri": "https://example.com/mistral-cpu", + "modelType": "text", + "publisher": "Mistral AI", + "license": "Apache-2.0", + "fileSizeMb": 7000, + "cached": false, + "createdAt": 1700100000, + "runtime": { + "deviceType": "CPU", + "executionProvider": "ORT" + } + } +] diff --git a/sdk/cpp/test/testdata/single_cached_model.json b/sdk/cpp/test/testdata/single_cached_model.json new file mode 100644 index 00000000..76efa8e7 --- /dev/null +++ b/sdk/cpp/test/testdata/single_cached_model.json @@ -0,0 +1 @@ +["multi-v1-cpu:1"] diff --git a/sdk/cpp/test/testdata/three_variants_one_model.json b/sdk/cpp/test/testdata/three_variants_one_model.json new file mode 100644 index 00000000..fad0555d --- /dev/null +++ b/sdk/cpp/test/testdata/three_variants_one_model.json @@ -0,0 +1,41 @@ +[ + { + "id": "multi-v1-gpu:1", + "name": "multi-v1-gpu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 GPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-gpu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "GPU", "executionProvider": "DML" } + }, + { + "id": "multi-v1-cpu:1", + "name": "multi-v1-cpu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 CPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-cpu", + "modelType": "text", + "cached": true, + "createdAt": 1700000000, + "runtime": { "deviceType": "CPU", "executionProvider": "ORT" } + }, + { + "id": "multi-v1-npu:1", + "name": "multi-v1-npu", + "version": 1, + "alias": "multi-model", + "displayName": "Multi Model v1 NPU", + "providerType": "onnx", + "uri": "https://example.com/multi-v1-npu", + "modelType": "text", + "cached": false, + "createdAt": 1700000000, + "runtime": { "deviceType": "NPU", "executionProvider": "QNN" } + } +] diff --git 
a/sdk/cpp/test/testdata/valid_cached_models.json b/sdk/cpp/test/testdata/valid_cached_models.json new file mode 100644 index 00000000..2b144174 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_cached_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1", "Phi-4-generic-cpu:1"] diff --git a/sdk/cpp/test/testdata/valid_loaded_models.json b/sdk/cpp/test/testdata/valid_loaded_models.json new file mode 100644 index 00000000..4d2ef328 --- /dev/null +++ b/sdk/cpp/test/testdata/valid_loaded_models.json @@ -0,0 +1 @@ +["Phi-4-generic-gpu:1"] diff --git a/sdk/cpp/triplets/x64-windows-static-md.cmake b/sdk/cpp/triplets/x64-windows-static-md.cmake new file mode 100644 index 00000000..63d6cde2 --- /dev/null +++ b/sdk/cpp/triplets/x64-windows-static-md.cmake @@ -0,0 +1,3 @@ +set(VCPKG_TARGET_ARCHITECTURE x64) +set(VCPKG_CRT_LINKAGE dynamic) +set(VCPKG_LIBRARY_LINKAGE static) diff --git a/sdk/cpp/vcpkg-configuration.json b/sdk/cpp/vcpkg-configuration.json new file mode 100644 index 00000000..a5253fb7 --- /dev/null +++ b/sdk/cpp/vcpkg-configuration.json @@ -0,0 +1,6 @@ +{ + "default-registry": { + "kind": "builtin", + "baseline": "a9f0cd0345fb29cd227d802f1fd1917c28f8e5a3" + } +} diff --git a/sdk/cpp/vcpkg.json b/sdk/cpp/vcpkg.json new file mode 100644 index 00000000..c4511497 --- /dev/null +++ b/sdk/cpp/vcpkg.json @@ -0,0 +1,12 @@ +{ + "name": "cppsdk", + "version-string": "0.1.0", + "dependencies": [ + "nlohmann-json", + "wil", + "ms-gsl" + ], + "dev-dependencies": [ + "gtest" + ] +}