diff --git a/src/BUILD b/src/BUILD index df50004d5b..43297bd6ca 100644 --- a/src/BUILD +++ b/src/BUILD @@ -682,6 +682,8 @@ ovms_cc_library( "kfs_backend_impl", "tfs_backend_impl", "anonymous_input_name", + "libovms_servable_name_checker", + "libovms_metric_provider", ] + select({ "//:not_disable_cloud": [ "libovmsazurefilesystem", @@ -1188,7 +1190,8 @@ ovms_cc_library( name = "libovmsstatus", hdrs = ["status.hpp",], srcs = ["status.cpp",], - deps = ["libovmslogging"], + deps = ["libovmslogging", + "@fmtlib",], visibility = ["//visibility:public"], ) @@ -1203,6 +1206,20 @@ ovms_cc_library( visibility = ["//visibility:public",], ) +ovms_cc_library( + name = "libovms_servable_name_checker", + hdrs = ["servable_name_checker.hpp",], + deps = [], + visibility = ["//visibility:public"], +) + +ovms_cc_library( + name = "libovms_metric_provider", + hdrs = ["metric_provider.hpp",], + deps = [], + visibility = ["//visibility:public"], +) + ovms_cc_library( # make ovms_lib dependent, use share doptions name = "libovmsstring_utils", hdrs = ["stringutils.hpp",], @@ -1718,6 +1735,7 @@ ovms_cc_library( hdrs = ["modelversionstatus.hpp",], srcs = ["modelversionstatus.cpp",], deps = [ + "@fmtlib", "libovmslogging", "libovmsmodelversion", ], @@ -1772,12 +1790,12 @@ ovms_cc_library( name = "nodeinfo", hdrs = ["dags/nodeinfo.hpp",], deps = [ + "@fmtlib", "libovms_threadsafequeue", "libovmsmodelversion", "libovms_tensorinfo", "node_library", "libovms_dags_aliases", - "libovmslogging", ], visibility = ["//visibility:public"], ) @@ -1797,7 +1815,7 @@ ovms_cc_library( deps = [ "libovmsstatus", "@com_github_grpc_grpc//:grpc++", - "libovmslogging", + "@fmtlib", ], visibility = ["//visibility:public"], ) diff --git a/src/azurestorage.hpp b/src/azurestorage.hpp index f45b2e8084..3a0c3ae58d 100644 --- a/src/azurestorage.hpp +++ b/src/azurestorage.hpp @@ -21,7 +21,6 @@ #include #include -#include "logging.hpp" #include "status.hpp" #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wall" diff --git a/src/capi_frontend/capi.cpp b/src/capi_frontend/capi.cpp index f6a4b2b159..ab1683f9a1 100644 --- a/src/capi_frontend/capi.cpp +++ b/src/capi_frontend/capi.cpp @@ -28,12 +28,14 @@ #pragma warning(pop) #include "../dags/pipeline.hpp" +#include "../dags/pipeline_factory.hpp" #include "../dags/pipelinedefinition.hpp" #include "../dags/pipelinedefinitionstatus.hpp" #include "../dags/pipelinedefinitionunloadguard.hpp" #include "../execution_context.hpp" #include "../version.hpp" #if (MEDIAPIPE_DISABLE == 0) +#include "../mediapipe_internal/mediapipefactory.hpp" #include "../mediapipe_internal/mediapipegraphdefinition.hpp" #endif #include "../model_service.hpp" @@ -120,7 +122,7 @@ static Status getPipeline(ovms::Server& server, const InferenceRequest* request, if (!status.ok()) { return status; } - return modelManager->createPipeline(pipelinePtr, request->getServableName(), request, response); + return modelManager->getPipelineFactory().create(pipelinePtr, request->getServableName(), request, response, *modelManager); } static Status getPipelineDefinition(Server& server, const std::string& servableName, PipelineDefinition** pipelineDefinition, std::unique_ptr& unloadGuard) { diff --git a/src/capi_frontend/capi_request_utils.hpp b/src/capi_frontend/capi_request_utils.hpp index 7d22e1c7ce..565a17e417 100644 --- a/src/capi_frontend/capi_request_utils.hpp +++ b/src/capi_frontend/capi_request_utils.hpp @@ -19,10 +19,10 @@ #include #include "../ovms.h" // NOLINT +#include "src/logging.hpp" #include "../precision.hpp" #include "inferencerequest.hpp" #include "../shape.hpp" -#include "../logging.hpp" #include "../status.hpp" // TODO move impl @atobisze #include "../extractchoice.hpp" #include "../requesttensorextractor.hpp" diff --git a/src/color_format_configuration.cpp b/src/color_format_configuration.cpp index 0b6640bc88..f1973f53da 100644 --- a/src/color_format_configuration.cpp +++ b/src/color_format_configuration.cpp @@ -18,6 +18,8 @@ #include #include +#include "logging.hpp" + namespace ovms { const char ColorFormatConfiguration::COLOR_FORMAT_DELIMITER = ':'; diff --git a/src/dags/entry_node.hpp b/src/dags/entry_node.hpp index c146b5021f..018aabdde2 100644 --- a/src/dags/entry_node.hpp +++ b/src/dags/entry_node.hpp @@ -21,7 +21,6 @@ #include -#include "../logging.hpp" #include "../ovms.h" // NOLINT #include "../regularovtensorfactory.hpp" #include "../tensorinfo.hpp" diff --git a/src/dags/gatherexitnodeinputhandler.hpp b/src/dags/gatherexitnodeinputhandler.hpp index 9432e3b2ac..14ac2aa522 100644 --- a/src/dags/gatherexitnodeinputhandler.hpp +++ b/src/dags/gatherexitnodeinputhandler.hpp @@ -23,7 +23,6 @@ #include "../capi_frontend/capi_utils.hpp" #include "../capi_frontend/capi_dag_utils.hpp" #include "../kfs_frontend/kfs_utils.hpp" -#include "../logging.hpp" #include "../profiler.hpp" #include "../status.hpp" #include "../tfs_frontend/tfs_utils.hpp" diff --git a/src/dags/nodeinfo.hpp b/src/dags/nodeinfo.hpp index 503c9397df..6a20f31717 100644 --- a/src/dags/nodeinfo.hpp +++ b/src/dags/nodeinfo.hpp @@ -24,11 +24,12 @@ #include #include +#include + #include "../modelversion.hpp" #include "../tensorinfo.hpp" #include "aliases.hpp" #include "node_library.hpp" -#include "../logging.hpp" namespace ovms { diff --git a/src/dags/pipelinedefinition.cpp b/src/dags/pipelinedefinition.cpp index 434859c2ad..31b0ea7c63 100644 --- a/src/dags/pipelinedefinition.cpp +++ b/src/dags/pipelinedefinition.cpp @@ -20,6 +20,7 @@ #include #include "../logging.hpp" +#include "../model.hpp" #include "../model_metric_reporter.hpp" #include "../modelinstance.hpp" #include "../modelinstanceunloadguard.hpp" @@ -68,16 +69,10 @@ PipelineDefinition::PipelineDefinition(const std::string& pipelineName, Status PipelineDefinition::validate(ModelManager& manager) { SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Started validation of pipeline: {}", getName()); ValidationResultNotifier notifier(status, loadedNotify); - if (manager.modelExists(this->pipelineName)) { - SPDLOG_LOGGER_ERROR(modelmanager_logger, "Pipeline name: {} is already occupied by model.", pipelineName); + if (manager.servableExists(this->pipelineName, ServableType::Model | ServableType::Mediapipe)) { + SPDLOG_LOGGER_ERROR(modelmanager_logger, "Pipeline name: {} is already occupied by model or mediapipe graph.", pipelineName); return StatusCode::PIPELINE_NAME_OCCUPIED; } -#if (MEDIAPIPE_DISABLE == 0) - if (manager.getMediapipeFactory().definitionExists(this->pipelineName)) { - SPDLOG_LOGGER_ERROR(modelmanager_logger, "Pipeline name: {} is already occupied by mediapipe graph.", pipelineName); - return StatusCode::PIPELINE_NAME_OCCUPIED; - } -#endif Status validationResult = initializeNodeResources(manager); if (!validationResult.ok()) { return validationResult; diff --git a/src/get_model_metadata_impl.cpp b/src/get_model_metadata_impl.cpp index 195df6da95..73b171040e 100644 --- a/src/get_model_metadata_impl.cpp +++ b/src/get_model_metadata_impl.cpp @@ -17,10 +17,12 @@ #include +#include "dags/pipeline_factory.hpp" #include "dags/pipelinedefinition.hpp" #include "dags/pipelinedefinitionstatus.hpp" #include "dags/pipelinedefinitionunloadguard.hpp" #include "execution_context.hpp" +#include "model.hpp" #include "modelinstance.hpp" #include "modelinstanceunloadguard.hpp" #include "modelmanager.hpp" diff --git a/src/grpc_utils.hpp b/src/grpc_utils.hpp index d9222f68f1..17baed5c3e 100644 --- a/src/grpc_utils.hpp +++ b/src/grpc_utils.hpp @@ -15,9 +15,9 @@ //***************************************************************************** #pragma once #include -#include -#include "logging.hpp" +#include +#include namespace ovms { class Status; diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index afe163e6dc..fc8f46b740 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -38,6 +38,7 @@ #include "config.hpp" #include "dags/pipeline.hpp" +#include "dags/pipeline_factory.hpp" #include "dags/pipelinedefinition.hpp" #include "dags/pipelinedefinitionunloadguard.hpp" #include "execution_context.hpp" @@ -68,6 +69,7 @@ #include "http_payload.hpp" #include "http_frontend/http_client_connection.hpp" #include "http_frontend/http_graph_executor_impl.hpp" +#include "mediapipe_internal/mediapipefactory.hpp" #include "mediapipe_internal/mediapipegraphexecutor.hpp" #endif @@ -1153,7 +1155,7 @@ Status HttpRestApiHandler::processPredictRequest( if (this->modelManager.modelExists(modelName)) { SPDLOG_DEBUG("Found model with name: {}. Searching for requested version...", modelName); status = processSingleModelRequest(modelName, modelVersion, request, requestOrder, responseProto, reporterOut); - } else if (this->modelManager.pipelineDefinitionExists(modelName)) { + } else if (this->modelManager.servableExists(modelName, ServableType::Pipeline)) { SPDLOG_DEBUG("Found pipeline with name: {}", modelName); status = processPipelineRequest(modelName, request, requestOrder, responseProto, reporterOut); } else { @@ -1288,7 +1290,7 @@ Status HttpRestApiHandler::processPipelineRequest(const std::string& modelName, tensorflow::serving::PredictRequest& requestProto = requestParser.getProto(); requestProto.mutable_model_spec()->set_name(modelName); - status = this->modelManager.createPipeline(pipelinePtr, modelName, &requestProto, &responseProto); + status = this->modelManager.getPipelineFactory().create(pipelinePtr, modelName, &requestProto, &responseProto, this->modelManager); if (!status.ok()) { INCREMENT_IF_ENABLED(reporterOut->getInferRequestMetric(executionContext, false)); return status; diff --git a/src/image_gen/imagegen_init.cpp b/src/image_gen/imagegen_init.cpp index a96cbd764c..df9bde7059 100644 --- a/src/image_gen/imagegen_init.cpp +++ b/src/image_gen/imagegen_init.cpp @@ -20,6 +20,8 @@ #include #include +#include + #include "absl/strings/str_replace.h" #include "absl/strings/ascii.h" diff --git a/src/inference_request_common.hpp b/src/inference_request_common.hpp index a0f79f2460..b6405827e5 100644 --- a/src/inference_request_common.hpp +++ b/src/inference_request_common.hpp @@ -22,7 +22,6 @@ #include #include -#include "logging.hpp" #include "shape.hpp" #include "anonymous_input_name.hpp" #include "status.hpp" diff --git a/src/kfs_frontend/kfs_grpc_inference_service.cpp b/src/kfs_frontend/kfs_grpc_inference_service.cpp index e3499caa9f..f129fe008b 100644 --- a/src/kfs_frontend/kfs_grpc_inference_service.cpp +++ b/src/kfs_frontend/kfs_grpc_inference_service.cpp @@ -26,6 +26,7 @@ #include "kfs_utils.hpp" #include "kfs_request_utils.hpp" #include "../dags/pipeline.hpp" +#include "../dags/pipeline_factory.hpp" #include "../dags/pipelinedefinition.hpp" #include "../dags/pipelinedefinitionstatus.hpp" #include "../dags/pipelinedefinitionunloadguard.hpp" @@ -36,11 +37,13 @@ // kfs_graph_executor_impl needs to be included before mediapipegraphexecutor // because it contains functions required by graph execution template #include "kfs_graph_executor_impl.hpp" +#include "../mediapipe_internal/mediapipefactory.hpp" #include "../mediapipe_internal/mediapipegraphdefinition.hpp" #include "../mediapipe_internal/mediapipegraphexecutor.hpp" // clang-format on #endif #include "../metric.hpp" +#include "../model.hpp" #include "../modelinstance.hpp" #include "../deserialization_main.hpp" #include "../inference_executor.hpp" @@ -85,7 +88,7 @@ Status KFSInferenceServiceImpl::getPipeline(const KFSRequest* request, KFSResponse* response, std::unique_ptr& pipelinePtr) { OVMS_PROFILE_FUNCTION(); - return this->modelManager.createPipeline(pipelinePtr, request->model_name(), request, response); + return this->modelManager.getPipelineFactory().create(pipelinePtr, request->model_name(), request, response, this->modelManager); } const std::string PLATFORM = "OpenVINO"; diff --git a/src/kfs_frontend/validation.hpp b/src/kfs_frontend/validation.hpp index 5dab096d64..b88e422b0a 100644 --- a/src/kfs_frontend/validation.hpp +++ b/src/kfs_frontend/validation.hpp @@ -24,7 +24,6 @@ #include "kfs_utils.hpp" #include "../precision.hpp" #include "../predict_request_validation_utils.hpp" -#include "../logging.hpp" #include "../profiler.hpp" #include "../tensorinfo.hpp" #include "../status.hpp" diff --git a/src/llm/apis/openai_completions.cpp b/src/llm/apis/openai_completions.cpp index 6898b51604..64859a9fee 100644 --- a/src/llm/apis/openai_completions.cpp +++ b/src/llm/apis/openai_completions.cpp @@ -25,6 +25,8 @@ #include #include +#include + #include "openai_json_response.hpp" #include "../../logging.hpp" diff --git a/src/logging.cpp b/src/logging.cpp index e89fce9a07..c0974c3a4e 100644 --- a/src/logging.cpp +++ b/src/logging.cpp @@ -15,6 +15,9 @@ //***************************************************************************** #include "logging.hpp" +#include +#include + #if (MEDIAPIPE_DISABLE == 0) #include #endif diff --git a/src/logging.hpp b/src/logging.hpp index 011458fe49..98b842f4d3 100644 --- a/src/logging.hpp +++ b/src/logging.hpp @@ -18,9 +18,6 @@ #include #include -#include -#include -#include #include namespace ovms { diff --git a/src/mediapipe_internal/mediapipefactory.cpp b/src/mediapipe_internal/mediapipefactory.cpp index aa3689ae31..7df68ab004 100644 --- a/src/mediapipe_internal/mediapipefactory.cpp +++ b/src/mediapipe_internal/mediapipefactory.cpp @@ -34,7 +34,8 @@ #pragma warning(pop) #include "../kfs_frontend/kfs_grpc_inference_service.hpp" #include "../logging.hpp" -#include "../modelmanager.hpp" +#include "../metric_provider.hpp" +#include "../servable_name_checker.hpp" #include "../status.hpp" #include "../stringutils.hpp" #pragma warning(push) @@ -62,13 +63,14 @@ MediapipeFactory::MediapipeFactory(PythonBackend* pythonBackend) { Status MediapipeFactory::createDefinition(const std::string& pipelineName, const MediapipeGraphConfig& config, - ModelManager& manager) { + MetricProvider& metrics, + const ServableNameChecker& checker) { if (definitionExists(pipelineName)) { SPDLOG_LOGGER_ERROR(modelmanager_logger, "Mediapipe graph definition: {} is already created", pipelineName); return StatusCode::PIPELINE_DEFINITION_ALREADY_EXIST; } - std::shared_ptr graphDefinition = std::make_shared(pipelineName, config, manager.getMetricRegistry(), &manager.getMetricConfig(), pythonBackend); - auto stat = graphDefinition->validate(manager); + std::shared_ptr graphDefinition = std::make_shared(pipelineName, config, metrics.getMetricRegistry(), &metrics.getMetricConfig(), pythonBackend); + auto stat = graphDefinition->validate(checker); if (stat.getCode() == StatusCode::MEDIAPIPE_GRAPH_NAME_OCCUPIED) { return stat; } @@ -94,19 +96,18 @@ MediapipeGraphDefinition* MediapipeFactory::findDefinitionByName(const std::stri Status MediapipeFactory::reloadDefinition(const std::string& name, const MediapipeGraphConfig& config, - ModelManager& manager) { + const ServableNameChecker& checker) { auto mgd = findDefinitionByName(name); if (mgd == nullptr) { SPDLOG_LOGGER_ERROR(modelmanager_logger, "Requested to reload mediapipe graph definition but it does not exist: {}", name); return StatusCode::INTERNAL_ERROR; } SPDLOG_LOGGER_INFO(modelmanager_logger, "Reloading mediapipe graph: {}", name); - return mgd->reload(manager, config); + return mgd->reload(checker, config); } Status MediapipeFactory::create(std::unique_ptr& pipeline, - const std::string& name, - ModelManager& manager) const { + const std::string& name) const { std::shared_lock lock(definitionsMtx); auto it = definitions.find(name); if (it == definitions.end()) { @@ -117,17 +118,17 @@ Status MediapipeFactory::create(std::unique_ptr& pipelin return definition.create(pipeline); } -void MediapipeFactory::retireOtherThan(std::set&& graphsInConfigFile, ModelManager& manager) { +void MediapipeFactory::retireOtherThan(std::set&& graphsInConfigFile) { std::for_each(definitions.begin(), definitions.end(), - [&graphsInConfigFile, &manager](auto& nameDefinitionPair) { + [&graphsInConfigFile](auto& nameDefinitionPair) { if (graphsInConfigFile.find(nameDefinitionPair.second->getName()) == graphsInConfigFile.end() && nameDefinitionPair.second->getStateCode() != PipelineDefinitionStateCode::RETIRED) { - nameDefinitionPair.second->retire(manager); + nameDefinitionPair.second->retire(); } }); } -Status MediapipeFactory::revalidatePipelines(ModelManager&) { +Status MediapipeFactory::revalidatePipelines() { SPDLOG_LOGGER_WARN(modelmanager_logger, "revalidation of mediapipe graphs not implemented yet"); return StatusCode::OK; } diff --git a/src/mediapipe_internal/mediapipefactory.hpp b/src/mediapipe_internal/mediapipefactory.hpp index e48146b0f0..bb63d22534 100644 --- a/src/mediapipe_internal/mediapipefactory.hpp +++ b/src/mediapipe_internal/mediapipefactory.hpp @@ -36,7 +36,8 @@ namespace ovms { -class ModelManager; +class MetricProvider; +class ServableNameChecker; class Status; class MediapipeGraphConfig; class MediapipeGraphDefinition; @@ -53,30 +54,22 @@ class MediapipeFactory { MediapipeFactory(PythonBackend* pythonBackend = nullptr); Status createDefinition(const std::string& pipelineName, const MediapipeGraphConfig& config, - ModelManager& manager); + MetricProvider& metrics, + const ServableNameChecker& checker); bool definitionExists(const std::string& name) const; -private: - template - Status createInternal(std::unique_ptr& pipeline, - const std::string& name, - const RequestType* request, - ResponseType* response, - ModelManager& manager) const; - public: Status create(std::unique_ptr& pipeline, - const std::string& name, - ModelManager& manager) const; + const std::string& name) const; MediapipeGraphDefinition* findDefinitionByName(const std::string& name) const; Status reloadDefinition(const std::string& pipelineName, const MediapipeGraphConfig& config, - ModelManager& manager); + const ServableNameChecker& checker); - void retireOtherThan(std::set&& pipelinesInConfigFile, ModelManager& manager); - Status revalidatePipelines(ModelManager&); + void retireOtherThan(std::set&& pipelinesInConfigFile); + Status revalidatePipelines(); const std::vector getMediapipePipelinesNames() const; const std::vector getNamesOfAvailableMediapipePipelines() const; ~MediapipeFactory(); diff --git a/src/mediapipe_internal/mediapipegraphdefinition.cpp b/src/mediapipe_internal/mediapipegraphdefinition.cpp index 9047765e75..ad33e96c98 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.cpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.cpp @@ -31,8 +31,8 @@ #include "../deserialization_main.hpp" #include "../metric.hpp" #include "../model_metric_reporter.hpp" -#include "../modelmanager.hpp" #include "../ov_utils.hpp" +#include "../servable_name_checker.hpp" #include "../llm/servable.hpp" #include "../llm/servable_initializer.hpp" #if (PYTHON_DISABLE == 0) @@ -127,14 +127,14 @@ Status MediapipeGraphDefinition::dryInitializeTest() { } return StatusCode::OK; } -Status MediapipeGraphDefinition::validate(ModelManager& manager) { +Status MediapipeGraphDefinition::validate(const ServableNameChecker& checker) { SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Started validation of mediapipe: {}", getName()); if (!this->sidePacketMaps.empty()) { SPDLOG_ERROR("Internal Error: MediaPipe definition is in unexpected state."); return StatusCode::INTERNAL_ERROR; } ValidationResultNotifier notifier(this->status, this->loadedNotify); - if (manager.modelExists(this->getName()) || manager.pipelineDefinitionExists(this->getName())) { + if (checker.servableExists(this->getName(), ServableType::Model | ServableType::Pipeline)) { SPDLOG_LOGGER_ERROR(modelmanager_logger, "Mediapipe graph name: {} is already occupied by model or pipeline.", this->getName()); return StatusCode::MEDIAPIPE_GRAPH_NAME_OCCUPIED; } @@ -332,7 +332,7 @@ Status MediapipeGraphDefinition::setStreamTypes() { return StatusCode::OK; } -Status MediapipeGraphDefinition::reload(ModelManager& manager, const MediapipeGraphConfig& config) { +Status MediapipeGraphDefinition::reload(const ServableNameChecker& checker, const MediapipeGraphConfig& config) { // block creating new unloadGuards this->status.handle(ReloadEvent()); while (requestsHandlesCounter > 0) { @@ -340,10 +340,10 @@ Status MediapipeGraphDefinition::reload(ModelManager& manager, const MediapipeGr } this->mgconfig = config; this->sidePacketMaps.clear(); - return validate(manager); + return validate(checker); } -void MediapipeGraphDefinition::retire(ModelManager& manager) { +void MediapipeGraphDefinition::retire() { this->sidePacketMaps.clear(); this->status.handle(RetireEvent()); } diff --git a/src/mediapipe_internal/mediapipegraphdefinition.hpp b/src/mediapipe_internal/mediapipegraphdefinition.hpp index 14c9e0679f..0ed4d7b752 100644 --- a/src/mediapipe_internal/mediapipegraphdefinition.hpp +++ b/src/mediapipe_internal/mediapipegraphdefinition.hpp @@ -54,7 +54,7 @@ class MediapipeGraphDefinitionUnloadGuard; class MetricConfig; class MetricRegistry; class MediapipeServableMetricReporter; -class ModelManager; +class ServableNameChecker; class MediapipeGraphExecutor; class Status; class PythonBackend; @@ -121,9 +121,9 @@ class MediapipeGraphDefinition { MediapipeServableMetricReporter& getMetricReporter() const { return *this->reporter; } Status create(std::unique_ptr& pipeline); - Status reload(ModelManager& manager, const MediapipeGraphConfig& config); - Status validate(ModelManager& manager); - void retire(ModelManager& manager); + Status reload(const ServableNameChecker& checker, const MediapipeGraphConfig& config); + Status validate(const ServableNameChecker& checker); + void retire(); Status initializeNodes(); bool isReloadRequired(const MediapipeGraphConfig& config) const; diff --git a/src/metric_provider.hpp b/src/metric_provider.hpp new file mode 100644 index 0000000000..36824a8599 --- /dev/null +++ b/src/metric_provider.hpp @@ -0,0 +1,30 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +namespace ovms { + +class MetricConfig; +class MetricRegistry; + +class MetricProvider { +public: + virtual ~MetricProvider() = default; + virtual MetricRegistry* getMetricRegistry() const = 0; + virtual const MetricConfig& getMetricConfig() const = 0; +}; + +} // namespace ovms diff --git a/src/model_service.cpp b/src/model_service.cpp index 658c628f76..8477b1ada3 100644 --- a/src/model_service.cpp +++ b/src/model_service.cpp @@ -33,6 +33,7 @@ #include "tensorflow_serving/apis/model_service.pb.h" #pragma GCC diagnostic pop +#include "dags/pipeline_factory.hpp" #include "dags/pipelinedefinition.hpp" #include "execution_context.hpp" #include "grpc_utils.hpp" @@ -40,6 +41,7 @@ #include "mediapipe_internal/mediapipefactory.hpp" #include "mediapipe_internal/mediapipegraphdefinition.hpp" #endif +#include "model.hpp" #include "modelinstance.hpp" #include "modelmanager.hpp" #include "servablemanagermodule.hpp" diff --git a/src/modelconfig.cpp b/src/modelconfig.cpp index d17e6fa6bf..d4666e83d3 100644 --- a/src/modelconfig.cpp +++ b/src/modelconfig.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -29,6 +30,7 @@ #include "src/port/rapidjson_writer.hpp" #pragma warning(pop) +#include "anonymous_input_name.hpp" #include "filesystem.hpp" #include "json_parser.hpp" #include "logging.hpp" @@ -811,4 +813,31 @@ const std::string ModelConfig::getPath() const { return getLocalPath() + FileSystem::getOsSeparator() + std::to_string(version); } +bool ModelConfig::anyShapeSetToAuto() const { + for (const auto& [name, shapeInfo] : getShapes()) { + if (shapeInfo.shapeMode == AUTO) + return true; + } + return false; +} + +bool ModelConfig::isShapeAuto(const std::string& name) const { + auto it = getShapes().find(name); + if (it == getShapes().end()) { + it = getShapes().find(ANONYMOUS_INPUT_NAME); + } + if (it == getShapes().end()) { + return false; + } + return it->second.shapeMode == Mode::AUTO; +} + +bool ModelConfig::isShapeAnonymous() const { + return getShapes().size() == 1 && getShapes().begin()->first == ANONYMOUS_INPUT_NAME; +} + +bool ModelConfig::isShapeAnonymousFixed() const { + return isShapeAnonymous() && !isShapeAuto(ANONYMOUS_INPUT_NAME); +} + } // namespace ovms diff --git a/src/modelconfig.hpp b/src/modelconfig.hpp index a41463cb1d..1cbb849338 100644 --- a/src/modelconfig.hpp +++ b/src/modelconfig.hpp @@ -15,11 +15,9 @@ //***************************************************************************** #pragma once -#include #include #include #include -#include #include #include #include @@ -30,7 +28,6 @@ #include #pragma warning(pop) -#include "anonymous_input_name.hpp" #include "layout_configuration.hpp" #include "color_format_configuration.hpp" #include "precision_configuration.hpp" @@ -192,7 +189,7 @@ class ModelConfig { /** * @brief Allowed configurable layouts */ - static const std::set configAllowedLayouts; + /** * @brief custom_loader_options config as map @@ -742,13 +739,7 @@ class ModelConfig { * * @return bool */ - bool anyShapeSetToAuto() const { - for (const auto& [name, shapeInfo] : getShapes()) { - if (shapeInfo.shapeMode == AUTO) - return true; - } - return false; - } + bool anyShapeSetToAuto() const; /** * @brief Get the shapes @@ -773,24 +764,9 @@ class ModelConfig { * * @return bool */ - bool isShapeAuto(const std::string& name) const { - auto it = getShapes().find(name); - if (it == getShapes().end()) { - it = getShapes().find(ANONYMOUS_INPUT_NAME); - } - if (it == getShapes().end()) { - return false; - } - return it->second.shapeMode == Mode::AUTO; - } - - bool isShapeAnonymous() const { - return getShapes().size() == 1 && getShapes().begin()->first == ANONYMOUS_INPUT_NAME; - } - - bool isShapeAnonymousFixed() const { - return isShapeAnonymous() && !isShapeAuto(ANONYMOUS_INPUT_NAME); - } + bool isShapeAuto(const std::string& name) const; + bool isShapeAnonymous() const; + bool isShapeAnonymousFixed() const; bool isCloudStored() const { return getLocalPath() != getBasePath(); diff --git a/src/modelinstance.hpp b/src/modelinstance.hpp index a8b4067d3e..a8ce7633ec 100644 --- a/src/modelinstance.hpp +++ b/src/modelinstance.hpp @@ -30,7 +30,6 @@ #include -#include "logging.hpp" #include "model_metric_reporter.hpp" #include "modelchangesubscription.hpp" #include "modelconfig.hpp" diff --git a/src/modelmanager.cpp b/src/modelmanager.cpp index e1588e86f7..58b4ce96e4 100644 --- a/src/modelmanager.cpp +++ b/src/modelmanager.cpp @@ -34,6 +34,7 @@ #endif #include +#include #pragma warning(push) #pragma warning(disable : 6313) #include @@ -64,6 +65,7 @@ #endif #include "metric_config.hpp" #include "metric_registry.hpp" +#include "model.hpp" #include "modelinstance.hpp" // for logging #include "ov_utils.hpp" #include "schema.hpp" @@ -79,8 +81,9 @@ const std::string DEFAULT_MODEL_CACHE_DIRECTORY = "/opt/cache"; #endif ModelManager::ModelManager(const std::string& modelCacheDirectory, MetricRegistry* registry, PythonBackend* pythonBackend) : ieCore(std::make_unique()), + pipelineFactory(std::make_unique()), #if (MEDIAPIPE_DISABLE == 0) - mediapipeFactory(pythonBackend), + mediapipeFactory(std::make_unique(pythonBackend)), #endif waitForModelLoadedTimeoutMs(DEFAULT_WAIT_FOR_MODEL_LOADED_TIMEOUT_MS), modelCacheDirectory(modelCacheDirectory), @@ -511,7 +514,7 @@ Status ModelManager::processMediapipeConfig(const MediapipeGraphConfig& config, MediapipeGraphDefinition* mediapipeGraphDefinition = factory.findDefinitionByName(config.getGraphName()); if (mediapipeGraphDefinition == nullptr) { SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Mediapipe graph:{} was not loaded so far. Triggering load", config.getGraphName()); - auto status = factory.createDefinition(config.getGraphName(), config, *this); + auto status = factory.createDefinition(config.getGraphName(), config, *this, *this); return status; } if (mediapipeGraphDefinition->isReloadRequired(config)) { @@ -692,7 +695,7 @@ Status ModelManager::loadCustomNodeLibrariesConfig(rapidjson::Document& configJs Status ModelManager::loadMediapipeGraphsConfig(std::vector& mediapipesInConfigFile) { if (mediapipesInConfigFile.size() == 0) { SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Configuration file doesn't have mediapipe property."); - mediapipeFactory.retireOtherThan({}, *this); + mediapipeFactory->retireOtherThan({}); return StatusCode::OK; } std::set mediapipesInConfigFileNames; @@ -701,13 +704,13 @@ Status ModelManager::loadMediapipeGraphsConfig(std::vector for (const auto& mediapipeGraphConfig : mediapipesInConfigFile) { mediapipesInConfigFileNames.insert(mediapipeGraphConfig.getGraphName()); } - mediapipeFactory.retireOtherThan(std::move(mediapipesInConfigFileNames), *this); + mediapipeFactory->retireOtherThan(std::move(mediapipesInConfigFileNames)); std::set mediapipesAlreadyLoaded; for (const auto& mediapipeGraphConfig : mediapipesInConfigFile) { if (spdlog::default_logger_raw()->level() <= spdlog::level::debug) { mediapipeGraphConfig.logGraphConfigContent(); } - auto status = processMediapipeConfig(mediapipeGraphConfig, mediapipesAlreadyLoaded, mediapipeFactory); + auto status = processMediapipeConfig(mediapipeGraphConfig, mediapipesAlreadyLoaded, *mediapipeFactory); if (status != StatusCode::OK) { IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(status); } @@ -725,18 +728,18 @@ Status ModelManager::loadPipelinesConfig(rapidjson::Document& configJson) { const auto itrp = configJson.FindMember("pipeline_config_list"); if (itrp == configJson.MemberEnd() || !itrp->value.IsArray()) { SPDLOG_LOGGER_DEBUG(modelmanager_logger, "Configuration file doesn't have pipelines property."); - pipelineFactory.retireOtherThan({}, *this); + pipelineFactory->retireOtherThan({}, *this); return StatusCode::OK; } std::set pipelinesInConfigFile; Status firstErrorStatus = StatusCode::OK; for (const auto& pipelineConfig : itrp->value.GetArray()) { - auto status = processPipelineConfig(configJson, pipelineConfig, pipelinesInConfigFile, pipelineFactory, *this); + auto status = processPipelineConfig(configJson, pipelineConfig, pipelinesInConfigFile, *pipelineFactory, *this); if (status != StatusCode::OK) { IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(status); } } - pipelineFactory.retireOtherThan(std::move(pipelinesInConfigFile), *this); + pipelineFactory->retireOtherThan(std::move(pipelinesInConfigFile), *this); return firstErrorStatus; } @@ -909,18 +912,11 @@ Status ModelManager::loadModels(const rapidjson::Value::MemberIterator& modelsCo modelConfig.setCacheDir(this->modelCacheDirectory); const auto& modelName = modelConfig.getName(); - if (pipelineDefinitionExists(modelName)) { + if (servableExists(modelName, ServableType::Pipeline | ServableType::Mediapipe)) { IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(StatusCode::MODEL_NAME_OCCUPIED); - SPDLOG_LOGGER_ERROR(modelmanager_logger, "Model name: {} is already occupied by pipeline definition.", modelName); + SPDLOG_LOGGER_ERROR(modelmanager_logger, "Model name: {} is already occupied by pipeline or mediapipe graph definition.", modelName); continue; } -#if (MEDIAPIPE_DISABLE == 0) - if (mediapipeFactory.definitionExists(modelName)) { - IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(StatusCode::MODEL_NAME_OCCUPIED); - SPDLOG_LOGGER_ERROR(modelmanager_logger, "Model name: {} is already occupied by mediapipe graph definition.", modelName); - continue; - } -#endif if (modelsInConfigFile.find(modelName) != modelsInConfigFile.end()) { IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(StatusCode::MODEL_NAME_OCCUPIED); SPDLOG_LOGGER_WARN(modelmanager_logger, "Duplicated model names: {} defined in config file. Only first definition will be loaded.", modelName); @@ -1180,7 +1176,7 @@ Status ModelManager::updateConfigurationWithoutConfigFile() { reloadNeeded = true; } } - status = pipelineFactory.revalidatePipelines(*this); + status = pipelineFactory->revalidatePipelines(*this); if (!status.ok()) { IF_ERROR_NOT_OCCURRED_EARLIER_THEN_SET_FIRST_ERROR(status); } @@ -1421,6 +1417,10 @@ Status ModelManager::checkStatefulFlagChange(const std::string& modelName, bool return StatusCode::OK; } +std::shared_ptr ModelManager::modelFactory(const std::string& name, const bool isStateful) { + return std::make_shared(name, isStateful, &this->globalSequencesViewer); +} + std::shared_ptr ModelManager::getModelIfExistCreateElse(const std::string& modelName, const bool isStateful) { std::unique_lock modelsLock(modelsMtx); auto modelIt = models.find(modelName); @@ -1724,7 +1724,7 @@ const std::vector ModelManager::getNamesOfAvailableModels() const { Status ModelManager::createPipeline(std::unique_ptr& graph, const std::string& name) { #if (MEDIAPIPE_DISABLE == 0) - return this->mediapipeFactory.create(graph, name, *this); + return this->mediapipeFactory->create(graph, name); #else SPDLOG_ERROR("Mediapipe support was disabled during build process..."); return StatusCode::INTERNAL_ERROR; @@ -1734,4 +1734,24 @@ Status ModelManager::createPipeline(std::unique_ptr& gra void ModelManager::setRootDirectoryPath(const std::string& configFileFullPath) { FileSystem::setRootDirectoryPath(this->rootDirectoryPath, configFileFullPath); } + +bool ModelManager::servableExists(const std::string& name, ServableType check) const { + if (hasFlag(check, ServableType::Model) && findModelByName(name) != nullptr) { + return true; + } + if (hasFlag(check, ServableType::Pipeline) && pipelineFactory->definitionExists(name)) { + return true; + } +#if (MEDIAPIPE_DISABLE == 0) + if (hasFlag(check, ServableType::Mediapipe) && mediapipeFactory->definitionExists(name)) { + return true; + } +#endif + return false; +} + +const PipelineFactory& ModelManager::getPipelineFactory() const { + return *pipelineFactory; +} + } // namespace ovms diff --git a/src/modelmanager.hpp b/src/modelmanager.hpp index f7e79c76ed..9bfea27572 100644 --- a/src/modelmanager.hpp +++ b/src/modelmanager.hpp @@ -26,23 +26,22 @@ #include #include -#include #pragma warning(push) #pragma warning(disable : 6313) #include #pragma warning(pop) -#include -#include -#include "dags/pipeline_factory.hpp" #include "global_sequences_viewer.hpp" -#if (MEDIAPIPE_DISABLE == 0) -#include "mediapipe_internal/mediapipefactory.hpp" -#endif #include "metric_config.hpp" -#include "model.hpp" +#include "metric_provider.hpp" +#include "modelconfig.hpp" +#include "servable_name_checker.hpp" #include "status.hpp" +namespace ov { +class Core; +} // namespace ov + namespace ovms { const uint32_t DEFAULT_WAIT_FOR_MODEL_LOADED_TIMEOUT_MS = 10000; @@ -54,16 +53,23 @@ struct ModelsSettingsImpl; class CustomLoaderConfig; class CustomNodeLibraryManager; class MetricRegistry; +class Model; class ModelConfig; class FileSystem; +class MediapipeFactory; +class MediapipeGraphConfig; class MediapipeGraphExecutor; +class ModelInstance; +class ModelInstanceUnloadGuard; +class Pipeline; +class PipelineFactory; struct FunctorSequenceCleaner; struct FunctorResourcesCleaner; class PythonBackend; /** * @brief Model manager is managing the list of model topologies enabled for serving and their versions. */ -class ModelManager { +class ModelManager : public ServableNameChecker, public MetricProvider { public: /** * @brief A default constructor is private @@ -84,9 +90,9 @@ class ModelManager { std::map> models; std::unique_ptr ieCore; - PipelineFactory pipelineFactory; + std::unique_ptr pipelineFactory; #if (MEDIAPIPE_DISABLE == 0) - MediapipeFactory mediapipeFactory; + std::unique_ptr mediapipeFactory; #endif std::unique_ptr customNodeLibraryManager; std::vector> resources = {}; @@ -320,13 +326,11 @@ class ModelManager { */ void startCleaner(); - const PipelineFactory& getPipelineFactory() const { - return pipelineFactory; - } + const PipelineFactory& getPipelineFactory() const; #if (MEDIAPIPE_DISABLE == 0) const MediapipeFactory& getMediapipeFactory() const { - return mediapipeFactory; + return *mediapipeFactory; } #endif @@ -363,20 +367,9 @@ class ModelManager { */ const std::shared_ptr findModelInstance(const std::string& name, model_version_t version = 0) const; - template - Status createPipeline(std::unique_ptr& pipeline, - const std::string& name, - const RequestType* request, - ResponseType* response) { - return pipelineFactory.create(pipeline, name, request, response, *this); - } Status createPipeline(std::unique_ptr& graph, const std::string& name); - const bool pipelineDefinitionExists(const std::string& name) const { - return pipelineFactory.definitionExists(name); - } - /** * @brief Starts model manager using provided config file * @@ -397,7 +390,7 @@ class ModelManager { * * @return const std::string& */ - const MetricConfig& getMetricConfig() const { + const MetricConfig& getMetricConfig() const override { return this->metricConfig; } @@ -445,9 +438,7 @@ class ModelManager { * * @return std::shared_ptr */ - virtual std::shared_ptr modelFactory(const std::string& name, const bool isStateful) { - return std::make_shared(name, isStateful, &this->globalSequencesViewer); - } + virtual std::shared_ptr modelFactory(const std::string& name, const bool isStateful); /** * @brief Reads available versions from given filesystem @@ -504,7 +495,8 @@ class ModelManager { */ void cleanupResources(); - MetricRegistry* getMetricRegistry() const { return this->metricRegistry; } + bool servableExists(const std::string& name, ServableType check = ServableType::All) const override; + MetricRegistry* getMetricRegistry() const override { return this->metricRegistry; } }; void cleanerRoutine(uint32_t resourcesCleanupInterval, FunctorResourcesCleaner& functorResourcesCleaner, uint32_t sequenceCleanerInterval, FunctorSequenceCleaner& functorSequenceCleaner, std::future& cleanerExitSignal); diff --git a/src/modelversionstatus.hpp b/src/modelversionstatus.hpp index e46fe4be96..87dded0350 100644 --- a/src/modelversionstatus.hpp +++ b/src/modelversionstatus.hpp @@ -19,8 +19,9 @@ #include #include +#include + #include "modelversion.hpp" -#include "logging.hpp" // note: think about using https://github.com/Neargye/magic_enum when compatible compiler is supported. diff --git a/src/precision_configuration.cpp b/src/precision_configuration.cpp index 58b0dbcd9d..83365f9130 100644 --- a/src/precision_configuration.cpp +++ b/src/precision_configuration.cpp @@ -18,6 +18,8 @@ #include #include +#include "logging.hpp" + namespace ovms { const char PrecisionConfiguration::PRECISION_DELIMITER = ':'; diff --git a/src/predict_request_validation_utils_impl.hpp b/src/predict_request_validation_utils_impl.hpp index 57707ffc7c..14eff89771 100644 --- a/src/predict_request_validation_utils_impl.hpp +++ b/src/predict_request_validation_utils_impl.hpp @@ -22,7 +22,6 @@ #include #include -#include "logging.hpp" #include "shape.hpp" #include "anonymous_input_name.hpp" #include "status.hpp" diff --git a/src/prediction_service.cpp b/src/prediction_service.cpp index 74e23583d1..e211ab1ccc 100644 --- a/src/prediction_service.cpp +++ b/src/prediction_service.cpp @@ -33,6 +33,7 @@ #include "tfs_frontend/tfs_request_utils.hpp" #include "dags/pipeline.hpp" +#include "dags/pipeline_factory.hpp" #include "execution_context.hpp" #include "get_model_metadata_impl.hpp" #include "grpc_utils.hpp" @@ -88,7 +89,7 @@ Status PredictionServiceImpl::getPipeline(const PredictRequest* request, PredictResponse* response, std::unique_ptr& pipelinePtr) { OVMS_PROFILE_FUNCTION(); - return this->modelManager.createPipeline(pipelinePtr, request->model_spec().name(), request, response); + return this->modelManager.getPipelineFactory().create(pipelinePtr, request->model_spec().name(), request, response, this->modelManager); } grpc::Status ovms::PredictionServiceImpl::Predict( diff --git a/src/prediction_service_utils.hpp b/src/prediction_service_utils.hpp index 788141d0a4..21769d08c1 100644 --- a/src/prediction_service_utils.hpp +++ b/src/prediction_service_utils.hpp @@ -27,7 +27,6 @@ #include "kfs_frontend/kfs_grpc_inference_service.hpp" #include "extractchoice.hpp" #include "requesttensorextractor.hpp" -#include "logging.hpp" #include "shape.hpp" #include "status.hpp" diff --git a/src/profilermodule.hpp b/src/profilermodule.hpp index 9927b75ab8..de4eaa86a8 100644 --- a/src/profilermodule.hpp +++ b/src/profilermodule.hpp @@ -17,7 +17,6 @@ #include #include -#include "logging.hpp" #include "module.hpp" namespace ovms { diff --git a/src/python/BUILD b/src/python/BUILD index f4fd4c571e..deed6b7449 100644 --- a/src/python/BUILD +++ b/src/python/BUILD @@ -41,7 +41,7 @@ mediapipe_proto_library( ovms_cc_library( name = "utils", hdrs = ["utils.hpp",], - srcs = [], + srcs = [], deps = PYBIND_DEPS + [ "//src:libovmslogging", ], diff --git a/src/python/utils.hpp b/src/python/utils.hpp index 13708be639..6e8b4f511d 100644 --- a/src/python/utils.hpp +++ b/src/python/utils.hpp @@ -18,6 +18,9 @@ #include #include + +#include "src/logging.hpp" + #pragma warning(push) #pragma warning(disable : 6326 28182 6011 28020) #include @@ -25,7 +28,6 @@ #include #pragma warning(pop) -#include "../logging.hpp" namespace py = pybind11; using namespace py::literals; diff --git a/src/servable_name_checker.hpp b/src/servable_name_checker.hpp new file mode 100644 index 0000000000..373dd940c2 --- /dev/null +++ b/src/servable_name_checker.hpp @@ -0,0 +1,46 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include +#include + +namespace ovms { + +enum class ServableType : uint8_t { + Model = 1 << 0, + Pipeline = 1 << 1, + Mediapipe = 1 << 2, + All = Model | Pipeline | Mediapipe +}; + +inline ServableType operator|(ServableType a, ServableType b) { + using T = std::underlying_type_t; + return static_cast(static_cast(a) | static_cast(b)); +} + +inline bool hasFlag(ServableType value, ServableType flag) { + using T = std::underlying_type_t; + return (static_cast(value) & static_cast(flag)) != 0; +} + +class ServableNameChecker { +public: + virtual ~ServableNameChecker() = default; + virtual bool servableExists(const std::string& name, ServableType check = ServableType::All) const = 0; +}; + +} // namespace ovms diff --git a/src/status.hpp b/src/status.hpp index 18a2b093b5..b6447a9066 100644 --- a/src/status.hpp +++ b/src/status.hpp @@ -21,7 +21,7 @@ #include #include -#include "logging.hpp" +#include namespace ovms { diff --git a/src/tensor_conversion.hpp b/src/tensor_conversion.hpp index 4b23a734b1..c5eb537062 100644 --- a/src/tensor_conversion.hpp +++ b/src/tensor_conversion.hpp @@ -21,6 +21,7 @@ #include #include "deps/opencv.hpp" +#include "logging.hpp" #include "precision.hpp" #include "predict_request_validation_utils_impl.hpp" #include "profiler.hpp" diff --git a/src/tensor_conversion_common.cpp b/src/tensor_conversion_common.cpp index ad073b1740..328705baa4 100644 --- a/src/tensor_conversion_common.cpp +++ b/src/tensor_conversion_common.cpp @@ -20,6 +20,7 @@ #include #include "deps/opencv.hpp" +#include "logging.hpp" #include "precision.hpp" #include "predict_request_validation_utils_impl.hpp" #include "profiler.hpp" diff --git a/src/test/embeddingsnode_test.cpp b/src/test/embeddingsnode_test.cpp index 92a852a502..6443bba525 100644 --- a/src/test/embeddingsnode_test.cpp +++ b/src/test/embeddingsnode_test.cpp @@ -19,6 +19,7 @@ #include #include "../http_rest_api_handler.hpp" +#include "../mediapipe_internal/mediapipefactory.hpp" #include "../servablemanagermodule.hpp" #include "../server.hpp" #include "rapidjson/document.h" diff --git a/src/test/ensemble_flow_custom_node_tests.cpp b/src/test/ensemble_flow_custom_node_tests.cpp index 01a81d827e..b95241c066 100644 --- a/src/test/ensemble_flow_custom_node_tests.cpp +++ b/src/test/ensemble_flow_custom_node_tests.cpp @@ -45,9 +45,11 @@ #include "../dags/node_library_utils.hpp" #include "../dags/nodestreamidguard.hpp" #include "../dags/pipeline.hpp" +#include "../dags/pipeline_factory.hpp" #include "../dags/pipelinedefinition.hpp" #include "../execution_context.hpp" #include "../metric_registry.hpp" +#include "../model.hpp" #include "../model_metric_reporter.hpp" #include "../modelinstance.hpp" #include "../modelinstanceunloadguard.hpp" @@ -1520,7 +1522,7 @@ TEST_F(EnsembleFlowCustomNodeLoadConfigThenExecuteTest, AddSubCustomNode) { std::unique_ptr pipeline; this->prepareRequest(inputValues); this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); } @@ -1573,17 +1575,17 @@ TEST_F(EnsembleFlowCustomNodeLoadConfigThenExecuteTest, ReferenceMissingLibraryT // Loading correct configuration is required for test to pass. // This is due to fact that when OVMS loads pipeline definition for the first time and fails, its status is RETIRED. this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); response.Clear(); this->loadConfiguration(pipelineCustomNodeReferenceMissingLibraryConfig, StatusCode::PIPELINE_DEFINITION_INVALID_NODE_LIBRARY); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::PIPELINE_DEFINITION_NOT_LOADED_YET); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::PIPELINE_DEFINITION_NOT_LOADED_YET); response.Clear(); this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); } @@ -1634,18 +1636,18 @@ TEST_F(EnsembleFlowCustomNodeLoadConfigThenExecuteTest, ReferenceLibraryWithExec // Loading correct configuration is required for test to pass. // This is due to fact that when OVMS loads pipeline definition for the first time and fails, its status is RETIRED. this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); response.Clear(); this->loadConfiguration(pipelineCustomNodeReferenceLibraryWithExecutionErrorMissingParamsLibraryConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::NODE_LIBRARY_EXECUTION_FAILED); response.Clear(); this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); } @@ -1697,18 +1699,18 @@ TEST_F(EnsembleFlowCustomNodeLoadConfigThenExecuteTest, MissingRequiredNodeParam // Loading correct configuration is required for test to pass. // This is due to fact that when OVMS loads pipeline definition for the first time and fails, its status is RETIRED. this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); response.Clear(); this->loadConfiguration(pipelineCustomNodeMissingParametersConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::NODE_LIBRARY_EXECUTION_FAILED); response.Clear(); this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); } @@ -1761,17 +1763,17 @@ TEST_F(EnsembleFlowCustomNodeLoadConfigThenExecuteTest, ReferenceLibraryWithRest // Loading correct configuration is required for test to pass. // This is due to fact that when OVMS loads pipeline definition for the first time and fails, its status is RETIRED. this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); response.Clear(); this->loadConfiguration(pipelineCustomNodeLibraryNotEscapedPathConfig, StatusCode::PIPELINE_DEFINITION_INVALID_NODE_LIBRARY); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::PIPELINE_DEFINITION_NOT_LOADED_YET); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::PIPELINE_DEFINITION_NOT_LOADED_YET); response.Clear(); this->loadCorrectConfiguration(); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); this->checkResponseForCorrectConfiguration(); } @@ -1920,7 +1922,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, JustDiff this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperationsConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); @@ -2011,7 +2013,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Differen this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperationsThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); prepareDifferentOpsExpectedOutput(expectedOutput, input, factors); @@ -2093,7 +2095,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Differen this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperations2OutputsConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); @@ -2207,7 +2209,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Differen this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperationsThenDummyThenChooseMaximumConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); @@ -2322,7 +2324,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Differen this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperationsThenDummyThenChooseMaximumThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); @@ -2415,7 +2417,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Demultip this->prepareRequest(request, input, differentOpsInputName, {4, 1, 10}); this->loadConfiguration(demultiplyThenDummyThenChooseMaximumConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); auto status = pipeline->execute(DEFAULT_TEST_CONTEXT); ASSERT_EQ(status, StatusCode::OK) << status.string(); @@ -4474,7 +4476,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, J std::vector input{static_cast(dynamicDemultiplyCount), 1, 2, 3, 4, 5, 6, 7, 8, 9}; this->prepareRequest(request, input, differentOpsInputName); this->loadConfiguration(pipelineCustomNodeDynamicDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(dynamicDemultiplyCount * DUMMY_MODEL_OUTPUT_SIZE); @@ -4667,7 +4669,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, D std::iota(input.begin(), input.end(), 42); this->prepareRequest(request, input, differentOpsInputName, {dynamicDemultiplyCount, 1, 10}); this->loadConfiguration(pipelineEntryNodeDynamicDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput = input; @@ -4741,7 +4743,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, D std::iota(input.begin(), input.end(), 42); this->prepareRequest(request, input, pipelineInputName, {3, 5, DUMMY_MODEL_INPUT_SIZE}); this->loadConfiguration(pipelineEntryNodeDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput = input; @@ -4774,7 +4776,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, D std::vector input{static_cast(dynamicDemultiplyCount), 1, 2, 3, 4, 5, 6, 7, 8, 9}; this->prepareRequest(request, input, differentOpsInputName); this->loadConfiguration(pipelineCustomNodeDynamicDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); auto status = pipeline->execute(DEFAULT_TEST_CONTEXT); ASSERT_EQ(status, StatusCode::PIPELINE_TOO_LARGE_DIMENSION_SIZE_TO_DEMULTIPLY) << status.string(); } @@ -4870,7 +4872,7 @@ TEST_F(EnsembleFlowCustomNodeAndDemultiplexerLoadConfigThenExecuteTest, Differen this->prepareRequest(request, input, differentOpsInputName); this->prepareRequest(request, factors, differentOpsFactorsName); this->loadConfiguration(pipelineCustomNodeDifferentOperationsThenDummyThenChooseMaximumNotInOrderConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(4 * DUMMY_MODEL_OUTPUT_SIZE); @@ -4887,7 +4889,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, D std::vector input{static_cast(dynamicDemultiplyCount), 1, 2, 3, 4, 5, 6, 7, 8, 9}; this->prepareRequest(request, input, differentOpsInputName); this->loadConfiguration(pipelineCustomNodeDynamicDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::PIPELINE_DEMULTIPLEXER_NO_RESULTS); } @@ -4897,7 +4899,7 @@ TEST_F(EnsembleFlowCustomNodeAndDynamicDemultiplexerLoadConfigThenExecuteTest, D std::vector input{static_cast(dynamicDemultiplyCount), 1, 2, 3, 4, 5, 6, 7, 8, 9}; this->prepareRequest(request, input, differentOpsInputName); this->loadConfiguration(pipelineCustomNodeDynamicDemultiplexThenDummyConfig); - ASSERT_EQ(manager.createPipeline(pipeline, pipelineName, &request, &response), StatusCode::OK); + ASSERT_EQ(manager.getPipelineFactory().create(pipeline, pipelineName, &request, &response, manager), StatusCode::OK); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); std::vector expectedOutput(dynamicDemultiplyCount * DUMMY_MODEL_OUTPUT_SIZE); diff --git a/src/test/ensemble_mapping_config_tests.cpp b/src/test/ensemble_mapping_config_tests.cpp index 23e6a84a4c..5c34f11f0c 100644 --- a/src/test/ensemble_mapping_config_tests.cpp +++ b/src/test/ensemble_mapping_config_tests.cpp @@ -26,6 +26,7 @@ #include "../execution_context.hpp" #include "../get_model_metadata_impl.hpp" #include "../kfs_frontend/kfs_grpc_inference_service.hpp" +#include "../model.hpp" #include "../model_metric_reporter.hpp" #include "../modelconfig.hpp" #include "../modelmanager.hpp" diff --git a/src/test/ensemble_tests.cpp b/src/test/ensemble_tests.cpp index 1eb83595d4..c0d4794b8c 100644 --- a/src/test/ensemble_tests.cpp +++ b/src/test/ensemble_tests.cpp @@ -39,7 +39,9 @@ #include "../inference_executor.hpp" #include "../localfilesystem.hpp" #include "../logging.hpp" +#include "../mediapipe_internal/mediapipefactory.hpp" #include "../metric_registry.hpp" +#include "../model.hpp" #include "../model_metric_reporter.hpp" #include "../modelconfig.hpp" #include "../modelinstance.hpp" @@ -225,10 +227,11 @@ class EnsembleFlowTest : public TestWithTempDir { ConstructorEnabledModelManager managerWithDummyModel; managerWithDummyModel.loadConfig(fileToReload); std::unique_ptr pipeline; - auto status = managerWithDummyModel.createPipeline(pipeline, + auto status = managerWithDummyModel.getPipelineFactory().create(pipeline, "pipeline1Dummy", &request, - &response); + &response, + managerWithDummyModel); ASSERT_EQ(status, ovms::StatusCode::PIPELINE_DEFINITION_NAME_MISSING) << status.string(); } @@ -2886,10 +2889,11 @@ TEST_F(EnsembleFlowTest, PipelineFactoryCreationWithInputOutputsMappings) { ConstructorEnabledModelManager managerWithDummyModel; managerWithDummyModel.loadConfig(fileToReload); std::unique_ptr pipeline; - auto status = managerWithDummyModel.createPipeline(pipeline, + auto status = managerWithDummyModel.getPipelineFactory().create(pipeline, "pipeline1Dummy", &request, - &response); + &response, + managerWithDummyModel); ASSERT_EQ(status, ovms::StatusCode::OK) << status.string(); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); const int dummySeriallyConnectedCount = 1; @@ -2959,10 +2963,11 @@ TEST_F(EnsembleFlowTest, PipelineFactoryCreationWithInputOutputsMappings2Paralle ConstructorEnabledModelManager managerWithDummyModel; managerWithDummyModel.loadConfig(fileToReload); std::unique_ptr pipeline; - auto status = managerWithDummyModel.createPipeline(pipeline, + auto status = managerWithDummyModel.getPipelineFactory().create(pipeline, "pipeline1Dummy", &request, - &response); + &response, + managerWithDummyModel); ASSERT_EQ(status, ovms::StatusCode::OK) << status.string(); ASSERT_EQ(pipeline->execute(DEFAULT_TEST_CONTEXT), StatusCode::OK); ASSERT_EQ(response.outputs().count(customPipelineOutputName), 1); @@ -4362,7 +4367,7 @@ TEST_F(EnsembleFlowTest, MediapipeConfigModelWithSameNamePipeline) { ASSERT_FALSE(manager.getMediapipeFactory().definitionExists(MEDIAPIPE_DUMMY_NAME)); - ASSERT_TRUE(manager.pipelineDefinitionExists(MEDIAPIPE_DUMMY_NAME)); + ASSERT_TRUE(manager.servableExists(MEDIAPIPE_DUMMY_NAME, ServableType::Pipeline)); } #endif TEST_F(EnsembleFlowTest, PipelineConfigModelWithSameName) { diff --git a/src/test/llm/assisted_decoding_test.cpp b/src/test/llm/assisted_decoding_test.cpp index e265cd6b68..9c870a1f63 100644 --- a/src/test/llm/assisted_decoding_test.cpp +++ b/src/test/llm/assisted_decoding_test.cpp @@ -22,6 +22,8 @@ #include #include +#include + #include #include #include diff --git a/src/test/llm/llmnode_test.cpp b/src/test/llm/llmnode_test.cpp index 55f87fa720..0cbcfa2f2d 100644 --- a/src/test/llm/llmnode_test.cpp +++ b/src/test/llm/llmnode_test.cpp @@ -22,6 +22,8 @@ #include #include +#include + #include #include #include diff --git a/src/test/mediapipeflow_test.cpp b/src/test/mediapipeflow_test.cpp index 55b6ab96ed..dd3358858d 100644 --- a/src/test/mediapipeflow_test.cpp +++ b/src/test/mediapipeflow_test.cpp @@ -36,6 +36,7 @@ #pragma GCC diagnostic pop #include "../config.hpp" +#include "../dags/pipeline_factory.hpp" #include "../dags/pipelinedefinition.hpp" #include "../grpcservermodule.hpp" #include "../http_rest_api_handler.hpp" @@ -47,6 +48,7 @@ #include "../mediapipe_internal/mediapipegraphexecutor.hpp" #include "../metric_config.hpp" #include "../metric_module.hpp" +#include "../model.hpp" #include "../model_service.hpp" #include "../ovms_exit_codes.hpp" #include "../precision.hpp" @@ -1487,8 +1489,6 @@ TEST_F(MediapipeStreamFlowAddTest, Infer) { TEST_F(MediapipeStreamFlowAddTest, InferOnUnloadedGraph) { const ovms::Module* grpcModule = server.getModule(ovms::GRPC_SERVER_MODULE_NAME); KFSInferenceServiceImpl& impl = dynamic_cast(grpcModule)->getKFSGrpcImpl(); - const ServableManagerModule* smm = dynamic_cast(server.getModule(SERVABLE_MANAGER_MODULE_NAME)); - ModelManager& modelManager = smm->getServableManager(); auto* definition = this->getMPDefinitionByName(this->modelName); ASSERT_NE(definition, nullptr); @@ -1532,10 +1532,10 @@ TEST_F(MediapipeStreamFlowAddTest, InferOnUnloadedGraph) { checkAddResponse("out", this->requestData1[2], this->requestData1[2], this->request[2], msg.infer_response(), 1, 1, this->modelName); return true; }); - std::thread unloader([&startUnloading, &finishedUnloading, &definition, &modelManager]() { + std::thread unloader([&startUnloading, &finishedUnloading, &definition]() { // Wait till first response notifies that we should start unloading startUnloading.get_future().get(); - definition->retire(modelManager); + definition->retire(); // Notify second request to arrive because we unloaded the graph finishedUnloading.set_value(); }); @@ -1654,11 +1654,9 @@ TEST_F(MediapipeStreamFlowAddTest, InferOnReloadedGraph) { TEST_F(MediapipeStreamFlowAddTest, NegativeShouldNotReachInferDueToRetiredGraph) { const ovms::Module* grpcModule = server.getModule(ovms::GRPC_SERVER_MODULE_NAME); KFSInferenceServiceImpl& impl = dynamic_cast(grpcModule)->getKFSGrpcImpl(); - const ServableManagerModule* smm = dynamic_cast(server.getModule(SERVABLE_MANAGER_MODULE_NAME)); - ModelManager& modelManager = smm->getServableManager(); auto* definition = this->getMPDefinitionByName(this->modelName); ASSERT_NE(definition, nullptr); - definition->retire(modelManager); + definition->retire(); // Opening new stream, expect graph to be unavailable MockedServerReaderWriter<::inference::ModelStreamInferResponse, ::inference::ModelInferRequest> stream; diff --git a/src/test/streaming_test.cpp b/src/test/streaming_test.cpp index 02e7c4178a..5af193b6c3 100644 --- a/src/test/streaming_test.cpp +++ b/src/test/streaming_test.cpp @@ -29,6 +29,7 @@ #include "../status.hpp" #include "../stringutils.hpp" #include "mediapipe/framework/port/integral_types.h" +#include "../mediapipe_internal/mediapipefactory.hpp" #include "constructor_enabled_model_manager.hpp" #include "platform_utils.hpp" #include "test_utils.hpp" diff --git a/src/test/stress_test_utils.hpp b/src/test/stress_test_utils.hpp index ccbdd60758..0cf85e1f2f 100644 --- a/src/test/stress_test_utils.hpp +++ b/src/test/stress_test_utils.hpp @@ -1726,7 +1726,7 @@ class ConfigChangeStressTest : public TestWithTempDir { RequestType request = preparePipelinePredictRequest(request2); ovms::Status createPipelineStatus = StatusCode::UNKNOWN_ERROR; if (typeid(ServableType) == typeid(ovms::Pipeline)) { - createPipelineStatus = this->manager->createPipeline(pipelinePtr, pipelineName, &request, &response); + createPipelineStatus = this->manager->getPipelineFactory().create(pipelinePtr, pipelineName, &request, &response, *(this->manager)); #if (MEDIAPIPE_DISABLE == 0) } else if (typeid(ServableType) == typeid(ovms::MediapipeGraphExecutor)) { mediacreate(executorPtr, *(this->manager), request, response, createPipelineStatus); diff --git a/src/tfs_frontend/serialization.hpp b/src/tfs_frontend/serialization.hpp index f6925206d1..1d227bc976 100644 --- a/src/tfs_frontend/serialization.hpp +++ b/src/tfs_frontend/serialization.hpp @@ -29,7 +29,6 @@ #pragma GCC diagnostic pop #include "../profiler.hpp" -#include "../logging.hpp" #include "../status.hpp" #include "../serialization_common.hpp" #include "../tensorinfo.hpp" diff --git a/src/tfs_frontend/tfs_request_utils.hpp b/src/tfs_frontend/tfs_request_utils.hpp index d040e7a0a9..e6f55d451d 100644 --- a/src/tfs_frontend/tfs_request_utils.hpp +++ b/src/tfs_frontend/tfs_request_utils.hpp @@ -26,7 +26,6 @@ #include "../extractchoice.hpp" #include "../requesttensorextractor.hpp" #include "../statefulrequestprocessor.hpp" -#include "../logging.hpp" #include "../profiler.hpp" #include "../shape.hpp" #include "../status.hpp"