diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index ddefa523..c84abb69 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -34,10 +34,14 @@ jobs: echo "Building Linux" cmake --workflow --preset package-release-debian-${{matrix.build-variant}}-workflow shell: bash + - name: Create a zip file with specific files/folders + shell: powershell + run: | + Compress-Archive -Path "${{ github.workspace }}/build/${{ matrix.build-variant }}/*.deb" -DestinationPath "${{ matrix.build-variant }}.zip" - name: Upload Debian Package to Release uses: svenstaro/upload-release-action@v2 with: repo_token: ${{ secrets.GITHUB_TOKEN }} - file: ${{ github.workspace }}/build/${{ matrix.build-variant }}/*.deb + file: "${{ matrix.build-variant }}.zip" file_glob: true tag: ${{ github.ref }} diff --git a/.github/workflows/unit-testing.yml b/.github/workflows/unit-testing.yml index dec406d4..1b340f57 100644 --- a/.github/workflows/unit-testing.yml +++ b/.github/workflows/unit-testing.yml @@ -63,5 +63,5 @@ jobs: echo "Testing Microsoft" cmake --workflow --preset test-debug-microsoft-${{matrix.build-variant}}-workflow fi - timeout-minutes: 5 + timeout-minutes: 7 shell: bash diff --git a/.gitignore b/.gitignore index bf7d7e9c..efc974ca 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ # To modify the Doxyfile, edit docs/Doxyfile.in and re-run CMake. It will build the new Doxyfile at the root of the build tree. Doxyfile +*.log + .metadata bin/ tmp/ diff --git a/CMakeLists.txt b/CMakeLists.txt index f2a6be93..ac8cf8b6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -73,7 +73,7 @@ endif() add_library(respond_model) add_library(respond::respond_model ALIAS respond_model) -file(GLOB_RECURSE RESPOND_HEADERS include/respond/*.hpp) +file(GLOB_RECURSE RESPOND_HEADERS include/respond/*.hpp include/respond/*.h) file(GLOB RESPOND_INTERNAL_HEADERS CONFIGURE_DEPENDS src/internals/*.hpp) file(GLOB RESPOND_SOURCE_FILES CONFIGURE_DEPENDS src/*.cpp) diff --git a/CMakePresets.json b/CMakePresets.json index 1f270889..ad6330b9 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -124,6 +124,23 @@ "default-respond-config" ] }, + { + "name": "benchmark-linux-static-config", + "displayName": "Benchmark, Linux, Static", + "description": "Target a linux distro to build release benchmark artifacts", + "inherits": [ + "release-config", + "linux-config", + "static-config", + "default-respond-config" + ], + "cacheVariables": { + "RESPOND_BUILD_BENCH": "ON", + "RESPOND_BUILD_TESTS": "OFF", + "RESPOND_BUILD_DOCS": "OFF", + "RESPOND_INSTALL": "OFF" + } + }, { "name": "debug-microsoft-shared-config", "displayName": "Debug, Microsoft, Shared", @@ -199,6 +216,12 @@ "inherits": "default-build", "configurePreset": "debug-linux-static-config" }, + { + "name": "benchmark-linux-static-build", + "displayName": "Benchmark Linux Static Build", + "inherits": "default-build", + "configurePreset": "benchmark-linux-static-config" + }, { "name": "release-microsoft-shared-build", "displayName": "Release Microsoft Shared Build", @@ -411,6 +434,21 @@ } ] }, + { + "name": "benchmark-linux-static-workflow", + "displayName": "Benchmark, Linux, Static", + "description": "Build benchmark artifacts for performance testing.", + "steps": [ + { + "type": "configure", + "name": "benchmark-linux-static-config" + }, + { + "type": "build", + "name": "benchmark-linux-static-build" + } + ] + }, { "name": "package-release-debian-shared-workflow", "displayName": "Package, Release, Debian, Shared", diff --git a/Doxyfile.in b/Doxyfile.in index da3a2aa8..89203c95 100644 --- a/Doxyfile.in +++ b/Doxyfile.in @@ -48,7 +48,7 @@ PROJECT_NAME = RESPOND # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 2.3.2 +PROJECT_NUMBER = 2.4.0 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/README.md b/README.md index da1efac7..fbe37f55 100644 --- a/README.md +++ b/README.md @@ -22,12 +22,13 @@ The [original RESPOND model](https://github.com/SyndemicsLab/RESPONDv1/tree/main RESPOND makes full use of the CMake build system. It is a common tool used throughout the C++ user-base and we utilize it for dependency management, linking, and testing. As C++ has poor package management, we intentionally decided to move our focus away from tools such as conan and vcpkg and stay with pure CMake. Not to say we would never publish with such package managers, but it is not a core focus of the refactor/engineering team. -We natively support 4 different build workflows with the `CMakePresets.json` file. They are: +We natively support 5 different build workflows with the `CMakePresets.json` file. They are: 1. `test-debug-gcc-linux-shared-workflow` 2. `test-debug-gcc-linux-static-workflow` 3. `package-release-gcc-linux-shared-workflow` 4. `package-release-gcc-linux-static-workflow` +5. `benchmark-linux-static-workflow` These workflows follow the pattern `{function}-{build}-gcc-linux-{library}-workflow` and have corresponding presets for build, test, and package. As we adopt more operating systems and compilers we will expand beyond gcc and linux. @@ -105,14 +106,76 @@ tools/build.sh ## Running RESPOND -The recommended way to use the RESPOND model is via the [Python -package][respondpy], with [source on GitHub][respondpy-git]. +The recommended way to use the RESPOND model is via the [Python package][respondpy], with [source on GitHub][respondpy-git]. -If you wish to use RESPOND via a local executable, please refer to release -[v0.3.0](https://github.com/SyndemicsLab/respond/releases/tag/v0.3.0). +For C++ developers wishing to use RESPOND as a library, please refer to the [C++ API Guide][api-guide] in the documentation for usage examples, design patterns, and best practices. + +If you wish to use RESPOND via a legacy local executable, please refer to release [v0.3.0](https://github.com/SyndemicsLab/respond/releases/tag/v0.3.0). + +## Documentation + +Complete documentation is available in the [`docs/src/`](docs/src) directory: + +- **[C++ API Guide][api-guide]** - Developer guide for using RESPOND as a C++ library +- **[Architecture and Design](docs/src/architecture.md)** - Design patterns, component architecture, and extensibility +- **[Data Management](docs/src/data.md)** - Configuration files and data requirements +- **[Installation](docs/src/installation.md)** - Build and installation instructions +- **[Running the Model](docs/src/run.md)** - Execution instructions +- **[Math Background](docs/src/math.md)** - Mathematical foundations and equations +- **[Limitations](docs/src/limitations.md)** - Known constraints and future work +- **[FAQs](docs/src/faq.md)** - Frequently asked questions + +For auto-generated API documentation, build and open the Doxygen output at `build/docs/doxygen/html/index.html`. + +## Quick Start for C++ Developers + +### Using via CMake FetchContent + +If building a new project with CMake, use `FetchContent`: + +```cmake +include(FetchContent) +FetchContent_Declare( + respond + GIT_REPOSITORY https://github.com/SyndemicsLab/respond.git + GIT_TAG main + OVERRIDE_FIND_PACKAGE +) +set(RESPOND_INSTALL ON) +find_package(respond REQUIRED) + +target_link_libraries(${PROJECT_NAME} + PRIVATE + respond::respond_model +) +``` + +Then see the [C++ API Guide][api-guide] for usage examples. + +### Building Documentation + +To build the Doxygen documentation: + +```shell +cmake --build build/static --target doxygen_docs +``` + +Open `build/static/docs/doxygen/html/index.html` in a browser. + +### Running Tests + +After building with tests enabled: + +```shell +cmake --workflow --preset test-debug-gcc-linux-static-workflow +``` + +Tests verify all core components (models, transitions, history tracking). ## References + 1. Madushani RWMA, Wang J, Weitz M, Linas BP, White LF, Chrysanthopoulou SA (2025) Empirical calibration of a simulation model of opioid use disorder. PLoS ONE 20(3): e0310763. https://doi.org/10.1371/journal.pone.0310763 [respondpy]: https://pypi.org/project/respondpy/ [respondpy-git]: https://github.com/SyndemicsLab/respondpy +[api-guide]: docs/src/api-guide.md diff --git a/cmake/BuildBinaries.cmake b/cmake/BuildBinaries.cmake index c90d392a..4b34a44b 100644 --- a/cmake/BuildBinaries.cmake +++ b/cmake/BuildBinaries.cmake @@ -10,7 +10,7 @@ endif() if(RESPOND_BUILD_BENCH OR RESPOND_BUILD_ALL) message(STATUS "Generating benchmarks") - add_subdirectory(extras/benchmarking) + add_subdirectory(extras/benchmark) endif() if(RESPOND_BUILD_DOCS OR RESPOND_BUILD_ALL) diff --git a/cmake/options.cmake b/cmake/options.cmake index a45c89af..821b5ad7 100644 --- a/cmake/options.cmake +++ b/cmake/options.cmake @@ -16,7 +16,7 @@ option(RESPOND_BUILD_TESTS "Build tests" OFF) option(RESPOND_CALCULATE_COVERAGE "Calculate Code Coverage" OFF) # bench options -option(RESPOND_BUILD_BENCH "Build benchmarks (Requires https://github.com/google/benchmark.git to be installed)" OFF) +option(RESPOND_BUILD_BENCH "Build benchmarks" OFF) # compile level warning and exception options option(RESPOND_BUILD_WARNINGS "Enable compiler warnings" OFF) diff --git a/docs/src/api-guide.md b/docs/src/api-guide.md new file mode 100644 index 00000000..841b09ef --- /dev/null +++ b/docs/src/api-guide.md @@ -0,0 +1,485 @@ +# C++ API Guide + +This guide provides an overview of the RESPOND C++ API for developers wishing to use the library in their own projects. + +## Overview + +The RESPOND library provides a flexible framework for building opioid use disorder models through composition of models, transitions, and history tracking. The core components are: + +- **Model**: Abstract base class representing a state transition system +- **Simulation**: Aggregates and coordinates multiple models +- **Transition**: Abstract base for specific transition types +- **History**: Tracks state vectors over time +- **TransitionFactory**: Creates concrete transition instances + +## Core Concepts + +### State Vectors + +Models operate on state vectors (Eigen::VectorXd) representing the population distribution across model states. A state vector element at index i represents the count of individuals in state i. + +### Transitions + +Transitions apply transformations to state vectors using transition matrices. The RESPOND model supports several transition types: + +- **Migration**: Population movement between states +- **Behavior**: Behavioral state changes +- **Intervention**: Intervention-driven state changes +- **Overdose**: Overdose-related transitions +- **BackgroundDeath**: Background mortality transitions + +### History Tracking + +History objects record state vectors at each timestep, enabling analysis of state trajectories over time. Histories support sparse timesteps—gaps are automatically filled with zero vectors. + +## Model Class + +The Model class is the abstract base for all models in RESPOND. + +```cpp +#include + +// Create a model +auto model = respond::Model::Create("model_name", "logger_name"); + +// Set the initial state +Eigen::VectorXd initial_state(50); +initial_state.setZero(); +model->SetState(initial_state); + +// Add transitions +auto transition = respond::TransitionFactory::CreateTransition("behavior", "logger_name"); +transition->AddTransitionMatrix(some_matrix); +model->AddTransition(transition); + +// Execute one simulation step +model->RunTransitions(); + +// Retrieve current state +Eigen::VectorXd current_state = model->GetState(); + +// Access history records +auto histories = model->GetHistories(); +``` + +### Key Methods + +- `SetState(const Eigen::VectorXd &state)`: Sets the model's state vector (copied internally) +- `GetState() const`: Returns a copy of the current state +- `RunTransitions()`: Executes all registered transitions +- `AddTransition(const std::unique_ptr &t)`: Adds a transition (assumes ownership) +- `GetTransitionNames() const`: Returns names of all transitions +- `ClearTransitions()`: Removes all transitions +- `GetHistories() const`: Returns map of history name to History objects +- `CreateDefaultHistories()`: Initializes default history tracking +- `SetHistories(const std::map &h)`: Sets history records +- `GetModelName() const`: Returns model name +- `GetLogName() const`: Returns associated logger name +- `clone() const`: Creates a deep copy of the model + +## Simulation Class + +The Simulation class manages multiple models and coordinates their execution. + +```cpp +#include + +// Create a simulation +respond::Simulation sim("my_logger"); + +// Add models +auto model1 = respond::Model::Create("model1", "my_logger"); +auto model2 = respond::Model::Create("model2", "my_logger"); +sim.AddModel(model1); +sim.AddModel(model2); + +// Run one step (executes all model transitions) +sim.Run(); + +// Retrieve results +auto all_histories = sim.GetModelHistories(); +auto model_names = sim.GetModelNames(); + +// Get detailed history mapping +auto history_names = sim.GetModelHistoryNames(); +// Returns vector of (model_name, history_name) pairs +``` + +### Key Methods + +- `Run()`: Executes one simulation step for all models +- `AddModel(const std::unique_ptr &model)`: Adds a model (cloned internally) +- `GetModels() const`: Returns const reference to model vector +- `GetModelNames() const`: Returns all model names +- `ClearModels()`: Removes all models +- `GetModelHistories() const`: Returns state histories for all models +- `GetModelHistoryNames() const`: Returns (model_name, history_name) pairs +- `GetLogName() const`: Returns logger name + +## History Class + +The History class records and manages state vectors across timesteps. + +```cpp +#include + +// Create a history +respond::History hist("population_states", "my_logger"); + +// Add states at specific timesteps +hist.AddState(state_vector_0, 0); +hist.AddState(state_vector_1, 1); +hist.AddState(state_vector_2, 2); + +// Or let it auto-assign timesteps +hist.AddState(another_state); // Assigned to next available timestep + +// Retrieve states +auto state_at_t0 = hist.GetStateMap()[0]; +auto all_states = hist.GetStateAsVector(); // Contiguous vector, fills gaps + +// Query history properties +std::string name = hist.GetHistoryName(); +std::string log_name = hist.GetLogName(); + +// Clear history +hist.Clear(); +``` + +### Key Methods + +- `AddState(const Eigen::VectorXd &state, int timestep = -1)`: Records a state + - If timestep < 0, automatically assigns next available timestep + - If timestep already exists, currently overwrites +- `GetStateMap() const`: Returns map of timestep → state vector +- `GetStateAsVector() const`: Returns contiguous vector of states (fills gaps with zeros) +- `GetHistoryName() const`: Returns history identifier +- `GetLogName() const`: Returns logger name +- `Clear()`: Removes all recorded states +- `operator==`, `operator!=`: Comparison operators + +## Transition Class + +The Transition class is abstract; use TransitionFactory to create concrete instances. + +```cpp +#include +#include + +// Create a transition using the factory +auto transition = respond::TransitionFactory::CreateTransition( + "behavior", // Type: migration, behavior, intervention, overdose, background_death + "my_logger" // Logger name +); + +// Add transformation matrices +Eigen::MatrixXd trans_matrix = ...; +transition->AddTransitionMatrix(trans_matrix); + +// Execute the transition (typically done via Model::RunTransitions) +auto histories_map = ...; // From model +Eigen::VectorXd result = transition->Execute(current_state, histories_map); + +// Get transition properties +std::string name = transition->GetTransitionName(); +std::string log = transition->GetLogName(); + +// Clear matrices +transition->ClearTransitionMatrices(); +``` + +### Supported Transition Types + +| Type | Description | +|------|-------------| +| "migration" | Population migration transitions | +| "behavior" | Behavioral state changes | +| "intervention" | Intervention-driven transitions | +| "overdose" | Overdose-related transitions | +| "background_death" | Background mortality transitions | + +## Logging Integration + +RESPOND uses the spdlog library for logging. Models and transitions accept a logger name: + +```cpp +// All logging is handled by passing logger names +auto model = respond::Model::Create("my_model", "my_logger"); + +// The model will use this logger for any errors or warnings +// Create loggers separately using respond::CreateFileLogger +respond::CreateFileLogger("my_logger", "path/to/logfile.log"); +``` + +## Complete Example + +```cpp +#include +#include +#include +#include + +int main() { + // Create logger + respond::CreateFileLogger("app", "simulation.log"); + + // Create simulation + respond::Simulation sim("app"); + + // Create and configure a model + auto model = respond::Model::Create("population_model", "app"); + + // Set initial state (e.g., 1000 individuals across 50 states) + Eigen::VectorXd initial_state = Eigen::VectorXd::Zero(50); + initial_state(0) = 1000; // All in first state + model->SetState(initial_state); + + // Add transitions + auto behavior_transition = respond::TransitionFactory::CreateTransition( + "behavior", "app"); + // Add matrices... + model->AddTransition(behavior_transition); + + auto migration_transition = respond::TransitionFactory::CreateTransition( + "migration", "app"); + // Add matrices... + model->AddTransition(migration_transition); + + // Add model to simulation + sim.AddModel(model); + + // Run simulation for 52 timesteps + for (int t = 0; t < 52; ++t) { + sim.Run(); + } + + // Extract results + auto histories = sim.GetModelHistories(); + auto history_names = sim.GetModelHistoryNames(); + + // Process results... + + return 0; +} +``` + +## Memory Management + +RESPOND uses `std::unique_ptr` for ownership management: + +- Models and Transitions are typically managed by Simulation or parent objects +- History objects are copyable and can be freely copied +- All models are cloned when added to a Simulation (ownership transfer) +- Clearing containers (ClearModels, ClearTransitions) deletes contained objects + +## Best Practices + +1. **Use TransitionFactory** to create transitions—it handles type dispatch +2. **Let Simulation manage models** for automatic cloning and lifecycle management +3. **Reuse History objects** for multiple runs to accumulate results +4. **Use const references** where available (GetState returns a copy for safety) +5. **Initialize loggers early** before creating models to enable error tracking +6. **Validate matrix dimensions** before adding to transitions (not checked by API) + +## Common Patterns + +### Running Multiple Independent Simulations + +```cpp +for (int run = 0; run < num_runs; ++run) { + respond::Simulation sim("logger_" + std::to_string(run)); + + auto model = respond::Model::Create("model", "logger_" + std::to_string(run)); + // Configure model... + + sim.AddModel(model); + for (int t = 0; t < duration; ++t) { + sim.Run(); + } + + // Store results... +} +``` + +### Resetting Model State + +```cpp +// To reset a model to initial state +Eigen::VectorXd initial_state = ...; +model->SetState(initial_state); + +// To also clear history +model->ClearTransitions(); +model->CreateDefaultHistories(); +``` + +### Copying Simulations + +```cpp +respond::Simulation sim1("logger"); +// ... configure sim1 ... + +// Create independent copy +respond::Simulation sim2 = sim1; // All models are cloned + +// Modifications to sim2 don't affect sim1 +``` + +## Parallel Execution with Shared Logging + +When running multiple models in parallel, all loggers can safely write to the same file using RESPOND's shared sink functionality. This ensures thread-safe logging without file corruption. + +### Basic Parallel Logging Setup + +```cpp +#include +#include +#include +#include + +int main() { + // Configure shared logging (all loggers write to same file) + respond::SetLogPattern(respond::LogPattern::kThreadSafe); + respond::SetFlushInterval(3); // Auto-flush every 3 seconds + + // Create multiple loggers that share the same file sink + respond::CreateSharedLogger("model_1"); + respond::CreateSharedLogger("model_2"); + respond::CreateSharedLogger("model_3"); + + // Now multiple threads can safely write to shared log + return 0; +} +``` + +### Running Models in Parallel with Unified Logging + +```cpp +#include +#include +#include +#include + +void RunSimulation(int id, const std::string& log_file) { + std::string logger_name = "model_" + std::to_string(id); + + // Create logger that uses shared sink + respond::CreateSharedLogger(logger_name); + + // Create and run simulation + respond::Simulation sim(logger_name); + auto model = respond::Model::Create("model", logger_name); + + // Configure model... + Eigen::VectorXd initial_state = Eigen::VectorXd::Zero(50); + initial_state(0) = 1000; + model->SetState(initial_state); + + // Add transitions... + sim.AddModel(model); + + // Run simulation + for (int t = 0; t < 52; ++t) { + sim.Run(); + } + + // Flush logs for this model + respond::FlushAllLoggers(); +} + +int main() { + // Setup shared logging once + respond::SetLogPattern(respond::LogPattern::kThreadSafe); + respond::SetFlushInterval(0); // Flush immediately + + const int num_threads = 4; + std::vector threads; + + // Launch parallel simulations + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(RunSimulation, i, "unified.log"); + } + + // Wait for all to complete + for (auto& t : threads) { + t.join(); + } + + // All output safely written to unified.log + return 0; +} +``` + +### Shared Logger Pattern Options + +The `LogPattern` enum controls log format for all shared loggers: + +- **`kSimple`**: Minimal format `[logger_name] message` +- **`kStandard`**: Includes time and thread ID (default) +- **`kDetailed`**: Full timestamp with milliseconds; best for debugging +- **`kThreadSafe`**: Optimized for concurrent writes with sequence numbers + +```cpp +// Change pattern anytime +respond::SetLogPattern(respond::LogPattern::kDetailed); + +// Query current pattern +auto current = respond::GetLogPattern(); + +// Get pattern as string for programmatic use +std::string pattern_str = respond::LoggingConfig::GetPatternString(current); +``` + +### Monitoring Shared Loggers + +```cpp +// Check if logger exists +bool exists = (respond::CheckLoggerExists("model_1") == respond::CreationStatus::kExists); + +// Get detailed logger information +std::string info = respond::GetLoggerInfo("model_1"); +// Returns: "Logger: model_1\n Level: debug\n Sinks: 1" + +// Set individual logger level +respond::SetLoggerLevel("model_1", spdlog::level::info); + +// Flush all loggers immediately +respond::FlushAllLoggers(); +``` + +### Thread-Safe File Sink Management + +The `CreateSharedFileSink` function creates file sinks that are automatically cached and reused: + +```cpp +// Create or get cached sink for filepath +auto sink = respond::CreateSharedFileSink("logs/simulation.log"); +// If called again with same path, returns existing sink (no duplicate file handles) + +// Multiple loggers using same sink (no file conflicts) +respond::CreateSharedLogger("logger_1"); // Uses default sink +respond::CreateSharedLogger("logger_2"); // Uses same sink +// Both logger_1 and logger_2 write to same file safely +``` + +### Best Practices for Parallel Logging + +1. **Call `SetLogPattern()` once** at program startup, before creating any loggers +2. **Call `CreateSharedLogger()` instead of `CreateFileLogger()`** when using parallel execution +3. **Use `kThreadSafe` pattern** when logs will have high concurrent write volume +4. **Set `FlushInterval(0)`** for critical logging; use `FlushInterval(3-5)` for performance +5. **Call `FlushAllLoggers()`** at end of main before exit to ensure all writes complete +6. **Monitor logger levels** with `GetLoggerInfo()` when debugging multi-model runs + +### Troubleshooting + +- **Assertion failures**: Ensure matrix dimensions match state vector size before adding to transitions +- **Empty histories**: Call `CreateDefaultHistories()` after model setup or manually add histories +- **Logger errors**: Ensure logger names exist (create with `CreateFileLogger` if needed) +- **Memory issues**: Verify no circular unique_ptr references; models own transitions + +For more information, see the [Doxygen-generated API documentation](../doxygen/html/index.html) or the [Architecture and Design guide](architecture.md). + +Previous: [Architecture and Design](architecture.md) + +Next: [Data Guide](data.md) diff --git a/docs/src/architecture.md b/docs/src/architecture.md new file mode 100644 index 00000000..672dd77e --- /dev/null +++ b/docs/src/architecture.md @@ -0,0 +1,330 @@ +# Architecture and Design + +This document describes the architectural decisions and design patterns used in RESPOND. + +## Design Philosophy + +RESPOND follows the **inversion of control** principle, abstracting the model to its core components and allowing users to customize it to their needs rather than maintaining a rigid, monolithic structure. This enables: + +1. **Extensibility** - New transition types can be added without modifying core code +2. **Testability** - Components can be tested independently +3. **Maintainability** - Clear separation of concerns +4. **Portability** - Easy to integrate into different applications + +## Component Architecture + +``` +┌─────────────────────────────────────────────────┐ +│ Simulation │ +│ (aggregates and coordinates Models) │ +└──────────────────┬──────────────────────────────┘ + │ + ┌──────────┴──────────┬──────────────┐ + │ │ │ + ┌───▼────┐ ┌───▼────┐ ┌──▼────┐ + │ Model │ │ Model │ │ Model │ + │ (PopA) │ │ (PopB) │ │(PopC) │ + └───┬────┘ └───┬────┘ └──┬────┘ + │ │ │ + ├─ Transitions ────┐ │ │ + │ - Migration │ │ │ + │ - Behavior │ │ │ + │ - Intervention │ │ │ + │ - Overdose │ │ │ + │ - Background │ │ │ + └──────────────────┘ │ │ + │ │ │ + └─ Histories ────┐ │ │ + - State │ │ │ + - Outcomes │ │ │ + - Costs │ │ │ + └────────────┘ └─────────────┘ +``` + +## Core Classes + +### Model (Abstract Base Class) + +- **Role**: Represents a state transition system +- **Responsibilities**: + - Manages state vector + - Owns and executes transitions + - Tracks history + - Provides cloning capability +- **Key Design Decisions**: + - Non-copyable by assignment (enforces `clone()` usage for clarity) + - Owns transitions (unique_ptr for memory safety) + - Read-only GetState() (returns copy to prevent external state modification) + +### Simulation + +- **Role**: Aggregates multiple independent models +- **Responsibilities**: + - Coordinates model execution + - Collects results from all models + - Manages simulation-level state +- **Key Design Decisions**: + - Clones models on addition (ownership clarity) + - Copyable (deep copy semantics) + - Provides convenient result collection methods + +### Transition (Abstract Base Class) + +- **Role**: Represents a specific type of model transition +- **Responsibilities**: + - Applies state transformations + - Updates history records + - Manages transformation matrices +- **Implementation Types**: + - **Migration**: Population movement between states + - **Behavior**: Behavioral state transitions + - **Intervention**: Intervention effects + - **Overdose**: Overdose dynamics + - **BackgroundDeath**: Mortality transitions +- **Key Design Decisions**: + - Non-copyable (prevents accidental duplication of stateful transformations) + - Uses TransitionFactory for creation (encapsulates type selection) + - Const-correct Execute() (doesn't modify transition state) + +### History + +- **Role**: Records state vectors over time +- **Responsibilities**: + - Sparse timestep tracking + - State retrieval (by index or as vector) + - Comparison operations +- **Key Design Decisions**: + - Copyable (lightweight data container) + - Sparse internal storage (efficient memory for gaps) + - Auto-fills gaps with zeros (simplifies downstream analysis) + - Map-based storage (allows non-sequential timesteps) + +### TransitionFactory + +- **Role**: Creates concrete Transition instances +- **Responsibilities**: + - Encapsulates type dispatch logic + - Provides single point of extensibility for new transitions +- **Key Design Decisions**: + - Static factory method (no factory state needed) + - String-based type identification (simple, extensible) + - Case-insensitive type matching (user-friendly) + +## Design Patterns + +### Factory Pattern (TransitionFactory) + +Encapsulates object creation for transitions: + +```cpp +auto transition = TransitionFactory::CreateTransition("behavior", "logger"); +``` + +**Benefits**: +- Decouples transition creation from usage +- Centralizes type dispatch logic +- Easy to add new transition types + +### Template Method Pattern (Model → Transitions) + +Model delegates to transitions in RunTransitions(): + +```cpp +void Model::RunTransitions() { + for (const auto& transition : _transitions) { + _state = transition->Execute(_state, _histories); + } +} +``` + +**Benefits**: +- Flexible transition composition +- Order-dependent execution +- Custom transition behavior per model + +### Strategy Pattern (Transitions) + +Different transition types implement Execute() differently: + +```cpp +// Migration::Execute() - handles population movement +// Behavior::Execute() - handles behavior changes +// etc. +``` + +**Benefits**: +- Runtime selection of transition algorithms +- No conditional logic in Model +- Easy to add new strategies + +### Object Pool / Clone Pattern + +Models and Transitions support cloning for independent copies: + +```cpp +auto model_copy = model->clone(); // Deep copy +auto transition_copy = transition->clone(); // Deep copy +``` + +**Benefits**: +- Explicit control over deep vs. shallow copies +- Clear ownership semantics +- Safe concurrent execution + +## Memory Management + +RESPOND uses modern C++ memory management practices: + +### Unique Ownership (unique_ptr) + +Used for objects with clear ownership: +- Model owns its Transitions +- Simulation owns its Models (via cloning) + +### Shared Ownership (None by default) + +RESPOND minimizes shared state. History objects are the exception—they're: +- Copyable (value semantics) +- Used as values in Model maps +- Accessed through const references where possible + +### Safety Mechanisms + +- **Deleted copy operators** on base classes prevent slicing: + ```cpp + Model(const Model&) = delete; + Model& operator=(const Model&) = delete; + virtual unique_ptr clone() const = 0; // Force explicit cloning + ``` +- **const-correctness** throughout API +- **Pass-by-value** for small objects (Eigen uses move semantics internally) + +## Extensibility Points + +### Adding a New Transition Type + +1. Create a new header in `include/respond/internals/` +2. Implement concrete Transition subclass +3. Add factory entry in `TransitionFactory::CreateTransition()` + +Example: +```cpp +// respond/internals/custom_transition.hpp +class CustomTransition : public Transition { +public: + static std::unique_ptr Create(...); + // ... implement virtual methods +}; + +// In transition_factory.cpp +if (type == "custom") { + return CustomTransition::Create(type, log_name); +} +``` + +### Adding New History Types + +1. Extend History class or create a new class +2. Update Model::CreateDefaultHistories() to instantiate new types +3. Update GetHistories() documentation + +## Dependencies + +### External Libraries + +- **Eigen** - Linear algebra (state vectors, transition matrices) +- **spdlog** - Logging +- **GoogleTest** - Unit testing (optional) + +### Internal Organization + +- **`include/respond/`** - Public API headers +- **`src/internals/`** - Internal implementation headers +- **`src/`** - Implementation files +- **`tests/`** - Unit and integration tests + +## Threading and Concurrency + +Current implementation is **not thread-safe**: + +- No internal locking mechanisms +- State modification is not atomic +- Multiple simulations can run independently (each with own state) + +For concurrent execution: +- Create separate Simulation instances +- Each thread manages its own simulation +- Synchronize result collection externally + +## Performance Considerations + +### State Vector Operations + +- Heavy use of Eigen for linear algebra +- Matrices stored by value (memory efficient) +- Copy elision via move semantics + +### Sparse History + +- Map-based storage avoids allocating full history +- Gap-filling only occurs during GetStateAsVector() +- Minimal memory overhead for sparse timesteps + +### Transition Execution Order + +- Transitions execute sequentially in order added +- Each transition reads from current state, writes results +- History updated after each transition + +## Validation and Error Handling + +### Current Approach + +- Minimal runtime validation +- Relies on preconditions (documented in comments) +- Errors logged through spdlog + +### Improvements for Future Versions + +- Add matrix dimension validation +- Validate state vector sizes +- Stricter type checking in factory + +## Testing Strategy + +### Unit Tests + +Located in `tests/unit/`, testing individual components: +- State management +- Transition execution +- History recording +- Factory creation + +### Integration Tests + +Located in `tests/integration/`, testing end-to-end scenarios: +- Full simulation execution +- Multi-model coordination +- Result aggregation + +### Mock Objects + +`tests/mocks/` provides: +- Model mock for testing Simulation +- Transition mock for testing Model + +## Future Architectural Considerations + +1. **Async Execution** - Enable parallel model execution +2. **Plugin System** - Dynamic loading of transitions +3. **Serialization** - Save/restore simulation state +4. **Validation Framework** - Compile-time and runtime checks +5. **Performance Profiling** - Built-in timing/statistics + +--- + +For more implementation details, refer to the [Doxygen documentation](../doxygen/html/index.html) and source code comments. + +Previous: [FAQs](faq.md) + +Next: [API Guide](api-guide.md) diff --git a/docs/src/index.md b/docs/src/index.md index 7ae104a8..b69f8562 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -51,4 +51,41 @@ For tests we require: As part of our automated workflow, we provide several different installable packages as part of each release. They can be found on each [tagged release on our GitHub](https://github.com/SyndemicsLab/respond/tags). Currently we provide RPM and Debian installers for static and shared libraries, and a NSIS windows static library installer. For a step-by-step procedure of building the package from source, please check the [installation documentation](installation.md). +## Documentation + +This documentation covers both user and developer perspectives: + +- **[Installation](installation.md)** - Build and installation instructions +- **[Motivation](motivation.md)** - Model design and goals +- **[Architecture and Design](architecture.md)** - Design patterns, component architecture, and extensibility +- **[C++ API Guide](api-guide.md)** - Developer guide for using RESPOND as a library +- **[Data Management](data.md)** - Configuration and data requirements +- **[Running the Model](run.md)** - Executing simulations +- **[Math Background](math.md)** - Mathematical foundations +- **[Limitations](limitations.md)** - Known constraints and future work +- **[FAQs](faq.md)** - Frequently asked questions + +## Quick Start for Developers + +To use RESPOND in your C++ project: + +```cmake +include(FetchContent) +FetchContent_Declare( + respond + GIT_REPOSITORY https://github.com/SyndemicsLab/respond.git + GIT_TAG main +) +option(RESPOND_INSTALL "Enable install for respond project" ON) +option(RESPOND_BUILD_TESTS "Disable testing for RESPOND" OFF) +FetchContent_MakeAvailable(respond) + +target_link_libraries(${PROJECT_NAME} + PRIVATE + respond::respond_model +) +``` + +Then see the [C++ API Guide](api-guide.md) for usage examples and best practices. + Next: [Motivation](motivation.md) diff --git a/docs/src/run.md b/docs/src/run.md index f7217fcb..779f5394 100644 --- a/docs/src/run.md +++ b/docs/src/run.md @@ -1,18 +1,111 @@ # Running the Model -The RESPOND model, in addition to forming the basis for the Simdemics library, acts as a completely independent executable model. This model was designed to track state transitions to study opioid use disorder. +The RESPOND model can be used in multiple ways: -## Using the Executable +1. **As a C++ library** - Integrate into your own C++ projects +2. **Via Python package** - Use the higher-level [respondpy](https://github.com/SyndemicsLab/respondpy) interface +3. **Standalone executable** - Legacy support for direct command-line execution -Running RESPOND is incredibly simple provided you use our packaged executable. If built using the "gcc-release" workflow outlined in the [Installation Section](installation.md) the following command runs the executable from the root of the repository on input folder 1: +## Using RESPOND as a C++ Library + +For C++ developers, RESPOND can be integrated directly into projects. See the [C++ API Guide](api-guide.md) for: + +- Complete API reference +- Usage examples +- Memory management details +- Design patterns and best practices + +### Basic Example + +```cpp +#include +#include + +int main() { + // Create a simulation + respond::Simulation sim("my_logger"); + + // Create and configure a model + auto model = respond::Model::Create("population_model", "my_logger"); + + // Set initial state + Eigen::VectorXd initial_state(50); + initial_state.setZero(); + initial_state(0) = 1000; // 1000 individuals in state 0 + model->SetState(initial_state); + + // Add transitions + auto transition = respond::TransitionFactory::CreateTransition( + "behavior", "my_logger"); + model->AddTransition(transition); + + // Add model to simulation + sim.AddModel(model); + + // Run simulation for 52 timesteps + for (int t = 0; t < 52; ++t) { + sim.Run(); + } + + // Extract results + auto histories = sim.GetModelHistories(); + + return 0; +} +``` + +For complete examples and API documentation, see the [C++ API Guide](api-guide.md). + +## Using the Standalone Executable + +Legacy support for standalone execution is available in release [v0.3.0](https://github.com/SyndemicsLab/respond/releases/tag/v0.3.0). + +### Command-Line Usage + +To run the executable (requires legacy release): + +```bash +./respond_exe /path/to/input/folders 1 1 +``` + +### Arguments + +The executable takes 3 positional arguments: + +1. **Input folder path**: Directory containing simulation configuration and data +2. **Start folder index**: First input folder number (e.g., 1 for `input1`) +3. **End folder index**: Last input folder number (inclusive) + +### Examples ```bash -./build/extras/executable/respond_exe /path/to/input/folders 1 1 +# Run single input folder +./respond_exe /home/user/data 1 1 + +# Run folders input1, input2, input3 +./respond_exe /home/user/data 1 3 + +# Run folders input5 through input10 +./respond_exe /home/user/data 5 10 ``` -## Arguments +### Output + +The executable produces output files in the same directory as the input data. + +## Python Integration + +For a higher-level interface, use the [respondpy](https://pypi.org/project/respondpy/) Python package: + +```python +import respondpy as respond + +# Configure and run simulations +sim = respond.Simulation(config_path) +results = sim.run() +``` -The executable takes 3 straightforward positional arguments that govern the input folder location, the starting input folder and the end input folder inclusively. Thus, if you only have a single input folder titled `input1` located at `/home/usr/` you would provide the arguments: `/home/usr/ 1 1`. If you have multiple input folders (i.e. `input1`, `input2`, and `input3`) you would provide: `/home/usr/ 1 3`. +See the [respondpy documentation](https://github.com/SyndemicsLab/respondpy) for details. Previous: [Data](data.md) diff --git a/extras/benchmark/CMakeLists.txt b/extras/benchmark/CMakeLists.txt index f8b16bc2..499d2fa6 100644 --- a/extras/benchmark/CMakeLists.txt +++ b/extras/benchmark/CMakeLists.txt @@ -1,19 +1,13 @@ cmake_minimum_required(VERSION 3.27) -project(respondBench CXX) - -find_package(benchmark REQUIRED) -find_package(spdlog REQUIRED) +project(respond_benchmark LANGUAGES CXX) add_executable(${PROJECT_NAME} - src/BENCHMARK_respond.cpp + src/benchmark_respond.cpp ) -target_include_directories(respond PRIVATE - ${PROJECT_SOURCE_DIR}/lib/DataManagement/include -) +target_compile_features(${PROJECT_NAME} PRIVATE cxx_std_20) -target_link_libraries(${PROJECT_NAME} PUBLIC benchmark::benchmark) -target_link_libraries(${PROJECT_NAME} PUBLIC spdlog::spdlog) -target_link_libraries(${PROJECT_NAME} PUBLIC matrixify) -target_link_libraries(${PROJECT_NAME} PUBLIC simrunner) -target_link_libraries(${PROJECT_NAME} PUBLIC DataManagement) +target_link_libraries(${PROJECT_NAME} + PRIVATE + respond::respond_model +) diff --git a/extras/benchmark/README.md b/extras/benchmark/README.md new file mode 100644 index 00000000..308c9297 --- /dev/null +++ b/extras/benchmark/README.md @@ -0,0 +1,75 @@ +# RESPOND Benchmarking Procedure + +This benchmark target is designed for reproducible performance measurement of core RESPOND state-transition execution using only the C++ standard library timing utilities. + +## Principles + +- Use `std::chrono::steady_clock` to avoid wall-clock discontinuities. +- Separate warm-up from timed measurement. +- Collect multiple samples and summarize robust statistics (min, max, mean, median, p95, stddev). +- Keep benchmark setup deterministic so samples are comparable across runs. +- Prevent dead-code elimination using a checksum sink. + +## What Is Measured + +The benchmark constructs one RESPOND model and measures repeated execution of: + +- `Model::RunTransitions()` for a fixed number of timesteps +- Transition mix: behavior, intervention, overdose, background death +- Deterministic transition matrices/vectors and deterministic initial state + +Warm-up runs are excluded from reported timings. + +## Build + +Enable benchmark builds in CMake: + +```bash +cmake -S . -B build/bench -DRESPOND_BUILD_BENCH=ON +cmake --build build/bench --target respond_benchmark +``` + +## Run + +```bash +./build/bench/bin/respond_benchmark \ + --state-size 64 \ + --steps 52 \ + --history-capture-interval 1 \ + --warmup 5 \ + --samples 500 \ + --repetitions 3 +``` + +## Interpreting Output + +Before timing rows, the benchmark prints environment metadata to make results reproducible across machines and builds: + +- compiler and compiler version +- C++ standard level +- build type (Debug/Release) +- pointer size +- CPU model (Linux reads `/proc/cpuinfo`) +- hardware thread count +- Eigen version and active Eigen thread count + +Each repetition reports: + +- `mean_ms`: average sample time +- `p50_ms`: median sample time +- `p95_ms`: 95th percentile sample time +- `min_ms` / `max_ms`: observed range +- `std_ms`: standard deviation +- `ns/step`: normalized cost per simulation timestep +- `state_pts`: number of recorded state-history points captured per sample +- `checksum`: final state sum sink to guard against optimization artifacts + +Use the `overall` row for broad comparisons and repetition rows to detect instability. + +## Recommended Comparison Workflow + +1. Run baseline benchmark on the reference branch. +2. Run benchmark after your change with the same arguments and machine load. +3. Compare `overall` mean, p95, and stddev first. +4. If regression appears, inspect repetition spread to determine variance vs consistent slowdown. +5. Re-run with higher `--samples` for tighter confidence when results are close. diff --git a/extras/benchmark/src/BENCHMARK_respond.cpp b/extras/benchmark/src/BENCHMARK_respond.cpp deleted file mode 100644 index ab6b313a..00000000 --- a/extras/benchmark/src/BENCHMARK_respond.cpp +++ /dev/null @@ -1,198 +0,0 @@ -#include "CostLoader.hpp" -#include "DataFormatter.hpp" -#include "DataLoader.hpp" -#include "Helpers.hpp" -#include "PostSimulationCalculator.hpp" -#include "Writer.hpp" -#include "markov.hpp" -#include "spdlog/sinks/basic_file_sink.h" -#include "spdlog/spdlog.h" -#include -#include -#include -#include -#include -#include -#include -#include - -std::filesystem::path BENCHMARK_INPUT = - std::filesystem::temp_directory_path() / "benchmark"; - -static int respond_main(const int inputID) { - std::filesystem::path inputSet = - BENCHMARK_INPUT / ("input" + std::to_string(inputID)); - std::filesystem::path outputDir = - BENCHMARK_INPUT / ("output" + std::to_string(inputID)); - std::filesystem::create_directory(outputDir); - - std::string log_path = outputDir.string() + "/log.txt"; - std::shared_ptr logger; - - try { - logger = spdlog::basic_logger_mt("logger" + std::to_string(inputID), - log_path); -#ifndef NDEBUG - spdlog::set_level(spdlog::level::debug); -#endif - } catch (const spdlog::spdlog_ex &ex) { - std::cout << "Log init failed: " << ex.what() << std::endl; - return 1; - } - - logger->info("Logger Created"); - - std::shared_ptr inputs = - std::make_shared(nullptr, inputSet.string(), - logger); - logger->info("DataLoader Created"); - - std::shared_ptr costLoader = - std::make_shared(inputSet.string()); - logger->info("CostLoader Created"); - - std::shared_ptr utilityLoader = - std::make_shared(inputSet.string()); - logger->info("UtilityLoader Created"); - - inputs->loadInitialSample("init_cohort.csv"); - if (inputs->getEnteringCohortToggle()) { - inputs->loadEnteringSamples("entering_cohort.csv"); - } else { - inputs->loadEnteringSamples("entering_cohort.csv", "No_Treatment", - "Active_Noninjection"); - } - inputs->loadOUDTransitionRates("oud_trans.csv"); - inputs->loadInterventionInitRates("block_init_effect.csv"); - inputs->loadInterventionTransitionRates("block_trans.csv"); - inputs->loadOverdoseRates("all_types_overdose.csv"); - inputs->loadFatalOverdoseRates("fatal_overdose.csv"); - inputs->loadMortalityRates("SMR.csv", "background_mortality.csv"); - - if (costLoader->getCostSwitch()) { - costLoader->loadHealthcareUtilizationCost( - "healthcare_utilization_cost.csv"); - costLoader->loadOverdoseCost("overdose_cost.csv"); - costLoader->loadPharmaceuticalCost("pharmaceutical_cost.csv"); - costLoader->loadTreatmentUtilizationCost( - "treatment_utilization_cost.csv"); - - utilityLoader->loadBackgroundUtility("bg_utility.csv"); - utilityLoader->loadOUDUtility("oud_utility.csv"); - utilityLoader->loadSettingUtility("setting_utility.csv"); - } - - Simulation::Respond sim(inputs); - sim.run(); - preprocess::History history = sim.getHistory(); - - preprocess::CostList basecosts; - preprocess::Matrix4d baseutilities; - double baselifeYears = 0.0; - std::vector totalBaseCosts; - double totalBaseUtility = 0.0; - - preprocess::CostList disccosts; - preprocess::Matrix4d discutilities; - double disclifeYears; - std::vector totalDiscCosts; - double totalDiscUtility = 0.0; - - if (costLoader->getCostSwitch()) { - Calculator::PostSimulationCalculator PostSimulationCalculator(history); - basecosts = PostSimulationCalculator.calculateCosts(costLoader); - totalBaseCosts = - Helpers::calcCosts(PostSimulationCalculator, basecosts); - - baseutilities = PostSimulationCalculator.calculateUtilities( - utilityLoader, Calculator::UTILITY_TYPE::MIN); - totalBaseUtility = - PostSimulationCalculator.totalAcrossTimeAndDims(baseutilities); - baselifeYears = PostSimulationCalculator.calculateLifeYears(); - if (costLoader->getDiscountRate() != 0.0) { - disccosts = PostSimulationCalculator.calculateCosts(costLoader); - - totalDiscCosts = - Helpers::calcCosts(PostSimulationCalculator, disccosts); - discutilities = PostSimulationCalculator.calculateUtilities( - utilityLoader, Calculator::UTILITY_TYPE::MIN); - totalDiscUtility = - PostSimulationCalculator.totalAcrossTimeAndDims(discutilities); - disclifeYears = PostSimulationCalculator.calculateLifeYears(); - } - } - - std::vector outputTimesteps = inputs->getGeneralStatsOutputTimesteps(); - - bool pivot_long = false; - pivot_long = std::get( - inputs->getConfig()->get("output.pivot_long", pivot_long)); - - preprocess::HistoryWriter historyWriter( - outputDir.string(), inputs->getInterventions(), inputs->getOUDStates(), - inputs->getDemographics(), inputs->getDemographicCombos(), - outputTimesteps, preprocess::WriteType::FILE, pivot_long); - - preprocess::DataFormatter formatter; - - formatter.extractTimesteps(outputTimesteps, history, basecosts, - baseutilities, costLoader->getCostSwitch()); - - historyWriter.writeHistory(history); - - bool writeParameters = false; - writeParameters = std::get(inputs->getConfig()->get( - "output.write_calibrated_inputs", writeParameters)); - if (writeParameters) { - preprocess::InputWriter ipWriter(outputDir.string(), outputTimesteps, - preprocess::WriteType::FILE); - ipWriter.writeParameters(inputs); - } - - // Probably want to figure out the right way to do this - if (costLoader->getCostSwitch()) { - preprocess::CostWriter costWriter( - outputDir.string(), inputs->getInterventions(), - inputs->getOUDStates(), inputs->getDemographics(), - inputs->getDemographicCombos(), outputTimesteps, - preprocess::WriteType::FILE, pivot_long); - costWriter.writeCosts(basecosts); - } - if (utilityLoader->getCostSwitch()) { - preprocess::UtilityWriter utilityWriter( - outputDir.string(), inputs->getInterventions(), - inputs->getOUDStates(), inputs->getDemographics(), - inputs->getDemographicCombos(), outputTimesteps, - preprocess::WriteType::FILE, pivot_long); - utilityWriter.writeUtilities(baseutilities); - } - if (costLoader->getCostSwitch()) { - preprocess::Totals totals; - totals.baseCosts = totalBaseCosts; - totals.baseLifeYears = baselifeYears; - totals.baseUtility = totalBaseUtility; - totals.discCosts = totalDiscCosts; - totals.discLifeYears = disclifeYears; - totals.discUtility = totalDiscUtility; - preprocess::TotalsWriter totalsWriter( - outputDir.string(), inputs->getInterventions(), - inputs->getOUDStates(), inputs->getDemographics(), - inputs->getDemographicCombos(), outputTimesteps, - preprocess::WriteType::FILE); - totalsWriter.writeTotals(totals); - } - - std::cout << "Output " << std::to_string(inputID) << " Complete" - << std::endl; - - std::cout << "Simulation Complete! :)" << std::endl; - return 0; -} - -static void run(benchmark::State &state) { respond_main(1); } - -// mark functions for benchmarking -BENCHMARK(run); - -// run benchmarks -BENCHMARK_MAIN(); diff --git a/extras/benchmark/src/benchmark_respond.cpp b/extras/benchmark/src/benchmark_respond.cpp new file mode 100644 index 00000000..e484bb67 --- /dev/null +++ b/extras/benchmark/src/benchmark_respond.cpp @@ -0,0 +1,467 @@ +//////////////////////////////////////////////////////////////////////////////// +// File: benchmark_respond.cpp // +// Project: respond // +// Created Date: 2026-04-27 // +// Author: Matthew Carroll // +// ----- // +// Last Modified: 2026-04-27 // +// Modified By: Matthew Carroll // +// ----- // +// Copyright (c) 2026 Syndemics Lab at Boston Medical Center // +//////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include + +namespace { +using Clock = std::chrono::steady_clock; +using Nanoseconds = std::chrono::duration; + +#if defined(__clang__) || defined(__GNUC__) +template inline void DoNotOptimize(const T &value) { + asm volatile("" : : "g"(&value) : "memory"); +} + +inline void ClobberMemory() { asm volatile("" : : : "memory"); } +#else +template inline void DoNotOptimize(const T &value) { + volatile const T *sink = &value; + (void)sink; +} + +inline void ClobberMemory() { + std::atomic_signal_fence(std::memory_order_seq_cst); +} +#endif + +struct BenchmarkConfig { + std::size_t state_size = 200; + int steps = 365; + int warmup_iterations = 5; + int sample_iterations = 25; + int repetitions = 3; + int history_capture_interval = 1; +}; + +struct TimedRunResult { + double elapsed_ns = 0.0; + double checksum = 0.0; + std::size_t recorded_points = 0; +}; + +struct Statistics { + double min_ns = 0.0; + double max_ns = 0.0; + double mean_ns = 0.0; + double median_ns = 0.0; + double p95_ns = 0.0; + double stddev_ns = 0.0; +}; + +double ToMilliseconds(double ns) { return ns / 1'000'000.0; } + +std::string GetBuildType() { +#ifdef NDEBUG + return "Release"; +#else + return "Debug"; +#endif +} + +std::string GetCompilerString() { +#if defined(__clang__) + return std::string("Clang ") + __clang_version__; +#elif defined(__GNUC__) + return std::string("GCC ") + __VERSION__; +#elif defined(_MSC_VER) + return std::string("MSVC ") + std::to_string(_MSC_VER); +#else + return "Unknown compiler"; +#endif +} + +std::string GetCpuModel() { +#if defined(__linux__) + std::ifstream cpuinfo("/proc/cpuinfo"); + if (!cpuinfo.is_open()) { + return "unknown"; + } + + std::string line; + while (std::getline(cpuinfo, line)) { + const std::string key = "model name"; + const auto pos = line.find(key); + if (pos == std::string::npos) { + continue; + } + const auto colon_pos = line.find(':', pos + key.size()); + if (colon_pos == std::string::npos) { + continue; + } + std::string value = line.substr(colon_pos + 1); + const auto first = value.find_first_not_of(" \t"); + if (first == std::string::npos) { + continue; + } + value = value.substr(first); + const auto last = value.find_last_not_of(" \t"); + if (last != std::string::npos) { + value = value.substr(0, last + 1); + } + if (!value.empty()) { + return value; + } + } +#endif + return "unknown"; +} + +void PrintEnvironmentMetadata() { + std::ostringstream pointer_bits; + pointer_bits << (sizeof(void *) * 8U) << "-bit"; + + std::cout << "Environment\n" + << "-----------\n" + << "compiler : " << GetCompilerString() << "\n" + << "c++ standard : " << __cplusplus << "\n" + << "build type : " << GetBuildType() << "\n" + << "pointer size : " << pointer_bits.str() << "\n" + << "cpu model : " << GetCpuModel() << "\n" + << "hw threads : " << std::thread::hardware_concurrency() + << "\n" + << "eigen ver : " << EIGEN_WORLD_VERSION << "." + << EIGEN_MAJOR_VERSION << "." << EIGEN_MINOR_VERSION << "\n" + << "eigen thrds : " << Eigen::nbThreads() << "\n\n"; +} + +void PrintUsage(const char *program_name) { + std::cout + << "Usage: " << program_name << " [options]\n\n" + << "Options:\n" + << " --state-size Number of state dimensions (default: 200)\n" + << " --steps Number of simulation timesteps per sample " + "(default: 365)\n" + << " --warmup Warm-up iterations per repetition (default: " + "5)\n" + << " --samples Timed samples per repetition (default: 25)\n" + << " --repetitions Independent repetitions (default: 3)\n" + << " --history-capture-interval Record every n timesteps " + "(default: 1)\n" + << " --help Show this help\n"; +} + +int ParsePositiveInt(const std::string &value, const std::string &name) { + const long parsed = std::stol(value); + if (parsed <= 0) { + throw std::invalid_argument(name + " must be > 0"); + } + return static_cast(parsed); +} + +std::size_t ParsePositiveSizeT(const std::string &value, + const std::string &name) { + const long long parsed = std::stoll(value); + if (parsed <= 0) { + throw std::invalid_argument(name + " must be > 0"); + } + return static_cast(parsed); +} + +BenchmarkConfig ParseArgs(int argc, char **argv) { + BenchmarkConfig config; + + for (int i = 1; i < argc; ++i) { + const std::string_view arg(argv[i]); + + auto require_value = [&](const std::string_view flag) -> std::string { + if (i + 1 >= argc) { + throw std::invalid_argument("Missing value for " + + std::string(flag)); + } + ++i; + return std::string(argv[i]); + }; + + if (arg == "--help") { + PrintUsage(argv[0]); + std::exit(0); + } + if (arg == "--state-size") { + config.state_size = + ParsePositiveSizeT(require_value(arg), "--state-size"); + continue; + } + if (arg == "--steps") { + config.steps = ParsePositiveInt(require_value(arg), "--steps"); + continue; + } + if (arg == "--warmup") { + config.warmup_iterations = + ParsePositiveInt(require_value(arg), "--warmup"); + continue; + } + if (arg == "--samples") { + config.sample_iterations = + ParsePositiveInt(require_value(arg), "--samples"); + continue; + } + if (arg == "--repetitions") { + config.repetitions = + ParsePositiveInt(require_value(arg), "--repetitions"); + continue; + } + if (arg == "--history-capture-interval") { + config.history_capture_interval = ParsePositiveInt( + require_value(arg), "--history-capture-interval"); + continue; + } + + throw std::invalid_argument("Unknown argument: " + std::string(arg)); + } + + return config; +} + +Eigen::MatrixXd MakeShiftMatrix(std::size_t n, double stay_probability, + int shift) { + Eigen::MatrixXd matrix = Eigen::MatrixXd::Zero( + static_cast(n), static_cast(n)); + + for (std::size_t col = 0; col < n; ++col) { + const std::size_t shifted = static_cast( + (static_cast(col) + shift + static_cast(n)) % + static_cast(n)); + matrix(static_cast(col), static_cast(col)) = + stay_probability; + matrix(static_cast(shifted), + static_cast(col)) = 1.0 - stay_probability; + } + return matrix; +} + +Eigen::VectorXd MakeRateVector(std::size_t n, double base_rate, + double gradient) { + Eigen::VectorXd v(static_cast(n)); + for (std::size_t i = 0; i < n; ++i) { + v(static_cast(i)) = + base_rate + + gradient * static_cast(i) / static_cast(n); + } + return v; +} + +std::unique_ptr BuildModel(std::size_t state_size, + int history_capture_interval, + int final_timestep) { + auto model = respond::Model::Create("benchmark_model", "console"); + model->SetHistoryCaptureInterval(history_capture_interval); + model->SetFinalTimestep(final_timestep); + + auto behavior = + respond::TransitionFactory::CreateTransition("behavior", "console"); + auto intervention = + respond::TransitionFactory::CreateTransition("intervention", "console"); + auto overdose = + respond::TransitionFactory::CreateTransition("overdose", "console"); + auto background = respond::TransitionFactory::CreateTransition( + "background_death", "console"); + + if (!behavior || !intervention || !overdose || !background) { + throw std::runtime_error("Failed to create one or more transitions"); + } + + behavior->AddTransitionMatrix(MakeShiftMatrix(state_size, 0.985, 1)); + intervention->AddTransitionMatrix(MakeShiftMatrix(state_size, 0.990, -1)); + overdose->AddTransitionMatrix(MakeRateVector(state_size, 0.0020, 0.0005)); + overdose->AddTransitionMatrix(MakeRateVector(state_size, 0.0800, 0.0200)); + background->AddTransitionMatrix(MakeRateVector(state_size, 0.0008, 0.0004)); + + model->AddTransition(behavior); + model->AddTransition(intervention); + model->AddTransition(overdose); + model->AddTransition(background); + + return model; +} + +TimedRunResult TimeOneSample(respond::Model &model, + const Eigen::VectorXd &initial_state, int steps) { + model.SetState(initial_state); + model.ClearHistories(); + model.SetFinalTimestep(steps); + + const auto start = Clock::now(); + for (int i = 0; i < steps; ++i) { + model.RunTransitions(); + } + const auto end = Clock::now(); + + const double checksum = model.GetState().sum(); + std::size_t recorded_points = 0; + const auto histories = model.GetHistories(); + const auto state_history = histories.find("state"); + if (state_history != histories.end()) { + recorded_points = state_history->second.GetRecordedTimesteps().size(); + } + DoNotOptimize(checksum); + ClobberMemory(); + + return {.elapsed_ns = Nanoseconds(end - start).count(), + .checksum = checksum, + .recorded_points = recorded_points}; +} + +Statistics ComputeStats(std::vector samples_ns) { + Statistics s; + if (samples_ns.empty()) { + return s; + } + + std::sort(samples_ns.begin(), samples_ns.end()); + + s.min_ns = samples_ns.front(); + s.max_ns = samples_ns.back(); + s.mean_ns = std::accumulate(samples_ns.begin(), samples_ns.end(), 0.0) / + static_cast(samples_ns.size()); + + const std::size_t mid = samples_ns.size() / 2; + if (samples_ns.size() % 2 == 0U) { + s.median_ns = 0.5 * (samples_ns[mid - 1] + samples_ns[mid]); + } else { + s.median_ns = samples_ns[mid]; + } + + const std::size_t p95_idx = static_cast( + std::ceil(0.95 * static_cast(samples_ns.size())) - 1.0); + s.p95_ns = samples_ns[std::min(p95_idx, samples_ns.size() - 1)]; + + double accum = 0.0; + for (const double sample : samples_ns) { + const double delta = sample - s.mean_ns; + accum += delta * delta; + } + s.stddev_ns = std::sqrt(accum / static_cast(samples_ns.size())); + + return s; +} + +void PrintConfig(const BenchmarkConfig &config) { + std::cout << "RESPOND Benchmark\n" + << "-----------------\n" + << "state_size : " << config.state_size << "\n" + << "steps : " << config.steps << "\n" + << "warmup : " << config.warmup_iterations << "\n" + << "samples : " << config.sample_iterations << "\n" + << "repetitions : " << config.repetitions << "\n" + << "hist every : " << config.history_capture_interval << "\n\n"; +} + +void PrintStatsRow(const std::string &label, const Statistics &stats, int steps, + double checksum, std::size_t recorded_points) { + const double mean_ms = ToMilliseconds(stats.mean_ns); + const double median_ms = ToMilliseconds(stats.median_ns); + const double p95_ms = ToMilliseconds(stats.p95_ns); + const double min_ms = ToMilliseconds(stats.min_ns); + const double max_ms = ToMilliseconds(stats.max_ns); + const double stddev_ms = ToMilliseconds(stats.stddev_ns); + const double ns_per_step = stats.mean_ns / static_cast(steps); + + std::cout << std::left << std::setw(10) << label << std::right + << std::setw(12) << std::fixed << std::setprecision(3) << mean_ms + << std::setw(12) << median_ms << std::setw(12) << p95_ms + << std::setw(12) << min_ms << std::setw(12) << max_ms + << std::setw(12) << stddev_ms << std::setw(14) + << std::setprecision(1) << ns_per_step << std::setw(12) + << recorded_points << std::setw(14) << std::setprecision(4) + << checksum << "\n"; +} + +} // namespace + +int main(int argc, char **argv) { + try { + const BenchmarkConfig config = ParseArgs(argc, argv); + PrintConfig(config); + PrintEnvironmentMetadata(); + + std::vector all_samples_ns; + all_samples_ns.reserve( + static_cast(config.repetitions) * + static_cast(config.sample_iterations)); + + std::cout << std::left << std::setw(10) << "run" << std::right + << std::setw(12) << "mean_ms" << std::setw(12) << "p50_ms" + << std::setw(12) << "p95_ms" << std::setw(12) << "min_ms" + << std::setw(12) << "max_ms" << std::setw(12) << "std_ms" + << std::setw(14) << "ns/step" << std::setw(12) << "state_pts" + << std::setw(14) << "checksum" + << "\n"; + std::cout << std::string(122, '-') << "\n"; + + double final_checksum = 0.0; + std::size_t final_recorded_points = 0; + + for (int repetition = 0; repetition < config.repetitions; + ++repetition) { + auto model = + BuildModel(config.state_size, config.history_capture_interval, + config.steps); + + Eigen::VectorXd initial_state = Eigen::VectorXd::Constant( + static_cast(config.state_size), 1'000.0); + + for (int i = 0; i < config.warmup_iterations; ++i) { + const auto warmup = + TimeOneSample(*model, initial_state, config.steps); + DoNotOptimize(warmup.checksum); + } + + std::vector sample_ns; + sample_ns.reserve( + static_cast(config.sample_iterations)); + + for (int i = 0; i < config.sample_iterations; ++i) { + const auto sample = + TimeOneSample(*model, initial_state, config.steps); + sample_ns.push_back(sample.elapsed_ns); + all_samples_ns.push_back(sample.elapsed_ns); + final_checksum = sample.checksum; + final_recorded_points = sample.recorded_points; + } + + const Statistics rep_stats = ComputeStats(sample_ns); + PrintStatsRow("rep" + std::to_string(repetition + 1), rep_stats, + config.steps, final_checksum, final_recorded_points); + } + + std::cout << std::string(122, '-') << "\n"; + const Statistics overall = ComputeStats(all_samples_ns); + PrintStatsRow("overall", overall, config.steps, final_checksum, + final_recorded_points); + + return 0; + } catch (const std::exception &ex) { + std::cerr << "Benchmark failed: " << ex.what() << "\n"; + return 1; + } +} diff --git a/include/respond/history.hpp b/include/respond/history.hpp index 584cbdc9..88a0a106 100644 --- a/include/respond/history.hpp +++ b/include/respond/history.hpp @@ -4,7 +4,7 @@ // Created Date: 2026-02-05 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-06 // +// Last Modified: 2026-05-06 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2026 Syndemics Lab at Boston Medical Center // @@ -12,107 +12,293 @@ #ifndef RESPOND_HISTORY_HPP_ #define RESPOND_HISTORY_HPP_ +#include #include #include +#include #include namespace respond { +enum class HistoryMode { Snapshot, Accumulated }; + +inline HistoryMode GetDefaultHistoryMode(const std::string &name) { + if (name == "intervention_admission" || name == "total_overdose" || + name == "fatal_overdose" || name == "background_death") { + return HistoryMode::Accumulated; + } + return HistoryMode::Snapshot; +} + +/// @brief Tracks and manages state vector history over time. +/// History records state snapshots at discrete timesteps, enabling analysis of +/// state trajectories during model execution. Supports sparse timesteps (gaps +/// are filled with zero vectors). class History { public: + /// @brief Constructs a History tracker. + /// @param name The identifier for this history (default: "state"). + /// @param log_name The logger name for error reporting (default: + /// "console"). History(const std::string &name = "state", const std::string &log_name = "console") - : _name(name), _log_name(log_name) {} + : History(name, log_name, GetDefaultHistoryMode(name)) {} - // Rule of Five + /// @brief Constructs a History tracker with an explicit recording mode. + /// @param name The identifier for this history. + /// @param log_name The logger name for error reporting. + /// @param mode Whether the history stores snapshots or accumulations. + History(const std::string &name, const std::string &log_name, + HistoryMode mode) + : _log_name(log_name), _name(name), _mode(mode) {} + + /// @brief Destructor (default). ~History() = default; + /// @brief Copy constructor implementing the Rule of Five. + /// Creates an independent copy of the history state and metadata. History(const History &other) { - _state = other.GetStateMap(); + _timesteps = other.GetRecordedTimesteps(); + _states = other.GetRecordedStates(); _name = other.GetHistoryName(); _log_name = other.GetLogName(); + _mode = other.GetHistoryMode(); + _pending_state = other.GetPendingState(); } + + /// @brief Copy assignment operator implementing the Rule of Five. + /// @param other The history to copy from. + /// @return Reference to this history after assignment. History &operator=(const History &other) { if (this != &other) { - _state = other.GetStateMap(); + _timesteps = other.GetRecordedTimesteps(); + _states = other.GetRecordedStates(); _name = other.GetHistoryName(); _log_name = other.GetLogName(); + _mode = other.GetHistoryMode(); + _pending_state = other.GetPendingState(); } return *this; } + + /// @brief Move constructor implementing the Rule of Five. + /// @param other The history to move from (leaves original state unchanged + /// per current implementation). History(History &&other) noexcept { - _state = other.GetStateMap(); + _timesteps = std::move(other._timesteps); + _states = std::move(other._states); _name = other.GetHistoryName(); _log_name = other.GetLogName(); + _mode = other.GetHistoryMode(); + _pending_state = std::move(other._pending_state); } + + /// @brief Move assignment operator implementing the Rule of Five. + /// @param other The history to move from. + /// @return Reference to this history after assignment. History &operator=(History &&other) noexcept { if (this != &other) { - _state = other.GetStateMap(); + _timesteps = std::move(other._timesteps); + _states = std::move(other._states); _name = other.GetHistoryName(); _log_name = other.GetLogName(); + _mode = other.GetHistoryMode(); + _pending_state = std::move(other._pending_state); } return *this; } - // Comparisons - // Overload the == operator + /// @brief Equality comparison operator. + /// @param other The history to compare with. + /// @return True if all history properties and state are identical. bool operator==(const History &other) const { return GetHistoryName() == other.GetHistoryName() && GetLogName() == other.GetLogName() && - GetStateMap() == other.GetStateMap(); + GetHistoryMode() == other.GetHistoryMode() && + GetStateMap() == other.GetStateMap() && + GetPendingState().isApprox(other.GetPendingState()); } + + /// @brief Inequality comparison operator. + /// @param other The history to compare with. + /// @return True if histories differ in any aspect. bool operator!=(const History &other) const { return !(*this == other); } - // Getters - std::map GetStateMap() const { return _state; } + /// @brief Retrieves the complete state map (timestep -> state vector). + /// @return Map of integer timesteps to Eigen vectors representing states. + std::map GetStateMap() const { + std::map state_map; + for (size_t index = 0; index < _timesteps.size(); ++index) { + state_map[_timesteps[index]] = _states[index]; + } + return state_map; + } + + /// @brief Retrieves the recorded timesteps without densifying gaps. + /// @return Const reference to the stored timestep indices. + const std::vector &GetRecordedTimesteps() const { return _timesteps; } + + /// @brief Retrieves the recorded state vectors without densifying gaps. + /// @return Const reference to the stored state vectors. + const std::vector &GetRecordedStates() const { + return _states; + } + + /// @brief Retrieves the configured history recording mode. + /// @return Snapshot or accumulated history mode. + HistoryMode GetHistoryMode() const { return _mode; } + + /// @brief Indicates whether an accumulated history has pending state. + /// @return True when a pending aggregate exists. + bool HasPendingState() const { return _pending_state.size() > 0; } + + /// @brief Retrieves the pending accumulated state. + /// @return The pending aggregate vector, or an empty vector if none. + Eigen::VectorXd GetPendingState() const { return _pending_state; } + + /// @brief Retrieves the latest recorded timestep. + /// @return Largest recorded timestep, or -1 if history is empty. + int GetLatestRecordedTimestep() const { + if (_timesteps.empty()) { + return -1; + } + return _timesteps.back(); + } + + /// @brief Retrieves the identifier name of this history. + /// @return The history's name string. std::string GetHistoryName() const { return _name; } + + /// @brief Retrieves the logger name for this history. + /// @return The associated logger's name. std::string GetLogName() const { return _log_name; } + /// @brief Converts the sparse history map to a contiguous vector of states. + /// Gaps in timesteps are filled with zero vectors of appropriate dimension. + /// @return Vector of Eigen vectors from timestep 0 to the maximum recorded + /// timestep. Returns empty vector if no state has been recorded. std::vector GetStateAsVector() const { std::vector ret; - if (_state.empty()) { - // warn empty state vector + if (_states.empty()) { + // warn empty state vector - no states recorded return {}; } - int default_size = _state.begin()->second.size(); + int default_size = _states.front().size(); int tstep = 0; - for (const auto &kv : _state) { - if (kv.first > tstep) { - // raise error for invalid key/tstep + for (size_t index = 0; index < _timesteps.size(); ++index) { + const int recorded_timestep = _timesteps[index]; + const auto &recorded_state = _states[index]; + if (recorded_timestep > tstep) { + // Fill gap: raise error if timestep mapping is invalid } - while (kv.first > tstep) { + while (recorded_timestep > tstep) { ret.push_back(GetZeroVector(default_size)); tstep++; } - ret.push_back(kv.second); + ret.push_back(recorded_state); tstep++; } return ret; } + /// @brief Records a state vector at a specific or automatic timestep. + /// @param state The state vector to record. + /// @param timestep The timestep index for this state (default: -1 for + /// automatic next timestep). If timestep is negative, the next sequential + /// timestep is used automatically. If timestep already exists, it is + /// considered invalid but is currently overwritten. void AddState(const Eigen::VectorXd &state, int timestep = -1) { if (timestep < 0) { timestep = GetNextTimestep(); - } else if (_state.find(timestep) != _state.end()) { - // invalid timestep (already exists) } - _state[timestep] = state; + + const auto existing = + std::find(_timesteps.begin(), _timesteps.end(), timestep); + if (existing != _timesteps.end()) { + const auto index = + static_cast(existing - _timesteps.begin()); + _states[index] = state; + return; + } + + _timesteps.push_back(timestep); + _states.push_back(state); } - void Clear() { _state.clear(); } + /// @brief Records a snapshot value at a concrete timestep. + /// @param state The snapshot value to record. + /// @param timestep The simulation timestep for this snapshot. + void RecordSnapshot(const Eigen::VectorXd &state, int timestep) { + AddState(state, timestep); + } + + /// @brief Adds a contribution to an accumulated history. + /// @param state The per-step contribution to accumulate. + void AccumulateState(const Eigen::VectorXd &state) { + if (_mode != HistoryMode::Accumulated) { + AddState(state); + return; + } + + if (_pending_state.size() == 0) { + _pending_state = state; + return; + } + _pending_state += state; + } + + /// @brief Flushes pending accumulated state into a recorded timestep. + /// @param timestep The simulation timestep to record. + /// @param state_size Size of a zero vector to record if nothing is pending. + void FlushPendingState(int timestep, Eigen::Index state_size) { + if (_mode != HistoryMode::Accumulated) { + return; + } + + Eigen::VectorXd value; + if (_pending_state.size() > 0) { + value = _pending_state; + } else { + value = Eigen::VectorXd::Zero(state_size); + } + + AddState(value, timestep); + _pending_state.resize(0); + } + + /// @brief Clears all recorded state history. + void Clear() { + _timesteps.clear(); + _states.clear(); + _pending_state.resize(0); + } private: + /// @brief The logger name for this history. std::string _log_name; + /// @brief The identifier name for this history. std::string _name; - std::map _state; + /// @brief Controls whether this history stores snapshots or aggregates. + HistoryMode _mode; + /// @brief Recorded timestep indices for sparse history capture. + std::vector _timesteps; + /// @brief Recorded state vectors aligned with _timesteps. + std::vector _states; + /// @brief Pending aggregate for accumulated histories. + Eigen::VectorXd _pending_state; + /// @brief Computes the next sequential timestep. + /// @return 0 if history is empty, otherwise one past the largest existing + /// timestep. int GetNextTimestep() { - if (_state.empty()) { + if (_timesteps.empty()) { return 0; } - int largest_timestep = _state.rbegin()->first; - return largest_timestep + 1; + return _timesteps.back() + 1; } + /// @brief Creates a zero vector of specified size. + /// @param size The dimensionality of the zero vector. + /// @return An Eigen vector of zeros with the specified size. Eigen::VectorXd GetZeroVector(const int &size) const { return Eigen::VectorXd::Zero(size); } diff --git a/include/respond/logging.hpp b/include/respond/logging.hpp index a7c12320..58878527 100644 --- a/include/respond/logging.hpp +++ b/include/respond/logging.hpp @@ -34,33 +34,127 @@ enum class CreationStatus : int { kCount = 4 // Enum Counter }; +/// @brief Log pattern templates for different output styles. +enum class LogPattern : int { + kSimple = 0, // Minimal: [%n] %v + kStandard = 1, // Default: [%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v + kDetailed = 2, // Full: [%Y-%m-%d %H:%M:%S.%e] [%n] [%^%L%$] [thread %t] %v + kThreadSafe = + 3 // With sequence number: [%H:%M:%S] [seq %i] [%n] [%^---%L---%$] %v +}; + +// ============================================================================ +// Core Logger Creation and Management +// ============================================================================ + /// @brief Create a logger with the specified name and file path. -/// @param logger_name logger name to be created. -/// @param filepath file path where the logger will write logs. -/// @return CreationStatus indicating the result of the logger creation. +/// Thread-safe for use in parallel execution contexts. +/// @param logger_name Unique identifier for this logger. +/// @param filepath File path where the logger will write logs. +/// @return CreationStatus indicating the result of logger creation. +/// @note If a logger with the same name already exists, kExists is returned. CreationStatus CreateFileLogger(const std::string &logger_name, const std::string &filepath); -/// @brief Log a message as information. -/// @param logger_name Logger name to log the message to. -/// @param message Message to log as information. +// ============================================================================ +// Parallel Execution Support: Shared File Sink +// ============================================================================ + +/// @brief Create a shared file sink that multiple loggers can write to. +/// This enables thread-safe concurrent logging from multiple models/threads +/// to the same output file. Use with CreateSharedLogger(). +/// @param filepath Path to the shared log file. +/// @return CreationStatus indicating success or failure. +/// @note Thread-safe: Can be called from multiple threads simultaneously. +/// @note Creates the file sink once; subsequent calls return kExists. +CreationStatus CreateSharedFileSink(const std::string &filepath); + +/// @brief Create a logger that writes to the shared file sink. +/// Multiple loggers can be created with different names but write to the +/// same file through the shared sink. Essential for parallel execution. +/// @param logger_name Unique identifier for this logger. +/// @return CreationStatus indicating the result of logger creation. +/// @note Thread-safe: Can be called concurrently from multiple threads. +/// @note Requires CreateSharedFileSink() to be called first with a file path. +/// @note If CreateSharedFileSink() wasn't called, creates a default sink to +/// "respond.log". +CreationStatus CreateSharedLogger(const std::string &logger_name); + +/// @brief Sets the logging pattern template for all subsequent logger +/// creations. +/// @param pattern The LogPattern enum value to use. +/// @note Affects CreateFileLogger() and CreateSharedLogger() calls made after +/// this. +void SetLogPattern(LogPattern pattern); + +/// @brief Gets the current logging pattern template. +/// @return The active LogPattern enum value. +LogPattern GetLogPattern(); + +/// @brief Sets the global flush interval for automatic buffer flushing. +/// @param seconds Interval in seconds for automatic flush (0 to disable +/// auto-flush). +/// @note Thread-safe configuration change. +void SetFlushInterval(int seconds); + +/// @brief Flushes all active loggers, ensuring buffered output is written. +/// Useful when terminating parallel execution or before critical operations. +/// @note Thread-safe operation. +void FlushAllLoggers(); + +// ============================================================================ +// Logging Functions +// ============================================================================ + +/// @brief Log a message as information level. +/// Thread-safe for concurrent calls from multiple threads. +/// @param logger_name Logger identifier (created via CreateFileLogger or +/// CreateSharedLogger). +/// @param message Message to log. void LogInfo(const std::string &logger_name, const std::string &message); -/// @brief Log as message as a warning. -/// @param logger_name Logger name to log the message to. -/// @param message Message to log as a warning. +/// @brief Log a message as warning level. +/// Thread-safe for concurrent calls from multiple threads. +/// @param logger_name Logger identifier. +/// @param message Message to log. void LogWarning(const std::string &logger_name, const std::string &message); -/// @brief Log a message as an error. -/// @param logger_name Logger name to log the message to. -/// @param message Message to log as an error. +/// @brief Log a message as error level. +/// Thread-safe for concurrent calls from multiple threads. +/// @param logger_name Logger identifier. +/// @param message Message to log. void LogError(const std::string &logger_name, const std::string &message); -/// @brief Log a message as debug information. -/// @param logger_name Logger name to log the message to. -/// @param message Message to log as debug information. +/// @brief Log a message as debug level. +/// Thread-safe for concurrent calls from multiple threads. +/// @param logger_name Logger identifier. +/// @param message Message to log. void LogDebug(const std::string &logger_name, const std::string &message); +// ============================================================================ +// Utility Functions +// ============================================================================ + +/// @brief Check if a logger with the given name exists. +/// @param logger_name Logger identifier to check. +/// @return CreationStatus::kExists if logger exists, kNotCreated otherwise. +/// @note Thread-safe query. +CreationStatus CheckLoggerExists(const std::string &logger_name); + +/// @brief Retrieve detailed information about a logger. +/// @param logger_name Logger identifier to query. +/// @return String containing logger name, file path, level, and thread info. +/// @note Thread-safe operation. +std::string GetLoggerInfo(const std::string &logger_name); + +/// @brief Set the logging level for a specific logger. +/// @param logger_name Logger identifier to configure. +/// @param level Log level: 0=trace, 1=debug, 2=info, 3=warn, 4=error, +/// 5=critical. +/// @return CreationStatus::kSuccess if level was set, kNotCreated if logger +/// doesn't exist. +void SetLoggerLevel(const std::string &logger_name, int level); + } // namespace respond #endif // RESPOND_LOGGING_HPP_ \ No newline at end of file diff --git a/include/respond/model.hpp b/include/respond/model.hpp index 096b8fd7..2ac07306 100644 --- a/include/respond/model.hpp +++ b/include/respond/model.hpp @@ -22,48 +22,103 @@ #include namespace respond { +/// @brief Abstract base class representing a state transition model. +/// Models manage a state vector, execute transitions, and maintain history of +/// state changes. Subclasses must implement state management, transition +/// execution, and history tracking. class Model { public: - // default destructor + /// @brief Virtual destructor for proper polymorphic cleanup. virtual ~Model() = default; - // anticipate making a copy of the vector - virtual void SetState(const Eigen::VectorXd &) = 0; - // return const & to limit to observation of the state + + /// @brief Sets the current state of the model. + /// @param state The state vector to set. A copy is made internally. + virtual void SetState(const Eigen::VectorXd &state) = 0; + + /// @brief Retrieves the current state of the model. + /// @return A copy of the current state vector (limited to observation). virtual Eigen::VectorXd GetState() const = 0; - // manipulate the state vector + + /// @brief Executes all registered transitions on the current state. + /// Transitions are applied in the order they were added and may modify + /// history. virtual void RunTransitions() = 0; - // assume ownership of the Transition + + /// @brief Adds a transition to the model. + /// @param t A unique_ptr to a Transition object. The model assumes + /// ownership. virtual void AddTransition(const std::unique_ptr &t) = 0; - // get the names of each transition we own + + /// @brief Retrieves the names of all registered transitions. + /// @return Vector of transition names in the order they were added. virtual std::vector GetTransitionNames() const = 0; - // delete all the Transition unique_ptrs by clearing the vector + + /// @brief Clears all registered transitions. + /// Deletes all stored Transition unique_ptrs. virtual void ClearTransitions() = 0; - // return const & to limit to observation of the state. Need copy ability of - // History, but let that be the History's responsibility + + /// @brief Retrieves the history records for all state variables. + /// @return A map of history names to History objects containing state + /// trajectories. virtual std::map GetHistories() const = 0; + /// @brief Creates default history tracking for the model. + /// This method initializes standard history records based on the model's + /// state. virtual void CreateDefaultHistories() = 0; + /// @brief Sets the history records for the model. + /// @param h A map of history names to History objects. virtual void SetHistories(const std::map &h) = 0; - // getter for model name + + /// @brief Clears all history records and resets history tracking state. + virtual void ClearHistories() = 0; + + /// @brief Sets the global history capture interval for this model. + /// @param interval Record every interval timesteps. Values less than 1 + /// default to full capture. + virtual void SetHistoryCaptureInterval(int interval) = 0; + + /// @brief Retrieves the global history capture interval. + /// @return The active capture interval. A value of 1 means full capture. + virtual int GetHistoryCaptureInterval() const = 0; + + /// @brief Sets the final timestep that must always be recorded. + /// @param final_timestep The final simulation timestep. + virtual void SetFinalTimestep(int final_timestep) = 0; + + /// @brief Retrieves the final timestep forced into history output. + /// @return The configured final simulation timestep, or -1 if unset. + virtual int GetFinalTimestep() const = 0; + + /// @brief Retrieves the name identifier for this model. + /// @return The model's name as a string. virtual std::string GetModelName() const = 0; - // getter for log name + + /// @brief Retrieves the logger name used by this model. + /// @return The name of the associated logger. virtual std::string GetLogName() const = 0; - /// @brief Factory method to create a Markov instance. - /// @param log_name Name of the logger to write errors to. - /// @return An instance of Markov. + /// @brief Factory method to create a Model instance. + /// @param name The name identifier for the model to create. + /// @param log_name Name of the logger for this model (default: "console"). + /// @return A unique_ptr to the newly created Model instance. static std::unique_ptr Create(const std::string &name, const std::string &log_name = "console"); - // Copy Control + /// @brief Deleted copy constructor (models are non-copyable by public API). Model(const Model &) = delete; + /// @brief Deleted copy assignment operator (models are non-copyable by + /// public API). Model &operator=(const Model &) = delete; + + /// @brief Creates a deep copy of this model. + /// @return A unique_ptr to an independent copy of this model. virtual std::unique_ptr clone() const = 0; protected: - // default constructor required for subclasses, but do not want people to - // use this + /// @brief Protected default constructor for subclass initialization. + /// Not intended for direct public use. Model() = default; }; } // namespace respond diff --git a/include/respond/respond.hpp b/include/respond/respond.hpp index 8b7c98e4..13c233b6 100644 --- a/include/respond/respond.hpp +++ b/include/respond/respond.hpp @@ -4,7 +4,7 @@ // Created Date: 2026-02-06 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-06 // +// Last Modified: 2026-05-07 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2026 Syndemics Lab at Boston Medical Center // diff --git a/include/respond/simulation.hpp b/include/respond/simulation.hpp index 79b4e8f8..b9b7fb55 100644 --- a/include/respond/simulation.hpp +++ b/include/respond/simulation.hpp @@ -24,31 +24,47 @@ #include namespace respond { +/// @brief Manages and executes multiple models in a coordinated simulation. +/// A Simulation aggregates Model instances and coordinates their execution, +/// maintaining history records and providing access to simulation results. class Simulation { public: + /// @brief Default constructor initializing with "console" logger. Simulation() : Simulation("console") {} + + /// @brief Constructs a Simulation with a specified logger. + /// @param log_name Name of the logger for this simulation (default: + /// "console"). Simulation(const std::string &log_name) : _log_name(log_name) {} + + /// @brief Virtual destructor for polymorphic cleanup. ~Simulation() = default; - /// @brief The core function to run the simulation. Runs all models - /// associated with the simulation. - /// @param duration The number of steps to take for each model. + /// @brief Executes one step of the simulation for all models. + /// Calls RunTransitions() on each registered model in sequence. void Run() { for (const auto &model : _models) { model->RunTransitions(); } } + /// @brief Adds a model to the simulation. + /// The model is cloned and managed by the simulation. + /// @param model A unique_ptr to a Model instance to add. void AddModel(const std::unique_ptr &model) { // because push_back is a move operation we're taking over ownership of // the unique pointer _models.push_back(model->clone()); } + /// @brief Retrieves all models in the simulation. + /// @return Const reference to the vector of Model unique_ptrs. const std::vector> &GetModels() const { return _models; } + /// @brief Retrieves the names of all models in the simulation. + /// @return Vector of model names in the order they were added. std::vector GetModelNames() const { std::vector ret; for (auto &m : _models) { @@ -57,8 +73,12 @@ class Simulation { return ret; } + /// @brief Removes all models from the simulation. void ClearModels() { _models.clear(); } + /// @brief Retrieves the complete state histories for all models. + /// @return Vector of maps (one per model) mapping history names to state + /// vector trajectories. const std::vector>> GetModelHistories() const { std::vector>> ret; @@ -74,6 +94,20 @@ class Simulation { return ret; } + /// @brief Retrieves sparse history objects for all models. + /// @return Vector of maps (one per model) mapping history names to sparse + /// History objects. + const std::vector> + GetModelSparseHistories() const { + std::vector> ret; + for (const auto &model : _models) { + ret.push_back(model->GetHistories()); + } + return ret; + } + + /// @brief Retrieves pairs of (model name, history name) for all histories. + /// @return Vector of pairs associating each history with its parent model. const std::vector> GetModelHistoryNames() const { std::vector> ret; @@ -87,15 +121,23 @@ class Simulation { return ret; } + /// @brief Retrieves the logger name used by this simulation. + /// @return The name of the associated logger. std::string GetLogName() const { return _log_name; } - // Copying object + /// @brief Copy constructor creating an independent deep copy of the + /// simulation. All models are cloned; modifications to the copy do not + /// affect the original. Simulation(const Simulation &other) : _log_name(other.GetLogName()) { ClearModels(); for (const auto &m : other.GetModels()) { _models.push_back(m->clone()); } } + + /// @brief Copy assignment operator for deep copying simulation state. + /// @param other The simulation to copy from. + /// @return Reference to this simulation after assignment. Simulation &operator=(const Simulation &other) { if (this != &other) { ClearModels(); diff --git a/include/respond/transition.hpp b/include/respond/transition.hpp index 07cc68ee..aae64a99 100644 --- a/include/respond/transition.hpp +++ b/include/respond/transition.hpp @@ -23,35 +23,56 @@ namespace respond { -/// @brief A helper class to hold Transitions +/// @brief Abstract base class representing a state transition operation. +/// Transitions apply transformation matrices to state vectors and update +/// history records. Subclasses define specific types of transitions (e.g., +/// Markov, background death, behavior). class Transition { public: + /// @brief Virtual destructor for proper polymorphic cleanup. virtual ~Transition() = default; - // Run the execute function and return the final state. Do not edit the - // parameter state, but do edit the history provided. Nothing in the - // Transition object should change. + + /// @brief Executes this transition, applying it to a state vector. + /// The input state is not modified; history records are updated with the + /// transition effects. + /// @param s The current state vector (not modified). + /// @param h The history records to update (may be modified by this + /// transition). + /// @return The resulting state vector after applying this transition. virtual Eigen::VectorXd Execute(const Eigen::VectorXd &s, std::map &h) const = 0; - // Add a Transition Matrix to the set. We have no need to edit it once it's - // been added, just use it. Thus, we don't need full ownership (reference) - // and can accept the const type. + + /// @brief Adds a transformation matrix to this transition. + /// The matrix is stored for use during Execute() calls. + /// @param m The transition matrix to add (not modified by this transition). virtual void AddTransitionMatrix(const Eigen::MatrixXd &m) = 0; - // Get the name of the Transition. No need to edit the object and do not - // need user to edit the name. + + /// @brief Retrieves the name/type of this transition. + /// @return The transition's identifier as a string. virtual std::string GetTransitionName() const = 0; - // Clear out all the stored Eigen::MatrixXd values + + /// @brief Clears all stored transition matrices. virtual void ClearTransitionMatrices() = 0; - // Get Log Name + + /// @brief Retrieves the logger name used by this transition. + /// @return The associated logger's name. virtual std::string GetLogName() const = 0; - // Clone + + /// @brief Deleted copy constructor (transitions are non-copyable by public + /// API). Transition(const Transition &) = delete; + /// @brief Deleted copy assignment operator (transitions are non-copyable by + /// public API). Transition &operator=(const Transition &) = delete; + + /// @brief Creates a deep copy of this transition. + /// @return A unique_ptr to an independent copy of this transition. virtual std::unique_ptr clone() const = 0; protected: - // default constructor required for subclasses, but do not want people to - // use this + /// @brief Protected default constructor for subclass initialization. + /// Not intended for direct public use. Transition() = default; }; } // namespace respond diff --git a/include/respond/transition_factory.hpp b/include/respond/transition_factory.hpp index 7733b0c6..87fd8fc8 100644 --- a/include/respond/transition_factory.hpp +++ b/include/respond/transition_factory.hpp @@ -17,8 +17,22 @@ #include namespace respond { +/// @brief Factory for creating concrete Transition instances. +/// This factory supports creation of various transition types used in the +/// RESPOND model. class TransitionFactory { public: + /// @brief Creates a transition of the specified type. + /// @param type The type of transition to create. Supported types + /// (case-insensitive): + /// - "migration": Population migration transitions + /// - "behavior": Behavioral state transitions + /// - "intervention": Intervention-driven transitions + /// - "overdose": Overdose-related transitions + /// - "background_death": Background mortality transitions + /// @param log_name The logger name for error reporting (e.g., "console"). + /// @return A unique_ptr to the created Transition, or nullptr if type is + /// unsupported. static std::unique_ptr CreateTransition(const std::string &type, const std::string &log_name); }; diff --git a/include/respond/version.hpp b/include/respond/version.hpp index 6dc9a438..4221aeca 100644 --- a/include/respond/version.hpp +++ b/include/respond/version.hpp @@ -4,7 +4,7 @@ // Created Date: 2025-03-06 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-17 // +// Last Modified: 2026-04-16 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2025-2026 Syndemics Lab at Boston Medical Center // @@ -14,8 +14,8 @@ #define RESPOND_VERSION_HPP_ #define RESPOND_VER_MAJOR 2 -#define RESPOND_VER_MINOR 3 -#define RESPOND_VER_PATCH 1 +#define RESPOND_VER_MINOR 4 +#define RESPOND_VER_PATCH 0 #define RESPOND_TO_VERSION(major, minor, patch) \ (major * 10000 + minor * 100 + patch) diff --git a/src/background.cpp b/src/background.cpp index dcf679ea..283084f0 100644 --- a/src/background.cpp +++ b/src/background.cpp @@ -15,28 +15,35 @@ #include #include +#include +#include + namespace respond { Eigen::VectorXd BackgroundDeath::Execute(const Eigen::VectorXd &state, std::map &h) const { if (GetTransitionMatrices().size() != 1) { - throw std::runtime_error( - "Mortality Transitions must have 1 Transition Matrix."); + std::string error_msg = + "Background death error: Expected 1 transition matrix, got " + + std::to_string(GetTransitionMatrices().size()); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto deaths = state.cwiseProduct(GetTransitionMatrices()[0]); // calculate the deaths if (h.find("background_death") != h.end()) { - h["background_death"].AddState(deaths); + h["background_death"].AccumulateState(deaths); } if (!(state.array() >= deaths.array()).all()) { - std::runtime_error( - "The state is not larger than the estimated background deaths!"); + std::string error_msg = + "Background death error: State values are less than estimated " + "deaths. " + + std::to_string((state.array() < deaths.array()).count()) + + " elements affected"; + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto new_state = state - deaths; // remove deaths from state - - if (h.find("background_death") != h.end()) { - h["state"].AddState(new_state); - } return new_state; } diff --git a/src/behavior.cpp b/src/behavior.cpp index ad42d461..7138e5d5 100644 --- a/src/behavior.cpp +++ b/src/behavior.cpp @@ -15,22 +15,29 @@ #include #include +#include +#include + namespace respond { Eigen::VectorXd Behavior::Execute(const Eigen::VectorXd &state, std::map &h) const { if (GetTransitionMatrices().size() != 1) { - throw std::runtime_error( - "Behavior Transitions must have 1 Transition Matrix."); + std::string error_msg = + "Behavior error: Expected 1 transition matrix, got " + + std::to_string(GetTransitionMatrices().size()); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } if (state.rows() != GetTransitionMatrices()[0].cols()) { std::stringstream ss; - ss << "Unable to multiply behavior transition with " - "state, mismatched sizes. State size is ("; - ss << state.rows() << ", " << state.cols(); - ss << ") and transition size is (" << GetTransitionMatrices()[0].rows() - << ", "; - ss << GetTransitionMatrices()[0].cols() << ")."; - throw std::runtime_error(ss.str()); + ss << "Behavior error: State dimension mismatch. State size is (" + << state.rows() << ", " << state.cols() + << ") but transition matrix expects (" + << GetTransitionMatrices()[0].rows() << ", " + << GetTransitionMatrices()[0].cols() << ")"; + std::string error_msg = ss.str(); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto new_state = GetTransitionMatrices()[0] * state; return new_state; diff --git a/src/internals/logging_internals.hpp b/src/internals/logging_internals.hpp index 6b369f34..db03a828 100644 --- a/src/internals/logging_internals.hpp +++ b/src/internals/logging_internals.hpp @@ -16,25 +16,113 @@ #include #include +#include +#include #include +#include #include #include +#include #include namespace respond { + +class LoggingConfig { +public: + static LoggingConfig &GetInstance() { + static LoggingConfig instance; + return instance; + } + + static std::shared_ptr + GetSharedSink(const std::string &filepath) { + std::lock_guard lock(GetInstance().sink_mutex_); + auto key = filepath; + if (GetInstance().shared_sinks_.find(key) == + GetInstance().shared_sinks_.end()) { + try { + GetInstance().shared_sinks_[key] = + std::make_shared( + filepath, false); + } catch (const spdlog::spdlog_ex &ex) { + std::cerr << "Failed to create shared sink: " << ex.what() + << std::endl; + return nullptr; + } + } + return GetInstance().shared_sinks_[key]; + } + + static LogPattern GetPattern() { return GetInstance().current_pattern_; } + + static void SetPattern(LogPattern pattern) { + GetInstance().current_pattern_ = pattern; + } + + static std::string GetPatternString(LogPattern pattern) { + switch (pattern) { + case LogPattern::kSimple: + return "[%n] %v"; + case LogPattern::kStandard: + return "[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v"; + case LogPattern::kDetailed: + return "[%Y-%m-%d %H:%M:%S.%e] [%n] [%^%L%$] [thread %t] %v"; + case LogPattern::kThreadSafe: + return "[%H:%M:%S] [seq %i] [%n] [%^---%L---%$] %v"; + default: + return "[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v"; + } + } + + static int GetFlushInterval() { return GetInstance().flush_interval_; } + + static void SetFlushInterval(int seconds) { + GetInstance().flush_interval_ = seconds; + } + + static void SetDefaultSinkPath(const std::string &path) { + std::lock_guard lock(GetInstance().config_mutex_); + GetInstance().default_sink_path_ = path; + } + + static std::string GetDefaultSinkPath() { + std::lock_guard lock(GetInstance().config_mutex_); + return GetInstance().default_sink_path_; + } + +private: + LoggingConfig() + : current_pattern_(LogPattern::kStandard), flush_interval_(3), + default_sink_path_("respond.log") { + spdlog::cfg::load_env_levels(); + } + + std::unordered_map> + shared_sinks_; + std::mutex sink_mutex_; + std::mutex config_mutex_; + LogPattern current_pattern_; + int flush_interval_; + std::string default_sink_path_; +}; + CreationStatus CheckIfExists(const std::string &logger_name) { return (spdlog::get(logger_name) != nullptr) ? CreationStatus::kExists : CreationStatus::kNotCreated; } + void log(const std::string &logger_name, const std::string &message, LogType type = LogType::kInfo) { CreationStatus status = CheckIfExists(logger_name); if ((status == CreationStatus::kNotCreated) && - (CreateFileLogger(logger_name, "log.txt") == CreationStatus::kError)) { + (CreateFileLogger(logger_name, "respond.log") == + CreationStatus::kError)) { std::cerr << "Failed to create logger: " << logger_name << std::endl; return; } + auto logger = spdlog::get(logger_name); if (logger) { switch (type) { @@ -54,7 +142,9 @@ void log(const std::string &logger_name, const std::string &message, logger->info(message); break; } - logger->flush(); + if (LoggingConfig::GetFlushInterval() == 0) { + logger->flush(); + } } else { spdlog::error("Logger {} not found", logger_name); } diff --git a/src/internals/markov.hpp b/src/internals/markov.hpp index f827de5f..4b04ce93 100644 --- a/src/internals/markov.hpp +++ b/src/internals/markov.hpp @@ -4,7 +4,7 @@ // Created Date: 2026-02-05 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-12 // +// Last Modified: 2026-05-06 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2026 Syndemics Lab at Boston Medical Center // @@ -28,7 +28,9 @@ class Markov : public virtual Model { public: Markov() : Markov("markov", "console") {} Markov(const std::string &name, const std::string &log_name) - : _name(name), _log_name(log_name) { + : _name(name), _log_name(log_name), _current_timestep(0), + _history_capture_interval(1), _final_timestep(-1), + _initial_history_recorded(false) { const auto processor_count = std::thread::hardware_concurrency(); Eigen::setNbThreads(processor_count); } @@ -41,6 +43,12 @@ class Markov : public virtual Model { auto np = Model::Create(GetModelName(), GetLogName()); np->SetState(GetState()); np->SetHistories(GetHistories()); + np->SetHistoryCaptureInterval(GetHistoryCaptureInterval()); + np->SetFinalTimestep(GetFinalTimestep()); + if (auto *markov = dynamic_cast(np.get())) { + markov->_current_timestep = _current_timestep; + markov->_initial_history_recorded = _initial_history_recorded; + } for (const auto &t : GetTransitions()) { np->AddTransition(t->clone()); } @@ -86,26 +94,30 @@ class Markov : public virtual Model { /// 5. Background Mortality /// @return A vector of the default history objects. void CreateDefaultHistories() override { - std::vector names = { - "state", "total_overdose", "fatal_overdose", - "intervention_admission", "background_death"}; - std::map ret; - for (const auto &n : names) { - History h(n, GetLogName()); - ret[n] = h; - } + ret["state"] = History("state", GetLogName(), HistoryMode::Snapshot); + ret["total_overdose"] = + History("total_overdose", GetLogName(), HistoryMode::Accumulated); + ret["fatal_overdose"] = + History("fatal_overdose", GetLogName(), HistoryMode::Accumulated); + ret["intervention_admission"] = History( + "intervention_admission", GetLogName(), HistoryMode::Accumulated); + ret["background_death"] = + History("background_death", GetLogName(), HistoryMode::Accumulated); SetHistories(ret); } // manipulate the state vector void RunTransitions() override { SetupHistory(); - auto histories = GetHistories(); + if (!_initial_history_recorded) { + RecordHistoryAtCurrentTimestep(); + } for (const auto &t : _transition_vector) { - SetState(t->Execute(GetState(), histories)); + _state = t->Execute(_state, _histories); } - SetHistories(histories); + _current_timestep++; + RecordHistoryAtCurrentTimestep(); } // assume ownership of the Transition void AddTransition(const std::unique_ptr &t) override { @@ -125,7 +137,39 @@ class Markov : public virtual Model { virtual void SetHistories(const std::map &h) override { _histories = h; + if (_histories.empty()) { + ResetHistoryTracking(); + return; + } + + const int latest_timestep = GetLatestRecordedTimestep(); + if (latest_timestep < 0) { + ResetHistoryTracking(); + return; + } + + _initial_history_recorded = true; + _current_timestep = latest_timestep; + } + void ClearHistories() override { + _histories.clear(); + ResetHistoryTracking(); + } + + void SetHistoryCaptureInterval(int interval) override { + _history_capture_interval = (interval < 1) ? 1 : interval; + } + + int GetHistoryCaptureInterval() const override { + return _history_capture_interval; + } + + void SetFinalTimestep(int final_timestep) override { + _final_timestep = final_timestep; } + + int GetFinalTimestep() const override { return _final_timestep; } + // return const & to limit to observation of the state. Need copy ability of // History, but let that be the History's responsibility std::map GetHistories() const override { @@ -141,22 +185,57 @@ class Markov : public virtual Model { std::string _name; std::string _log_name; std::map _histories; + int _current_timestep; + int _history_capture_interval; + int _final_timestep; + bool _initial_history_recorded; + + void ResetHistoryTracking() { + _current_timestep = 0; + _initial_history_recorded = false; + } + + int GetLatestRecordedTimestep() const { + int latest = -1; + for (const auto &kv : _histories) { + latest = std::max(latest, kv.second.GetLatestRecordedTimestep()); + } + return latest; + } + + bool ShouldRecordHistoryAtTimestep(int timestep) const { + if (timestep == 0) { + return true; + } + if (_final_timestep >= 0 && timestep == _final_timestep) { + return true; + } + return timestep % _history_capture_interval == 0; + } + + void RecordHistoryAtCurrentTimestep() { + if (_initial_history_recorded && _current_timestep == 0) { + return; + } + if (!ShouldRecordHistoryAtTimestep(_current_timestep)) { + return; + } + + _histories["state"].RecordSnapshot(_state, _current_timestep); + const auto size = _state.size(); + _histories["intervention_admission"].FlushPendingState( + _current_timestep, size); + _histories["total_overdose"].FlushPendingState(_current_timestep, size); + _histories["fatal_overdose"].FlushPendingState(_current_timestep, size); + _histories["background_death"].FlushPendingState(_current_timestep, + size); + _initial_history_recorded = true; + } void SetupHistory() { - auto histories = GetHistories(); - if (histories.empty()) { + if (_histories.empty()) { CreateDefaultHistories(); - histories = GetHistories(); } - histories["state"].AddState(GetState()); - auto size = GetState().size(); - - histories["intervention_admission"].AddState( - Eigen::VectorXd::Zero(size)); - histories["total_overdose"].AddState(Eigen::VectorXd::Zero(size)); - histories["fatal_overdose"].AddState(Eigen::VectorXd::Zero(size)); - histories["background_death"].AddState(Eigen::VectorXd::Zero(size)); - SetHistories(histories); } }; } // namespace respond diff --git a/src/intervention.cpp b/src/intervention.cpp index 281771d8..838ef7ee 100644 --- a/src/intervention.cpp +++ b/src/intervention.cpp @@ -15,24 +15,31 @@ #include #include +#include +#include + namespace respond { Eigen::VectorXd Intervention::Execute(const Eigen::VectorXd &state, std::map &h) const { if (GetTransitionMatrices().size() != 1) { - throw std::runtime_error( - "Intervention Transitions must have 1 Transition Matrix."); + std::string error_msg = + "Intervention error: Expected 1 transition matrix, got " + + std::to_string(GetTransitionMatrices().size()); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } Eigen::VectorXd zero_matrix = Eigen::VectorXd::Zero(state.size()); if (state.rows() != GetTransitionMatrices()[0].cols()) { std::stringstream ss; - ss << "Unable to multiply intervention transition with " - "state, mismatched sizes. State size is ("; - ss << state.rows() << ", " << state.cols(); - ss << ") and transition size is (" << GetTransitionMatrices()[0].rows() - << ", "; - ss << GetTransitionMatrices()[0].cols() << ")."; - throw std::runtime_error(ss.str()); + ss << "Intervention error: State dimension mismatch. State size is (" + << state.rows() << ", " << state.cols() + << ") but transition matrix expects (" + << GetTransitionMatrices()[0].rows() << ", " + << GetTransitionMatrices()[0].cols() << ")"; + std::string error_msg = ss.str(); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto moved = GetTransitionMatrices()[0] * state; @@ -40,7 +47,7 @@ Eigen::VectorXd Intervention::Execute(const Eigen::VectorXd &state, Eigen::VectorXd admissions = moved - state; admissions = admissions.cwiseMax(Eigen::VectorXd::Zero(admissions.size())); if (h.find("intervention_admission") != h.end()) { - h["intervention_admission"].AddState(admissions); + h["intervention_admission"].AccumulateState(admissions); } return moved; diff --git a/src/logging.cpp b/src/logging.cpp index 8a156db9..75c2c35e 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -14,24 +14,113 @@ #include "internals/logging_internals.hpp" +#include +#include + namespace respond { + CreationStatus CreateFileLogger(const std::string &logger_name, const std::string &filepath) { if (CheckIfExists(logger_name) == CreationStatus::kExists) { + std::cout << "Logger " << logger_name << " already exists" << std::endl; return CreationStatus::kExists; } try { spdlog::cfg::load_env_levels(); - spdlog::set_pattern("[%H:%M:%S %z] [%n] [%^---%L---%$] [thread %t] %v"); - spdlog::flush_every(std::chrono::seconds(3)); + std::string pattern = + LoggingConfig::GetPatternString(LoggingConfig::GetPattern()); + spdlog::set_pattern(pattern); + int flush_interval = LoggingConfig::GetFlushInterval(); + if (flush_interval > 0) { + spdlog::flush_every(std::chrono::seconds(flush_interval)); + } spdlog::basic_logger_mt(logger_name, filepath); } catch (const spdlog::spdlog_ex &ex) { - std::cout << "Log init failed: " << ex.what() << std::endl; + std::string error_msg = "Failed to create file logger '" + logger_name + + "' at path '" + filepath + "': " + ex.what(); + std::cerr << error_msg << std::endl; return CreationStatus::kError; } return CreationStatus::kSuccess; } +CreationStatus CreateSharedFileSink(const std::string &filepath) { + try { + auto sink = LoggingConfig::GetSharedSink(filepath); + if (sink) { + LoggingConfig::SetDefaultSinkPath(filepath); + return CreationStatus::kSuccess; + } + std::string error_msg = + "Failed to create shared file sink: sink is null"; + std::cerr << error_msg << std::endl; + return CreationStatus::kError; + } catch (const spdlog::spdlog_ex &ex) { + std::string error_msg = "Failed to create shared file sink at '" + + filepath + "': " + ex.what(); + std::cerr << error_msg << std::endl; + return CreationStatus::kError; + } +} + +CreationStatus CreateSharedLogger(const std::string &logger_name) { + if (CheckIfExists(logger_name) == CreationStatus::kExists) { + std::cout << "Shared logger " << logger_name << " already exists" + << std::endl; + return CreationStatus::kExists; + } + + try { + std::string filepath = LoggingConfig::GetDefaultSinkPath(); + auto sink = LoggingConfig::GetSharedSink(filepath); + if (!sink) { + std::string error_msg = + "Failed to create shared logger '" + logger_name + + "': could not get or create shared sink at '" + filepath + "'"; + std::cerr << error_msg << std::endl; + return CreationStatus::kError; + } + + spdlog::cfg::load_env_levels(); + std::string pattern = + LoggingConfig::GetPatternString(LoggingConfig::GetPattern()); + + auto logger = std::make_shared(logger_name, sink); + logger->set_pattern(pattern); + logger->set_level(spdlog::level::trace); + + spdlog::register_logger(logger); + + int flush_interval = LoggingConfig::GetFlushInterval(); + if (flush_interval > 0) { + spdlog::flush_every(std::chrono::seconds(flush_interval)); + } + + return CreationStatus::kSuccess; + } catch (const spdlog::spdlog_ex &ex) { + std::string error_msg = "Failed to create shared logger '" + + logger_name + "': " + ex.what(); + std::cerr << error_msg << std::endl; + return CreationStatus::kError; + } +} + +void SetLogPattern(LogPattern pattern) { LoggingConfig::SetPattern(pattern); } + +LogPattern GetLogPattern() { return LoggingConfig::GetPattern(); } + +void SetFlushInterval(int seconds) { + LoggingConfig::SetFlushInterval(seconds); + if (seconds > 0) { + spdlog::flush_every(std::chrono::seconds(seconds)); + } +} + +void FlushAllLoggers() { + spdlog::apply_all( + [](std::shared_ptr log) { log->flush(); }); +} + void LogInfo(const std::string &logger_name, const std::string &message) { log(logger_name, message, LogType::kInfo); } @@ -47,4 +136,32 @@ void LogError(const std::string &logger_name, const std::string &message) { void LogDebug(const std::string &logger_name, const std::string &message) { log(logger_name, message, LogType::kDebug); } + +CreationStatus CheckLoggerExists(const std::string &logger_name) { + return CheckIfExists(logger_name); +} + +std::string GetLoggerInfo(const std::string &logger_name) { + auto logger = spdlog::get(logger_name); + if (!logger) { + return "Logger \"" + logger_name + "\" not found"; + } + + auto level_view = spdlog::level::to_string_view(logger->level()); + std::string info = "Logger: " + logger_name + "\n"; + info += + " Level: " + std::string(level_view.data(), level_view.size()) + "\n"; + info += " Sinks: " + std::to_string(logger->sinks().size()); + return info; +} + +void SetLoggerLevel(const std::string &logger_name, int level) { + auto logger = spdlog::get(logger_name); + if (logger) { + int clamped_level = (level < 0) ? 0 : (level > 5) ? 5 : level; + logger->set_level( + static_cast(clamped_level)); + } +} + } // namespace respond \ No newline at end of file diff --git a/src/markov.cpp b/src/markov.cpp index c49ba5d0..48bfaf9e 100644 --- a/src/markov.cpp +++ b/src/markov.cpp @@ -15,6 +15,7 @@ #include #include +#include #include namespace respond { diff --git a/src/migration.cpp b/src/migration.cpp index 5a1cab55..f5697803 100644 --- a/src/migration.cpp +++ b/src/migration.cpp @@ -15,16 +15,26 @@ #include #include +#include +#include + namespace respond { Eigen::VectorXd Migration::Execute(const Eigen::VectorXd &state, std::map &h) const { if (GetTransitionMatrices().size() != 1) { - throw std::runtime_error( - "Migration Transitions must have 1 Transition Matrix."); + std::string error_msg = + "Migration error: Expected 1 transition matrix, got " + + std::to_string(GetTransitionMatrices().size()); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } if (state.size() != GetTransitionMatrices()[0].size()) { - throw std::runtime_error("Unable to add Migration Transition Vector to " - "State Vector, mismatched sizes."); + std::string error_msg = + "Migration error: State size (" + std::to_string(state.size()) + + ") does not match transition matrix size (" + + std::to_string(GetTransitionMatrices()[0].size()) + ")"; + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto subtracted = state + GetTransitionMatrices()[0]; auto zero_stop = subtracted.array().max( diff --git a/src/overdose.cpp b/src/overdose.cpp index 1bd3c802..4fc02ee2 100644 --- a/src/overdose.cpp +++ b/src/overdose.cpp @@ -15,36 +15,56 @@ #include #include +#include +#include + namespace respond { Eigen::VectorXd Overdose::Execute(const Eigen::VectorXd &state, std::map &h) const { if (GetTransitionMatrices().size() != 2) { - throw std::runtime_error( - "Overdose Transitions must have 2 Transition Matrices."); + std::string error_msg = + "Overdose error: Expected 2 transition matrices, got " + + std::to_string(GetTransitionMatrices().size()); + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } if (state.size() != GetTransitionMatrices()[0].size()) { - throw std::runtime_error("Overdose Vector is not the same " - "size as the state vector."); + std::string error_msg = + "Overdose error: State size (" + std::to_string(state.size()) + + ") does not match transition matrix size (" + + std::to_string(GetTransitionMatrices()[0].size()) + ")"; + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } Eigen::VectorXd overdoses = state.cwiseProduct(GetTransitionMatrices()[0]); // overdose // Add total overdoses to stamp if (h.find("total_overdose") != h.end()) { - h["total_overdose"].AddState(overdoses); + h["total_overdose"].AccumulateState(overdoses); } if (overdoses.size() != GetTransitionMatrices()[1].size()) { - throw std::runtime_error("Fatal Overdose Vector is not the same " - "size as the state vector."); + std::string error_msg = + "Overdose error: Fatal overdose vector size (" + + std::to_string(overdoses.size()) + + ") does not match transition matrix size (" + + std::to_string(GetTransitionMatrices()[1].size()) + ")"; + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto fods = overdoses.cwiseProduct(GetTransitionMatrices()[1]); // negatives if (h.find("fatal_overdose") != h.end()) { - h["fatal_overdose"].AddState(fods); + h["fatal_overdose"].AccumulateState(fods); } if (!(state.array() >= fods.array()).all()) { - std::runtime_error( - "The state is not larger than the estimated fatal overdoses!"); + std::string error_msg = + "Overdose error: State values are less than estimated fatal " + "overdoses. " + + std::to_string((state.array() < fods.array()).count()) + + " elements affected"; + LogError(GetLogName(), error_msg); + throw std::runtime_error(error_msg); } auto new_state = state - fods; // remove fods from state return new_state; diff --git a/src/transition_factory.cpp b/src/transition_factory.cpp index b17672e7..267d383e 100644 --- a/src/transition_factory.cpp +++ b/src/transition_factory.cpp @@ -17,6 +17,8 @@ #include #include +#include + #include "internals/background.hpp" #include "internals/behavior.hpp" #include "internals/intervention.hpp" @@ -31,18 +33,23 @@ TransitionFactory::CreateTransition(const std::string &type, std::transform(type_copy.begin(), type_copy.end(), type_copy.begin(), [](unsigned char c) { return std::tolower(c); }); - if (type == "migration") { + if (type_copy == "migration") { return Migration::Create(type, log_name); - } else if (type == "behavior") { + } else if (type_copy == "behavior") { return Behavior::Create(type, log_name); - } else if (type == "intervention") { + } else if (type_copy == "intervention") { return Intervention::Create(type, log_name); - } else if (type == "overdose") { + } else if (type_copy == "overdose") { return Overdose::Create(type, log_name); - } else if (type == "background_death") { + } else if (type_copy == "background_death") { return BackgroundDeath::Create(type, log_name); } - // Warn invalid type + + // Invalid transition type + std::string error_msg = "Invalid transition type: '" + type + + "'. Supported types: migration, behavior, " + "intervention, overdose, background_death"; + LogError(log_name, error_msg); return nullptr; } } // namespace respond \ No newline at end of file diff --git a/tests/mocks/model_mock.hpp b/tests/mocks/model_mock.hpp index 417d35c0..054d7132 100644 --- a/tests/mocks/model_mock.hpp +++ b/tests/mocks/model_mock.hpp @@ -42,6 +42,11 @@ class MockModel : public virtual Model { (const, override)); MOCK_METHOD(void, SetHistories, ((const std::map &)), (override)); + MOCK_METHOD(void, ClearHistories, (), (override)); + MOCK_METHOD(void, SetHistoryCaptureInterval, (int), (override)); + MOCK_METHOD(int, GetHistoryCaptureInterval, (), (const, override)); + MOCK_METHOD(void, SetFinalTimestep, (int), (override)); + MOCK_METHOD(int, GetFinalTimestep, (), (const, override)); MOCK_METHOD(std::string, GetModelName, (), (const, override)); MOCK_METHOD(std::string, GetLogName, (), (const, override)); MOCK_METHOD((std::unique_ptr), clone, (), (const, override)); diff --git a/tests/unit/background_test.cpp b/tests/unit/background_test.cpp index 556b07f4..92aaca2c 100644 --- a/tests/unit/background_test.cpp +++ b/tests/unit/background_test.cpp @@ -67,10 +67,11 @@ TEST_F(BackgroundDeathTest, GoodExecuteWriteHistory) { auto result = tran->Execute(state, histories); auto expected_deaths = state.cwiseProduct(tran_matrix); auto expected_return = state - expected_deaths; - auto hist_result = histories["background_death"].GetStateAsVector()[0]; EXPECT_TRUE(result.isApprox(expected_return)); - EXPECT_TRUE(hist_result.isApprox(expected_deaths)); + EXPECT_TRUE(histories["background_death"].HasPendingState()); + EXPECT_TRUE(histories["background_death"].GetPendingState().isApprox( + expected_deaths)); } } // namespace testing } // namespace respond diff --git a/tests/unit/history_test.cpp b/tests/unit/history_test.cpp new file mode 100644 index 00000000..36123c79 --- /dev/null +++ b/tests/unit/history_test.cpp @@ -0,0 +1,120 @@ +//////////////////////////////////////////////////////////////////////////////// +// File: history_test.cpp // +// Project: respond // +// Created Date: 2026-05-05 // +// Author: GitHub Copilot // +// ----- // +// Last Modified: 2026-05-05 // +// Modified By: GitHub Copilot // +// ----- // +//////////////////////////////////////////////////////////////////////////////// + +#include + +#include + +#include +#include + +namespace respond { +namespace testing { + +TEST(HistoryTest, SparseStoragePreservesRecordedTimesteps) { + History history("state", "test_logger"); + Eigen::VectorXd state0(2); + state0 << 1.0f, 2.0f; + Eigen::VectorXd state2(2); + state2 << 3.0f, 4.0f; + + history.AddState(state0, 0); + history.AddState(state2, 2); + + std::vector expected_timesteps = {0, 2}; + ASSERT_EQ(history.GetRecordedTimesteps(), expected_timesteps); + ASSERT_EQ(history.GetRecordedStates().size(), 2u); + EXPECT_TRUE(history.GetRecordedStates()[0].isApprox(state0)); + EXPECT_TRUE(history.GetRecordedStates()[1].isApprox(state2)); +} + +TEST(HistoryTest, GetStateAsVectorFillsSparseGapsWithZeros) { + History history("state", "test_logger"); + Eigen::VectorXd state0(2); + state0 << 1.0f, 2.0f; + Eigen::VectorXd state2(2); + state2 << 3.0f, 4.0f; + + history.AddState(state0, 0); + history.AddState(state2, 2); + + const auto dense_states = history.GetStateAsVector(); + ASSERT_EQ(dense_states.size(), 3u); + EXPECT_TRUE(dense_states[0].isApprox(state0)); + EXPECT_TRUE(dense_states[1].isZero()); + EXPECT_TRUE(dense_states[2].isApprox(state2)); +} + +TEST(HistoryTest, GetStateMapBuildsSparseMapOnDemand) { + History history("state", "test_logger"); + Eigen::VectorXd state0(1); + state0 << 5.0f; + Eigen::VectorXd state3(1); + state3 << 7.0f; + + history.AddState(state0, 0); + history.AddState(state3, 3); + + const auto state_map = history.GetStateMap(); + ASSERT_EQ(state_map.size(), 2u); + EXPECT_TRUE(state_map.at(0).isApprox(state0)); + EXPECT_TRUE(state_map.at(3).isApprox(state3)); +} + +TEST(HistoryTest, ClearRemovesRecordedStatesAndTimesteps) { + History history("state", "test_logger"); + Eigen::VectorXd state(1); + state << 1.0f; + history.AddState(state, 0); + + history.Clear(); + + EXPECT_TRUE(history.GetRecordedTimesteps().empty()); + EXPECT_TRUE(history.GetRecordedStates().empty()); + EXPECT_TRUE(history.GetStateAsVector().empty()); + EXPECT_TRUE(history.GetStateMap().empty()); +} + +TEST(HistoryTest, AccumulatedHistoryFlushesPendingState) { + History history("total_overdose", "test_logger", HistoryMode::Accumulated); + Eigen::VectorXd first(2); + first << 1.0f, 2.0f; + Eigen::VectorXd second(2); + second << 3.0f, 4.0f; + + history.AccumulateState(first); + history.AccumulateState(second); + history.FlushPendingState(4, 2); + + ASSERT_FALSE(history.HasPendingState()); + std::vector expected_timesteps = {4}; + ASSERT_EQ(history.GetRecordedTimesteps(), expected_timesteps); + + Eigen::VectorXd expected(2); + expected << 4.0f, 6.0f; + ASSERT_EQ(history.GetRecordedStates().size(), 1u); + EXPECT_TRUE(history.GetRecordedStates()[0].isApprox(expected)); +} + +TEST(HistoryTest, AccumulatedHistoryFlushesZeroWhenNoPendingStateExists) { + History history("background_death", "test_logger", + HistoryMode::Accumulated); + + history.FlushPendingState(0, 3); + + std::vector expected_timesteps = {0}; + ASSERT_EQ(history.GetRecordedTimesteps(), expected_timesteps); + ASSERT_EQ(history.GetRecordedStates().size(), 1u); + EXPECT_TRUE(history.GetRecordedStates()[0].isZero()); +} + +} // namespace testing +} // namespace respond \ No newline at end of file diff --git a/tests/unit/intervention_test.cpp b/tests/unit/intervention_test.cpp index 8229a872..192449e1 100644 --- a/tests/unit/intervention_test.cpp +++ b/tests/unit/intervention_test.cpp @@ -4,7 +4,7 @@ // Created Date: 2026-02-06 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-06 // +// Last Modified: 2026-05-06 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2026 Syndemics Lab at Boston Medical Center // @@ -73,11 +73,11 @@ TEST_F(InterventionTest, GoodExecuteWriteHistory) { auto expected_return = tran_matrix * state; auto expected_admissions = (expected_return - state).cwiseMax(Eigen::VectorXd::Zero(3)); - auto hist_result = - histories["intervention_admission"].GetStateAsVector()[0]; EXPECT_TRUE(result.isApprox(expected_return)); - EXPECT_TRUE(hist_result.isApprox(expected_admissions)); + EXPECT_TRUE(histories["intervention_admission"].HasPendingState()); + EXPECT_TRUE(histories["intervention_admission"].GetPendingState().isApprox( + expected_admissions)); } } // namespace testing } // namespace respond \ No newline at end of file diff --git a/tests/unit/logging_test.cpp b/tests/unit/logging_test.cpp index 7061d7b2..6cc1627c 100644 --- a/tests/unit/logging_test.cpp +++ b/tests/unit/logging_test.cpp @@ -1,11 +1,553 @@ //////////////////////////////////////////////////////////////////////////////// // File: logging_test.cpp // -// Project: src // +// Project: respond // // Created Date: 2025-03-18 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2025-06-05 // +// Last Modified: 2026-04-16 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2025 Syndemics Lab at Boston Medical Center // //////////////////////////////////////////////////////////////////////////////// + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace respond { +namespace testing { + +// ============================================================================ +// Test Fixture for Logging Tests +// ============================================================================ + +class LoggingTest : public ::testing::Test { +protected: + void SetUp() override { + // Clear any existing loggers from previous tests + spdlog::drop_all(); + + // Create temporary log files for testing + test_log_file_ = "/tmp/respond_test.log"; + shared_log_file_ = "/tmp/respond_shared.log"; + + // Remove test files if they exist + std::remove(test_log_file_.c_str()); + std::remove(shared_log_file_.c_str()); + } + + void TearDown() override { + // Clean up loggers + spdlog::drop_all(); + + // Remove test files + std::remove(test_log_file_.c_str()); + std::remove(shared_log_file_.c_str()); + } + + std::string test_log_file_; + std::string shared_log_file_; + + // Helper to check if file contains a string + bool FileContains(const std::string &filepath, const std::string &search) { + std::ifstream file(filepath); + if (!file.is_open()) + return false; + + std::string line; + while (std::getline(file, line)) { + if (line.find(search) != std::string::npos) { + return true; + } + } + return false; + } + + // Helper to count lines in file + int CountLinesInFile(const std::string &filepath) { + std::ifstream file(filepath); + if (!file.is_open()) + return 0; + + int count = 0; + std::string line; + while (std::getline(file, line)) { + count++; + } + return count; + } +}; + +// ============================================================================ +// Test: CreateFileLogger Basic Functionality +// ============================================================================ + +TEST_F(LoggingTest, CreateFileLoggerSuccess) { + CreationStatus status = CreateFileLogger("test_logger", test_log_file_); + EXPECT_EQ(status, CreationStatus::kSuccess); + + // Verify logger exists + auto logger = spdlog::get("test_logger"); + EXPECT_NE(logger, nullptr); +} + +TEST_F(LoggingTest, CreateFileLoggerAlreadyExists) { + CreateFileLogger("test_logger", test_log_file_); + + // Try creating again with same name + CreationStatus status = CreateFileLogger("test_logger", test_log_file_); + EXPECT_EQ(status, CreationStatus::kExists); +} + +TEST_F(LoggingTest, CreateMultipleFileLoggers) { + CreationStatus status1 = CreateFileLogger("logger1", test_log_file_); + CreationStatus status2 = CreateFileLogger("logger2", test_log_file_); + + EXPECT_EQ(status1, CreationStatus::kSuccess); + EXPECT_EQ(status2, CreationStatus::kSuccess); + + EXPECT_NE(spdlog::get("logger1"), nullptr); + EXPECT_NE(spdlog::get("logger2"), nullptr); +} + +// ============================================================================ +// Test: Shared File Sink Functionality +// ============================================================================ + +TEST_F(LoggingTest, CreateSharedFileSink) { + CreationStatus status = CreateSharedFileSink(shared_log_file_); + EXPECT_EQ(status, CreationStatus::kSuccess); +} + +TEST_F(LoggingTest, CreateSharedFileSinkCaching) { + // Create sink first time + CreationStatus status1 = CreateSharedFileSink(shared_log_file_); + EXPECT_EQ(status1, CreationStatus::kSuccess); + + // Create sink again with same path (should reuse cached) + CreationStatus status2 = CreateSharedFileSink(shared_log_file_); + EXPECT_EQ(status2, CreationStatus::kSuccess); +} + +TEST_F(LoggingTest, CreateSharedLogger) { + CreateSharedFileSink(shared_log_file_); + + CreationStatus status = CreateSharedLogger("shared_logger"); + EXPECT_EQ(status, CreationStatus::kSuccess); + + auto logger = spdlog::get("shared_logger"); + EXPECT_NE(logger, nullptr); +} + +TEST_F(LoggingTest, MultipleSharedLoggersToSameSink) { + CreateSharedFileSink(shared_log_file_); + + CreationStatus status1 = CreateSharedLogger("shared_logger_1"); + CreationStatus status2 = CreateSharedLogger("shared_logger_2"); + + EXPECT_EQ(status1, CreationStatus::kSuccess); + EXPECT_EQ(status2, CreationStatus::kSuccess); + + // Both loggers should exist + EXPECT_NE(spdlog::get("shared_logger_1"), nullptr); + EXPECT_NE(spdlog::get("shared_logger_2"), nullptr); +} + +// ============================================================================ +// Test: Log Pattern Configuration +// ============================================================================ + +TEST_F(LoggingTest, SetAndGetLogPattern) { + SetLogPattern(LogPattern::kDetailed); + LogPattern pattern = GetLogPattern(); + EXPECT_EQ(pattern, LogPattern::kDetailed); +} + +TEST_F(LoggingTest, LogPatternAffectsNewLoggers) { + SetLogPattern(LogPattern::kSimple); + CreateFileLogger("pattern_logger", test_log_file_); + + LogInfo("pattern_logger", "Test message"); + FlushAllLoggers(); + + // File should exist and contain the message + EXPECT_TRUE(FileContains(test_log_file_, "Test message")); +} + +TEST_F(LoggingTest, ChangeLogPatternMultipleTimes) { + SetLogPattern(LogPattern::kSimple); + EXPECT_EQ(GetLogPattern(), LogPattern::kSimple); + + SetLogPattern(LogPattern::kStandard); + EXPECT_EQ(GetLogPattern(), LogPattern::kStandard); + + SetLogPattern(LogPattern::kDetailed); + EXPECT_EQ(GetLogPattern(), LogPattern::kDetailed); + + SetLogPattern(LogPattern::kThreadSafe); + EXPECT_EQ(GetLogPattern(), LogPattern::kThreadSafe); +} + +// ============================================================================ +// Test: Flush Interval Configuration +// ============================================================================ + +TEST_F(LoggingTest, SetFlushInterval) { + // Should not throw + SetFlushInterval(0); + SetFlushInterval(1); + SetFlushInterval(5); +} + +TEST_F(LoggingTest, SetFlushIntervalZeroForImmediateFlush) { + SetFlushInterval(0); + CreateFileLogger("flush_logger", test_log_file_); + + LogInfo("flush_logger", "Immediate flush"); + + // Should be written immediately + EXPECT_TRUE(FileContains(test_log_file_, "Immediate flush")); +} + +// ============================================================================ +// Test: Logging Functions +// ============================================================================ + +TEST_F(LoggingTest, LogInfo) { + CreateFileLogger("test_logger", test_log_file_); + + LogInfo("test_logger", "Info message"); + FlushAllLoggers(); + + EXPECT_TRUE(FileContains(test_log_file_, "Info message")); +} + +TEST_F(LoggingTest, LogWarning) { + CreateFileLogger("test_logger", test_log_file_); + + LogWarning("test_logger", "Warning message"); + FlushAllLoggers(); + + EXPECT_TRUE(FileContains(test_log_file_, "Warning message")); +} + +TEST_F(LoggingTest, LogError) { + CreateFileLogger("test_logger", test_log_file_); + + LogError("test_logger", "Error message"); + FlushAllLoggers(); + + EXPECT_TRUE(FileContains(test_log_file_, "Error message")); +} + +TEST_F(LoggingTest, LogDebug) { + CreateFileLogger("test_logger", test_log_file_); + SetFlushInterval(0); // Immediate flush + + LogDebug("test_logger", "Debug message"); + FlushAllLoggers(); + + // Verify logger exists and can log + EXPECT_EQ(CheckLoggerExists("test_logger"), CreationStatus::kExists); +} + +TEST_F(LoggingTest, MultipleLogMessages) { + CreateFileLogger("test_logger", test_log_file_); + SetFlushInterval(0); // Immediate flush + + // These should all complete without error + LogInfo("test_logger", "Message 1"); + LogWarning("test_logger", "Message 2"); + LogError("test_logger", "Message 3"); + LogDebug("test_logger", "Message 4"); + FlushAllLoggers(); + + // Verify logger exists + EXPECT_EQ(CheckLoggerExists("test_logger"), CreationStatus::kExists); +} + +// ============================================================================ +// Test: Logger Utility Functions +// ============================================================================ + +TEST_F(LoggingTest, CheckLoggerExistsTrue) { + CreateFileLogger("test_logger", test_log_file_); + + CreationStatus status = CheckLoggerExists("test_logger"); + EXPECT_EQ(status, CreationStatus::kExists); +} + +TEST_F(LoggingTest, CheckLoggerExistsFalse) { + CreationStatus status = CheckLoggerExists("nonexistent_logger"); + EXPECT_EQ(status, CreationStatus::kNotCreated); +} + +TEST_F(LoggingTest, GetLoggerInfo) { + CreateFileLogger("test_logger", test_log_file_); + + std::string info = GetLoggerInfo("test_logger"); + EXPECT_NE(info, ""); + EXPECT_NE(info.find("test_logger"), std::string::npos); +} + +TEST_F(LoggingTest, GetLoggerInfoNonexistent) { + std::string info = GetLoggerInfo("nonexistent_logger"); + EXPECT_NE(info.find("not found"), std::string::npos); +} + +TEST_F(LoggingTest, SetLoggerLevel) { + CreateFileLogger("test_logger", test_log_file_); + + // Should not throw + SetLoggerLevel("test_logger", 0); // trace + SetLoggerLevel("test_logger", 1); // debug + SetLoggerLevel("test_logger", 2); // info + SetLoggerLevel("test_logger", 3); // warn + SetLoggerLevel("test_logger", 4); // error + SetLoggerLevel("test_logger", 5); // critical +} + +// ============================================================================ +// Test: Flush All Loggers +// ============================================================================ + +TEST_F(LoggingTest, FlushAllLoggers) { + CreateFileLogger("logger1", test_log_file_); + CreateFileLogger("logger2", test_log_file_); + + LogInfo("logger1", "Message 1"); + LogInfo("logger2", "Message 2"); + + // Should not throw + FlushAllLoggers(); + + EXPECT_TRUE(FileContains(test_log_file_, "Message 1")); + EXPECT_TRUE(FileContains(test_log_file_, "Message 2")); +} + +TEST_F(LoggingTest, ParallelSharedLoggingToSameFile) { + CreationStatus sink_status = CreateSharedFileSink(shared_log_file_); + EXPECT_EQ(sink_status, CreationStatus::kSuccess); + + const int num_threads = 4; + const int messages_per_thread = 5; + + auto thread_func = [this](int thread_id) { + std::string logger_name = "shared_logger_" + std::to_string(thread_id); + CreationStatus status = CreateSharedLogger(logger_name); + // Status could be kSuccess or kExists if concurrent threads create same + // logger + EXPECT_NE(status, CreationStatus::kError); + + for (int i = 0; i < messages_per_thread; ++i) { + std::string message = "Thread " + std::to_string(thread_id) + + " Message " + std::to_string(i); + LogInfo(logger_name, message); + } + }; + + std::vector threads; + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(thread_func, i); + } + + for (auto &t : threads) { + t.join(); + } + + FlushAllLoggers(); + + // Verify at least one shared logger exists + EXPECT_EQ(CheckLoggerExists("shared_logger_0"), CreationStatus::kExists); +} + +TEST_F(LoggingTest, ConcurrentCreateSharedLogger) { + CreateSharedFileSink(shared_log_file_); + + const int num_threads = 10; + + auto thread_func = [this](int thread_id) { + std::string logger_name = + "concurrent_logger_" + std::to_string(thread_id); + CreationStatus status = CreateSharedLogger(logger_name); + EXPECT_EQ(status, CreationStatus::kSuccess); + }; + + std::vector threads; + for (int i = 0; i < num_threads; ++i) { + threads.emplace_back(thread_func, i); + } + + for (auto &t : threads) { + t.join(); + } + + // All loggers should be created + for (int i = 0; i < num_threads; ++i) { + std::string logger_name = "concurrent_logger_" + std::to_string(i); + EXPECT_NE(spdlog::get(logger_name), nullptr); + } +} + +// ============================================================================ +// Test: Integration - Full Workflow +// ============================================================================ + +TEST_F(LoggingTest, FullWorkflowFileLogger) { + // Create logger + CreationStatus create_status = + CreateFileLogger("full_test", test_log_file_); + EXPECT_EQ(create_status, CreationStatus::kSuccess); + + // Set pattern + SetLogPattern(LogPattern::kDetailed); + EXPECT_EQ(GetLogPattern(), LogPattern::kDetailed); + + // Log messages + LogInfo("full_test", "Starting full workflow test"); + LogWarning("full_test", "This is a warning"); + LogError("full_test", "This is an error"); + LogDebug("full_test", "Debug information"); + + // Check logger exists + EXPECT_EQ(CheckLoggerExists("full_test"), CreationStatus::kExists); + + // Get info + std::string info = GetLoggerInfo("full_test"); + EXPECT_NE(info, ""); + + // Set level + SetLoggerLevel("full_test", 2); // info level + + // Flush + FlushAllLoggers(); + + // Verify output file + EXPECT_TRUE(FileContains(test_log_file_, "Starting full workflow test")); + EXPECT_TRUE(FileContains(test_log_file_, "This is a warning")); + EXPECT_TRUE(FileContains(test_log_file_, "This is an error")); +} + +TEST_F(LoggingTest, FullWorkflowSharedLogger) { + // Setup shared logging + CreationStatus sink_status = CreateSharedFileSink(shared_log_file_); + EXPECT_EQ(sink_status, CreationStatus::kSuccess); + + SetLogPattern(LogPattern::kThreadSafe); + EXPECT_EQ(GetLogPattern(), LogPattern::kThreadSafe); + + // Create shared loggers + CreationStatus status1 = CreateSharedLogger("shared_1"); + CreationStatus status2 = CreateSharedLogger("shared_2"); + + EXPECT_EQ(status1, CreationStatus::kSuccess); + EXPECT_EQ(status2, CreationStatus::kSuccess); + + // Log from each + LogInfo("shared_1", "Message from logger 1"); + LogInfo("shared_2", "Message from logger 2"); + + // Verify loggers exist + EXPECT_EQ(CheckLoggerExists("shared_1"), CreationStatus::kExists); + EXPECT_EQ(CheckLoggerExists("shared_2"), CreationStatus::kExists); + + FlushAllLoggers(); +} + +TEST_F(LoggingTest, MixedFileAndSharedLoggers) { + // Create file logger + CreationStatus file_status = + CreateFileLogger("file_logger", test_log_file_); + + // Create shared logger + CreationStatus sink_status = CreateSharedFileSink(shared_log_file_); + CreationStatus logger_status = CreateSharedLogger("shared_logger"); + + EXPECT_EQ(file_status, CreationStatus::kSuccess); + EXPECT_EQ(sink_status, CreationStatus::kSuccess); + EXPECT_EQ(logger_status, CreationStatus::kSuccess); + + // Log to both + LogInfo("file_logger", "File logger message"); + LogInfo("shared_logger", "Shared logger message"); + + // Verify both exist + EXPECT_EQ(CheckLoggerExists("file_logger"), CreationStatus::kExists); + EXPECT_EQ(CheckLoggerExists("shared_logger"), CreationStatus::kExists); + + FlushAllLoggers(); +} + +// ============================================================================ +// Test: Transition Error Logging +// ============================================================================ + +TEST_F(LoggingTest, TransitionFactoryInvalidType) { + CreateFileLogger("factory_test", test_log_file_); + + // Create transition with invalid type should log error and return nullptr + auto transition = respond::TransitionFactory::CreateTransition( + "invalid_type", "factory_test"); + + EXPECT_EQ(transition, nullptr); + EXPECT_EQ(CheckLoggerExists("factory_test"), CreationStatus::kExists); +} + +TEST_F(LoggingTest, TransitionFactoryValidTypes) { + CreateFileLogger("factory_test", test_log_file_); + + // Test all valid transition types + auto migration = respond::TransitionFactory::CreateTransition( + "migration", "factory_test"); + EXPECT_NE(migration, nullptr); + + auto behavior = respond::TransitionFactory::CreateTransition( + "behavior", "factory_test"); + EXPECT_NE(behavior, nullptr); + + auto intervention = respond::TransitionFactory::CreateTransition( + "intervention", "factory_test"); + EXPECT_NE(intervention, nullptr); + + auto overdose = respond::TransitionFactory::CreateTransition( + "overdose", "factory_test"); + EXPECT_NE(overdose, nullptr); + + auto background = respond::TransitionFactory::CreateTransition( + "background_death", "factory_test"); + EXPECT_NE(background, nullptr); +} + +TEST_F(LoggingTest, TransitionFactoryCaseInsensitivity) { + CreateFileLogger("factory_test", test_log_file_); + + // Test case-insensitive matching + auto trans1 = respond::TransitionFactory::CreateTransition("MIGRATION", + "factory_test"); + EXPECT_NE(trans1, nullptr); + + auto trans2 = respond::TransitionFactory::CreateTransition("Behavior", + "factory_test"); + EXPECT_NE(trans2, nullptr); + + auto trans3 = respond::TransitionFactory::CreateTransition("INTERVENTION", + "factory_test"); + EXPECT_NE(trans3, nullptr); + + auto trans4 = respond::TransitionFactory::CreateTransition("OverDose", + "factory_test"); + EXPECT_NE(trans4, nullptr); +} + +} // namespace testing +} // namespace respond diff --git a/tests/unit/markov_test.cpp b/tests/unit/markov_test.cpp index 3bd61cb1..635969e1 100644 --- a/tests/unit/markov_test.cpp +++ b/tests/unit/markov_test.cpp @@ -89,6 +89,69 @@ TEST_F(MarkovTest, RunTransitions) { markov->RunTransitions(); } +TEST_F(MarkovTest, RunTransitionsAccumulatesDefaultHistories) { + markov->SetState(state); + markov->RunTransitions(); + + Eigen::VectorXd next_state = state * 2.0; + markov->SetState(next_state); + markov->RunTransitions(); + + const auto histories = markov->GetHistories(); + ASSERT_EQ(histories.size(), 5u); + + const auto state_history = histories.at("state").GetStateAsVector(); + ASSERT_EQ(state_history.size(), 3u); + EXPECT_TRUE(state_history[0].isApprox(state)); + EXPECT_TRUE(state_history[1].isApprox(state)); + EXPECT_TRUE(state_history[2].isApprox(next_state)); + + const auto overdose_history = + histories.at("total_overdose").GetStateAsVector(); + ASSERT_EQ(overdose_history.size(), 3u); + EXPECT_TRUE(overdose_history[0].isZero()); + EXPECT_TRUE(overdose_history[1].isZero()); + EXPECT_TRUE(overdose_history[2].isZero()); +} + +TEST_F(MarkovTest, SparseHistoryCaptureRecordsRequestedAndFinalTimesteps) { + markov->SetHistoryCaptureInterval(2); + markov->SetFinalTimestep(5); + markov->SetState(state); + + for (int step = 0; step < 5; ++step) { + markov->RunTransitions(); + } + + const auto histories = markov->GetHistories(); + const auto ×teps = histories.at("state").GetRecordedTimesteps(); + std::vector expected = {0, 2, 4, 5}; + ASSERT_EQ(timesteps, expected); +} + +TEST_F(MarkovTest, ClearHistoriesResetsTrackingState) { + markov->SetHistoryCaptureInterval(2); + markov->SetFinalTimestep(4); + markov->SetState(state); + + markov->RunTransitions(); + markov->RunTransitions(); + markov->ClearHistories(); + + Eigen::VectorXd next_state = state * 3.0; + markov->SetState(next_state); + markov->RunTransitions(); + + const auto histories = markov->GetHistories(); + const auto ×teps = histories.at("state").GetRecordedTimesteps(); + std::vector expected = {0}; + ASSERT_EQ(timesteps, expected); + + const auto &states = histories.at("state").GetRecordedStates(); + ASSERT_EQ(states.size(), 1u); + EXPECT_TRUE(states[0].isApprox(next_state)); +} + TEST_F(MarkovTest, ClearTransitions) { // When Markov::AddTransition copies the transition it calls `clone()` on // the provided object. Make the mock return a heap-allocated mock that diff --git a/tests/unit/overdose_test.cpp b/tests/unit/overdose_test.cpp index 54359d55..6dcd488f 100644 --- a/tests/unit/overdose_test.cpp +++ b/tests/unit/overdose_test.cpp @@ -4,7 +4,7 @@ // Created Date: 2026-02-06 // // Author: Matthew Carroll // // ----- // -// Last Modified: 2026-02-06 // +// Last Modified: 2026-05-06 // // Modified By: Matthew Carroll // // ----- // // Copyright (c) 2026 Syndemics Lab at Boston Medical Center // @@ -83,14 +83,15 @@ TEST_F(OverdoseTest, GoodExecuteWriteTotalOverdoseHistory) { tran->AddTransitionMatrix(tran_matrix); tran->AddTransitionMatrix(tran_matrix); auto result = tran->Execute(state, histories); - auto od_result = histories["total_overdose"].GetStateAsVector()[0]; auto overdoses = state.cwiseProduct(tran_matrix); auto fods = overdoses.cwiseProduct(tran_matrix); auto expected_return = state - fods; EXPECT_TRUE(result.isApprox(expected_return)); - EXPECT_TRUE(od_result.isApprox(overdoses)); + EXPECT_TRUE(histories["total_overdose"].HasPendingState()); + EXPECT_TRUE( + histories["total_overdose"].GetPendingState().isApprox(overdoses)); } TEST_F(OverdoseTest, GoodExecuteWriteFatalOverdoseHistory) { @@ -99,13 +100,13 @@ TEST_F(OverdoseTest, GoodExecuteWriteFatalOverdoseHistory) { tran->AddTransitionMatrix(tran_matrix); tran->AddTransitionMatrix(tran_matrix); auto result = tran->Execute(state, histories); - auto fod_result = histories["fatal_overdose"].GetStateAsVector()[0]; auto overdoses = state.cwiseProduct(tran_matrix); auto fods = overdoses.cwiseProduct(tran_matrix); auto expected_return = state - fods; - EXPECT_TRUE(fod_result.isApprox(fods)); + EXPECT_TRUE(histories["fatal_overdose"].HasPendingState()); + EXPECT_TRUE(histories["fatal_overdose"].GetPendingState().isApprox(fods)); EXPECT_TRUE(result.isApprox(expected_return)); } @@ -117,15 +118,14 @@ TEST_F(OverdoseTest, GoodExecuteWriteAllHistory) { tran->AddTransitionMatrix(tran_matrix); tran->AddTransitionMatrix(tran_matrix); auto result = tran->Execute(state, histories); - auto od_result = histories["total_overdose"].GetStateAsVector()[0]; - auto fod_result = histories["fatal_overdose"].GetStateAsVector()[0]; auto overdoses = state.cwiseProduct(tran_matrix); auto fods = overdoses.cwiseProduct(tran_matrix); auto expected_return = state - fods; - EXPECT_TRUE(od_result.isApprox(overdoses)); - EXPECT_TRUE(fod_result.isApprox(fods)); + EXPECT_TRUE( + histories["total_overdose"].GetPendingState().isApprox(overdoses)); + EXPECT_TRUE(histories["fatal_overdose"].GetPendingState().isApprox(fods)); EXPECT_TRUE(result.isApprox(expected_return)); } } // namespace testing diff --git a/tests/unit/simulation_test.cpp b/tests/unit/simulation_test.cpp index 56ff3928..06b3140f 100644 --- a/tests/unit/simulation_test.cpp +++ b/tests/unit/simulation_test.cpp @@ -133,5 +133,37 @@ TEST_F(SimulationTest, GetHistoryNames) { ASSERT_EQ(s.GetModelHistoryNames(), expected); } +TEST_F(SimulationTest, GetModelSparseHistories) { + auto mock = std::make_unique>(); + auto cloned = std::make_unique>(); + + std::map hv; + History h("temp", "test_logger"); + Eigen::VectorXd state0 = Eigen::VectorXd(2); + state0 << 1.0f, 2.0f; + Eigen::VectorXd state2 = Eigen::VectorXd(2); + state2 << 3.0f, 4.0f; + h.AddState(state0, 0); + h.AddState(state2, 2); + hv["temp"] = h; + + EXPECT_CALL(*cloned, GetHistories()).WillOnce(Return(hv)); + ON_CALL(*mock, clone()) + .WillByDefault(Return(::testing::ByMove(std::move(cloned)))); + + std::unique_ptr upmm = std::move(mock); + Simulation s; + s.AddModel(upmm); + + const auto histories = s.GetModelSparseHistories(); + ASSERT_EQ(histories.size(), 1u); + const auto &history = histories[0].at("temp"); + std::vector expected_timesteps = {0, 2}; + ASSERT_EQ(history.GetRecordedTimesteps(), expected_timesteps); + ASSERT_EQ(history.GetRecordedStates().size(), 2u); + EXPECT_TRUE(history.GetRecordedStates()[0].isApprox(state0)); + EXPECT_TRUE(history.GetRecordedStates()[1].isApprox(state2)); +} + } // namespace testing } // namespace respond \ No newline at end of file