Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions agentrun/knowledgebase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
get_data_api,
KnowledgeBaseControlAPI,
KnowledgeBaseDataAPI,
OTSDataAPI,
RagFlowDataAPI,
)
from .client import KnowledgeBaseClient
Expand All @@ -20,6 +21,16 @@
KnowledgeBaseListOutput,
KnowledgeBaseProvider,
KnowledgeBaseUpdateInput,
OTSDenseVectorSearchConfig,
OTSEmbeddingConfiguration,
OTSFullTextSearchConfig,
OTSMetadataField,
OTSModelConfig,
OTSProviderSettings,
OTSRerankingConfig,
OTSRetrieveSettings,
OTSRRFConfig,
OTSWeightConfig,
ProviderSettings,
RagFlowProviderSettings,
RagFlowRetrieveSettings,
Expand All @@ -37,6 +48,7 @@
"RagFlowDataAPI",
"BailianDataAPI",
"ADBDataAPI",
"OTSDataAPI",
"get_data_api",
# enums
"KnowledgeBaseProvider",
Expand All @@ -45,11 +57,21 @@
"RagFlowProviderSettings",
"BailianProviderSettings",
"ADBProviderSettings",
"OTSProviderSettings",
"OTSMetadataField",
"OTSEmbeddingConfiguration",
# retrieve settings
"RetrieveSettings",
"RagFlowRetrieveSettings",
"BailianRetrieveSettings",
"ADBRetrieveSettings",
"OTSRetrieveSettings",
"OTSDenseVectorSearchConfig",
"OTSFullTextSearchConfig",
"OTSRerankingConfig",
"OTSRRFConfig",
"OTSWeightConfig",
"OTSModelConfig",
# api model
"KnowledgeBaseCreateInput",
"KnowledgeBaseUpdateInput",
Expand Down
123 changes: 119 additions & 4 deletions agentrun/knowledgebase/__knowledgebase_async_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@
KnowledgeBaseProvider,
KnowledgeBaseSystemProps,
KnowledgeBaseUpdateInput,
OTSDenseVectorSearchConfig,
OTSEmbeddingConfiguration,
OTSFullTextSearchConfig,
OTSMetadataField,
OTSModelConfig,
OTSProviderSettings,
OTSRerankingConfig,
OTSRetrieveSettings,
OTSRRFConfig,
OTSWeightConfig,
RagFlowProviderSettings,
RagFlowRetrieveSettings,
RetrieveInput,
Expand Down Expand Up @@ -302,8 +312,6 @@ def _get_data_api(self, config: Optional[Config] = None):
if isinstance(self.provider_settings, ADBProviderSettings):
converted_provider_settings = self.provider_settings
elif isinstance(self.provider_settings, dict):
# ADB provider_settings 使用 PascalCase 键名,需要转换为 snake_case
# ADB provider_settings uses PascalCase keys, need to convert to snake_case
converted_provider_settings = ADBProviderSettings(
db_instance_id=self.provider_settings.get(
"DBInstanceId", ""
Expand All @@ -323,8 +331,6 @@ def _get_data_api(self, config: Optional[Config] = None):
if isinstance(self.retrieve_settings, ADBRetrieveSettings):
converted_retrieve_settings = self.retrieve_settings
elif isinstance(self.retrieve_settings, dict):
# ADB retrieve_settings 使用 PascalCase 键名,需要转换为 snake_case
# ADB retrieve_settings uses PascalCase keys, need to convert to snake_case
converted_retrieve_settings = ADBRetrieveSettings(
top_k=self.retrieve_settings.get("TopK"),
use_full_text_retrieval=self.retrieve_settings.get(
Expand All @@ -344,6 +350,115 @@ def _get_data_api(self, config: Optional[Config] = None):
),
)

elif provider == KnowledgeBaseProvider.OTS:
# OTS 设置 / OTS settings (camelCase → snake_case)
if self.provider_settings:
if isinstance(self.provider_settings, OTSProviderSettings):
converted_provider_settings = self.provider_settings
elif isinstance(self.provider_settings, dict):
ps = self.provider_settings

metadata = None
raw_metadata = ps.get("metadata")
if raw_metadata and isinstance(raw_metadata, list):
metadata = [
OTSMetadataField(
name=m.get("name", ""),
type=m.get("type", ""),
)
for m in raw_metadata
]

embedding_config = None
raw_ec = ps.get("embeddingConfiguration")
if raw_ec and isinstance(raw_ec, dict):
embedding_config = OTSEmbeddingConfiguration(
provider=raw_ec.get("provider", ""),
model=raw_ec.get("model", ""),
dimension=raw_ec.get("dimension", 0),
url=raw_ec.get("url"),
api_key=raw_ec.get("apiKey"),
)

converted_provider_settings = OTSProviderSettings(
ots_instance_name=ps.get("otsInstanceName", ""),
tags=ps.get("tags"),
metadata=metadata,
embedding_configuration=embedding_config,
)

if self.retrieve_settings:
if isinstance(self.retrieve_settings, OTSRetrieveSettings):
converted_retrieve_settings = self.retrieve_settings
elif isinstance(self.retrieve_settings, dict):
rs = self.retrieve_settings

dvsc = None
raw_dvsc = rs.get("denseVectorSearchConfiguration")
if raw_dvsc and isinstance(raw_dvsc, dict):
dvsc = OTSDenseVectorSearchConfig(
number_of_results=raw_dvsc.get("numberOfResults"),
)

ftsc = None
raw_ftsc = rs.get("fullTextSearchConfiguration")
if raw_ftsc and isinstance(raw_ftsc, dict):
ftsc = OTSFullTextSearchConfig(
number_of_results=raw_ftsc.get("numberOfResults"),
)

reranking = None
raw_rr = rs.get("rerankingConfiguration")
if raw_rr and isinstance(raw_rr, dict):
rrf_config = None
raw_rrf = raw_rr.get("rrfConfiguration")
if raw_rrf and isinstance(raw_rrf, dict):
rrf_config = OTSRRFConfig(
dense_vector_search_weight=raw_rrf.get(
"denseVectorSearchWeight"
),
full_text_search_weight=raw_rrf.get(
"fullTextSearchWeight"
),
k=raw_rrf.get("k"),
)

weight_config = None
raw_wc = raw_rr.get("weightConfiguration")
if raw_wc and isinstance(raw_wc, dict):
weight_config = OTSWeightConfig(
dense_vector_search_weight=raw_wc.get(
"denseVectorSearchWeight"
),
full_text_search_weight=raw_wc.get(
"fullTextSearchWeight"
),
)

model_config = None
raw_mc = raw_rr.get("modelConfiguration")
if raw_mc and isinstance(raw_mc, dict):
model_config = OTSModelConfig(
provider=raw_mc.get("provider"),
model=raw_mc.get("model"),
)

reranking = OTSRerankingConfig(
type=raw_rr.get("type"),
number_of_results=raw_rr.get("numberOfResults"),
rrf_configuration=rrf_config,
weight_configuration=weight_config,
model_configuration=model_config,
)

converted_retrieve_settings = OTSRetrieveSettings(
search_type=rs.get("searchType"),
dense_vector_search_configuration=dvsc,
full_text_search_configuration=ftsc,
reranking_configuration=reranking,
filter=rs.get("filter"),
)

return get_data_api(
provider=provider,
knowledge_base_name=self.knowledge_base_name or "",
Expand Down
Loading
Loading