diff --git a/agentrun/knowledgebase/__init__.py b/agentrun/knowledgebase/__init__.py index 6a01376..fd314f1 100644 --- a/agentrun/knowledgebase/__init__.py +++ b/agentrun/knowledgebase/__init__.py @@ -6,6 +6,7 @@ get_data_api, KnowledgeBaseControlAPI, KnowledgeBaseDataAPI, + OTSDataAPI, RagFlowDataAPI, ) from .client import KnowledgeBaseClient @@ -20,6 +21,16 @@ KnowledgeBaseListOutput, KnowledgeBaseProvider, KnowledgeBaseUpdateInput, + OTSDenseVectorSearchConfig, + OTSEmbeddingConfiguration, + OTSFullTextSearchConfig, + OTSMetadataField, + OTSModelConfig, + OTSProviderSettings, + OTSRerankingConfig, + OTSRetrieveSettings, + OTSRRFConfig, + OTSWeightConfig, ProviderSettings, RagFlowProviderSettings, RagFlowRetrieveSettings, @@ -37,6 +48,7 @@ "RagFlowDataAPI", "BailianDataAPI", "ADBDataAPI", + "OTSDataAPI", "get_data_api", # enums "KnowledgeBaseProvider", @@ -45,11 +57,21 @@ "RagFlowProviderSettings", "BailianProviderSettings", "ADBProviderSettings", + "OTSProviderSettings", + "OTSMetadataField", + "OTSEmbeddingConfiguration", # retrieve settings "RetrieveSettings", "RagFlowRetrieveSettings", "BailianRetrieveSettings", "ADBRetrieveSettings", + "OTSRetrieveSettings", + "OTSDenseVectorSearchConfig", + "OTSFullTextSearchConfig", + "OTSRerankingConfig", + "OTSRRFConfig", + "OTSWeightConfig", + "OTSModelConfig", # api model "KnowledgeBaseCreateInput", "KnowledgeBaseUpdateInput", diff --git a/agentrun/knowledgebase/__knowledgebase_async_template.py b/agentrun/knowledgebase/__knowledgebase_async_template.py index 501449e..d0042e2 100644 --- a/agentrun/knowledgebase/__knowledgebase_async_template.py +++ b/agentrun/knowledgebase/__knowledgebase_async_template.py @@ -26,6 +26,16 @@ KnowledgeBaseProvider, KnowledgeBaseSystemProps, KnowledgeBaseUpdateInput, + OTSDenseVectorSearchConfig, + OTSEmbeddingConfiguration, + OTSFullTextSearchConfig, + OTSMetadataField, + OTSModelConfig, + OTSProviderSettings, + OTSRerankingConfig, + OTSRetrieveSettings, + OTSRRFConfig, + 
OTSWeightConfig, RagFlowProviderSettings, RagFlowRetrieveSettings, RetrieveInput, @@ -302,8 +312,6 @@ def _get_data_api(self, config: Optional[Config] = None): if isinstance(self.provider_settings, ADBProviderSettings): converted_provider_settings = self.provider_settings elif isinstance(self.provider_settings, dict): - # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case - # ADB provider_settings uses PascalCase keys, need to convert to snake_case converted_provider_settings = ADBProviderSettings( db_instance_id=self.provider_settings.get( "DBInstanceId", "" @@ -323,8 +331,6 @@ def _get_data_api(self, config: Optional[Config] = None): if isinstance(self.retrieve_settings, ADBRetrieveSettings): converted_retrieve_settings = self.retrieve_settings elif isinstance(self.retrieve_settings, dict): - # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case - # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case converted_retrieve_settings = ADBRetrieveSettings( top_k=self.retrieve_settings.get("TopK"), use_full_text_retrieval=self.retrieve_settings.get( @@ -344,6 +350,115 @@ def _get_data_api(self, config: Optional[Config] = None): ), ) + elif provider == KnowledgeBaseProvider.OTS: + # OTS 设置 / OTS settings (camelCase → snake_case) + if self.provider_settings: + if isinstance(self.provider_settings, OTSProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + ps = self.provider_settings + + metadata = None + raw_metadata = ps.get("metadata") + if raw_metadata and isinstance(raw_metadata, list): + metadata = [ + OTSMetadataField( + name=m.get("name", ""), + type=m.get("type", ""), + ) + for m in raw_metadata + ] + + embedding_config = None + raw_ec = ps.get("embeddingConfiguration") + if raw_ec and isinstance(raw_ec, dict): + embedding_config = OTSEmbeddingConfiguration( + provider=raw_ec.get("provider", ""), + model=raw_ec.get("model", ""), + dimension=raw_ec.get("dimension", 
0), + url=raw_ec.get("url"), + api_key=raw_ec.get("apiKey"), + ) + + converted_provider_settings = OTSProviderSettings( + ots_instance_name=ps.get("otsInstanceName", ""), + tags=ps.get("tags"), + metadata=metadata, + embedding_configuration=embedding_config, + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, OTSRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + rs = self.retrieve_settings + + dvsc = None + raw_dvsc = rs.get("denseVectorSearchConfiguration") + if raw_dvsc and isinstance(raw_dvsc, dict): + dvsc = OTSDenseVectorSearchConfig( + number_of_results=raw_dvsc.get("numberOfResults"), + ) + + ftsc = None + raw_ftsc = rs.get("fullTextSearchConfiguration") + if raw_ftsc and isinstance(raw_ftsc, dict): + ftsc = OTSFullTextSearchConfig( + number_of_results=raw_ftsc.get("numberOfResults"), + ) + + reranking = None + raw_rr = rs.get("rerankingConfiguration") + if raw_rr and isinstance(raw_rr, dict): + rrf_config = None + raw_rrf = raw_rr.get("rrfConfiguration") + if raw_rrf and isinstance(raw_rrf, dict): + rrf_config = OTSRRFConfig( + dense_vector_search_weight=raw_rrf.get( + "denseVectorSearchWeight" + ), + full_text_search_weight=raw_rrf.get( + "fullTextSearchWeight" + ), + k=raw_rrf.get("k"), + ) + + weight_config = None + raw_wc = raw_rr.get("weightConfiguration") + if raw_wc and isinstance(raw_wc, dict): + weight_config = OTSWeightConfig( + dense_vector_search_weight=raw_wc.get( + "denseVectorSearchWeight" + ), + full_text_search_weight=raw_wc.get( + "fullTextSearchWeight" + ), + ) + + model_config = None + raw_mc = raw_rr.get("modelConfiguration") + if raw_mc and isinstance(raw_mc, dict): + model_config = OTSModelConfig( + provider=raw_mc.get("provider"), + model=raw_mc.get("model"), + ) + + reranking = OTSRerankingConfig( + type=raw_rr.get("type"), + number_of_results=raw_rr.get("numberOfResults"), + rrf_configuration=rrf_config, + 
weight_configuration=weight_config, + model_configuration=model_config, + ) + + converted_retrieve_settings = OTSRetrieveSettings( + search_type=rs.get("searchType"), + dense_vector_search_configuration=dvsc, + full_text_search_configuration=ftsc, + reranking_configuration=reranking, + filter=rs.get("filter"), + ) + return get_data_api( provider=provider, knowledge_base_name=self.knowledge_base_name or "", diff --git a/agentrun/knowledgebase/api/__data_async_template.py b/agentrun/knowledgebase/api/__data_async_template.py index de99b50..fce66a9 100644 --- a/agentrun/knowledgebase/api/__data_async_template.py +++ b/agentrun/knowledgebase/api/__data_async_template.py @@ -13,6 +13,7 @@ from alibabacloud_bailian20231229 import models as bailian_models from alibabacloud_gpdb20160503 import models as gpdb_models import httpx +from tablestore_agent_storage import AgentStorageClient from agentrun.utils.config import Config from agentrun.utils.control_api import ControlAPI @@ -25,6 +26,8 @@ BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, + OTSProviderSettings, + OTSRetrieveSettings, RagFlowProviderSettings, RagFlowRetrieveSettings, ) @@ -545,6 +548,227 @@ async def retrieve_async( } +class OTSDataAPI(KnowledgeBaseDataAPI): + """OTS (TableStore) 知识库数据链路 API / OTS KnowledgeBase Data API + + 实现 OTS 知识库的检索逻辑,通过 tablestore-agent-storage 包调用 retrieve 接口。 + Implements retrieval logic for OTS knowledge base via tablestore-agent-storage retrieve API. 
+ """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[OTSProviderSettings] = None, + retrieve_settings: Optional[OTSRetrieveSettings] = None, + ): + """初始化 OTS 知识库数据链路 API / Initialize OTS KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: OTS 提供商设置 / OTS provider settings + retrieve_settings: OTS 检索设置 / OTS retrieve settings + """ + super().__init__(knowledge_base_name, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_agent_storage_client( + self, config: Optional[Config] = None + ) -> AgentStorageClient: + """构建 AgentStorageClient / Build AgentStorageClient + + Args: + config: 配置 / Configuration + + Returns: + AgentStorageClient: OTS 存储客户端 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for OTS retrieval") + + cfg = Config.with_configs(self.config, config) + region_id = cfg.get_region_id() + ots_endpoint = f"http://ots-{region_id}.aliyuncs.com" + + return AgentStorageClient( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + ots_endpoint=ots_endpoint, + ots_instance_name=self.provider_settings.ots_instance_name, + ) + + def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]: + """将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式 + Convert OTSRetrieveSettings to tablestore-agent-storage dict format + + Returns: + Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict + """ + if self.retrieve_settings is None: + return None + + config: Dict[str, Any] = {} + + if self.retrieve_settings.search_type is not None: + config["searchType"] = self.retrieve_settings.search_type + + if self.retrieve_settings.dense_vector_search_configuration is not None: + dvsc = self.retrieve_settings.dense_vector_search_configuration + dv_config: Dict[str, 
Any] = {} + if dvsc.number_of_results is not None: + dv_config["numberOfResults"] = dvsc.number_of_results + config["denseVectorSearchConfiguration"] = dv_config + + if self.retrieve_settings.full_text_search_configuration is not None: + ftsc = self.retrieve_settings.full_text_search_configuration + ft_config: Dict[str, Any] = {} + if ftsc.number_of_results is not None: + ft_config["numberOfResults"] = ftsc.number_of_results + config["fullTextSearchConfiguration"] = ft_config + + if self.retrieve_settings.reranking_configuration is not None: + rc = self.retrieve_settings.reranking_configuration + rr_config: Dict[str, Any] = {} + + if rc.type is not None: + rr_config["type"] = rc.type + if rc.number_of_results is not None: + rr_config["numberOfResults"] = rc.number_of_results + + if rc.rrf_configuration is not None: + rrf: Dict[str, Any] = {} + if rc.rrf_configuration.dense_vector_search_weight is not None: + rrf["denseVectorSearchWeight"] = ( + rc.rrf_configuration.dense_vector_search_weight + ) + if rc.rrf_configuration.full_text_search_weight is not None: + rrf["fullTextSearchWeight"] = ( + rc.rrf_configuration.full_text_search_weight + ) + if rc.rrf_configuration.k is not None: + rrf["k"] = rc.rrf_configuration.k + rr_config["rrfConfiguration"] = rrf + + if rc.weight_configuration is not None: + wc: Dict[str, Any] = {} + if ( + rc.weight_configuration.dense_vector_search_weight + is not None + ): + wc["denseVectorSearchWeight"] = ( + rc.weight_configuration.dense_vector_search_weight + ) + if rc.weight_configuration.full_text_search_weight is not None: + wc["fullTextSearchWeight"] = ( + rc.weight_configuration.full_text_search_weight + ) + rr_config["weightConfiguration"] = wc + + if rc.model_configuration is not None: + mc: Dict[str, Any] = {} + if rc.model_configuration.provider is not None: + mc["provider"] = rc.model_configuration.provider + if rc.model_configuration.model is not None: + mc["model"] = rc.model_configuration.model + 
rr_config["modelConfiguration"] = mc + + config["rerankingConfiguration"] = rr_config + + if self.retrieve_settings.filter is not None: + config["filter"] = self.retrieve_settings.filter + + return config if config else None + + def _parse_retrieve_response( + self, response: Dict[str, Any], query: str + ) -> Dict[str, Any]: + """解析 OTS 检索响应 / Parse OTS retrieve response + + Args: + response: AgentStorageClient.retrieve 的响应 / Response from retrieve + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_results: List[Dict[str, Any]] = [] + + data = response.get("data", {}) + retrieval_results = data.get("retrievalResults", []) + + for item in retrieval_results: + all_results.append({ + "content": item.get("content"), + "score": item.get("score"), + "doc_id": item.get("docId"), + "chunk_id": item.get("chunkId"), + "subspace": item.get("subspace"), + "oss_key": item.get("ossKey"), + "metadata": item.get("metadata"), + }) + + return { + "data": all_results, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """OTS 检索(异步)/ OTS retrieval asynchronously + + 通过 tablestore-agent-storage 调用 retrieve 接口进行知识库检索。 + Retrieves from OTS knowledge base via tablestore-agent-storage retrieve API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for OTS retrieval" + ) + + client = self._build_agent_storage_client(config) + + retrieval_config = self._build_retrieval_configuration() + + request: Dict[str, Any] = { + "knowledgeBaseName": self.knowledge_base_name, + "retrievalQuery": {"text": query, "type": "TEXT"}, + } + + if retrieval_config: + request["retrievalConfiguration"] = retrieval_config + + logger.debug(f"OTS retrieve request: {request}") + response = client.retrieve(request) + logger.debug(f"OTS retrieve response: {response}") + + return self._parse_retrieve_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, @@ -554,6 +778,7 @@ def get_data_api( RagFlowProviderSettings, BailianProviderSettings, ADBProviderSettings, + OTSProviderSettings, ] ] = None, retrieve_settings: Optional[ @@ -561,6 +786,7 @@ def get_data_api( RagFlowRetrieveSettings, BailianRetrieveSettings, ADBRetrieveSettings, + OTSRetrieveSettings, ] ] = None, credential_name: Optional[str] = None, @@ -633,5 +859,22 @@ def get_data_api( provider_settings=adb_provider_settings, retrieve_settings=adb_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.OTS or provider == "ots": + ots_provider_settings = ( + provider_settings + if isinstance(provider_settings, OTSProviderSettings) + else None + ) + ots_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, OTSRetrieveSettings) + else None + ) + return OTSDataAPI( + knowledge_base_name, + config, + 
provider_settings=ots_provider_settings, + retrieve_settings=ots_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/api/__init__.py b/agentrun/knowledgebase/api/__init__.py index bcfc80c..3cb38be 100644 --- a/agentrun/knowledgebase/api/__init__.py +++ b/agentrun/knowledgebase/api/__init__.py @@ -6,6 +6,7 @@ BailianDataAPI, get_data_api, KnowledgeBaseDataAPI, + OTSDataAPI, RagFlowDataAPI, ) @@ -17,5 +18,6 @@ "RagFlowDataAPI", "BailianDataAPI", "ADBDataAPI", + "OTSDataAPI", "get_data_api", ] diff --git a/agentrun/knowledgebase/api/data.py b/agentrun/knowledgebase/api/data.py index d5d1a58..2dacece 100644 --- a/agentrun/knowledgebase/api/data.py +++ b/agentrun/knowledgebase/api/data.py @@ -23,6 +23,7 @@ from alibabacloud_bailian20231229 import models as bailian_models from alibabacloud_gpdb20160503 import models as gpdb_models import httpx +from tablestore_agent_storage import AgentStorageClient from agentrun.utils.config import Config from agentrun.utils.control_api import ControlAPI @@ -35,6 +36,8 @@ BailianProviderSettings, BailianRetrieveSettings, KnowledgeBaseProvider, + OTSProviderSettings, + OTSRetrieveSettings, RagFlowProviderSettings, RagFlowRetrieveSettings, ) @@ -804,6 +807,280 @@ def retrieve( } +class OTSDataAPI(KnowledgeBaseDataAPI): + """OTS (TableStore) 知识库数据链路 API / OTS KnowledgeBase Data API + + 实现 OTS 知识库的检索逻辑,通过 tablestore-agent-storage 包调用 retrieve 接口。 + Implements retrieval logic for OTS knowledge base via tablestore-agent-storage retrieve API. 
+ """ + + def __init__( + self, + knowledge_base_name: str, + config: Optional[Config] = None, + provider_settings: Optional[OTSProviderSettings] = None, + retrieve_settings: Optional[OTSRetrieveSettings] = None, + ): + """初始化 OTS 知识库数据链路 API / Initialize OTS KnowledgeBase Data API + + Args: + knowledge_base_name: 知识库名称 / Knowledge base name + config: 配置 / Configuration + provider_settings: OTS 提供商设置 / OTS provider settings + retrieve_settings: OTS 检索设置 / OTS retrieve settings + """ + super().__init__(knowledge_base_name, config) + self.provider_settings = provider_settings + self.retrieve_settings = retrieve_settings + + def _build_agent_storage_client( + self, config: Optional[Config] = None + ) -> AgentStorageClient: + """构建 AgentStorageClient / Build AgentStorageClient + + Args: + config: 配置 / Configuration + + Returns: + AgentStorageClient: OTS 存储客户端 + """ + if self.provider_settings is None: + raise ValueError("provider_settings is required for OTS retrieval") + + cfg = Config.with_configs(self.config, config) + region_id = cfg.get_region_id() + ots_endpoint = f"http://ots-{region_id}.aliyuncs.com" + + return AgentStorageClient( + access_key_id=cfg.get_access_key_id(), + access_key_secret=cfg.get_access_key_secret(), + ots_endpoint=ots_endpoint, + ots_instance_name=self.provider_settings.ots_instance_name, + ) + + def _build_retrieval_configuration(self) -> Optional[Dict[str, Any]]: + """将 OTSRetrieveSettings 转换为 tablestore-agent-storage 的 dict 格式 + Convert OTSRetrieveSettings to tablestore-agent-storage dict format + + Returns: + Optional[Dict[str, Any]]: 检索配置字典 / Retrieval configuration dict + """ + if self.retrieve_settings is None: + return None + + config: Dict[str, Any] = {} + + if self.retrieve_settings.search_type is not None: + config["searchType"] = self.retrieve_settings.search_type + + if self.retrieve_settings.dense_vector_search_configuration is not None: + dvsc = self.retrieve_settings.dense_vector_search_configuration + dv_config: Dict[str, 
Any] = {} + if dvsc.number_of_results is not None: + dv_config["numberOfResults"] = dvsc.number_of_results + config["denseVectorSearchConfiguration"] = dv_config + + if self.retrieve_settings.full_text_search_configuration is not None: + ftsc = self.retrieve_settings.full_text_search_configuration + ft_config: Dict[str, Any] = {} + if ftsc.number_of_results is not None: + ft_config["numberOfResults"] = ftsc.number_of_results + config["fullTextSearchConfiguration"] = ft_config + + if self.retrieve_settings.reranking_configuration is not None: + rc = self.retrieve_settings.reranking_configuration + rr_config: Dict[str, Any] = {} + + if rc.type is not None: + rr_config["type"] = rc.type + if rc.number_of_results is not None: + rr_config["numberOfResults"] = rc.number_of_results + + if rc.rrf_configuration is not None: + rrf: Dict[str, Any] = {} + if rc.rrf_configuration.dense_vector_search_weight is not None: + rrf["denseVectorSearchWeight"] = ( + rc.rrf_configuration.dense_vector_search_weight + ) + if rc.rrf_configuration.full_text_search_weight is not None: + rrf["fullTextSearchWeight"] = ( + rc.rrf_configuration.full_text_search_weight + ) + if rc.rrf_configuration.k is not None: + rrf["k"] = rc.rrf_configuration.k + rr_config["rrfConfiguration"] = rrf + + if rc.weight_configuration is not None: + wc: Dict[str, Any] = {} + if ( + rc.weight_configuration.dense_vector_search_weight + is not None + ): + wc["denseVectorSearchWeight"] = ( + rc.weight_configuration.dense_vector_search_weight + ) + if rc.weight_configuration.full_text_search_weight is not None: + wc["fullTextSearchWeight"] = ( + rc.weight_configuration.full_text_search_weight + ) + rr_config["weightConfiguration"] = wc + + if rc.model_configuration is not None: + mc: Dict[str, Any] = {} + if rc.model_configuration.provider is not None: + mc["provider"] = rc.model_configuration.provider + if rc.model_configuration.model is not None: + mc["model"] = rc.model_configuration.model + 
rr_config["modelConfiguration"] = mc + + config["rerankingConfiguration"] = rr_config + + if self.retrieve_settings.filter is not None: + config["filter"] = self.retrieve_settings.filter + + return config if config else None + + def _parse_retrieve_response( + self, response: Dict[str, Any], query: str + ) -> Dict[str, Any]: + """解析 OTS 检索响应 / Parse OTS retrieve response + + Args: + response: AgentStorageClient.retrieve 的响应 / Response from retrieve + query: 原始查询文本 / Original query text + + Returns: + Dict[str, Any]: 格式化的检索结果 / Formatted retrieval results + """ + all_results: List[Dict[str, Any]] = [] + + data = response.get("data", {}) + retrieval_results = data.get("retrievalResults", []) + + for item in retrieval_results: + all_results.append({ + "content": item.get("content"), + "score": item.get("score"), + "doc_id": item.get("docId"), + "chunk_id": item.get("chunkId"), + "subspace": item.get("subspace"), + "oss_key": item.get("ossKey"), + "metadata": item.get("metadata"), + }) + + return { + "data": all_results, + "query": query, + "knowledge_base_name": self.knowledge_base_name, + } + + async def retrieve_async( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """OTS 检索(异步)/ OTS retrieval asynchronously + + 通过 tablestore-agent-storage 调用 retrieve 接口进行知识库检索。 + Retrieves from OTS knowledge base via tablestore-agent-storage retrieve API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for OTS retrieval" + ) + + client = self._build_agent_storage_client(config) + + retrieval_config = self._build_retrieval_configuration() + + request: Dict[str, Any] = { + "knowledgeBaseName": self.knowledge_base_name, + "retrievalQuery": {"text": query, "type": "TEXT"}, + } + + if retrieval_config: + request["retrievalConfiguration"] = retrieval_config + + logger.debug(f"OTS retrieve request: {request}") + response = client.retrieve(request) + logger.debug(f"OTS retrieve response: {response}") + + return self._parse_retrieve_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def retrieve( + self, + query: str, + config: Optional[Config] = None, + ) -> Dict[str, Any]: + """OTS 检索(同步)/ OTS retrieval synchronously + + 通过 tablestore-agent-storage 调用 retrieve 接口进行知识库检索。 + Retrieves from OTS knowledge base via tablestore-agent-storage retrieve API. 
+ + Args: + query: 查询文本 / Query text + config: 配置 / Configuration + + Returns: + Dict[str, Any]: 检索结果 / Retrieval results + """ + try: + if self.provider_settings is None: + raise ValueError( + "provider_settings is required for OTS retrieval" + ) + + client = self._build_agent_storage_client(config) + + retrieval_config = self._build_retrieval_configuration() + + request: Dict[str, Any] = { + "knowledgeBaseName": self.knowledge_base_name, + "retrievalQuery": {"text": query, "type": "TEXT"}, + } + + if retrieval_config: + request["retrievalConfiguration"] = retrieval_config + + logger.debug(f"OTS retrieve request: {request}") + response = client.retrieve(request) + logger.debug(f"OTS retrieve response: {response}") + + return self._parse_retrieve_response(response, query) + + except Exception as e: + logger.warning( + "Failed to retrieve from OTS knowledge base " + f"'{self.knowledge_base_name}': {e}" + ) + return { + "data": f"Failed to retrieve: {e}", + "query": query, + "knowledge_base_name": self.knowledge_base_name, + "error": True, + } + + def get_data_api( provider: KnowledgeBaseProvider, knowledge_base_name: str, @@ -813,6 +1090,7 @@ def get_data_api( RagFlowProviderSettings, BailianProviderSettings, ADBProviderSettings, + OTSProviderSettings, ] ] = None, retrieve_settings: Optional[ @@ -820,6 +1098,7 @@ def get_data_api( RagFlowRetrieveSettings, BailianRetrieveSettings, ADBRetrieveSettings, + OTSRetrieveSettings, ] ] = None, credential_name: Optional[str] = None, @@ -892,5 +1171,22 @@ def get_data_api( provider_settings=adb_provider_settings, retrieve_settings=adb_retrieve_settings, ) + elif provider == KnowledgeBaseProvider.OTS or provider == "ots": + ots_provider_settings = ( + provider_settings + if isinstance(provider_settings, OTSProviderSettings) + else None + ) + ots_retrieve_settings = ( + retrieve_settings + if isinstance(retrieve_settings, OTSRetrieveSettings) + else None + ) + return OTSDataAPI( + knowledge_base_name, + config, + 
provider_settings=ots_provider_settings, + retrieve_settings=ots_retrieve_settings, + ) else: raise ValueError(f"Unsupported provider type: {provider}") diff --git a/agentrun/knowledgebase/knowledgebase.py b/agentrun/knowledgebase/knowledgebase.py index 74c5c50..e4901f0 100644 --- a/agentrun/knowledgebase/knowledgebase.py +++ b/agentrun/knowledgebase/knowledgebase.py @@ -36,6 +36,16 @@ KnowledgeBaseProvider, KnowledgeBaseSystemProps, KnowledgeBaseUpdateInput, + OTSDenseVectorSearchConfig, + OTSEmbeddingConfiguration, + OTSFullTextSearchConfig, + OTSMetadataField, + OTSModelConfig, + OTSProviderSettings, + OTSRerankingConfig, + OTSRetrieveSettings, + OTSRRFConfig, + OTSWeightConfig, RagFlowProviderSettings, RagFlowRetrieveSettings, RetrieveInput, @@ -480,8 +490,6 @@ def _get_data_api(self, config: Optional[Config] = None): if isinstance(self.provider_settings, ADBProviderSettings): converted_provider_settings = self.provider_settings elif isinstance(self.provider_settings, dict): - # ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case - # ADB provider_settings uses PascalCase keys, need to convert to snake_case converted_provider_settings = ADBProviderSettings( db_instance_id=self.provider_settings.get( "DBInstanceId", "" @@ -501,8 +509,6 @@ def _get_data_api(self, config: Optional[Config] = None): if isinstance(self.retrieve_settings, ADBRetrieveSettings): converted_retrieve_settings = self.retrieve_settings elif isinstance(self.retrieve_settings, dict): - # ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case - # ADB retrieve_settings uses PascalCase keys, need to convert to snake_case converted_retrieve_settings = ADBRetrieveSettings( top_k=self.retrieve_settings.get("TopK"), use_full_text_retrieval=self.retrieve_settings.get( @@ -522,6 +528,115 @@ def _get_data_api(self, config: Optional[Config] = None): ), ) + elif provider == KnowledgeBaseProvider.OTS: + # OTS 设置 / OTS settings (camelCase → snake_case) + if self.provider_settings: + if 
isinstance(self.provider_settings, OTSProviderSettings): + converted_provider_settings = self.provider_settings + elif isinstance(self.provider_settings, dict): + ps = self.provider_settings + + metadata = None + raw_metadata = ps.get("metadata") + if raw_metadata and isinstance(raw_metadata, list): + metadata = [ + OTSMetadataField( + name=m.get("name", ""), + type=m.get("type", ""), + ) + for m in raw_metadata + ] + + embedding_config = None + raw_ec = ps.get("embeddingConfiguration") + if raw_ec and isinstance(raw_ec, dict): + embedding_config = OTSEmbeddingConfiguration( + provider=raw_ec.get("provider", ""), + model=raw_ec.get("model", ""), + dimension=raw_ec.get("dimension", 0), + url=raw_ec.get("url"), + api_key=raw_ec.get("apiKey"), + ) + + converted_provider_settings = OTSProviderSettings( + ots_instance_name=ps.get("otsInstanceName", ""), + tags=ps.get("tags"), + metadata=metadata, + embedding_configuration=embedding_config, + ) + + if self.retrieve_settings: + if isinstance(self.retrieve_settings, OTSRetrieveSettings): + converted_retrieve_settings = self.retrieve_settings + elif isinstance(self.retrieve_settings, dict): + rs = self.retrieve_settings + + dvsc = None + raw_dvsc = rs.get("denseVectorSearchConfiguration") + if raw_dvsc and isinstance(raw_dvsc, dict): + dvsc = OTSDenseVectorSearchConfig( + number_of_results=raw_dvsc.get("numberOfResults"), + ) + + ftsc = None + raw_ftsc = rs.get("fullTextSearchConfiguration") + if raw_ftsc and isinstance(raw_ftsc, dict): + ftsc = OTSFullTextSearchConfig( + number_of_results=raw_ftsc.get("numberOfResults"), + ) + + reranking = None + raw_rr = rs.get("rerankingConfiguration") + if raw_rr and isinstance(raw_rr, dict): + rrf_config = None + raw_rrf = raw_rr.get("rrfConfiguration") + if raw_rrf and isinstance(raw_rrf, dict): + rrf_config = OTSRRFConfig( + dense_vector_search_weight=raw_rrf.get( + "denseVectorSearchWeight" + ), + full_text_search_weight=raw_rrf.get( + "fullTextSearchWeight" + ), + 
# =============================================================================
# OTS configuration models
# =============================================================================


class OTSMetadataField(BaseModel):
    """Definition of one metadata field of an OTS knowledge base.

    Field types named by the original documentation: string / long /
    double / boolean / date / string_list / long_list / double_list /
    boolean_list / date_list. ``type`` is annotated as a plain ``str``,
    so no check against that list is declared on this model.
    """

    name: str
    """字段名 / Field name"""
    type: str
    """字段类型 / Field type"""


class OTSEmbeddingConfiguration(BaseModel):
    """Embedding (vectorization) configuration of an OTS knowledge base.

    ``provider``, ``model`` and ``dimension`` are required; ``url`` and
    ``api_key`` are optional and default to ``None``.
    """

    provider: str
    """向量化服务提供商,如 "bailian" / Embedding provider"""
    model: str
    """向量化模型名称,如 "text-embedding-v3" / Embedding model name"""
    dimension: int
    """向量维度,如 1024 / Vector dimension"""
    url: Optional[str] = None
    """向量化服务地址(可选)/ Embedding service URL (optional)"""
    api_key: Optional[str] = None
    """向量化服务 API Key(可选)/ Embedding API key (optional)"""


class OTSProviderSettings(BaseModel):
    """Provider settings for an OTS (TableStore) knowledge base.

    Holds the connection and access parameters of the knowledge base.
    Only ``ots_instance_name`` is required; every other field defaults
    to ``None``.
    """

    ots_instance_name: str
    """OTS 实例名称 / OTS instance name"""
    tags: Optional[List[str]] = None
    """标签列表 / Tag list"""
    metadata: Optional[List[OTSMetadataField]] = None
    """元数据字段定义列表 / Metadata field definitions"""
    embedding_configuration: Optional[OTSEmbeddingConfiguration] = None
    """向量化配置 / Embedding configuration"""


class OTSDenseVectorSearchConfig(BaseModel):
    """Dense-vector search configuration for OTS retrieval."""

    number_of_results: Optional[int] = None
    """向量检索返回结果数量 / Number of dense vector search results"""


class OTSFullTextSearchConfig(BaseModel):
    """Full-text search configuration for OTS retrieval."""

    number_of_results: Optional[int] = None
    """全文检索返回结果数量 / Number of full text search results"""


class OTSRRFConfig(BaseModel):
    """RRF reranking configuration for OTS retrieval.

    NOTE(review): the attribute notes below mention defaults (1.0 / 1.0 /
    60), but every field here defaults to ``None``; presumably those
    defaults are applied by the service when a value is omitted. Confirm.
    """

    dense_vector_search_weight: Optional[float] = None
    """向量检索权重,默认 1.0 / Dense vector search weight"""
    full_text_search_weight: Optional[float] = None
    """全文检索权重,默认 1.0 / Full text search weight"""
    k: Optional[int] = None
    """RRF 参数 k,默认 60 / RRF parameter k"""


class OTSWeightConfig(BaseModel):
    """Weight-based reranking configuration for OTS retrieval."""

    dense_vector_search_weight: Optional[float] = None
    """向量检索权重 / Dense vector search weight"""
    full_text_search_weight: Optional[float] = None
    """全文检索权重 / Full text search weight"""


class OTSModelConfig(BaseModel):
    """Model-based reranking configuration for OTS retrieval."""

    provider: Optional[str] = None
    """重排序模型提供商 / Reranking model provider"""
    model: Optional[str] = None
    """重排序模型名称 / Reranking model name"""


class OTSRerankingConfig(BaseModel):
    """Reranking configuration for OTS retrieval.

    ``type`` names the reranking strategy (RRF / WEIGHT / MODEL) and the
    matching ``*_configuration`` field carries its parameters. The model
    declares no validation tying ``type`` to the populated sub-configuration.
    """

    type: Optional[str] = None
    """重排序类型:RRF / WEIGHT / MODEL / Reranking type"""
    number_of_results: Optional[int] = None
    """重排序后返回结果数量 / Number of results after reranking"""
    rrf_configuration: Optional[OTSRRFConfig] = None
    """RRF 配置(当 type=RRF 时)/ RRF config (when type=RRF)"""
    weight_configuration: Optional[OTSWeightConfig] = None
    """Weight 配置(当 type=WEIGHT 时)/ Weight config (when type=WEIGHT)"""
    model_configuration: Optional[OTSModelConfig] = None
    """Model 配置(当 type=MODEL 时)/ Model config (when type=MODEL)"""


class OTSRetrieveSettings(BaseModel):
    """Retrieve settings for an OTS knowledge base.

    Configures retrieval parameters; dense-vector, full-text and hybrid
    (both) search are supported. Every field is optional and defaults to
    ``None``.
    """

    search_type: Optional[List[str]] = None
    """检索类型列表,支持 DENSE_VECTOR 和 FULL_TEXT
    Search type list, supports DENSE_VECTOR and FULL_TEXT"""
    dense_vector_search_configuration: Optional[OTSDenseVectorSearchConfig] = (
        None
    )
    """向量检索配置 / Dense vector search configuration"""
    full_text_search_configuration: Optional[OTSFullTextSearchConfig] = None
    """全文检索配置 / Full text search configuration"""
    reranking_configuration: Optional[OTSRerankingConfig] = None
    """重排序配置 / Reranking configuration"""
    filter: Optional[Dict[str, Any]] = None
    """元数据过滤条件 / Metadata filter"""
# -----------------------------------------------------------------------------
# OTS knowledge base configuration
# -----------------------------------------------------------------------------
# Every value is read from the environment and falls back to the default shown.

# Name of the OTS knowledge base used by the example.
OTS_KB_NAME = os.environ.get("OTS_KB_NAME", f"sdk-test-ots-kb-{TIMESTAMP}")

# OTS instance name; replace the placeholder with a real instance.
OTS_INSTANCE_NAME = os.environ.get(
    "OTS_INSTANCE_NAME", "your-ots-instance-name"
)

# Embedding service provider.
OTS_EMBEDDING_PROVIDER = os.environ.get("OTS_EMBEDDING_PROVIDER", "bailian")

# Embedding model name.
OTS_EMBEDDING_MODEL = os.environ.get(
    "OTS_EMBEDDING_MODEL", "text-embedding-v3"
)

# Embedding vector dimension, parsed as an integer.
OTS_EMBEDDING_DIMENSION = int(
    os.environ.get("OTS_EMBEDDING_DIMENSION", "1024")
)
# ============================================================================
# OTS knowledge base example functions
# ============================================================================


def create_or_get_ots_kb() -> KnowledgeBase:
    """Create the example OTS knowledge base, or fetch it when it exists.

    Returns:
        KnowledgeBase: The newly created knowledge base, or the existing one
        named ``OTS_KB_NAME`` when creation reports that it already exists.
    """
    for banner_line in (
        "=" * 60,
        "创建或获取 OTS 知识库",
        "Create or get OTS knowledge base",
        "=" * 60,
    ):
        logger.info(banner_line)

    # Provider side: instance, tags, metadata schema and embedding model.
    provider_settings = OTSProviderSettings(
        ots_instance_name=OTS_INSTANCE_NAME,
        tags=["sdk-test", "example"],
        metadata=[
            OTSMetadataField(name=field_name, type=field_type)
            for field_name, field_type in (
                ("source", "string"),
                ("category", "string"),
                ("score", "double"),
            )
        ],
        embedding_configuration=OTSEmbeddingConfiguration(
            provider=OTS_EMBEDDING_PROVIDER,
            model=OTS_EMBEDDING_MODEL,
            dimension=OTS_EMBEDDING_DIMENSION,
        ),
    )

    # Retrieval side: hybrid search with RRF reranking.
    retrieve_settings = OTSRetrieveSettings(
        search_type=["DENSE_VECTOR", "FULL_TEXT"],
        dense_vector_search_configuration=OTSDenseVectorSearchConfig(
            number_of_results=10,
        ),
        full_text_search_configuration=OTSFullTextSearchConfig(
            number_of_results=10,
        ),
        reranking_configuration=OTSRerankingConfig(
            type="RRF",
            number_of_results=5,
            rrf_configuration=OTSRRFConfig(
                dense_vector_search_weight=1.0,
                full_text_search_weight=0.5,
                k=60,
            ),
        ),
    )

    create_input = KnowledgeBaseCreateInput(
        knowledge_base_name=OTS_KB_NAME,
        description=(
            "通过 SDK 创建的 OTS 知识库示例 / OTS KB example created via SDK"
        ),
        provider=KnowledgeBaseProvider.OTS,
        provider_settings=provider_settings,
        retrieve_settings=retrieve_settings,
    )

    try:
        kb = KnowledgeBase.create(create_input)
    except ResourceAlreadyExistError:
        # Creation is only attempted once; fall back to the existing resource.
        logger.info(
            "ℹ️ OTS 知识库已存在,获取已有资源 / OTS KB exists, getting existing"
        )
        kb = client.get(OTS_KB_NAME)
    else:
        logger.info("✅ OTS 知识库创建成功 / OTS KB created successfully")

    _log_kb_info(kb)
    return kb
def _log_ots_retrieve_results(results) -> None:
    """Log the outcome of one OTS retrieval call.

    The OTS data API reports failures in-band as
    ``{"error": True, "data": "<message>"}``. Assumes
    ``KnowledgeBase.retrieve`` passes that payload through unchanged —
    TODO confirm. Such a payload is logged as a warning instead of being
    reported as a success whose "result count" is the length of the error
    message.

    Args:
        results: Mapping returned by ``retrieve``.
    """
    if results.get("error"):
        logger.warning(
            "⚠️ 检索返回错误 / Retrieval returned an error: %s",
            results.get("data"),
        )
        return

    logger.info("✅ 查询成功 / Query successful")
    logger.info("检索结果 / Retrieval results: %s", results)
    hits = results.get("data")
    # Only a real sequence of hits has a meaningful count.
    hit_count = len(hits) if isinstance(hits, (list, tuple)) else 0
    logger.info(" - 结果数量 / Result count: %s", hit_count)


def query_ots_kb(kb: KnowledgeBase):
    """Run a sample query against an OTS knowledge base and log the result.

    Failures are logged as warnings rather than raised, so the example can
    continue when the knowledge base is misconfigured or unreachable.

    Args:
        kb: Knowledge base object to query.
    """
    logger.info("=" * 60)
    logger.info("查询 OTS 知识库")
    logger.info("Query OTS knowledge base")
    logger.info("=" * 60)

    query_text = "什么是云原生"
    logger.info("查询文本 / Query text: %s", query_text)

    try:
        _log_ots_retrieve_results(kb.retrieve(query=query_text))
    except Exception as e:  # deliberate best-effort: example must not abort
        logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e)


def query_ots_kb_by_name(knowledgebase_name: str):
    """Look up an OTS knowledge base by name, query it and log the result.

    Failures (lookup or retrieval) are logged as warnings rather than raised.

    Args:
        knowledgebase_name: Name of the knowledge base to query.
    """
    try:
        kb = KnowledgeBase.get_by_name(knowledgebase_name)
        _log_ots_retrieve_results(kb.retrieve(query="什么是云原生"))
    except Exception as e:  # deliberate best-effort: example must not abort
        logger.warning("⚠️ 查询失败(可能是配置或连接问题): %s", e)


def update_ots_kb(kb: KnowledgeBase):
    """Update the description and retrieve settings of an OTS knowledge base.

    Switches the example from RRF reranking to weight-based reranking and
    raises the per-search result counts to 20.

    Args:
        kb: Knowledge base object to update.
    """
    logger.info("=" * 60)
    logger.info("更新 OTS 知识库配置")
    logger.info("Update OTS knowledge base configuration")
    logger.info("=" * 60)

    new_description = f"[OTS] 更新于 {time.strftime('%Y-%m-%d %H:%M:%S')}"

    kb.update(
        KnowledgeBaseUpdateInput(
            description=new_description,
            retrieve_settings=OTSRetrieveSettings(
                search_type=["DENSE_VECTOR", "FULL_TEXT"],
                dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                    number_of_results=20,
                ),
                full_text_search_configuration=OTSFullTextSearchConfig(
                    number_of_results=20,
                ),
                reranking_configuration=OTSRerankingConfig(
                    type="WEIGHT",
                    number_of_results=10,
                    weight_configuration=OTSWeightConfig(
                        dense_vector_search_weight=0.7,
                        full_text_search_weight=0.3,
                    ),
                ),
            ),
        )
    )

    logger.info("✅ OTS 知识库更新成功 / OTS KB updated successfully")
    logger.info(" - 新描述 / New description: %s", kb.description)
def delete_ots_kb(kb: KnowledgeBase):
    """Delete an OTS knowledge base and check that it is gone.

    After the delete call, ``OTS_KB_NAME`` is looked up again: a
    ``ResourceNotExistError`` confirms the deletion, while a successful
    lookup is logged as a warning.

    Args:
        kb: Knowledge base object to delete.
    """
    for banner_line in (
        "=" * 60,
        "删除 OTS 知识库",
        "Delete OTS knowledge base",
        "=" * 60,
    ):
        logger.info(banner_line)

    kb.delete()
    logger.info("✅ OTS 知识库删除请求已发送 / OTS KB delete request sent")

    try:
        client.get(OTS_KB_NAME)
    except ResourceNotExistError:
        logger.info("✅ OTS 知识库已成功删除 / OTS KB deleted successfully")
    else:
        logger.warning("⚠️ OTS 知识库仍然存在 / OTS KB still exists")


def ots_example():
    """Run the complete OTS walkthrough: create/get, query, update, delete."""
    logger.info("")
    logger.info("🔸 OTS 知识库示例 / OTS Knowledge Base Example")
    logger.info("=" * 60)

    kb = create_or_get_ots_kb()

    # Query, then update, then delete — each step takes the same KB object.
    for step in (query_ots_kb, update_ots_kb, delete_ots_kb):
        step(kb)

    logger.info("🔸 OTS 知识库示例完成 / OTS KB Example Complete")
    logger.info("")
def ots_only_example():
    """Run only the OTS example, listing knowledge bases before and after."""
    logger.info("🚀 OTS 知识库示例 / OTS KB Example")
    for step in (list_knowledge_bases, ots_example, list_knowledge_bases):
        step()
    logger.info("🎉 完成 / Complete")
"""Unit tests for the OTS knowledge base support."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from agentrun.knowledgebase.api.data import get_data_api, OTSDataAPI
from agentrun.knowledgebase.knowledgebase import KnowledgeBase
from agentrun.knowledgebase.model import (
    KnowledgeBaseProvider,
    OTSDenseVectorSearchConfig,
    OTSEmbeddingConfiguration,
    OTSFullTextSearchConfig,
    OTSMetadataField,
    OTSModelConfig,
    OTSProviderSettings,
    OTSRerankingConfig,
    OTSRetrieveSettings,
    OTSRRFConfig,
    OTSWeightConfig,
)
from agentrun.utils.config import Config

# =============================================================================
# OTS model tests
# =============================================================================


class TestOTSModels:
    """Construction and default-value checks for the OTS configuration models."""

    def test_ots_provider_enum(self):
        """The OTS enum member equals "ots" and is found by value lookup."""
        assert KnowledgeBaseProvider.OTS == "ots"
        assert KnowledgeBaseProvider("ots") == KnowledgeBaseProvider.OTS

    def test_ots_metadata_field(self):
        """OTSMetadataField keeps the name and type it was given."""
        meta = OTSMetadataField(name="author", type="string")
        assert (meta.name, meta.type) == ("author", "string")

    def test_ots_embedding_configuration(self):
        """Required embedding fields are stored; optional ones default to None."""
        embedding = OTSEmbeddingConfiguration(
            provider="bailian",
            model="text-embedding-v3",
            dimension=1024,
        )
        assert embedding.provider == "bailian"
        assert embedding.model == "text-embedding-v3"
        assert embedding.dimension == 1024
        assert embedding.url is None
        assert embedding.api_key is None

    def test_ots_embedding_configuration_full(self):
        """Optional embedding fields are stored when provided."""
        embedding = OTSEmbeddingConfiguration(
            provider="bailian",
            model="text-embedding-v3",
            dimension=1024,
            url="https://embedding.example.com",
            api_key="test-key",
        )
        assert embedding.url == "https://embedding.example.com"
        assert embedding.api_key == "test-key"

    def test_ots_provider_settings_required_only(self):
        """Only the instance name is required; the rest default to None."""
        settings = OTSProviderSettings(ots_instance_name="test-instance")
        assert settings.ots_instance_name == "test-instance"
        for optional_attr in ("tags", "metadata", "embedding_configuration"):
            assert getattr(settings, optional_attr) is None

    def test_ots_provider_settings_full(self):
        """All provider-settings fields are stored when provided."""
        metadata_fields = [
            OTSMetadataField(name="author", type="string"),
            OTSMetadataField(name="score", type="double"),
        ]
        embedding = OTSEmbeddingConfiguration(
            provider="bailian",
            model="text-embedding-v3",
            dimension=1024,
        )
        settings = OTSProviderSettings(
            ots_instance_name="test-instance",
            tags=["demo", "test"],
            metadata=metadata_fields,
            embedding_configuration=embedding,
        )
        assert settings.ots_instance_name == "test-instance"
        assert settings.tags == ["demo", "test"]
        assert len(settings.metadata) == 2
        assert settings.metadata[0].name == "author"
        assert settings.embedding_configuration.provider == "bailian"

    def test_ots_dense_vector_search_config(self):
        """OTSDenseVectorSearchConfig stores number_of_results."""
        cfg = OTSDenseVectorSearchConfig(number_of_results=10)
        assert cfg.number_of_results == 10

    def test_ots_full_text_search_config(self):
        """OTSFullTextSearchConfig stores number_of_results."""
        cfg = OTSFullTextSearchConfig(number_of_results=20)
        assert cfg.number_of_results == 20

    def test_ots_rrf_config(self):
        """OTSRRFConfig stores both weights and k."""
        cfg = OTSRRFConfig(
            dense_vector_search_weight=1.0,
            full_text_search_weight=1.0,
            k=60,
        )
        assert cfg.dense_vector_search_weight == 1.0
        assert cfg.full_text_search_weight == 1.0
        assert cfg.k == 60

    def test_ots_weight_config(self):
        """OTSWeightConfig stores both weights."""
        cfg = OTSWeightConfig(
            dense_vector_search_weight=0.7,
            full_text_search_weight=0.3,
        )
        assert cfg.dense_vector_search_weight == 0.7
        assert cfg.full_text_search_weight == 0.3

    def test_ots_model_config(self):
        """OTSModelConfig stores provider and model."""
        cfg = OTSModelConfig(provider="bailian", model="gte-rerank-v2")
        assert (cfg.provider, cfg.model) == ("bailian", "gte-rerank-v2")

    def test_ots_reranking_config_rrf(self):
        """OTSRerankingConfig carries an RRF sub-configuration."""
        rrf = OTSRRFConfig(
            dense_vector_search_weight=1.0,
            full_text_search_weight=1.0,
            k=60,
        )
        cfg = OTSRerankingConfig(
            type="RRF", number_of_results=10, rrf_configuration=rrf
        )
        assert cfg.type == "RRF"
        assert cfg.rrf_configuration.k == 60

    def test_ots_reranking_config_weight(self):
        """OTSRerankingConfig carries a WEIGHT sub-configuration."""
        weights = OTSWeightConfig(
            dense_vector_search_weight=0.7,
            full_text_search_weight=0.3,
        )
        cfg = OTSRerankingConfig(
            type="WEIGHT", number_of_results=10, weight_configuration=weights
        )
        assert cfg.type == "WEIGHT"
        assert cfg.weight_configuration.dense_vector_search_weight == 0.7

    def test_ots_reranking_config_model(self):
        """OTSRerankingConfig carries a MODEL sub-configuration."""
        rerank_model = OTSModelConfig(provider="bailian", model="gte-rerank-v2")
        cfg = OTSRerankingConfig(
            type="MODEL",
            number_of_results=10,
            model_configuration=rerank_model,
        )
        assert cfg.type == "MODEL"
        assert cfg.model_configuration.model == "gte-rerank-v2"

    def test_ots_retrieve_settings_dense_only(self):
        """Dense-vector-only settings leave the full-text config unset."""
        retrieve = OTSRetrieveSettings(
            search_type=["DENSE_VECTOR"],
            dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                number_of_results=10
            ),
            reranking_configuration=OTSRerankingConfig(
                type="RRF", number_of_results=10
            ),
        )
        assert retrieve.search_type == ["DENSE_VECTOR"]
        assert retrieve.dense_vector_search_configuration.number_of_results == 10
        assert retrieve.full_text_search_configuration is None

    def test_ots_retrieve_settings_hybrid(self):
        """Hybrid settings keep both search types and both configs."""
        retrieve = OTSRetrieveSettings(
            search_type=["DENSE_VECTOR", "FULL_TEXT"],
            dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                number_of_results=20
            ),
            full_text_search_configuration=OTSFullTextSearchConfig(
                number_of_results=20
            ),
        )
        assert len(retrieve.search_type) == 2
        assert retrieve.full_text_search_configuration.number_of_results == 20

    def test_ots_retrieve_settings_with_filter(self):
        """A metadata filter dict is stored on the settings."""
        metadata_filter = {
            "andAll": [
                {"equals": {"key": "author", "value": "test"}},
                {"greaterThan": {"key": "score", "value": 0.5}},
            ]
        }
        retrieve = OTSRetrieveSettings(
            search_type=["DENSE_VECTOR"], filter=metadata_filter
        )
        assert retrieve.filter is not None
        assert "andAll" in retrieve.filter

    def test_ots_retrieve_settings_all_none(self):
        """Every OTSRetrieveSettings field defaults to None."""
        retrieve = OTSRetrieveSettings()
        for attr in (
            "search_type",
            "dense_vector_search_configuration",
            "full_text_search_configuration",
            "reranking_configuration",
            "filter",
        ):
            assert getattr(retrieve, attr) is None
