diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..bc07a72 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,92 @@ +# Changelog + +## [1.3.0] + +### Added +- **Admin configuration UI** with role-based access for DB, LLM, and GraphRAG settings + - Separate pages for DB config, LLM provider config, and GraphRAG config + - Graph admin role restriction via `ConfigScopeToggle` + - `apiToken` auth option added to GraphDB config with conditional UI +- **Per-graph chatbot LLM override** (`chat_service` in `llm_config`) with inheritance from `completion_service` + - Missing keys fall back to `completion_service` automatically + - Graph admins can configure per graph via the UI +- **Secret masking** in configuration API responses + - GET responses return masked values; backend substitutes on save/test + - Credentials never reach the frontend +- **Session idle timeout** (1 hour) that auto-clears the session on inactivity + - Session data moved from `localStorage` to `sessionStorage`; theme stays in `localStorage` + - Timer pauses during long-running operations (ingest, rebuild) +- **Auth guard** on all UI routes + - `RequireAuth` wrapper redirects unauthenticated users to login + - SPA routing with `serve -s` and catch-all route +- **GraphRAG config UI fields** + - Search parameters: `top_k`, `num_hops`, `num_seen_min`, `community_level`, `doc_only` + - Advanced ingestion settings: `load_batch_size`, `upsert_delay`, `default_concurrency` + - All chunker settings (chunk_size, overlap_size, method, threshold, pattern) shown and saved regardless of selected chunker +- **Multimodal inherit checkbox** in LLM config UI + - "Use same model as completion service" option in both single and multi-provider modes + - Amber warning when inheriting: "Ensure your completion model supports vision input" +- **`get_embedding_config()`** getter in `common/config.py` for parity with other service getters +- **Greeting detection** in agent router + - Regex-based pattern 
matching for common greetings, farewells, and thanks + - Responds directly without invoking query generation or search +- **Centralized LLM token usage tracking** + - All LLM call sites (15+) migrated to `invoke_with_parser` / `ainvoke_with_parser` + - Supports both structured (JSON) and plain text LLM responses +- **JSON parsing fallback** for LLM responses + - Handles responses wrapped in preamble text or markdown code fences + - Entity extraction uses a 3-tier fallback: direct parse, code fence extraction, regex extraction +- **Cypher/GSQL output validation** before query execution + - Checks for required query keywords before wrapping in `INTERPRET OPENCYPHER QUERY` + - Invalid output raises an error and retries instead of executing garbage queries +- **Retriever scoring** for all retriever types when `combine=False` + - Scoring logic lifted from `CommunityRetriever` into `BaseRetriever` + - Similarity, Hybrid, and Sibling retrievers now score and rank context chunks +- **User-customized prompts** persisted under `configs/` across container restarts +- **Unit tests** for LLM invocation and JSON parsing (13 test cases) + +### Changed +- **All config consumers use `get_xxx_config(graphname)` getters** instead of direct `llm_config` access + - `root.py`, `report-service/root.py`, `ecc/main.py`, `ui.py` migrated + - Test connection and save endpoints use `_build_test_config()` overlay pattern + - `_unmask_auth` resolves credentials via getters for correct per-graph resolution +- **Multimodal service inherits completion model directly** when not explicitly configured + - Removed hardcoded `DEFAULT_MULTIMODAL_MODELS` that silently substituted different models +- **LLM config UI improvements** + - Red asterisk markers on mandatory model name fields + - Shared `LLM_PROVIDERS` constant replaces duplicate provider lists + - State synced when toggling between single/multi-provider modes + - Reordered sections: Completion → Chatbot → Multimodal → Embedding +- Config file 
writes are now atomic with file locking to prevent race conditions + - `_config_file_lock` prevents concurrent overwrites + - In-memory config updates use atomic dict replacement instead of clear-and-update +- Chat history messages display instantly without typewriter animation + - History messages tagged with `response_type: "history"` to skip CSS animation +- Chatbot model selection uses `chat_service` config with `completion_service` fallback + - Community summarization prompt loaded at call time instead of import time +- README config documentation updated for clarity and consistency + - Parameter descriptions focus on purpose, not implementation details + - `token_limit`, `default_concurrency`, and other parameters reworded + - `multimodal_service` defaults corrected to show inheritance from `completion_service` +- `default_concurrency` replaces `tg_concurrency` in `graphrag_config` + - Configurable per graph +- Wired up `default_mem_threshold` and `default_thread_limit` in database connection proxy + +### Fixed +- **Bedrock multimodal connection test** — 1x1 test PNG rejected by Bedrock image validation; replaced with 20x20 PNG +- **Provider-aware image format** in multimodal test and `image_data_extractor` + - GenAI/VertexAI require `image_url` format; Bedrock/Anthropic use `type:"image"` with source block +- **report-service/root.py** — `llm_config` used but never imported (NameError on health endpoint) +- **Null service values** stripped before config reload (null = inherit, key should be absent) +- Login page shows proper error messages based on HTTP status + - 401/403: "Invalid credentials"; other errors: "Server error (N)"; network failure: "Unable to connect" +- SPA routing fixed with catch-all route to login page +- Rebuild dialog button no longer flickers between status labels + - Polling stops once rebuild completes; final status message preserved +- Idle timer pauses during long-running operations (ingest, rebuild) + - Uses pause/resume instead of 
repeated signal activity calls +- Bedrock model names no longer trigger token calculator warnings + - Provider prefix and version suffix stripped before tiktoken lookup +- Config reload no longer clears in-memory state during concurrent requests +- Startup validation restored for `llm_service` and `llm_model` +- `HTTPException` properly re-raised in config and DB test endpoints diff --git a/README.md b/README.md index a6261f5..e49317b 100644 --- a/README.md +++ b/README.md @@ -469,14 +469,19 @@ Copy the below code into `configs/server_config.json`. You shouldn’t need to c | `chat_history_api` | string | `"http://chat-history:8002"` | URL of the chat history service. No change needed when using the provided Docker Compose file. | | `chunker` | string | `"semantic"` | Default document chunker. Options: `semantic`, `character`, `regex`, `markdown`, `html`, `recursive`. | | `extractor` | string | `"llm"` | Entity extraction method. Options: `llm`, `graphrag`. | -| `chunker_config` | object | `{}` | Chunker-specific settings. For `character`/`markdown`/`recursive`: `chunk_size`, `overlap_size`. For `semantic`: `method`, `threshold`. For `regex`: `pattern`. | -| `top_k` | int | `5` | Number of top similar results to retrieve during search. | -| `num_hops` | int | `2` | Number of graph hops to traverse when expanding retrieved results. | -| `num_seen_min` | int | `2` | Minimum occurrence threshold for a node to be included in search results. | -| `community_level` | int | `2` | Community hierarchy level used for community search. | -| `chunk_only` | bool | `true` | If true, hybrid search only retrieves document chunks (not entities). | -| `doc_only` | bool | `false` | If true, hybrid search retrieves whole documents instead of chunks. | -| `with_chunk` | bool | `true` | If true, community search also includes document chunks in results. | +| `chunker_config` | object | `{}` | Chunker-specific settings (see sub-parameters below). 
All settings are saved regardless of which chunker is selected as default. | +| ↳ `chunk_size` | int | `2048` | Maximum number of characters per chunk. Used by `character`, `markdown`, `html`, and `recursive` chunkers. Larger values produce fewer, bigger chunks; smaller values produce more, finer-grained chunks. | +| ↳ `overlap_size` | int | 1/8 of `chunk_size` | Number of overlapping characters between consecutive chunks. Used by `character`, `markdown`, `html`, and `recursive` chunkers. More overlap preserves cross-chunk context but increases total chunk count. Set to `0` for no overlap. | +| ↳ `method` | string | `"percentile"` | Breakpoint detection method for the `semantic` chunker. Options: `percentile`, `standard_deviation`, `interquartile`, `gradient`. Controls how the chunker decides where to split based on embedding similarity. | +| ↳ `threshold` | float | `0.95` | Breakpoint threshold for the `semantic` chunker, interpreted according to `method`. Higher values produce fewer splits (larger chunks); lower values produce more splits (smaller chunks). | +| ↳ `pattern` | string | `""` | Regular expression pattern for the `regex` chunker. The document is split at each match of this pattern. | +| `top_k` | int | `5` | Number of initial seed results to retrieve per search. Also caps the final scored results. Increasing `top_k` increases the overall context size sent to the LLM. | +| `num_hops` | int | `2` | Number of graph hops to traverse from seed nodes during hybrid search. More hops expand the result set with related context. | +| `num_seen_min` | int | `2` | Minimum occurrence count for a node to be included during hybrid search traversal. Higher values filter out loosely connected nodes, reducing context size. | +| `community_level` | int | `2` | Community hierarchy level for community search. Higher levels retrieve broader, higher-order community summaries. | +| `chunk_only` | bool | `true` | If true, hybrid search only retrieves document chunks, excluding entity data. 
| +| `doc_only` | bool | `false` | If true, hybrid search retrieves whole documents instead of chunks. Significantly increases context size. | +| `with_chunk` | bool | `true` | If true, community search also includes document chunks alongside community summaries. Increases context size. | | `doc_process_switch` | bool | `true` | Enable/disable document processing during knowledge graph build. | | `entity_extraction_switch` | bool | same as `doc_process_switch` | Enable/disable entity extraction during knowledge graph build. | | `community_detection_switch` | bool | same as `entity_extraction_switch` | Enable/disable community detection during knowledge graph build. | @@ -552,7 +557,7 @@ In the `llm_config` section of `configs/server_config.json` file, copy JSON conf | Parameter | Type | Default | Description | | --- | --- | --- | --- | | `authentication_configuration` | object | — | Shared authentication credentials for all services. Service-level values take precedence. | -| `token_limit` | int | — | Maximum token count for retrieved context. Inherited by all services if not set at service level. `0` or omitted means unlimited. | +| `token_limit` | int | — | Hard cap on token count for retrieved context sent to the LLM. Context exceeding this limit is truncated. Inherited by all services if not set at service level. `0` or omitted means unlimited. | **`completion_service` parameters:** @@ -564,7 +569,7 @@ In the `llm_config` section of `configs/server_config.json` file, copy JSON conf | `model_kwargs` | object | No | `{}` | Additional model parameters (e.g., `{"temperature": 0}`). | | `prompt_path` | string | No | `"./common/prompts/openai_gpt4/"` | Path to prompt template files. | | `base_url` | string | No | — | Custom API endpoint URL. | -| `token_limit` | int | No | inherited from top-level | Max token count for retrieved context sent to the LLM. `0` or omitted means unlimited. 
| +| `token_limit` | int | No | inherited from top-level | Hard cap on token count for retrieved context sent to the LLM. Context exceeding this limit is truncated. `0` or omitted means unlimited. | **`embedding_service` parameters:** @@ -587,16 +592,16 @@ Chatbot LLM override. If not configured, inherits from `completion_service`. Con | `model_kwargs` | object | No | inherited from completion | Additional model parameters (e.g., `{"temperature": 0}`). | | `prompt_path` | string | No | inherited from completion | Path to prompt template files. | | `base_url` | string | No | inherited from completion | Custom API endpoint URL. | -| `token_limit` | int | No | inherited from completion | Max token count for retrieved context sent to the chatbot LLM. `0` or omitted means unlimited. | +| `token_limit` | int | No | inherited from completion | Hard cap on token count for retrieved context sent to the chatbot LLM. Context exceeding this limit is truncated. `0` or omitted means unlimited. | **`multimodal_service` parameters (optional):** -Vision model for image processing during document ingestion. If not configured, inherits from `completion_service` with a default vision model derived per provider. +Vision model for image processing during document ingestion. If not configured, inherits from `completion_service` — ensure the completion model supports vision input. | Parameter | Type | Required | Default | Description | | --- | --- | --- | --- | --- | | `llm_service` | string | No | inherited from completion | Multimodal LLM provider. | -| `llm_model` | string | No | auto-derived per provider | Vision model name (e.g., `gpt-4o`). | +| `llm_model` | string | No | inherited from completion | Vision model name (e.g., `gpt-4o`). | | `authentication_configuration` | object | No | inherited from completion | Service-specific auth credentials. Overrides top-level values. | | `model_kwargs` | object | No | inherited from completion | Additional model parameters. 
| | `prompt_path` | string | No | inherited from completion | Path to prompt template files. | diff --git a/common/chunkers/character_chunker.py b/common/chunkers/character_chunker.py index 6d4138a..abf2480 100644 --- a/common/chunkers/character_chunker.py +++ b/common/chunkers/character_chunker.py @@ -1,12 +1,12 @@ from common.chunkers.base_chunker import BaseChunker -_DEFAULT_FALLBACK_SIZE = 4096 +_DEFAULT_CHUNK_SIZE = 2048 class CharacterChunker(BaseChunker): - def __init__(self, chunk_size=0, overlap_size=0): - self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE - self.overlap_size = overlap_size + def __init__(self, chunk_size=0, overlap_size=-1): + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_CHUNK_SIZE + self.overlap_size = overlap_size if overlap_size >= 0 else self.chunk_size // 8 def chunk(self, input_string): if self.chunk_size <= self.overlap_size: diff --git a/common/chunkers/html_chunker.py b/common/chunkers/html_chunker.py index 326dff8..83b3477 100644 --- a/common/chunkers/html_chunker.py +++ b/common/chunkers/html_chunker.py @@ -20,7 +20,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter -_DEFAULT_FALLBACK_SIZE = 4096 +_DEFAULT_CHUNK_SIZE = 2048 class HTMLChunker(BaseChunker): @@ -30,7 +30,7 @@ class HTMLChunker(BaseChunker): - Automatically detects which headers (h1-h6) are present in the HTML - Uses only the headers that exist in the document for optimal chunking - If custom headers are provided, uses those instead of auto-detection - - Supports chunk_size / chunk_overlap: when chunk_size > 0, oversized + - Supports chunk_size / overlap_size: when chunk_size > 0, oversized header-based chunks are further split with RecursiveCharacterTextSplitter - When chunk_size is 0 (default), a fallback of 4096 is used so that headerless HTML documents are still split into reasonable chunks @@ -39,11 +39,11 @@ class HTMLChunker(BaseChunker): def __init__( self, chunk_size: int = 0, - chunk_overlap: 
int = 0, + overlap_size: int = -1, headers: Optional[List[Tuple[str, str]]] = None, ): - self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE - self.chunk_overlap = chunk_overlap + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_CHUNK_SIZE + self.overlap_size = overlap_size if overlap_size >= 0 else self.chunk_size // 8 self.headers = headers def _detect_headers(self, html_content: str) -> List[Tuple[str, str]]: @@ -96,7 +96,7 @@ def chunk(self, input_string: str) -> List[str]: recursive_splitter = RecursiveCharacterTextSplitter( separators=TEXT_SEPARATORS, chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, + chunk_overlap=self.overlap_size, ) final_chunks = [] for chunk in initial_chunks: diff --git a/common/chunkers/markdown_chunker.py b/common/chunkers/markdown_chunker.py index 2d4c4ce..85c1a82 100644 --- a/common/chunkers/markdown_chunker.py +++ b/common/chunkers/markdown_chunker.py @@ -20,18 +20,18 @@ # When chunk_size is not configured, cap any heading-section that exceeds this # so that form-based PDFs (tables/bold but no # headings) are not left as a # single multi-thousand-character chunk. 
-_DEFAULT_FALLBACK_SIZE = 4096 +_DEFAULT_CHUNK_SIZE = 2048 class MarkdownChunker(BaseChunker): - + def __init__( self, chunk_size: int = 0, - chunk_overlap: int = 0 + overlap_size: int = -1 ): - self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE - self.chunk_overlap = chunk_overlap + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_CHUNK_SIZE + self.overlap_size = overlap_size if overlap_size >= 0 else self.chunk_size // 8 def chunk(self, input_string): md_splitter = ExperimentalMarkdownSyntaxTextSplitter() @@ -46,7 +46,7 @@ def chunk(self, input_string): recursive_splitter = RecursiveCharacterTextSplitter( separators=TEXT_SEPARATORS, chunk_size=self.chunk_size, - chunk_overlap=self.chunk_overlap, + chunk_overlap=self.overlap_size, ) md_chunks = [] for chunk in initial_chunks: diff --git a/common/chunkers/recursive_chunker.py b/common/chunkers/recursive_chunker.py index 4c8a324..69ee83a 100644 --- a/common/chunkers/recursive_chunker.py +++ b/common/chunkers/recursive_chunker.py @@ -16,13 +16,13 @@ from common.chunkers.separators import TEXT_SEPARATORS from langchain.text_splitter import RecursiveCharacterTextSplitter -_DEFAULT_FALLBACK_SIZE = 4096 +_DEFAULT_CHUNK_SIZE = 2048 class RecursiveChunker(BaseChunker): - def __init__(self, chunk_size=0, overlap_size=0): - self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_FALLBACK_SIZE - self.overlap_size = overlap_size + def __init__(self, chunk_size=0, overlap_size=-1): + self.chunk_size = chunk_size if chunk_size > 0 else _DEFAULT_CHUNK_SIZE + self.overlap_size = overlap_size if overlap_size >= 0 else self.chunk_size // 8 def chunk(self, input_string): text_splitter = RecursiveCharacterTextSplitter( diff --git a/common/config.py b/common/config.py index 371e303..3dc3be1 100644 --- a/common/config.py +++ b/common/config.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import copy import json import logging import os @@ -122,6 +123,73 @@ def _resolve_service_config(base_config, override=None): return result +def resolve_llm_services(llm_cfg: dict) -> dict: + """ + Resolve per-service configs from an llm_config dict. + + Applies the same resolution chain as the get_xxx_config() getters but + operates on the provided dict instead of the global llm_config. This + allows both the on-disk config and a candidate config (from UI payload) + to be resolved with the same logic. + + Resolution: + 1. Inject top-level authentication_configuration into each service + 2. completion_service / embedding_service: used as-is + 3. chat_service / multimodal_service: completion_service base + overrides + + When chat_service or multimodal_service is absent, the resolved config + falls back to completion_service (inherit). + + Returns dict with keys: completion_service, embedding_service, + chat_service, multimodal_service — each a fully resolved config. + """ + # Work on deep copies to avoid mutating the input + cfg = copy.deepcopy(llm_cfg) + + # Inject top-level auth into service configs (same as reload_llm_config) + top_auth = cfg.get("authentication_configuration", {}) + if top_auth: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in cfg: + svc = cfg[svc_key] + if "authentication_configuration" not in svc: + svc["authentication_configuration"] = top_auth.copy() + else: + merged = top_auth.copy() + merged.update(svc["authentication_configuration"]) + svc["authentication_configuration"] = merged + + # Inject top-level region_name into service configs if missing + top_region = cfg.get("region_name") + if top_region: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in cfg and "region_name" not in cfg[svc_key]: + cfg[svc_key]["region_name"] = top_region + + completion = cfg.get("completion_service", {}) + + # Resolve 
embedding: inherit provider-level config from completion + # when the embedding provider matches the completion provider. + # (embedding has a different schema — model_name vs llm_model — + # so we only inherit shared provider fields like region_name.) + embedding = cfg.get("embedding_service", {}).copy() + embedding_provider = embedding.get("embedding_model_service", "").lower() + completion_provider = completion.get("llm_service", "").lower() + if embedding_provider and embedding_provider == completion_provider: + # Identity/schema keys that belong to the embedding service itself + embedding_own_keys = {"embedding_model_service", "model_name", "authentication_configuration", "token_limit"} + for k, v in completion.items(): + if k not in embedding_own_keys and k not in embedding: + embedding[k] = v + + return { + "completion_service": completion.copy(), + "embedding_service": embedding, + "chat_service": _resolve_service_config(completion, cfg.get("chat_service")), + "multimodal_service": _resolve_service_config(completion, cfg.get("multimodal_service")), + } + + def get_completion_config(graphname=None): """ Return completion_service config for the given graph. @@ -142,13 +210,24 @@ def get_completion_config(graphname=None): return result -DEFAULT_MULTIMODAL_MODELS = { - "openai": "gpt-4o-mini", - "azure": "gpt-4o-mini", - "genai": "gemini-3.5-flash", - "vertexai": "gemini-3.5-flash", - "bedrock": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", -} +def get_embedding_config(graphname=None): + """ + Return embedding_service config for the given graph. + + Resolution: merge graph-specific embedding_service overrides on top of + global embedding_service. Graph configs only store overrides, so unchanged + fields always inherit the latest global values. 
+ """ + graph_llm = _load_graph_llm_config(graphname) + override = graph_llm.get("embedding_service") + if override: + logger.debug(f"[get_embedding_config] graph={graphname} using graph-specific overrides") + result = _resolve_service_config(llm_config["embedding_service"], override) + + if graphname: + result["graphname"] = graphname + + return result def get_chat_config(graphname=None): @@ -189,21 +268,6 @@ def get_chat_config(graphname=None): return result -def _apply_default_multimodal_model(override, provider): - """Apply default vision model if llm_model is not explicitly set.""" - if override and "llm_model" not in override: - default_model = DEFAULT_MULTIMODAL_MODELS.get(provider) - if default_model: - return {**override, "llm_model": default_model} - return override - if not override: - default_model = DEFAULT_MULTIMODAL_MODELS.get(provider) - if default_model: - return {"llm_model": default_model} - return None - return override - - def get_multimodal_config(graphname=None): """ Return the multimodal/vision config for the given graph. @@ -211,9 +275,10 @@ def get_multimodal_config(graphname=None): Resolution chain: 1. Start with global completion_service 2. Merge graph-specific completion_service overrides (shared base) - 3. Merge multimodal_service overrides (graph-specific > global > default model) + 3. Merge multimodal_service overrides (graph-specific > global) - Returns the merged config, or None if the provider doesn't support vision. + When no multimodal_service override exists ("inherit"), the completion + config is returned as-is — the completion model is used for vision. 
""" graph_llm = _load_graph_llm_config(graphname) @@ -223,17 +288,11 @@ def get_multimodal_config(graphname=None): graph_llm.get("completion_service"), ) - # Find multimodal override: graph-specific > global > None + # Find multimodal override: graph-specific > global > None (inherit) mm_override = graph_llm.get("multimodal_service") if mm_override is None and "multimodal_service" in llm_config: mm_override = llm_config["multimodal_service"] - provider = (mm_override or {}).get("llm_service", base.get("llm_service", "")).lower() - mm_override = _apply_default_multimodal_model(mm_override, provider) - - if mm_override is None: - return None - return _resolve_service_config(base, mm_override) @@ -301,6 +360,12 @@ def get_graphrag_config(graphname=None): merged.update(svc["authentication_configuration"]) svc["authentication_configuration"] = merged +# Inject top-level region_name into service configs if missing +if "region_name" in llm_config: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in llm_config and "region_name" not in llm_config[svc_key]: + llm_config[svc_key]["region_name"] = llm_config["region_name"] + _comp = llm_config.get("completion_service") if _comp is None: raise Exception("completion_service is not found in llm_config") @@ -479,6 +544,12 @@ def reload_llm_config(new_llm_config: dict = None): merged.update(svc["authentication_configuration"]) svc["authentication_configuration"] = merged + # Inject top-level region_name into service configs if missing + if "region_name" in new_llm_config: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in new_llm_config and "region_name" not in new_llm_config[svc_key]: + new_llm_config[svc_key]["region_name"] = new_llm_config["region_name"] + new_completion_config = new_llm_config.get("completion_service") new_embedding_config = new_llm_config.get("embedding_service") diff --git 
a/common/utils/image_data_extractor.py b/common/utils/image_data_extractor.py index 48f9b65..711c562 100644 --- a/common/utils/image_data_extractor.py +++ b/common/utils/image_data_extractor.py @@ -7,16 +7,31 @@ logger = logging.getLogger(__name__) _multimodal_client = None +_multimodal_provider = None def _get_client(): - global _multimodal_client + global _multimodal_client, _multimodal_provider if _multimodal_client is None and get_multimodal_config(): try: - _multimodal_client = get_llm_service(get_multimodal_config()) + config = get_multimodal_config() + _multimodal_provider = config.get("llm_service", "").lower() + _multimodal_client = get_llm_service(config) except Exception: logger.warning("Failed to create multimodal LLM client") return _multimodal_client +def _build_image_content_block(image_base64: str, media_type: str) -> dict: + """Build a LangChain image content block appropriate for the configured provider.""" + if _multimodal_provider in ("genai", "vertexai"): + return { + "type": "image_url", + "image_url": {"url": f"data:{media_type};base64,{image_base64}"}, + } + return { + "type": "image", + "source": {"type": "base64", "media_type": media_type, "data": image_base64}, + } + def describe_image_with_llm(file_path): """ Read image file and convert to base64 to send to LLM. @@ -49,10 +64,7 @@ def describe_image_with_llm(file_path): "If the image has any logo, identify and describe the logo." 
), }, - { - "type": "image_url", - "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, - }, + _build_image_content_block(image_base64, "image/jpeg"), ], ), ] diff --git a/ecc/app/ecc_util.py b/ecc/app/ecc_util.py index e17ce9f..a28567a 100644 --- a/ecc/app/ecc_util.py +++ b/ecc/app/ecc_util.py @@ -18,24 +18,24 @@ def get_chunker(chunker_type: str = "", graphname: str = None): ) elif chunker_type == "character": chunker = character_chunker.CharacterChunker( - chunk_size=chunker_config.get("chunk_size", 4096), - overlap_size=chunker_config.get("overlap_size", 0), + chunk_size=chunker_config.get("chunk_size", 0), + overlap_size=chunker_config.get("overlap_size", -1), ) elif chunker_type == "markdown": chunker = markdown_chunker.MarkdownChunker( chunk_size=chunker_config.get("chunk_size", 0), - chunk_overlap=chunker_config.get("overlap_size", 0), + overlap_size=chunker_config.get("overlap_size", -1), ) elif chunker_type == "html": chunker = html_chunker.HTMLChunker( chunk_size=chunker_config.get("chunk_size", 0), - chunk_overlap=chunker_config.get("overlap_size", 0), + overlap_size=chunker_config.get("overlap_size", -1), headers=chunker_config.get("headers", None), ) elif chunker_type == "recursive": chunker = recursive_chunker.RecursiveChunker( - chunk_size=chunker_config.get("chunk_size", 4096), - overlap_size=chunker_config.get("overlap_size", 0), + chunk_size=chunker_config.get("chunk_size", 0), + overlap_size=chunker_config.get("overlap_size", -1), ) elif chunker_type == "single" or chunker_type == "image": # Single chunker: NEVER splits, always returns 1 chunk diff --git a/ecc/app/main.py b/ecc/app/main.py index 0db691b..d15ac75 100644 --- a/ecc/app/main.py +++ b/ecc/app/main.py @@ -35,7 +35,6 @@ graphrag_config, embedding_service, get_llm_service, - llm_config, get_completion_config, get_graphrag_config, reload_db_config, @@ -225,7 +224,7 @@ async def run_with_tracking(task_key: str, run_func, graphname: str, conn): llm_result = reload_llm_config() 
if llm_result["status"] == "success": LogWriter.info(f"LLM config reloaded: {llm_result['message']}") - completion_service = llm_config.get("completion_service", {}) + completion_service = get_completion_config(graphname) ecc_model = completion_service.get("llm_model", "unknown") ecc_provider = completion_service.get("llm_service", "unknown") LogWriter.info( diff --git a/graphrag-ui/src/pages/setup/GraphRAGConfig.tsx b/graphrag-ui/src/pages/setup/GraphRAGConfig.tsx index 0e05e4e..dc33689 100644 --- a/graphrag-ui/src/pages/setup/GraphRAGConfig.tsx +++ b/graphrag-ui/src/pages/setup/GraphRAGConfig.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from "react"; +import React, { useState, useEffect, useRef } from "react"; import { Settings, Save, Loader2 } from "lucide-react"; import { Input } from "@/components/ui/input"; import { Button } from "@/components/ui/button"; @@ -14,7 +14,7 @@ import ConfigScopeToggle from "@/components/ConfigScopeToggle"; const GraphRAGConfig = () => { const [selectedGraph, setSelectedGraph] = useState(sessionStorage.getItem("selectedGraph") || ""); const [availableGraphs, setAvailableGraphs] = useState([]); - const [reuseEmbedding, setReuseEmbedding] = useState(false); + const [reuseEmbedding, setReuseEmbedding] = useState(true); const [eccUrl, setEccUrl] = useState("http://graphrag-ecc:8001"); const [chatHistoryUrl, setChatHistoryUrl] = useState("http://chat-history:8002"); @@ -35,11 +35,11 @@ const GraphRAGConfig = () => { const [maxConcurrency, setMaxConcurrency] = useState("10"); // Chunker-specific settings - const [chunkSize, setChunkSize] = useState("1024"); - const [overlapSize, setOverlapSize] = useState("0"); - const [semanticMethod, setSemanticMethod] = useState("percentile"); - const [semanticThreshold, setSemanticThreshold] = useState("0.95"); - const [regexPattern, setRegexPattern] = useState("\\r?\\n"); + const [chunkSize, setChunkSize] = useState(""); + const [overlapSize, setOverlapSize] = useState(""); + const 
[semanticMethod, setSemanticMethod] = useState(""); + const [semanticThreshold, setSemanticThreshold] = useState(""); + const [regexPattern, setRegexPattern] = useState(""); const [isLoading, setIsLoading] = useState(false); const [isSaving, setIsSaving] = useState(false); @@ -50,6 +50,10 @@ const GraphRAGConfig = () => { const [configScope, setConfigScope] = useState<"global" | "graph">("global"); const [graphOverrides, setGraphOverrides] = useState>({}); + // Track configs as loaded from API so we only save what's needed + const loadedGlobalConfig = useRef>({}); + const loadedGraphOverrides = useRef>({}); + useEffect(() => { const site = JSON.parse(sessionStorage.getItem("site") || "{}"); setAvailableGraphs(site.graphs || []); @@ -59,7 +63,7 @@ const GraphRAGConfig = () => { const applyGraphragConfig = (graphragConfig: any) => { if (!graphragConfig) return; - setReuseEmbedding(graphragConfig.reuse_embedding || false); + setReuseEmbedding(graphragConfig.reuse_embedding ?? true); setEccUrl(graphragConfig.ecc || "http://graphrag-ecc:8001"); setChatHistoryUrl(graphragConfig.chat_history_api || "http://chat-history:8002"); setDefaultChunker(graphragConfig.chunker || "semantic"); @@ -67,17 +71,17 @@ const GraphRAGConfig = () => { setNumHops(String(graphragConfig.num_hops ?? 2)); setNumSeenMin(String(graphragConfig.num_seen_min ?? 2)); setCommunityLevel(String(graphragConfig.community_level ?? 2)); - setDocOnly(graphragConfig.doc_only || false); + setDocOnly(graphragConfig.doc_only ?? false); setLoadBatchSize(String(graphragConfig.load_batch_size ?? 500)); setUpsertDelay(String(graphragConfig.upsert_delay ?? 0)); setMaxConcurrency(String(graphragConfig.default_concurrency ?? 
10)); const chunkerConfig = graphragConfig.chunker_config || {}; - setChunkSize(String(chunkerConfig.chunk_size || 1024)); - setOverlapSize(String(chunkerConfig.overlap_size || 0)); - setSemanticMethod(chunkerConfig.method || "percentile"); - setSemanticThreshold(String(chunkerConfig.threshold || 0.95)); - setRegexPattern(chunkerConfig.pattern || "\\r?\\n"); + setChunkSize(String(chunkerConfig.chunk_size ?? "")); + setOverlapSize(chunkerConfig.overlap_size != null ? String(chunkerConfig.overlap_size) : ""); + setSemanticMethod(chunkerConfig.method || ""); + setSemanticThreshold(chunkerConfig.threshold != null ? String(chunkerConfig.threshold) : ""); + setRegexPattern(chunkerConfig.pattern != null ? chunkerConfig.pattern : ""); }; const fetchConfig = async (scope?: "global" | "graph", graphname?: string) => { @@ -100,12 +104,17 @@ const GraphRAGConfig = () => { const data = await response.json(); + const deepCopy = (obj: any) => JSON.parse(JSON.stringify(obj || {})); + loadedGlobalConfig.current = deepCopy(data.graphrag_config); + if (effectiveScope === "graph" && data.graphrag_overrides) { + loadedGraphOverrides.current = deepCopy(data.graphrag_overrides); setGraphOverrides(data.graphrag_overrides); // Show per-graph values: merge global + overrides for display const merged = { ...data.graphrag_config, ...data.graphrag_overrides }; applyGraphragConfig(merged); } else { + loadedGraphOverrides.current = {}; setGraphOverrides({}); applyGraphragConfig(data.graphrag_config); } @@ -126,28 +135,20 @@ const GraphRAGConfig = () => { try { const creds = sessionStorage.getItem("creds"); - // Prepare chunker config based on selected chunker type - const chunkerConfig: any = {}; - - if (defaultChunker === "character" || defaultChunker === "markdown" || defaultChunker === "recursive") { - chunkerConfig.chunk_size = parseInt(chunkSize); - chunkerConfig.overlap_size = parseInt(overlapSize); - } else if (defaultChunker === "semantic") { - chunkerConfig.method = semanticMethod; - 
chunkerConfig.threshold = parseFloat(semanticThreshold); - } else if (defaultChunker === "regex") { - chunkerConfig.pattern = regexPattern; - } else if (defaultChunker === "html") { - // HTML chunker doesn't require specific config in the current implementation - // but we keep it consistent - } - - const graphragConfigData: any = { + // Build current UI state — only include non-empty fields + const currentChunkerConfig: any = {}; + if (chunkSize !== "") currentChunkerConfig.chunk_size = parseInt(chunkSize); + if (overlapSize !== "") currentChunkerConfig.overlap_size = parseInt(overlapSize); + if (semanticMethod !== "") currentChunkerConfig.method = semanticMethod; + if (semanticThreshold !== "") currentChunkerConfig.threshold = parseFloat(semanticThreshold); + if (regexPattern !== "") currentChunkerConfig.pattern = regexPattern; + + const currentConfig: any = { reuse_embedding: reuseEmbedding, ecc: eccUrl, chat_history_api: chatHistoryUrl, chunker: defaultChunker, - chunker_config: chunkerConfig, + chunker_config: currentChunkerConfig, top_k: parseInt(topK), num_hops: parseInt(numHops), num_seen_min: parseInt(numSeenMin), @@ -158,6 +159,87 @@ const GraphRAGConfig = () => { default_concurrency: parseInt(maxConcurrency), }; + // Display defaults — used to avoid saving values the user never changed + const displayDefaults: Record = { + reuse_embedding: true, + ecc: "http://graphrag-ecc:8001", + chat_history_api: "http://chat-history:8002", + chunker: "semantic", + top_k: 5, + num_hops: 2, + num_seen_min: 2, + community_level: 2, + doc_only: false, + load_batch_size: 500, + upsert_delay: 0, + default_concurrency: 10, + }; + + // Determine which config to diff against based on scope + const globalCfg = loadedGlobalConfig.current; + const globalChunker = globalCfg.chunker_config || {}; + const graphragConfigData: any = {}; + + // Helper: should a key be saved? 
+ // - If it was in the reference config (loaded/overrides), always save + // - If it differs from the reference config AND differs from display default, save + const shouldSave = (key: string, current: any, reference: Record, wasInRef: boolean) => { + if (wasInRef) return true; + const diffFromRef = JSON.stringify(current) !== JSON.stringify(reference[key]); + const matchesDefault = JSON.stringify(current) === JSON.stringify(displayDefaults[key]); + return diffFromRef && !matchesDefault; + }; + + if (configScope === "graph") { + // Graph scope: save values that differ from effective global (loaded or display default) + // or were already overridden per-graph + const overrides = loadedGraphOverrides.current; + const overridesChunker = overrides.chunker_config || {}; + + for (const key of Object.keys(currentConfig)) { + if (key === "chunker_config") continue; + const effectiveGlobal = key in globalCfg ? globalCfg[key] : displayDefaults[key]; + const diffFromGlobal = JSON.stringify(currentConfig[key]) !== JSON.stringify(effectiveGlobal); + const wasOverridden = key in overrides; + if (wasOverridden || diffFromGlobal) { + graphragConfigData[key] = currentConfig[key]; + } + } + + const chunkerConfig: any = {}; + for (const key of Object.keys(currentChunkerConfig)) { + const effectiveGlobal = key in globalChunker ? 
globalChunker[key] : undefined; + const diffFromGlobal = JSON.stringify(currentChunkerConfig[key]) !== JSON.stringify(effectiveGlobal); + const wasOverridden = key in overridesChunker; + if (wasOverridden || diffFromGlobal) { + chunkerConfig[key] = currentChunkerConfig[key]; + } + } + if (Object.keys(chunkerConfig).length > 0 || "chunker_config" in overrides) { + graphragConfigData.chunker_config = chunkerConfig; + } + } else { + // Global scope: save loaded keys + user changes (skip display defaults for unloaded keys) + for (const key of Object.keys(currentConfig)) { + if (key === "chunker_config") continue; + if (shouldSave(key, currentConfig[key], globalCfg, key in globalCfg)) { + graphragConfigData[key] = currentConfig[key]; + } + } + + const chunkerConfig: any = {}; + for (const key of Object.keys(currentChunkerConfig)) { + const wasLoaded = key in globalChunker; + const changed = JSON.stringify(currentChunkerConfig[key]) !== JSON.stringify(globalChunker[key]); + if (wasLoaded || changed) { + chunkerConfig[key] = currentChunkerConfig[key]; + } + } + if (Object.keys(chunkerConfig).length > 0 || "chunker_config" in globalCfg) { + graphragConfigData.chunker_config = chunkerConfig; + } + } + if (configScope === "graph") { graphragConfigData.scope = "graph"; graphragConfigData.graphname = selectedGraph; @@ -408,8 +490,11 @@ const GraphRAGConfig = () => {

- {/* Settings for character/markdown/recursive chunkers */} - {(defaultChunker === "character" || defaultChunker === "markdown" || defaultChunker === "recursive") && ( + {/* Character/Markdown/Recursive chunker settings */} +
+

+ Character / Markdown / Recursive Chunker +

-
- )} +
- {/* Settings for semantic chunker */} - {defaultChunker === "semantic" && ( + {/* Semantic chunker settings */} +
+

+ Semantic Chunker +

- setSemanticMethod(v)}> - + Percentile Standard Deviation Interquartile + Gradient

Breakpoint detection method

-
- )} +
- {/* Settings for regex chunker */} - {defaultChunker === "regex" && ( + {/* Regex chunker settings */} +
+

+ Regex Chunker +

- )} - - {/* Info for HTML chunker */} - {defaultChunker === "html" && ( -
-

- HTML chunker uses the document structure to split content. No additional configuration needed. -

-
- )} +
diff --git a/graphrag-ui/src/pages/setup/LLMConfig.tsx b/graphrag-ui/src/pages/setup/LLMConfig.tsx index aa5596f..836e17c 100644 --- a/graphrag-ui/src/pages/setup/LLMConfig.tsx +++ b/graphrag-ui/src/pages/setup/LLMConfig.tsx @@ -64,7 +64,7 @@ const PROVIDER_FIELDS: Record = { { key: "AWS_SECRET_ACCESS_KEY", label: "AWS Secret Access Key", type: "password", required: true } ], configFields: [ - { key: "region_name", label: "AWS Region", type: "text", required: true, placeholder: "us-east-1" } + { key: "region_name", label: "AWS Region", type: "text", required: false, placeholder: "us-east-1" } ] }, groq: { @@ -107,6 +107,20 @@ const PROVIDER_FIELDS: Record = { } }; +// Single provider list shared across all service Select dropdowns +const LLM_PROVIDERS = [ + { value: "openai", label: "OpenAI" }, + { value: "azure", label: "Azure OpenAI" }, + { value: "genai", label: "Google GenAI (Gemini)" }, + { value: "vertexai", label: "Google Vertex AI" }, + { value: "bedrock", label: "AWS Bedrock" }, + { value: "groq", label: "Groq" }, + { value: "ollama", label: "Ollama" }, + { value: "sagemaker", label: "AWS SageMaker" }, + { value: "huggingface", label: "HuggingFace" }, + { value: "watsonx", label: "IBM WatsonX" }, +] as const; + const LLMConfig = () => { const [selectedGraph, setSelectedGraph] = useState(sessionStorage.getItem("selectedGraph") || ""); const [availableGraphs, setAvailableGraphs] = useState([]); @@ -119,15 +133,10 @@ const LLMConfig = () => { const [messageType, setMessageType] = useState<"success" | "error" | "">(""); const [testResults, setTestResults] = useState(null); const [connectionTested, setConnectionTested] = useState(false); - - // Single provider state - const [singleProvider, setSingleProvider] = useState("openai"); - const [singleConfig, setSingleConfig] = useState>({}); - const [singleDefaultModel, setSingleDefaultModel] = useState(""); - const [singleEmbeddingModel, setSingleEmbeddingModel] = useState(""); - const [multimodalModel, 
setMultimodalModel] = useState(""); - - // Multi-provider state + + const [useCustomMultimodal, setUseCustomMultimodal] = useState(false); + + // Canonical per-service state — both single and multi-provider UIs read/write these const [completionProvider, setCompletionProvider] = useState("openai"); const [completionConfig, setCompletionConfig] = useState>({}); const [completionDefaultModel, setCompletionDefaultModel] = useState(""); @@ -183,124 +192,118 @@ const LLMConfig = () => { const llmConfig = data.llm_config; setLlmConfigAccess(data.llm_config_access === "chatbot_only" ? "chatbot_only" : "full"); + // Store graph overrides when in per-graph scope + if (data.graph_overrides) { + setGraphOverrides(data.graph_overrides); + } else { + setGraphOverrides({}); + } + + // Detect providers (needed by chat/multimodal fallback below) + const completionProv = llmConfig.completion_service?.llm_service?.toLowerCase(); + const embeddingProv = llmConfig.embedding_service?.embedding_model_service?.toLowerCase(); + const multimodalProv = llmConfig.multimodal_service?.llm_service?.toLowerCase(); + const chatProv = llmConfig.chat_service?.llm_service?.toLowerCase(); + const defaultProv = completionProv || "openai"; + + // All config field keys that any provider might use + const allConfigKeys = ["base_url", "azure_deployment", "region_name", "project", "location", "endpoint_name", "endpoint_url"]; + + // Build the base config: top-level auth + completion_service fields. + // Every service inherits missing keys from this base. 
+ const baseConfig: Record = {}; + // Layer 1: top-level auth + if (llmConfig.authentication_configuration) { + for (const [key, value] of Object.entries(llmConfig.authentication_configuration)) { + if (typeof value === "string") baseConfig[key] = value; + } + } + // Layer 2: completion_service config fields + auth + if (llmConfig.completion_service) { + for (const key of allConfigKeys) { + if (llmConfig.completion_service[key]) baseConfig[key] = llmConfig.completion_service[key]; + } + if (llmConfig.completion_service.authentication_configuration) { + for (const [key, value] of Object.entries(llmConfig.completion_service.authentication_configuration)) { + if (typeof value === "string") baseConfig[key] = value; + } + } + } + + // Helper: load a service config, inheriting all missing keys from baseConfig + const loadServiceConfigResolved = (svc: any) => { + // Start with base config as defaults + const cfg: Record = { ...baseConfig }; + // Override with service-specific config fields + if (svc) { + for (const key of allConfigKeys) { + if (svc[key]) cfg[key] = svc[key]; + } + // Override with service-specific auth + if (svc.authentication_configuration) { + for (const [key, value] of Object.entries(svc.authentication_configuration)) { + if (typeof value === "string") cfg[key] = value; + } + } + } + return cfg; + }; + // Parse per-graph chatbot config (chatbot_only mode) if (data.global_chat_info) { setGlobalChatInfo(data.global_chat_info); } if (data.chatbot_config) { setUseCustomChatbot(true); - setChatbotProvider(data.chatbot_config.llm_service?.toLowerCase() || "openai"); + setChatbotProvider(data.chatbot_config.llm_service?.toLowerCase() || defaultProv); setChatbotModelName(data.chatbot_config.llm_model || ""); setChatbotTemperature(String(data.chatbot_config.model_kwargs?.temperature ?? 
"0")); - // Load provider-specific config fields + masked auth - const cfg: Record = {}; - for (const key of ["base_url", "azure_deployment", "region_name", "project", "location", "endpoint_name", "endpoint_url"]) { - if (data.chatbot_config[key]) cfg[key] = data.chatbot_config[key]; - } - if (data.chatbot_config.authentication_configuration) { - for (const [key, value] of Object.entries(data.chatbot_config.authentication_configuration)) { - if (typeof value === "string") cfg[key] = value; - } - } - setChatbotProviderConfig(cfg); + // Resolve chatbot config: base config + chatbot overrides + setChatbotProviderConfig(loadServiceConfigResolved(data.chatbot_config)); } else { setUseCustomChatbot(false); } - // Store graph overrides when in per-graph scope - if (data.graph_overrides) { - setGraphOverrides(data.graph_overrides); - } else { - setGraphOverrides({}); - } - const currentDefaultModel = llmConfig.completion_service?.llm_model || ""; - setSingleDefaultModel(currentDefaultModel); + setCompletionDefaultModel(currentDefaultModel); + + const allSameProvider = + completionProv === embeddingProv && + (!multimodalProv || completionProv === multimodalProv) && + (!chatProv || completionProv === chatProv); + + setUseMultipleProviders(!allSameProvider); // Load chat_service config for full mode (superadmin) + // Chat inherits from base (completion) when not explicitly set if (llmConfig.chat_service) { setUseCustomChatbot(true); - setChatbotProvider(llmConfig.chat_service.llm_service?.toLowerCase() || "openai"); + setChatbotProvider(chatProv || defaultProv); setChatbotModelName(llmConfig.chat_service.llm_model || ""); setChatbotTemperature(String(llmConfig.chat_service.model_kwargs?.temperature ?? 
"0")); - const chatCfg: Record = {}; - for (const key of ["base_url", "azure_deployment", "region_name", "project", "location", "endpoint_name", "endpoint_url"]) { - if (llmConfig.chat_service[key]) chatCfg[key] = llmConfig.chat_service[key]; - } - if (llmConfig.chat_service.authentication_configuration) { - for (const [key, value] of Object.entries(llmConfig.chat_service.authentication_configuration)) { - if (typeof value === "string") chatCfg[key] = value; - } - } - setChatbotProviderConfig(chatCfg); + setChatbotProviderConfig(loadServiceConfigResolved(llmConfig.chat_service)); } else { setUseCustomChatbot(false); - setChatbotProvider("openai"); + setChatbotProvider(defaultProv); setChatbotModelName(""); setChatbotTemperature("0"); - setChatbotProviderConfig({}); + setChatbotProviderConfig({ ...baseConfig }); } - // Detect if using multiple providers - const completionProv = llmConfig.completion_service?.llm_service?.toLowerCase(); - const embeddingProv = llmConfig.embedding_service?.embedding_model_service?.toLowerCase(); - const multimodalProv = llmConfig.multimodal_service?.llm_service?.toLowerCase(); - const chatProv = llmConfig.chat_service?.llm_service?.toLowerCase(); + // Canonical per-service state — both single and multi-provider UIs read these + setCompletionProvider(completionProv || "openai"); + setCompletionDefaultModel(llmConfig.completion_service?.llm_model || ""); + setCompletionConfig(loadServiceConfigResolved(llmConfig.completion_service)); - const allSameProvider = - completionProv === embeddingProv && - (!multimodalProv || completionProv === multimodalProv) && - (!chatProv || completionProv === chatProv); - - setUseMultipleProviders(!allSameProvider); + setEmbeddingProvider(embeddingProv || completionProv || "openai"); + setEmbeddingModel(llmConfig.embedding_service?.model_name || ""); + setEmbeddingConfig(loadServiceConfigResolved(llmConfig.embedding_service)); - // Helper: load config fields + masked auth fields from a service config - const 
loadServiceConfig = (svc: any, configKeys: string[]) => { - const cfg: Record = {}; - for (const key of configKeys) { - if (svc?.[key]) cfg[key] = svc[key]; - } - // Load masked auth fields from authentication_configuration - if (svc?.authentication_configuration) { - for (const [key, value] of Object.entries(svc.authentication_configuration)) { - if (typeof value === "string") cfg[key] = value; - } - } - return cfg; - }; - - const completionConfigKeys = ["base_url", "azure_deployment", "region_name", "project", "location", "endpoint_name", "endpoint_url"]; - const embeddingConfigKeys = ["base_url", "azure_deployment", "region_name"]; - - if (!allSameProvider) { - // Multi-provider mode - Load from backend - setCompletionProvider(completionProv || "openai"); - setCompletionDefaultModel(llmConfig.completion_service?.llm_model || ""); - setCompletionConfig(loadServiceConfig(llmConfig.completion_service, completionConfigKeys)); - - setEmbeddingProvider(embeddingProv || "openai"); - setEmbeddingModel(llmConfig.embedding_service?.model_name || ""); - setEmbeddingConfig(loadServiceConfig(llmConfig.embedding_service, embeddingConfigKeys)); - - setMultimodalProvider(multimodalProv || "openai"); - setMultimodalModelName(llmConfig.multimodal_service?.llm_model || ""); - setMultimodalConfig(loadServiceConfig(llmConfig.multimodal_service, ["azure_deployment"])); - } else { - // Single provider mode - Load from backend - setSingleProvider(completionProv || "openai"); - setSingleDefaultModel(llmConfig.completion_service?.llm_model || ""); - setSingleEmbeddingModel(llmConfig.embedding_service?.model_name || ""); - setMultimodalModel(llmConfig.multimodal_service?.llm_model || ""); - // Load config + auth from completion_service (single provider shares auth) - const singleCfg = loadServiceConfig(llmConfig.completion_service, completionConfigKeys); - // Also load top-level authentication_configuration (used in single-provider mode) - if (llmConfig.authentication_configuration) { - 
for (const [key, value] of Object.entries(llmConfig.authentication_configuration)) { - if (typeof value === "string" && !singleCfg[key]) singleCfg[key] = value; - } - } - setSingleConfig(singleCfg); - } + setMultimodalProvider(multimodalProv || completionProv || "openai"); + const mmModel = llmConfig.multimodal_service?.llm_model || ""; + setMultimodalModelName(mmModel); + setMultimodalConfig(loadServiceConfigResolved(llmConfig.multimodal_service)); + setUseCustomMultimodal(!!mmModel || !!multimodalProv); } catch (error: any) { console.error("Error fetching config:", error); setMessage(`Failed to load configuration: ${error.message}`); @@ -318,19 +321,20 @@ const LLMConfig = () => { }; // Update config when provider changes - CLEAR ALL FIELDS - const handleProviderChange = (newProvider: string, target: 'single' | 'completion' | 'embedding' | 'multimodal') => { - if (target === 'single') { - setSingleProvider(newProvider); - setSingleConfig({}); - // Clear model names when switching provider - setSingleDefaultModel(""); - setSingleEmbeddingModel(""); - setMultimodalModel(""); - } else if (target === 'completion') { + const handleProviderChange = (newProvider: string, target: 'completion' | 'embedding' | 'multimodal') => { + if (target === 'completion') { setCompletionProvider(newProvider); setCompletionConfig({}); - // Clear model names when switching provider setCompletionDefaultModel(""); + // In single-provider mode, all services share the same provider + if (!useMultipleProviders) { + setEmbeddingProvider(newProvider); + setEmbeddingConfig({}); + setEmbeddingModel(""); + setMultimodalProvider(newProvider); + setMultimodalConfig({}); + setMultimodalModelName(""); + } } else if (target === 'embedding') { setEmbeddingProvider(newProvider); setEmbeddingConfig({}); @@ -373,6 +377,103 @@ const LLMConfig = () => { return serviceConfig; }; + /** + * Build the candidate LLM config payload. 
+ * Used by both test-connection and save — same structure, single source of truth. + * Inherited services (multimodal, chatbot) are set to null when not customized. + */ + const buildLLMConfigPayload = (): any => { + let llmConfigData: any; + + if (useMultipleProviders) { + const completionServiceConfig: any = { + llm_service: completionProvider, + llm_model: completionDefaultModel, + authentication_configuration: buildAuthConfig(completionProvider, completionConfig), + model_kwargs: { temperature: 0 }, + prompt_path: `./common/prompts/${getPromptPath(completionProvider)}/`, + ...buildServiceConfig(completionProvider, completionConfig) + }; + + llmConfigData = { + graphname: selectedGraph || undefined, + completion_service: completionServiceConfig, + embedding_service: { + embedding_model_service: embeddingProvider, + model_name: embeddingModel, + authentication_configuration: buildAuthConfig(embeddingProvider, embeddingConfig), + ...buildServiceConfig(embeddingProvider, embeddingConfig) + }, + }; + + if (useCustomMultimodal && multimodalModelName) { + llmConfigData.multimodal_service = { + llm_service: multimodalProvider, + llm_model: multimodalModelName, + authentication_configuration: buildAuthConfig(multimodalProvider, multimodalConfig), + model_kwargs: { temperature: 0 }, + ...buildServiceConfig(multimodalProvider, multimodalConfig) + }; + } else { + llmConfigData.multimodal_service = null; + } + + if (useCustomChatbot) { + llmConfigData.chat_service = { + llm_service: chatbotProvider, + llm_model: chatbotModelName, + authentication_configuration: buildAuthConfig(chatbotProvider, chatbotProviderConfig), + model_kwargs: { temperature: parseFloat(chatbotTemperature) || 0 }, + ...buildServiceConfig(chatbotProvider, chatbotProviderConfig), + }; + } else { + llmConfigData.chat_service = null; + } + } else { + const completionServiceConfig: any = { + llm_service: completionProvider, + llm_model: completionDefaultModel, + model_kwargs: { temperature: 0 }, + 
prompt_path: `./common/prompts/${getPromptPath(completionProvider)}/`, + ...buildServiceConfig(completionProvider, completionConfig) + }; + + llmConfigData = { + graphname: selectedGraph || undefined, + authentication_configuration: buildAuthConfig(completionProvider, completionConfig), + completion_service: completionServiceConfig, + embedding_service: { + embedding_model_service: completionProvider, + model_name: embeddingModel, + }, + }; + + if (useCustomMultimodal && multimodalModelName.trim()) { + llmConfigData.multimodal_service = { + llm_model: multimodalModelName, + }; + } else { + llmConfigData.multimodal_service = null; + } + + if (useCustomChatbot) { + const chatTemp = parseFloat(chatbotTemperature) || 0; + llmConfigData.chat_service = { + ...(chatbotModelName.trim() ? { llm_model: chatbotModelName } : {}), + model_kwargs: { temperature: chatTemp }, + }; + } else { + llmConfigData.chat_service = null; + } + } + + if (configScope === "graph") { + llmConfigData.scope = "graph"; + } + + return llmConfigData; + }; + const handleSave = async () => { setIsSaving(true); setMessage(""); @@ -420,85 +521,7 @@ const LLMConfig = () => { return; } - if (useMultipleProviders) { - const completionServiceConfig: any = { - llm_service: completionProvider, - llm_model: completionDefaultModel, - authentication_configuration: buildAuthConfig(completionProvider, completionConfig), - model_kwargs: { temperature: 0 }, - prompt_path: `./common/prompts/${getPromptPath(completionProvider)}/`, - ...buildServiceConfig(completionProvider, completionConfig) - }; - - llmConfigData = { - graphname: selectedGraph || undefined, - completion_service: completionServiceConfig, - embedding_service: { - embedding_model_service: embeddingProvider, - model_name: embeddingModel, - authentication_configuration: buildAuthConfig(embeddingProvider, embeddingConfig), - ...buildServiceConfig(embeddingProvider, embeddingConfig) - }, - multimodal_service: { - llm_service: multimodalProvider, - 
llm_model: multimodalModelName, - authentication_configuration: buildAuthConfig(multimodalProvider, multimodalConfig), - model_kwargs: { temperature: 0 }, - ...buildServiceConfig(multimodalProvider, multimodalConfig) - }, - }; - - // Save chat_service if not inheriting from completion service - if (useCustomChatbot) { - llmConfigData.chat_service = { - llm_service: chatbotProvider, - llm_model: chatbotModelName, - authentication_configuration: buildAuthConfig(chatbotProvider, chatbotProviderConfig), - model_kwargs: { temperature: parseFloat(chatbotTemperature) || 0 }, - ...buildServiceConfig(chatbotProvider, chatbotProviderConfig), - }; - } else { - llmConfigData.chat_service = null; - } - } else { - const completionServiceConfig: any = { - llm_service: singleProvider, - llm_model: singleDefaultModel, - model_kwargs: { temperature: 0 }, - prompt_path: `./common/prompts/${getPromptPath(singleProvider)}/`, - ...buildServiceConfig(singleProvider, singleConfig) - }; - - llmConfigData = { - graphname: selectedGraph || undefined, - authentication_configuration: buildAuthConfig(singleProvider, singleConfig), - completion_service: completionServiceConfig, - embedding_service: { - embedding_model_service: singleProvider, - model_name: singleEmbeddingModel, - }, - multimodal_service: { - llm_service: singleProvider, - llm_model: multimodalModel, - model_kwargs: { temperature: 0 }, - ...buildServiceConfig(singleProvider, singleConfig) - }, - }; - - // Save chat_service with just the model name (same provider as completion) - if (chatbotModelName.trim()) { - llmConfigData.chat_service = { - llm_model: chatbotModelName, - }; - } else { - llmConfigData.chat_service = null; - } - } - - // Add scope for superadmin per-graph saves - if (configScope === "graph") { - llmConfigData.scope = "graph"; - } + llmConfigData = buildLLMConfigPayload(); const response = await fetch("/ui/config/llm", { method: "POST", @@ -519,6 +542,9 @@ const LLMConfig = () => { setMessageType("success"); 
setTestResults(null); setConnectionTested(false); + + // Refetch to sync all state with the saved config + fetchConfig(configScope === "graph" ? "graph" : "global", selectedGraph || undefined); } catch (error: any) { console.error("Error saving config:", error); setMessage(`❌ Error: ${error.message}`); @@ -553,112 +579,43 @@ const LLMConfig = () => { return null; }; + const failValidation = (msg: string) => { + setMessage(`❌ ${msg}`); + setMessageType("error"); + setIsTesting(false); + }; + if (useMultipleProviders) { const completionError = validateProvider(completionProvider, completionConfig, "Completion Service"); - if (completionError) { - setMessage(`❌ ${completionError}`); - setMessageType("error"); - setIsTesting(false); - return; - } - + if (completionError) { failValidation(completionError); return; } + if (!completionDefaultModel.trim()) { failValidation("Model Name is required for Completion Service"); return; } + const embeddingError = validateProvider(embeddingProvider, embeddingConfig, "Embedding Service"); - if (embeddingError) { - setMessage(`❌ ${embeddingError}`); - setMessageType("error"); - setIsTesting(false); - return; + if (embeddingError) { failValidation(embeddingError); return; } + if (!embeddingModel.trim()) { failValidation("Model Name is required for Embedding Service"); return; } + + if (useCustomMultimodal) { + const multimodalError = validateProvider(multimodalProvider, multimodalConfig, "Multimodal Service"); + if (multimodalError) { failValidation(multimodalError); return; } + if (!multimodalModelName.trim()) { failValidation("Model Name is required for Multimodal Service"); return; } } - const multimodalError = validateProvider(multimodalProvider, multimodalConfig, "Multimodal Service"); - if (multimodalError) { - setMessage(`❌ ${multimodalError}`); - setMessageType("error"); - setIsTesting(false); - return; + if (useCustomChatbot) { + const chatbotError = validateProvider(chatbotProvider, chatbotProviderConfig, "Chatbot 
Service"); + if (chatbotError) { failValidation(chatbotError); return; } + if (!chatbotModelName.trim()) { failValidation("Model Name is required for Chatbot Service"); return; } } } else { - const singleError = validateProvider(singleProvider, singleConfig, singleProvider); - if (singleError) { - setMessage(`❌ ${singleError}`); - setMessageType("error"); - setIsTesting(false); - return; - } + const singleError = validateProvider(completionProvider, completionConfig, completionProvider); + if (singleError) { failValidation(singleError); return; } + if (!completionDefaultModel.trim()) { failValidation("Completion Model is required"); return; } + if (!embeddingModel.trim()) { failValidation("Embedding Model is required"); return; } + if (useCustomMultimodal && !multimodalModelName.trim()) { failValidation("Multimodal Model is required when not inheriting from completion"); return; } + if (useCustomChatbot && !chatbotModelName.trim()) { failValidation("Chatbot Model is required when not inheriting from completion"); return; } } const creds = sessionStorage.getItem("creds"); - let llmConfigData: any; - - if (useMultipleProviders) { - llmConfigData = { - graphname: selectedGraph || undefined, - completion_service: { - llm_service: completionProvider, - llm_model: completionDefaultModel, - authentication_configuration: buildAuthConfig(completionProvider, completionConfig), - ...buildServiceConfig(completionProvider, completionConfig) - }, - embedding_service: { - embedding_model_service: embeddingProvider, - model_name: embeddingModel, - authentication_configuration: buildAuthConfig(embeddingProvider, embeddingConfig), - ...buildServiceConfig(embeddingProvider, embeddingConfig) - }, - }; - - llmConfigData.multimodal_service = { - llm_service: multimodalProvider, - llm_model: multimodalModelName, - authentication_configuration: buildAuthConfig(multimodalProvider, multimodalConfig), - ...buildServiceConfig(multimodalProvider, multimodalConfig) - }; - } else { - 
llmConfigData = { - graphname: selectedGraph || undefined, - authentication_configuration: buildAuthConfig(singleProvider, singleConfig), - completion_service: { - llm_service: singleProvider, - llm_model: singleDefaultModel, - ...buildServiceConfig(singleProvider, singleConfig) - }, - embedding_service: { - embedding_model_service: singleProvider, - model_name: singleEmbeddingModel, - }, - multimodal_service: { - llm_service: singleProvider, - llm_model: multimodalModel, - ...buildServiceConfig(singleProvider, singleConfig) - }, - }; - - } - - // Add chat_service to test config if custom chatbot is configured - // Add chat_service to test config if not inheriting - if (useCustomChatbot) { - if (useMultipleProviders) { - const chatbotError = validateProvider(chatbotProvider, chatbotProviderConfig, "Chatbot Service"); - if (chatbotError) { - setMessage(`❌ ${chatbotError}`); - setMessageType("error"); - setIsTesting(false); - return; - } - llmConfigData.chat_service = { - llm_service: chatbotProvider, - llm_model: chatbotModelName, - authentication_configuration: buildAuthConfig(chatbotProvider, chatbotProviderConfig), - model_kwargs: { temperature: parseFloat(chatbotTemperature) || 0 }, - ...buildServiceConfig(chatbotProvider, chatbotProviderConfig), - }; - } else if (chatbotModelName.trim()) { - llmConfigData.chat_service = { - llm_model: chatbotModelName, - }; - } - } + const llmConfigData = buildLLMConfigPayload(); const response = await fetch("/ui/config/llm/test", { method: "POST", @@ -919,16 +876,9 @@ const LLMConfig = () => { - OpenAI - Azure OpenAI - Google GenAI (Gemini) - Google Vertex AI - AWS Bedrock - Groq - Ollama - AWS SageMaker - HuggingFace - IBM WatsonX + {LLM_PROVIDERS.map((p) => ( + {p.label} + ))} @@ -1102,6 +1052,13 @@ const LLMConfig = () => { checked={useMultipleProviders} onChange={(e) => { setUseMultipleProviders(e.target.checked); + if (!e.target.checked) { + // Switching to single-provider: unify providers/configs to completion + 
setEmbeddingProvider(completionProvider); + setEmbeddingConfig({ ...completionConfig }); + setMultimodalProvider(completionProvider); + setMultimodalConfig({ ...completionConfig }); + } clearTestResults(); }} className="h-4 w-4 rounded border-gray-300 dark:border-[#3D3D3D]" @@ -1132,37 +1089,34 @@ const LLMConfig = () => { - handleProviderChange(value, 'completion')}> - OpenAI - Azure OpenAI - Google GenAI (Gemini) - Google Vertex AI - AWS Bedrock - Ollama + {LLM_PROVIDERS.map((p) => ( + {p.label} + ))}