# =============================================================================
# OTSDataAPI tests
# =============================================================================


class TestOTSDataAPIBuildRetrievalConfiguration:
    """Checks for OTSDataAPI._build_retrieval_configuration."""

    @staticmethod
    def _build(settings):
        """Return the retrieval configuration built for *settings*."""
        api = OTSDataAPI("test-kb", retrieve_settings=settings)
        return api._build_retrieval_configuration()

    def test_none_retrieve_settings(self):
        """Without retrieve settings no configuration is built."""
        api = OTSDataAPI("test-kb", provider_settings=None)
        assert api._build_retrieval_configuration() is None

    def test_empty_retrieve_settings(self):
        """Settings with every field unset build no configuration."""
        assert self._build(OTSRetrieveSettings()) is None

    def test_dense_vector_only(self):
        """Dense-vector settings are emitted with camelCase keys."""
        built = self._build(
            OTSRetrieveSettings(
                search_type=["DENSE_VECTOR"],
                dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                    number_of_results=10
                ),
            )
        )
        assert built["searchType"] == ["DENSE_VECTOR"]
        assert built["denseVectorSearchConfiguration"]["numberOfResults"] == 10

    def test_hybrid_search_with_rrf(self):
        """Hybrid search with RRF reranking is emitted in full."""
        built = self._build(
            OTSRetrieveSettings(
                search_type=["DENSE_VECTOR", "FULL_TEXT"],
                dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                    number_of_results=20
                ),
                full_text_search_configuration=OTSFullTextSearchConfig(
                    number_of_results=20
                ),
                reranking_configuration=OTSRerankingConfig(
                    type="RRF",
                    number_of_results=10,
                    rrf_configuration=OTSRRFConfig(
                        dense_vector_search_weight=1.0,
                        full_text_search_weight=1.0,
                        k=60,
                    ),
                ),
            )
        )
        assert built["searchType"] == ["DENSE_VECTOR", "FULL_TEXT"]
        assert built["rerankingConfiguration"]["type"] == "RRF"
        assert built["rerankingConfiguration"]["rrfConfiguration"]["k"] == 60

    def test_weight_reranking(self):
        """Weight reranking is emitted with both weights."""
        built = self._build(
            OTSRetrieveSettings(
                reranking_configuration=OTSRerankingConfig(
                    type="WEIGHT",
                    number_of_results=10,
                    weight_configuration=OTSWeightConfig(
                        dense_vector_search_weight=0.7,
                        full_text_search_weight=0.3,
                    ),
                ),
            )
        )
        reranking = built["rerankingConfiguration"]
        assert reranking["type"] == "WEIGHT"
        assert reranking["weightConfiguration"]["denseVectorSearchWeight"] == 0.7
        assert reranking["weightConfiguration"]["fullTextSearchWeight"] == 0.3

    def test_model_reranking(self):
        """Model reranking is emitted with provider and model."""
        built = self._build(
            OTSRetrieveSettings(
                reranking_configuration=OTSRerankingConfig(
                    type="MODEL",
                    number_of_results=10,
                    model_configuration=OTSModelConfig(
                        provider="bailian", model="gte-rerank-v2"
                    ),
                ),
            )
        )
        reranking = built["rerankingConfiguration"]
        assert reranking["type"] == "MODEL"
        assert reranking["modelConfiguration"]["provider"] == "bailian"
        assert reranking["modelConfiguration"]["model"] == "gte-rerank-v2"

    def test_with_filter(self):
        """A metadata filter is passed through to the configuration."""
        built = self._build(
            OTSRetrieveSettings(
                search_type=["DENSE_VECTOR"],
                filter={
                    "andAll": [{"equals": {"key": "author", "value": "test"}}]
                },
            )
        )
        assert built["filter"]["andAll"][0]["equals"]["key"] == "author"
class TestOTSDataAPIParseResponse:
    """Checks for OTSDataAPI._parse_retrieve_response."""

    @staticmethod
    def _parse(response, query):
        """Parse *response* with a fresh OTSDataAPI for knowledge base "test-kb"."""
        return OTSDataAPI("test-kb")._parse_retrieve_response(response, query)

    def test_parse_normal_response(self):
        """A single hit is converted to snake_case keys."""
        raw_hit = {
            "ossKey": "oss://testbucket/xxx.pdf",
            "docId": "96fb386e-44d5-40aa-aa4d-edc0762f867c",
            "chunkId": 3,
            "subspace": "test",
            "score": 0.1,
            "content": "test content",
            "metadata": {"date": "2026-01-22"},
        }
        parsed = self._parse(
            {
                "code": "SUCCESS",
                "data": {"retrievalResults": [raw_hit]},
                "message": "success",
            },
            "test query",
        )
        assert parsed["query"] == "test query"
        assert parsed["knowledge_base_name"] == "test-kb"
        assert len(parsed["data"]) == 1

        hit = parsed["data"][0]
        expected_fields = {
            "content": "test content",
            "score": 0.1,
            "doc_id": "96fb386e-44d5-40aa-aa4d-edc0762f867c",
            "chunk_id": 3,
            "subspace": "test",
            "oss_key": "oss://testbucket/xxx.pdf",
        }
        for key, value in expected_fields.items():
            assert hit[key] == value
        assert hit["metadata"]["date"] == "2026-01-22"

    def test_parse_empty_response(self):
        """An empty result list yields empty data."""
        parsed = self._parse(
            {"code": "SUCCESS", "data": {"retrievalResults": []}},
            "test query",
        )
        assert parsed["data"] == []

    def test_parse_no_data(self):
        """A response without a data field yields empty data."""
        parsed = self._parse({"code": "SUCCESS"}, "test query")
        assert parsed["data"] == []

    def test_parse_multiple_results(self):
        """Several hits are kept in order."""
        hits = [
            {"content": "result 1", "score": 0.9},
            {"content": "result 2", "score": 0.8},
            {"content": "result 3", "score": 0.7},
        ]
        parsed = self._parse(
            {"code": "SUCCESS", "data": {"retrievalResults": hits}},
            "query",
        )
        assert len(parsed["data"]) == 3
        assert parsed["data"][0]["score"] == 0.9
        assert parsed["data"][2]["content"] == "result 3"