- Only providers supporting both completion and embedding services are shown + This provider will be used for all services (completion, embedding, multimodal)

- {renderProviderFields(singleProvider, singleConfig, setSingleConfig)} + {renderProviderFields(completionProvider, completionConfig, setCompletionConfig)}
{ - setSingleDefaultModel(e.target.value); + setCompletionDefaultModel(e.target.value); clearTestResults(); }} /> @@ -1171,6 +1125,8 @@ const LLMConfig = () => {

+
+
@@ -1275,16 +1273,9 @@ const LLMConfig = () => { - OpenAI - Azure OpenAI - Google GenAI (Gemini) - Google Vertex AI - AWS Bedrock - AWS SageMaker - Groq - Ollama - HuggingFace - IBM WatsonX + {LLM_PROVIDERS.map((p) => ( + {p.label} + ))} @@ -1293,7 +1284,7 @@ const LLMConfig = () => {
{ checked={!useCustomChatbot} onChange={(e) => { setUseCustomChatbot(!e.target.checked); - if (e.target.checked) { - setChatbotModelName(""); - setChatbotProviderConfig({}); - } clearTestResults(); }} /> @@ -1361,16 +1348,9 @@ const LLMConfig = () => { - OpenAI - Azure OpenAI - Google GenAI (Gemini) - Google Vertex AI - AWS Bedrock - Groq - Ollama - AWS SageMaker - HuggingFace - IBM WatsonX + {LLM_PROVIDERS.map((p) => ( + {p.label} + ))}
@@ -1410,62 +1390,84 @@ const LLMConfig = () => { - {/* Embedding Service Provider */} + {/* Multimodal Service Provider */}

- Embedding Service + Multimodal Service

-

- Configure the provider for generating embeddings. +

+ Configure the provider for processing images and multimodal content (vision tasks).

-
- - -
- - {renderProviderFields(embeddingProvider, embeddingConfig, setEmbeddingConfig)} - -
- - + { - setEmbeddingModel(e.target.value); + setUseCustomMultimodal(!e.target.checked); clearTestResults(); }} /> +
+ {!useCustomMultimodal && ( +

+ Ensure your completion model supports vision input. Use "Test Connection" to verify. +

+ )} + + {useCustomMultimodal && ( + <> +
+ + +
+ + {renderProviderFields(multimodalProvider, multimodalConfig, setMultimodalConfig)} + +
+ + { + setMultimodalModelName(e.target.value); + clearTestResults(); + }} + /> +
+ + )}
- {/* Multimodal Service Provider */} + {/* Embedding Service Provider */}

- Multimodal Service + Embedding Service

- Configure the provider for processing images and multimodal content (vision tasks). + Configure the provider for generating embeddings.

@@ -1473,35 +1475,31 @@ const LLMConfig = () => { - handleProviderChange(value, 'embedding')}> - OpenAI - Azure OpenAI - Google GenAI (Gemini) - Google Vertex AI + {LLM_PROVIDERS.map((p) => ( + {p.label} + ))} -

- Only OpenAI, Azure, GenAI, VertexAI support vision -

- {renderProviderFields(multimodalProvider, multimodalConfig, setMultimodalConfig)} + {renderProviderFields(embeddingProvider, embeddingConfig, setEmbeddingConfig)}
{ - setMultimodalModelName(e.target.value); + setEmbeddingModel(e.target.value); clearTestResults(); }} /> @@ -1535,7 +1533,7 @@ const LLMConfig = () => { ? "bg-green-50 dark:bg-green-900/20 text-green-700 dark:text-green-300" : "bg-red-50 dark:bg-red-900/20 text-red-700 dark:text-red-300" }`}> - Default LLM Model: {testResults.completion.message} + Completion Model: {testResults.completion.message}
)} @@ -1545,27 +1543,27 @@ const LLMConfig = () => { ? "bg-green-50 dark:bg-green-900/20 text-green-700 dark:text-green-300" : "bg-red-50 dark:bg-red-900/20 text-red-700 dark:text-red-300" }`}> - Chatbot LLM Model: {testResults.chatbot.message} + Chatbot Model: {testResults.chatbot.message}
)} - - {testResults.embedding && testResults.embedding.status !== "not_tested" && ( + + {testResults.multimodal && testResults.multimodal.status !== "not_tested" && (
- Embedding Model: {testResults.embedding.message} + Multimodal Model: {testResults.multimodal.message}
)} - - {testResults.multimodal && testResults.multimodal.status !== "not_tested" && ( + + {testResults.embedding && testResults.embedding.status !== "not_tested" && (
- Multimodal Model: {testResults.multimodal.message} + Embedding Model: {testResults.embedding.message}
)} diff --git a/graphrag/app/agent/agent_generation.py b/graphrag/app/agent/agent_generation.py index d6b3461..22d10d4 100644 --- a/graphrag/app/agent/agent_generation.py +++ b/graphrag/app/agent/agent_generation.py @@ -26,10 +26,10 @@ logger = logging.getLogger(__name__) class TigerGraphAgentGenerator: - def __init__(self, llm_model): - self.llm = llm_model - llm_config = getattr(llm_model, "config", {}) - self.token_calculator = get_token_calculator(token_limit=llm_config.get("token_limit"), model_name=llm_config.get("llm_model")) + def __init__(self, llm_service): + self.llm = llm_service + svc_config = getattr(llm_service, "config", {}) + self.token_calculator = get_token_calculator(token_limit=svc_config.get("token_limit"), model_name=svc_config.get("llm_model")) def generate_answer(self, question: str, context: str | dict, query: str = "") -> dict: """Generate an answer based on the question and context. diff --git a/graphrag/app/routers/root.py b/graphrag/app/routers/root.py index f96bb40..e986194 100644 --- a/graphrag/app/routers/root.py +++ b/graphrag/app/routers/root.py @@ -5,7 +5,7 @@ from fastapi.responses import FileResponse, HTMLResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials -from common.config import llm_config, service_status +from common.config import get_completion_config, service_status logger = logging.getLogger(__name__) router = APIRouter() @@ -13,7 +13,7 @@ @router.get("/") def read_root(): - return {"config": llm_config["model_name"]} + return {"config": get_completion_config().get("llm_model", "unknown")} @router.get("/health") diff --git a/graphrag/app/routers/ui.py b/graphrag/app/routers/ui.py index 9bd22b8..30971cd 100644 --- a/graphrag/app/routers/ui.py +++ b/graphrag/app/routers/ui.py @@ -51,7 +51,7 @@ from pyTigerGraph import TigerGraphConnection from tools.validation_utils import MapQuestionToSchemaException -from common.config import db_config, graphrag_config, embedding_service, llm_config, service_status, 
SERVER_CONFIG, get_chat_config, validate_graphname +from common.config import db_config, graphrag_config, embedding_service, llm_config, service_status, get_chat_config, get_completion_config, get_embedding_config, get_multimodal_config, validate_graphname, get_llm_service, resolve_llm_services from common.db.connections import get_db_connection_pwd_manual from common.logs.log import req_id_cv from common.logs.logwriter import LogWriter @@ -181,14 +181,6 @@ def _require_roles(credentials: HTTPBasicCredentials, allowed_roles: set[str]) - return roles -def _create_llm_service(provider: str, config: dict): - """Instantiate an LLM provider, returning None for unsupported providers.""" - try: - return get_llm_service(config) - except Exception: - return None - - def _create_embedding_service(provider: str, config: dict): from common.embeddings.embedding_services import ( OpenAI_Embedding, AzureOpenAI_Ada002, GenAI_Embedding, @@ -1118,7 +1110,7 @@ async def chat( status_code=503, detail=service_status["embedding_store"]["error"] ) - + await websocket.accept() # AUTH with proper error handling and timeout @@ -1817,7 +1809,7 @@ async def save_llm_config( Save LLM configuration and reload services. """ try: - graphname = llm_config_data.pop("graphname", None) + graphname = llm_config_data.get("graphname") llm_access_mode = _resolve_llm_config_access(credentials, graphname) graphs = auth(credentials.username, credentials.password)[0] auth_header = "Basic " + base64.b64encode( @@ -1837,10 +1829,7 @@ async def save_llm_config( # Save and reload in graphrag service from common.config import reload_llm_config - scope = llm_config_data.pop("scope", None) - - # Substitute masked sentinel values with real stored values - _unmask_auth(llm_config_data, llm_config) + candidate, graphname, scope = _prepare_llm_config(llm_config_data) if llm_access_mode == "chatbot_only" or (llm_access_mode == "full" and scope == "graph"): # Per-graph save: write only overrides to graph config file. 
@@ -1864,20 +1853,34 @@ async def save_llm_config( graph_llm = graph_server_config.setdefault("llm_config", {}) - # Also unmask against the graph's own stored config - _unmask_auth(llm_config_data, graph_llm) - if llm_access_mode == "chatbot_only": # Graph admin: only chat_service svc_keys = ["chat_service"] else: # Superadmin per-graph: all services - svc_keys = ["completion_service", "chat_service", "multimodal_service"] + svc_keys = ["completion_service", "embedding_service", "chat_service", "multimodal_service"] + + # Resolve both candidate and global to get fully expanded configs, + # then store only the delta as the graph override. + resolved_candidate = resolve_llm_services(candidate) + resolved_global = resolve_llm_services(llm_config) for svc_key in svc_keys: - incoming = llm_config_data.get(svc_key) + incoming = candidate.get(svc_key) if incoming: - graph_llm[svc_key] = incoming + rc = resolved_candidate.get(svc_key, {}) + rg = resolved_global.get(svc_key, {}) + # Compute delta: keys whose resolved values differ + delta = {} + for k, v in rc.items(): + if k == "authentication_configuration": + continue + if rg.get(k) != v: + delta[k] = v + if delta: + graph_llm[svc_key] = delta + else: + graph_llm.pop(svc_key, None) else: # Revert to inherit: remove override graph_llm.pop(svc_key, None) @@ -1890,7 +1893,7 @@ async def save_llm_config( result = {"status": "success"} else: # Superadmin global save - result = reload_llm_config(llm_config_data) + result = reload_llm_config(candidate) if result["status"] != "success": raise HTTPException(status_code=500, detail=result["message"]) @@ -1917,247 +1920,167 @@ async def test_llm_config( Test LLM configuration by making actual API calls to the provider. Tests completion, embedding, and multimodal services. 
""" + test_results = { + "completion": {"status": "not_tested", "message": ""}, + "chatbot": {"status": "not_tested", "message": ""}, + "embedding": {"status": "not_tested", "message": ""}, + "multimodal": {"status": "not_tested", "message": ""} + } try: - graphname = llm_test_config.pop("graphname", None) + graphname = llm_test_config.get("graphname") llm_access_mode = _resolve_llm_config_access(credentials, graphname) - # Substitute masked sentinel values with real stored values - _unmask_auth(llm_test_config, llm_config) - from common import config as cfg - - test_results = { - "completion": {"status": "not_tested", "message": ""}, - "chatbot": {"status": "not_tested", "message": ""}, - "embedding": {"status": "not_tested", "message": ""}, - "multimodal": {"status": "not_tested", "message": ""} - } + + # Build candidate config — same preparation as save + candidate, graphname, scope = _prepare_llm_config(llm_test_config) + # Resolve partial service configs into full configs for testing + # (same resolution logic used when parsing config from disk) + test_configs = resolve_llm_services(candidate) # Graph admins (chatbot_only) can only test chat_service if llm_access_mode == "chatbot_only": - if "chat_service" in llm_test_config: + if "chat_service" in candidate: try: - test_chat_config = llm_test_config["chat_service"].copy() - provider = test_chat_config.get("llm_service", "openai").lower() - model = test_chat_config.get("llm_model", "gpt-4o-mini") - - if "authentication_configuration" not in test_chat_config: - test_chat_config["authentication_configuration"] = {} - - if hasattr(cfg, 'completion_config') and cfg.completion_config: - for key in ["model_kwargs", "prompt_path", "base_url", "token_limit"]: - if key not in test_chat_config and key in cfg.completion_config: - test_chat_config[key] = cfg.completion_config[key] - - if "model_kwargs" not in test_chat_config: - test_chat_config["model_kwargs"] = {"temperature": 0} - if "prompt_path" not in 
test_chat_config: - test_chat_config["prompt_path"] = "common/prompts/openai_gpt4/" - - llm_service = _create_llm_service(provider, test_chat_config) - if llm_service: - response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") - if not response or not str(response).strip(): - raise ValueError("LLM returned an empty response") - test_results["chatbot"]["status"] = "success" - test_results["chatbot"]["message"] = f"Chatbot LLM ({model}) connected successfully" - else: - test_results["chatbot"]["status"] = "error" - test_results["chatbot"]["message"] = f"Provider '{provider}' not supported" + test_config = test_configs["chat_service"] + model = test_config.get("llm_model", "") + llm_service = get_llm_service(test_config) + response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") + if not response or not str(response).strip(): + raise ValueError("LLM returned an empty response") + test_results["chatbot"]["status"] = "success" + test_results["chatbot"]["message"] = f"Chatbot LLM ({model}) connected successfully" except Exception as e: test_results["chatbot"]["status"] = "error" test_results["chatbot"]["message"] = f"Chatbot test failed: {str(e)}" logger.error(f"Chatbot test failed for graph {graphname}: {str(e)}") - overall_status = "success" if test_results["chatbot"]["status"] == "success" else "error" + chatbot_status = test_results["chatbot"]["status"] + overall_status = "success" if chatbot_status == "success" else ("error" if chatbot_status == "error" else "not_tested") return { "status": overall_status, "message": "Connection test completed", "results": {"chatbot": test_results["chatbot"]} } - # Full access: test all services - # Test Completion Service (Default LLM Model) - if "completion_service" in llm_test_config or "llm_service" in llm_test_config: + # Full access: test all services from the resolved test configs + + # Test Completion Service + if "completion_service" in test_configs: try: - if "completion_service" in 
llm_test_config: - test_completion_config = llm_test_config["completion_service"].copy() - provider = test_completion_config.get("llm_service", "openai").lower() - model = test_completion_config.get("llm_model", "gpt-4o-mini") - else: - test_completion_config = { - "llm_service": llm_test_config.get("llm_service", "openai"), - "llm_model": llm_test_config.get("llm_model", "gpt-4o-mini"), - "authentication_configuration": llm_test_config.get("authentication_configuration", {}) - } - provider = test_completion_config["llm_service"].lower() - model = test_completion_config["llm_model"] - - # Ensure authentication_configuration exists (may be at top level in single-provider mode) - if "authentication_configuration" not in test_completion_config: - test_completion_config["authentication_configuration"] = llm_test_config.get("authentication_configuration", {}) - - # Merge with existing config to get model_kwargs and prompt_path - if hasattr(cfg, 'completion_config') and cfg.completion_config: - for key in ["model_kwargs", "prompt_path", "base_url", "token_limit"]: - if key not in test_completion_config and key in cfg.completion_config: - test_completion_config[key] = cfg.completion_config[key] - - # Ensure required fields exist - if "model_kwargs" not in test_completion_config: - test_completion_config["model_kwargs"] = {"temperature": 0} - if "prompt_path" not in test_completion_config: - test_completion_config["prompt_path"] = "common/prompts/openai_gpt4/" - - llm_service = _create_llm_service(provider, test_completion_config) - - if llm_service: - response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") - if not response or not str(response).strip(): - raise ValueError("LLM returned an empty response") - test_results["completion"]["status"] = "success" - test_results["completion"]["message"] = f"✅ Default LLM model ({model}) connected successfully" - else: - test_results["completion"]["status"] = "error" - test_results["completion"]["message"] = 
f"Provider '{provider}' not supported for completion" - + test_config = test_configs["completion_service"] + model = test_config.get("llm_model", "") + llm_service = get_llm_service(test_config) + response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") + if not response or not str(response).strip(): + raise ValueError("LLM returned an empty response") + test_results["completion"]["status"] = "success" + test_results["completion"]["message"] = f"Completion model ({model}) connected successfully" except Exception as e: test_results["completion"]["status"] = "error" - test_results["completion"]["message"] = f"❌ Completion test failed: {str(e)}" + test_results["completion"]["message"] = f"Completion test failed: {str(e)}" logger.error(f"Completion test failed: {str(e)}") - - # Test Chatbot Service (if different model is provided) - if "chatbot_service" in llm_test_config: + + # Test Chatbot Service (only if custom config provided in candidate; + # when inheriting from completion, the completion test already covers it) + if "chat_service" in candidate: try: - test_chatbot_config = llm_test_config["chatbot_service"].copy() - provider = test_chatbot_config.get("llm_service", "openai").lower() - model = test_chatbot_config.get("llm_model", "gpt-4o-mini") - - # Ensure authentication_configuration exists - if "authentication_configuration" not in test_chatbot_config: - test_chatbot_config["authentication_configuration"] = llm_test_config.get("authentication_configuration", {}) - - # Merge with existing config to get model_kwargs and prompt_path - if hasattr(cfg, 'completion_config') and cfg.completion_config: - for key in ["model_kwargs", "prompt_path", "base_url", "token_limit"]: - if key not in test_chatbot_config and key in cfg.completion_config: - test_chatbot_config[key] = cfg.completion_config[key] - - # Ensure required fields exist - if "model_kwargs" not in test_chatbot_config: - test_chatbot_config["model_kwargs"] = {"temperature": 0} - if 
"prompt_path" not in test_chatbot_config: - test_chatbot_config["prompt_path"] = "common/prompts/openai_gpt4/" - - llm_service = _create_llm_service(provider, test_chatbot_config) - - if llm_service: - response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") - if not response or not str(response).strip(): - raise ValueError("LLM returned an empty response") - test_results["chatbot"]["status"] = "success" - test_results["chatbot"]["message"] = f"✅ Chatbot LLM model ({model}) connected successfully" - else: - test_results["chatbot"]["status"] = "error" - test_results["chatbot"]["message"] = f"Provider '{provider}' not supported for chatbot" - + test_config = test_configs["chat_service"] + model = test_config.get("llm_model", "") + llm_service = get_llm_service(test_config) + response = llm_service.llm.invoke("Say 'Connection successful' in 2 words") + if not response or not str(response).strip(): + raise ValueError("LLM returned an empty response") + test_results["chatbot"]["status"] = "success" + test_results["chatbot"]["message"] = f"Chatbot LLM model ({model}) connected successfully" except Exception as e: test_results["chatbot"]["status"] = "error" - test_results["chatbot"]["message"] = f"❌ Chatbot test failed: {str(e)}" + test_results["chatbot"]["message"] = f"Chatbot test failed: {str(e)}" logger.error(f"Chatbot test failed: {str(e)}") - + # Test Embedding Service - if "embedding_service" in llm_test_config: + if "embedding_service" in test_configs: try: - test_embedding_config = llm_test_config["embedding_service"].copy() - provider = test_embedding_config.get("embedding_model_service", "openai").lower() - model = test_embedding_config.get("model_name", "text-embedding-3-small") - - # Ensure authentication_configuration exists - if "authentication_configuration" not in test_embedding_config: - test_embedding_config["authentication_configuration"] = llm_test_config.get("authentication_configuration", {}) - - # Merge with existing config - if 
hasattr(cfg, 'embedding_config') and cfg.embedding_config: - for key in ["dimensions", "token_limit"]: - if key not in test_embedding_config and key in cfg.embedding_config: - test_embedding_config[key] = cfg.embedding_config[key] - - embedding_service_test = _create_embedding_service(provider, test_embedding_config) - - if embedding_service_test: - # Test with a simple text - embeddings = embedding_service_test.embed_query("test connection") - if embeddings and len(embeddings) > 0: - test_results["embedding"]["status"] = "success" - test_results["embedding"]["message"] = f"✅ Embedding model ({model}) connected successfully" - else: - test_results["embedding"]["status"] = "error" - test_results["embedding"]["message"] = "❌ Embedding returned empty result" - else: - test_results["embedding"]["status"] = "error" - test_results["embedding"]["message"] = f"Provider '{provider}' not supported for embeddings" - + test_config = test_configs["embedding_service"] + provider = test_config.get("embedding_model_service", "openai").lower() + model = test_config.get("model_name", "") + embedding_service_test = _create_embedding_service(provider, test_config) + if not embedding_service_test: + raise ValueError(f"Provider '{provider}' not supported for embeddings") + embeddings = embedding_service_test.embed_query("test connection") + if not embeddings or len(embeddings) == 0: + raise ValueError("Embedding returned empty result") + test_results["embedding"]["status"] = "success" + test_results["embedding"]["message"] = f"Embedding model ({model}) connected successfully" except Exception as e: test_results["embedding"]["status"] = "error" - test_results["embedding"]["message"] = f"❌ Embedding test failed: {str(e)}" + test_results["embedding"]["message"] = f"Embedding test failed: {str(e)}" logger.error(f"Embedding test failed: {str(e)}") - - # Test Multimodal Service - if "multimodal_service" in llm_test_config: + + # Test Multimodal Service — verifies the model supports vision + # 
When multimodal_service is absent (inheriting), use completion_service + # config — that's what will be used at runtime after save. + multimodal_config = test_configs.get("multimodal_service") or test_configs.get("completion_service") + if multimodal_config: + model = "" try: - test_multimodal_config = llm_test_config["multimodal_service"].copy() - provider = test_multimodal_config.get("llm_service", "openai").lower() - model = test_multimodal_config.get("llm_model", "gpt-4o") - - # Ensure authentication_configuration exists - if "authentication_configuration" not in test_multimodal_config: - test_multimodal_config["authentication_configuration"] = llm_test_config.get("authentication_configuration", {}) - - # Merge with existing config to get model_kwargs and prompt_path - if hasattr(cfg, 'multimodal_config') and cfg.multimodal_config: - for key in ["model_kwargs", "prompt_path", "base_url", "token_limit"]: - if key not in test_multimodal_config and key in cfg.multimodal_config: - test_multimodal_config[key] = cfg.multimodal_config[key] - elif hasattr(cfg, 'completion_config') and cfg.completion_config: - # Fallback to completion config - for key in ["model_kwargs", "prompt_path", "base_url", "token_limit"]: - if key not in test_multimodal_config and key in cfg.completion_config: - test_multimodal_config[key] = cfg.completion_config[key] - - # Ensure required fields exist - if "model_kwargs" not in test_multimodal_config: - test_multimodal_config["model_kwargs"] = {"temperature": 0} - if "prompt_path" not in test_multimodal_config: - test_multimodal_config["prompt_path"] = "common/prompts/openai_gpt4/" - - multimodal_service = _create_llm_service(provider, test_multimodal_config) - - if multimodal_service: - response = multimodal_service.llm.invoke("Say 'Connection successful' in 2 words") - if not response or not str(response).strip(): - raise ValueError("Multimodal LLM returned an empty response") - test_results["multimodal"]["status"] = "success" - 
test_results["multimodal"]["message"] = f"✅ Multimodal model ({model}) connected successfully" + from langchain_core.messages import HumanMessage + test_config = multimodal_config + model = test_config.get("llm_model", "") + llm_service = get_llm_service(test_config) + # Send a small 20x20 red PNG to verify the model accepts image input. + # Some providers (e.g. Bedrock) reject 1x1 images. + TEST_IMAGE_B64 = ( + "iVBORw0KGgoAAAANSUhEUgAAABQAAAAUCAIAAAAC64paAAAAKUlEQVR4" + "nGP8z0A+YKJAL8OoZhIBE6kakMGoZhIBE6kakMGoZhIBRZoBIpwBJy3" + "phGMAAAAASUVORK5CYII=" + ) + provider = test_config.get("llm_service", "").lower() + # Google GenAI/VertexAI only accept image_url format; + # Bedrock/Anthropic-native providers prefer type:"image" with source. + if provider in ("genai", "vertexai"): + image_block = { + "type": "image_url", + "image_url": {"url": f"data:image/png;base64,{TEST_IMAGE_B64}"}, + } else: - test_results["multimodal"]["status"] = "error" - test_results["multimodal"]["message"] = f"Provider '{provider}' not supported for multimodal" - + image_block = { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": TEST_IMAGE_B64, + }, + } + vision_message = HumanMessage( + content=[ + {"type": "text", "text": "Describe this image in one word."}, + image_block, + ] + ) + response = llm_service.llm.invoke([vision_message]) + if not response or not str(response).strip(): + raise ValueError("Multimodal LLM returned an empty response") + test_results["multimodal"]["status"] = "success" + test_results["multimodal"]["message"] = f"Multimodal model ({model}) connected and supports vision" except Exception as e: test_results["multimodal"]["status"] = "error" - test_results["multimodal"]["message"] = f"❌ Multimodal test failed: {str(e)}" + test_results["multimodal"]["message"] = ( + f"Multimodal test failed for model ({model}): {str(e)}. " + f"Please ensure the model supports vision input (e.g., GPT-4o, Claude 3.5+, Gemini)." 
+ ) logger.error(f"Multimodal test failed: {str(e)}") - + # Determine overall status all_success = all(result["status"] == "success" for result in test_results.values() if result["status"] != "not_tested") any_error = any(result["status"] == "error" for result in test_results.values()) - + overall_status = "success" if all_success and not any_error else "error" if any_error else "partial" - + return { "status": overall_status, "message": "Connection test completed", "results": test_results } - + except HTTPException: raise except Exception as e: @@ -2172,37 +2095,131 @@ async def test_llm_config( MASKED_SECRET = "********" +def _prepare_llm_config(llm_config_data: dict): + """ + Shared preparation for both save and test endpoints. + + 1. Pop metadata keys (graphname, scope) + 2. Unmask MASKED_SECRET values using current config from disk + 3. Strip null service values (null = inherit, key should be absent) + + Returns (candidate_config, graphname, scope). + The candidate_config is save-ready. Top-level parameters (authentication_configuration, + region_name) are promoted from completion_service if missing and redundant per-service + copies are stripped. reload_llm_config() and resolve_llm_services() handle injecting + them back into service configs at runtime. + """ + graphname = llm_config_data.pop("graphname", None) + scope = llm_config_data.pop("scope", None) + + # Resolve masked secrets from disk before modifying the payload + _unmask_auth(llm_config_data, graphname) + + # Strip null values — null means "inherit from base", key should be absent + for key in list(llm_config_data.keys()): + if llm_config_data[key] is None: + del llm_config_data[key] + + # Normalize auth: ensure top-level authentication_configuration exists. + # If missing, promote from completion_service so future config files + # always have auth at the top level. 
+ if "authentication_configuration" not in llm_config_data: + completion_svc = llm_config_data.get("completion_service") + if isinstance(completion_svc, dict) and "authentication_configuration" in completion_svc: + llm_config_data["authentication_configuration"] = completion_svc["authentication_configuration"] + + # Strip per-service auth if identical to top-level (redundant on disk; + # reload_llm_config injects top-level auth into services on load) + top_auth = llm_config_data.get("authentication_configuration") + if top_auth: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + svc = llm_config_data.get(svc_key) + if isinstance(svc, dict) and svc.get("authentication_configuration") == top_auth: + del svc["authentication_configuration"] + + # Normalize region_name: promote from completion_service to top level, + # strip per-service copies if identical (same pattern as auth). + if "region_name" not in llm_config_data: + completion_svc = llm_config_data.get("completion_service") + if isinstance(completion_svc, dict) and "region_name" in completion_svc: + llm_config_data["region_name"] = completion_svc["region_name"] + + top_region = llm_config_data.get("region_name") + if top_region: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + svc = llm_config_data.get(svc_key) + if isinstance(svc, dict) and svc.get("region_name") == top_region: + del svc["region_name"] + + return llm_config_data, graphname, scope + + + def _mask_secret_values(auth_config: dict) -> dict: """Replace all values in an authentication_configuration dict with the masked sentinel.""" return {k: MASKED_SECRET for k in auth_config} -def _unmask_auth(incoming: dict, stored_config: dict): +def _unmask_auth(incoming: dict, graphname: str = None): """ In-place: replace MASKED_SECRET values in incoming authentication_configuration - with the real values from stored_config. 
+ with real values resolved through the full config chain via getters. - Works on both top-level and per-service authentication_configuration. + Uses get_xxx_config(graphname) which resolves: + Layer 1 (base) → Layer 2 (global service) → Layer 3 (graph base) → Layer 4 (graph service) """ - def _unmask_dict(incoming_auth, stored_auth): - if not isinstance(incoming_auth, dict) or not isinstance(stored_auth, dict): - return - for k, v in incoming_auth.items(): - if v == MASKED_SECRET: - incoming_auth[k] = stored_auth.get(k, "") + # Use completion_service as the primary source for top-level auth resolution + # (backward compat: base bootstraps from completion_service) + resolved_completion = get_completion_config(graphname) + + # Resolved configs for each service (lazy — only built if needed) + _resolved_cache = {} + def _get_resolved(svc_key): + if svc_key not in _resolved_cache: + getter = { + "completion_service": get_completion_config, + "embedding_service": get_embedding_config, + "chat_service": get_chat_config, + "multimodal_service": get_multimodal_config, + }.get(svc_key) + if getter: + result = getter(graphname) + _resolved_cache[svc_key] = result if result else {} + else: + _resolved_cache[svc_key] = {} + return _resolved_cache[svc_key] + + def _resolve_real_value(key, svc_key=None): + """Find real value for an auth key using the resolved config chain.""" + # Check the specific service first + if svc_key: + resolved = _get_resolved(svc_key) + val = resolved.get("authentication_configuration", {}).get(key, "") + if val and val != MASKED_SECRET: + return val + # Fallback to completion (which has full base resolution) + val = resolved_completion.get("authentication_configuration", {}).get(key, "") + if val and val != MASKED_SECRET: + return val + return "" # Top-level authentication_configuration if "authentication_configuration" in incoming: - stored_top = stored_config.get("authentication_configuration", {}) - 
_unmask_dict(incoming["authentication_configuration"], stored_top) + auth = incoming["authentication_configuration"] + if isinstance(auth, dict): + for k, v in auth.items(): + if v == MASKED_SECRET: + auth[k] = _resolve_real_value(k) # Per-service authentication_configuration for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: svc = incoming.get(svc_key) - if svc and "authentication_configuration" in svc: - stored_svc = stored_config.get(svc_key, {}) - stored_svc_auth = stored_svc.get("authentication_configuration", {}) - _unmask_dict(svc["authentication_configuration"], stored_svc_auth) + if isinstance(svc, dict) and "authentication_configuration" in svc: + auth = svc["authentication_configuration"] + if isinstance(auth, dict): + for k, v in auth.items(): + if v == MASKED_SECRET: + auth[k] = _resolve_real_value(k, svc_key) def _strip_auth(config: dict) -> dict: @@ -2248,7 +2265,7 @@ async def get_config( graph_chat_service["authentication_configuration"] = _mask_secret_values(graph_chat_service["authentication_configuration"]) # Global chat info for "Inherited from" display - global_chat = llm_config.get("chat_service", llm_config.get("completion_service", {})) + global_chat = get_chat_config() global_chat_info = { "llm_service": global_chat.get("llm_service", ""), "llm_model": global_chat.get("llm_model", ""), diff --git a/graphrag/app/supportai/supportai_ingest.py b/graphrag/app/supportai/supportai_ingest.py index 4ba69f1..e312f25 100644 --- a/graphrag/app/supportai/supportai_ingest.py +++ b/graphrag/app/supportai/supportai_ingest.py @@ -39,7 +39,7 @@ def chunk_document(self, document, chunker, chunker_params): from common.chunkers.character_chunker import CharacterChunker chunker = CharacterChunker( - chunker_params["chunk_size"], chunker_params.get("overlap", 0) + chunker_params.get("chunk_size", 0), chunker_params.get("overlap_size", -1) ) elif chunker.lower() == "semantic": from common.chunkers.semantic_chunker 
import SemanticChunker @@ -54,7 +54,7 @@ def chunk_document(self, document, chunker, chunker_params): chunker = HTMLChunker( chunk_size=chunker_params.get("chunk_size", 0), - chunk_overlap=chunker_params.get("overlap_size", 0), + overlap_size=chunker_params.get("overlap_size", -1), headers=chunker_params.get("headers", None), ) elif chunker.lower() == "markdown": @@ -62,7 +62,7 @@ def chunk_document(self, document, chunker, chunker_params): chunker = MarkdownChunker( chunk_size=chunker_params.get("chunk_size", 0), - chunk_overlap=chunker_params.get("overlap_size", 0) + overlap_size=chunker_params.get("overlap_size", -1) ) else: raise ValueError(f"Chunker {chunker} not supported") diff --git a/graphrag/tests/test_character_chunker.py b/graphrag/tests/test_character_chunker.py index f132ce7..8b60b06 100644 --- a/graphrag/tests/test_character_chunker.py +++ b/graphrag/tests/test_character_chunker.py @@ -5,7 +5,7 @@ class TestCharacterChunker(unittest.TestCase): def test_chunk_without_overlap(self): """Test chunking without overlap.""" - chunker = CharacterChunker(chunk_size=4) + chunker = CharacterChunker(chunk_size=4, overlap_size=0) input_string = "abcdefghijkl" expected_chunks = ["abcd", "efgh", "ijkl"] self.assertEqual(chunker.chunk(input_string), expected_chunks) @@ -33,7 +33,7 @@ def test_empty_input_string(self): def test_input_shorter_than_chunk_size(self): """Test input string shorter than chunk size.""" - chunker = CharacterChunker(chunk_size=10) + chunker = CharacterChunker(chunk_size=10, overlap_size=0) input_string = "abc" expected_chunks = ["abc"] self.assertEqual(chunker.chunk(input_string), expected_chunks) @@ -46,24 +46,27 @@ def test_last_chunk_shorter_than_chunk_size(self): self.assertEqual(chunker.chunk(input_string), expected_chunks) def test_chunk_size_equals_overlap_size(self): - """Test when chunk size equals overlap size.""" + """Test when chunk size equals overlap size raises on chunk().""" + chunker = CharacterChunker(chunk_size=4, 
overlap_size=4) with self.assertRaises(ValueError): - CharacterChunker(chunk_size=4, overlap_size=4) + chunker.chunk("abcdefgh") def test_overlap_larger_than_chunk_should_raise_error(self): - """Test initialization with overlap size larger than chunk size should raise an error.""" + """Test overlap size larger than chunk size raises on chunk().""" + chunker = CharacterChunker(chunk_size=3, overlap_size=4) with self.assertRaises(ValueError): - CharacterChunker(chunk_size=3, overlap_size=4) + chunker.chunk("abcdefgh") - def test_chunk_size_zero_should_raise_error(self): - """Test initialization with a chunk size of zero should raise an error.""" - with self.assertRaises(ValueError): - CharacterChunker(chunk_size=0, overlap_size=0) + def test_chunk_size_zero_uses_default(self): + """Test that chunk_size=0 falls back to default values.""" + chunker = CharacterChunker(chunk_size=0) + self.assertEqual(chunker.chunk_size, 2048) + self.assertEqual(chunker.overlap_size, 256) - def test_chunk_size_negative_should_raise_error(self): - """Test initialization with a negative chunk size.""" - with self.assertRaises(ValueError): - CharacterChunker(chunk_size=-1) + def test_chunk_size_negative_uses_default(self): + """Test that negative chunk_size falls back to default values.""" + chunker = CharacterChunker(chunk_size=-1) + self.assertEqual(chunker.chunk_size, 2048) if __name__ == "__main__": diff --git a/report-service/app/routers/root.py b/report-service/app/routers/root.py index 2e618a2..01a567a 100644 --- a/report-service/app/routers/root.py +++ b/report-service/app/routers/root.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Request, Depends, Response from typing import Annotated -from common.config import get_completion_config, get_llm_service +from common.config import get_completion_config, get_embedding_config, get_llm_service from common.py_schemas import ReportCreationRequest from report_agent.agent import TigerGraphReportAgent @@ -19,17 +19,15 @@ @router.get("/") 
def read_root(): - return {"config": llm_config["model_name"]} + return {"config": get_completion_config().get("llm_model", "unknown")} @router.get("/health") async def health(): return { "status": "healthy", - "llm_completion_model": llm_config["completion_service"]["llm_model"], - "embedding_service": llm_config["embedding_service"][ - "embedding_model_service" - ], + "llm_completion_model": get_completion_config().get("llm_model", "unknown"), + "embedding_service": get_embedding_config().get("embedding_model_service", "unknown"), } def retrieve_template(template_name: str):