class TestOTSDataAPIRetrieve:
    """Checks for OTSDataAPI.retrieve and OTSDataAPI.retrieve_async."""

    # Patch target: the factory that builds the underlying storage client.
    _BUILD_CLIENT = (
        "agentrun.knowledgebase.api.data.OTSDataAPI._build_agent_storage_client"
    )

    @staticmethod
    def _client_returning(hits):
        """Return a mock client whose retrieve() succeeds with *hits*."""
        client = MagicMock()
        client.retrieve.return_value = {
            "code": "SUCCESS",
            "data": {"retrievalResults": hits},
        }
        return client

    @staticmethod
    def _api(retrieve_settings=None):
        """Build an OTSDataAPI for "test-kb" with instance "test-instance"."""
        kwargs = {
            "provider_settings": OTSProviderSettings(
                ots_instance_name="test-instance"
            )
        }
        if retrieve_settings is not None:
            kwargs["retrieve_settings"] = retrieve_settings
        return OTSDataAPI("test-kb", **kwargs)

    @staticmethod
    def _dense_settings():
        """Dense-vector retrieve settings returning 10 results."""
        return OTSRetrieveSettings(
            search_type=["DENSE_VECTOR"],
            dense_vector_search_configuration=OTSDenseVectorSearchConfig(
                number_of_results=10
            ),
        )

    @patch(_BUILD_CLIENT)
    def test_retrieve_sync_success(self, mock_build_client):
        """Synchronous retrieve returns the parsed hits."""
        mock_client = self._client_returning(
            [{"content": "test result", "score": 0.9}]
        )
        mock_build_client.return_value = mock_client

        outcome = self._api(self._dense_settings()).retrieve("test query")

        assert outcome["query"] == "test query"
        assert len(outcome["data"]) == 1
        assert outcome["data"][0]["content"] == "test result"
        mock_client.retrieve.assert_called_once()

    @patch(_BUILD_CLIENT)
    @pytest.mark.asyncio
    async def test_retrieve_async_success(self, mock_build_client):
        """Asynchronous retrieve returns the parsed hits."""
        mock_build_client.return_value = self._client_returning(
            [{"content": "async result", "score": 0.85}]
        )

        outcome = await self._api().retrieve_async("async query")

        assert outcome["query"] == "async query"
        assert outcome["data"][0]["content"] == "async result"

    @patch(_BUILD_CLIENT)
    def test_retrieve_error_handling(self, mock_build_client):
        """A client-construction failure is reported in-band."""
        mock_build_client.side_effect = Exception("Connection failed")

        outcome = self._api().retrieve("test query")

        assert outcome["error"] is True
        assert "Failed to retrieve" in outcome["data"]
        assert outcome["query"] == "test query"
        assert outcome["knowledge_base_name"] == "test-kb"

    @patch(_BUILD_CLIENT)
    @pytest.mark.asyncio
    async def test_retrieve_async_error_handling(self, mock_build_client):
        """The async path reports a failure in-band as well."""
        mock_build_client.side_effect = Exception("Network error")

        outcome = await self._api().retrieve_async("test query")

        assert outcome["error"] is True
        assert "Failed to retrieve" in outcome["data"]

    def test_retrieve_without_provider_settings(self):
        """Retrieve without provider settings reports the missing settings."""
        outcome = OTSDataAPI("test-kb").retrieve("test query")
        assert outcome["error"] is True
        assert "provider_settings is required" in outcome["data"]

    @pytest.mark.asyncio
    async def test_retrieve_async_without_provider_settings(self):
        """Async retrieve without provider settings reports the same error."""
        outcome = await OTSDataAPI("test-kb").retrieve_async("test query")
        assert outcome["error"] is True
        assert "provider_settings is required" in outcome["data"]

    @patch(_BUILD_CLIENT)
    def test_retrieve_without_retrieve_settings(self, mock_build_client):
        """Without retrieve settings no retrievalConfiguration is sent."""
        mock_client = self._client_returning(
            [{"content": "default", "score": 0.5}]
        )
        mock_build_client.return_value = mock_client

        outcome = self._api().retrieve("test query")
        assert len(outcome["data"]) == 1

        sent_request = mock_client.retrieve.call_args[0][0]
        assert "retrievalConfiguration" not in sent_request

    @patch(_BUILD_CLIENT)
    def test_retrieve_request_structure(self, mock_build_client):
        """The request carries the KB name, query text/type and search type."""
        mock_client = self._client_returning([])
        mock_build_client.return_value = mock_client

        self._api(self._dense_settings()).retrieve("test query")

        sent_request = mock_client.retrieve.call_args[0][0]
        assert sent_request["knowledgeBaseName"] == "test-kb"
        assert sent_request["retrievalQuery"]["text"] == "test query"
        assert sent_request["retrievalQuery"]["type"] == "TEXT"
        assert sent_request["retrievalConfiguration"]["searchType"] == [
            "DENSE_VECTOR"
        ]
class TestOTSDataAPIBuildClient:
    """Checks for OTSDataAPI._build_agent_storage_client."""

    @patch("agentrun.knowledgebase.api.data.AgentStorageClient")
    def test_build_client(self, mock_client_class):
        """The client is built from the config's region and credentials."""
        mock_config = MagicMock(spec=Config)
        mock_config.configure_mock(
            **{
                "get_region_id.return_value": "cn-hangzhou",
                "get_access_key_id.return_value": "test-ak",
                "get_access_key_secret.return_value": "test-sk",
            }
        )
        settings = OTSProviderSettings(ots_instance_name="test-instance")

        with patch.object(Config, "with_configs", return_value=mock_config):
            OTSDataAPI(
                "test-kb", provider_settings=settings
            )._build_agent_storage_client()

        mock_client_class.assert_called_once_with(
            access_key_id="test-ak",
            access_key_secret="test-sk",
            ots_endpoint="http://ots-cn-hangzhou.aliyuncs.com",
            ots_instance_name="test-instance",
        )

    def test_build_client_without_provider_settings(self):
        """Building a client without provider settings raises ValueError."""
        bare_api = OTSDataAPI("test-kb")
        with pytest.raises(ValueError, match="provider_settings is required"):
            bare_api._build_agent_storage_client()
# =============================================================================
# get_data_api factory tests
# =============================================================================


class TestGetDataAPIOTS:
    """Checks for the OTS branch of the get_data_api factory."""

    def test_get_data_api_ots_with_enum(self):
        """The enum member selects OTSDataAPI."""
        settings = OTSProviderSettings(ots_instance_name="test-instance")
        created = get_data_api(
            provider=KnowledgeBaseProvider.OTS,
            knowledge_base_name="test-kb",
            provider_settings=settings,
        )
        assert isinstance(created, OTSDataAPI)

    def test_get_data_api_ots_with_string(self):
        """The plain string "ots" selects OTSDataAPI."""
        settings = OTSProviderSettings(ots_instance_name="test-instance")
        created = get_data_api(
            provider="ots",
            knowledge_base_name="test-kb",
            provider_settings=settings,
        )
        assert isinstance(created, OTSDataAPI)

    def test_get_data_api_ots_with_settings(self):
        """Typed settings are handed to the API object unchanged."""
        provider_settings = OTSProviderSettings(
            ots_instance_name="test-instance"
        )
        retrieve_settings = OTSRetrieveSettings(search_type=["DENSE_VECTOR"])

        created = get_data_api(
            provider=KnowledgeBaseProvider.OTS,
            knowledge_base_name="test-kb",
            provider_settings=provider_settings,
            retrieve_settings=retrieve_settings,
        )
        assert isinstance(created, OTSDataAPI)
        assert created.provider_settings is provider_settings
        assert created.retrieve_settings is retrieve_settings

    def test_get_data_api_ots_without_settings(self):
        """Omitted settings stay None on the API object."""
        created = get_data_api(
            provider=KnowledgeBaseProvider.OTS,
            knowledge_base_name="test-kb",
        )
        assert isinstance(created, OTSDataAPI)
        assert created.provider_settings is None
        assert created.retrieve_settings is None

    def test_get_data_api_ots_wrong_settings_type(self):
        """Provider settings of another provider's type are dropped."""
        from agentrun.knowledgebase.model import RagFlowProviderSettings

        foreign_settings = RagFlowProviderSettings(
            base_url="http://example.com",
            dataset_ids=["ds-1"],
        )
        created = get_data_api(
            provider=KnowledgeBaseProvider.OTS,
            knowledge_base_name="test-kb",
            provider_settings=foreign_settings,
        )
        assert isinstance(created, OTSDataAPI)
        assert created.provider_settings is None
test_get_data_api_ots_wrong_settings_type(self): + """测试传入非 OTS 类型的 settings""" + from agentrun.knowledgebase.model import RagFlowProviderSettings + + api = get_data_api( + provider=KnowledgeBaseProvider.OTS, + knowledge_base_name="test-kb", + provider_settings=RagFlowProviderSettings( + base_url="http://example.com", + dataset_ids=["ds-1"], + ), + ) + assert isinstance(api, OTSDataAPI) + assert api.provider_settings is None + + +# ============================================================================= +# KnowledgeBase._get_data_api OTS 分支测试 +# ============================================================================= + + +class TestKnowledgeBaseGetDataAPIOTS: + """测试 KnowledgeBase._get_data_api 的 OTS 分支""" + + def test_get_data_api_ots_with_typed_settings(self): + """测试 OTS 使用类型化设置""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance", + ), + retrieve_settings=OTSRetrieveSettings( + search_type=["DENSE_VECTOR"], + ), + ) + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + assert data_api.provider_settings.ots_instance_name == "test-instance" + assert data_api.retrieve_settings.search_type == ["DENSE_VECTOR"] + + def test_get_data_api_ots_with_camelcase_dict(self): + """测试 OTS 使用 camelCase dict 设置""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + { + "otsInstanceName": "jingsu-ots-test", + "tags": ["demo", "test"], + "metadata": [ + {"name": "author", "type": "string"}, + {"name": "date", "type": "date"}, + ], + "embeddingConfiguration": { + "provider": "bailian", + "model": "text-embedding-v3", + "dimension": 1024, + }, + }, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "searchType": ["DENSE_VECTOR", "FULL_TEXT"], + "denseVectorSearchConfiguration": {"numberOfResults": 20}, + 
"fullTextSearchConfiguration": {"numberOfResults": 20}, + "rerankingConfiguration": { + "type": "RRF", + "numberOfResults": 10, + "rrfConfiguration": { + "denseVectorSearchWeight": 1.0, + "fullTextSearchWeight": 1.0, + "k": 60, + }, + }, + }, + ) + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + + ps = data_api.provider_settings + assert ps.ots_instance_name == "jingsu-ots-test" + assert ps.tags == ["demo", "test"] + assert len(ps.metadata) == 2 + assert ps.metadata[0].name == "author" + assert ps.embedding_configuration.provider == "bailian" + assert ps.embedding_configuration.dimension == 1024 + + rs = data_api.retrieve_settings + assert rs.search_type == ["DENSE_VECTOR", "FULL_TEXT"] + assert rs.dense_vector_search_configuration.number_of_results == 20 + assert rs.full_text_search_configuration.number_of_results == 20 + assert rs.reranking_configuration.type == "RRF" + assert rs.reranking_configuration.rrf_configuration.k == 60 + + def test_get_data_api_ots_with_weight_reranking_dict(self): + """测试 OTS camelCase dict 带 Weight 重排序""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + {"otsInstanceName": "test-instance"}, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "searchType": ["DENSE_VECTOR", "FULL_TEXT"], + "rerankingConfiguration": { + "type": "WEIGHT", + "numberOfResults": 10, + "weightConfiguration": { + "denseVectorSearchWeight": 0.7, + "fullTextSearchWeight": 0.3, + }, + }, + }, + ) + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + rs = data_api.retrieve_settings + assert rs.reranking_configuration.type == "WEIGHT" + assert ( + rs.reranking_configuration.weight_configuration.dense_vector_search_weight + == 0.7 + ) + + def test_get_data_api_ots_with_model_reranking_dict(self): + """测试 OTS camelCase dict 带 Model 重排序""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + 
provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + {"otsInstanceName": "test-instance"}, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "rerankingConfiguration": { + "type": "MODEL", + "numberOfResults": 10, + "modelConfiguration": { + "provider": "bailian", + "model": "gte-rerank-v2", + }, + }, + }, + ) + + data_api = kb._get_data_api() + rs = data_api.retrieve_settings + assert rs.reranking_configuration.type == "MODEL" + assert ( + rs.reranking_configuration.model_configuration.provider == "bailian" + ) + + def test_get_data_api_ots_with_filter_dict(self): + """测试 OTS camelCase dict 带 filter""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + {"otsInstanceName": "test-instance"}, + ) + object.__setattr__( + kb, + "retrieve_settings", + { + "searchType": ["DENSE_VECTOR"], + "filter": { + "andAll": [{"equals": {"key": "author", "value": "test"}}] + }, + }, + ) + + data_api = kb._get_data_api() + rs = data_api.retrieve_settings + assert rs.filter is not None + assert rs.filter["andAll"][0]["equals"]["key"] == "author" + + def test_get_data_api_ots_minimal_dict(self): + """测试 OTS 最小化 camelCase dict""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + {"otsInstanceName": "test-instance"}, + ) + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + assert data_api.provider_settings.ots_instance_name == "test-instance" + assert data_api.retrieve_settings is None + + def test_get_data_api_ots_without_settings(self): + """测试 OTS 无设置""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + assert data_api.provider_settings is None + assert data_api.retrieve_settings is None + + 
def test_get_data_api_ots_with_invalid_provider_settings_type(self): + """测试 OTS 使用无效类型的 provider_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__(kb, "provider_settings", "invalid") + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + assert data_api.provider_settings is None + + def test_get_data_api_ots_with_invalid_retrieve_settings_type(self): + """测试 OTS 使用无效类型的 retrieve_settings""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + object.__setattr__(kb, "retrieve_settings", "invalid") + + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + assert data_api.provider_settings is not None + assert data_api.retrieve_settings is None + + def test_get_data_api_ots_with_string_provider(self): + """测试使用字符串 provider='ots'""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider="ots", + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + data_api = kb._get_data_api() + assert isinstance(data_api, OTSDataAPI) + + def test_get_data_api_ots_embedding_with_api_key(self): + """测试 camelCase dict 带 apiKey 的 embeddingConfiguration""" + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + ) + object.__setattr__( + kb, + "provider_settings", + { + "otsInstanceName": "test-instance", + "embeddingConfiguration": { + "provider": "custom", + "model": "my-model", + "dimension": 512, + "url": "https://custom.embedding.com", + "apiKey": "secret-key", + }, + }, + ) + + data_api = kb._get_data_api() + ec = data_api.provider_settings.embedding_configuration + assert ec.provider == "custom" + assert ec.url == "https://custom.embedding.com" + assert ec.api_key == "secret-key" + + +# 
============================================================================= +# KnowledgeBase.retrieve OTS 测试 +# ============================================================================= + + +class TestKnowledgeBaseRetrieveOTS: + """测试 KnowledgeBase.retrieve 的 OTS 分支""" + + @patch("agentrun.knowledgebase.api.data.OTSDataAPI.retrieve") + def test_retrieve_ots_sync(self, mock_retrieve): + """测试 OTS 同步检索""" + mock_retrieve.return_value = { + "data": [{"content": "ots result", "score": 0.9}], + "query": "test query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + + result = kb.retrieve("test query") + assert result["query"] == "test query" + assert result["data"][0]["content"] == "ots result" + + @patch("agentrun.knowledgebase.api.data.OTSDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_retrieve_ots_async(self, mock_retrieve_async): + """测试 OTS 异步检索""" + mock_retrieve_async.return_value = { + "data": [{"content": "ots async result", "score": 0.85}], + "query": "async query", + "knowledge_base_name": "test-kb", + } + + kb = KnowledgeBase( + knowledge_base_name="test-kb", + provider=KnowledgeBaseProvider.OTS, + provider_settings=OTSProviderSettings( + ots_instance_name="test-instance" + ), + ) + + result = await kb.retrieve_async("async query") + assert result["query"] == "async query" + assert result["data"][0]["content"] == "ots async result" + + +# ============================================================================= +# OTS multi_retrieve 测试 +# ============================================================================= + + +class MockOTSKnowledgeBaseData: + """模拟 OTS 知识库数据""" + + def to_map(self): + return { + "knowledgeBaseId": "kb-ots-001", + "knowledgeBaseName": "test-ots-kb", + "provider": "ots", + "description": "Test OTS knowledge base", + "providerSettings": 
{ + "otsInstanceName": "test-instance", + }, + "retrieveSettings": { + "searchType": ["DENSE_VECTOR"], + }, + "createdAt": "2024-01-01T00:00:00Z", + "lastUpdatedAt": "2024-01-01T00:00:00Z", + } + + +class TestKnowledgeBaseMultiRetrieveOTS: + """测试 OTS 参与 multi_retrieve""" + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.OTSDataAPI.retrieve") + def test_multi_retrieve_with_ots( + self, mock_retrieve, mock_control_api_class + ): + """测试 OTS 参与同步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base.return_value = ( + MockOTSKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve.return_value = { + "data": [{"content": "ots content"}], + "query": "test query", + "knowledge_base_name": "test-ots-kb", + } + + result = KnowledgeBase.multi_retrieve( + query="test query", + knowledge_base_names=["test-ots-kb"], + ) + + assert "results" in result + assert "test-ots-kb" in result["results"] + assert ( + result["results"]["test-ots-kb"]["data"][0]["content"] + == "ots content" + ) + + @patch("agentrun.knowledgebase.client.KnowledgeBaseControlAPI") + @patch("agentrun.knowledgebase.api.data.OTSDataAPI.retrieve_async") + @pytest.mark.asyncio + async def test_multi_retrieve_async_with_ots( + self, mock_retrieve_async, mock_control_api_class + ): + """测试 OTS 参与异步多知识库检索""" + mock_control_api = MagicMock() + mock_control_api.get_knowledge_base_async = AsyncMock( + return_value=MockOTSKnowledgeBaseData() + ) + mock_control_api_class.return_value = mock_control_api + + mock_retrieve_async.return_value = { + "data": [{"content": "ots async content"}], + "query": "test query", + "knowledge_base_name": "test-ots-kb", + } + + result = await KnowledgeBase.multi_retrieve_async( + query="test query", + knowledge_base_names=["test-ots-kb"], + ) + + assert "results" in result + assert "test-ots-kb" in result["results"] + + def test_from_inner_object_ots(self): + 
"""测试从内部对象创建 OTS 知识库""" + mock_data = MockOTSKnowledgeBaseData() + kb = KnowledgeBase.from_inner_object(mock_data) + + assert kb.knowledge_base_id == "kb-ots-001" + assert kb.knowledge_base_name == "test-ots-kb" + assert kb.provider == "ots"