diff --git a/README.md b/README.md
index dce6f98..a6261f5 100644
--- a/README.md
+++ b/README.md
@@ -31,8 +31,10 @@
   - [More Detailed Configurations](#more-detailed-configurations)
     - [DB configuration](#db-configuration)
     - [GraphRAG configuration](#graphrag-configuration)
-    - [Chat configuration](#chat-configuration)
+    - [Chat History Configuration](#chat-history-configuration)
     - [LLM provider configuration](#llm-provider-configuration)
+      - [Supported parameters](#supported-parameters)
+      - [Provider examples](#provider-examples)
       - [OpenAI](#openai)
       - [Google GenAI](#google-genai)
       - [GCP VertexAI](#gcp-vertexai)
@@ -53,7 +55,7 @@
 ---
 ## Releases
-* **2/28/2025**: GraphRAG v1.2.0 released. Added Admin UI for graph initialization, document ingestion, and knowledge graph rebuild, along with many other improvements and bug fixes. See [release notes](https://github.com/tigergraph/graphrag/releases/tag/v1.2.0) for details.
+* **2/28/2026**: GraphRAG v1.2.0 released. Added Admin UI for graph initialization, document ingestion, and knowledge graph rebuild, along with many other improvements and bug fixes. See [release notes](https://github.com/tigergraph/graphrag/releases/tag/v1.2.0) for details.
 * **9/22/2025**: GraphRAG v1.1 (v1.1.0) is now officially available. AWS Bedrock support is completed with BDA integration for multimodal document ingestion. See [release notes](https://github.com/tigergraph/graphrag/releases/tag/v1.1.0) for details.
 * **6/18/2025**: GraphRAG v1.0 (v1.0.0) is now officially available. TigerGraph database is the only graph and vector storage supported. Please see [Release Notes](https://docs.tigergraph.com/tg-graphrag/current/release-notes/) for details.
@@ -103,7 +105,7 @@ Organizing the data as a knowledge graph allows a chatbot to access accurate, fa
 ### Quick Start
 #### Use TigerGraph Docker-Based Instance
-Set your LLM Provider (supported `openai` or `gemini`) api key as environment varabiel LLM_API_KEY and use the following command for a one-step quick deployment with TigerGraph Community Edition and default configurations:
+Set your LLM provider's API key (supported providers: `openai` and `gemini`) as the environment variable `LLM_API_KEY`, then use the following command for a one-step quick deployment with TigerGraph Community Edition and default configurations:
 ```
 curl -k https://raw.githubusercontent.com/tigergraph/graphrag/refs/heads/main/docs/tutorials/setup_graphrag.sh | bash
 ```
@@ -198,10 +200,10 @@ Run command `docker compose down` and wait for all the service containers to sto
 If you prefer to start a TigerGraph Community Edition instance without a license key, please make sure the container can be accessed from the GraphRAG containers by adding `--network graphrag_default`:
 ```
-docker run -d -p 14240:14240 --name tigergraph --ulimit nofile=1000000:1000000 --init --network graphrag_default -t tigergraph/community:4.2.1
+docker run -d -p 14240:14240 --name tigergraph --ulimit nofile=1000000:1000000 --init --network graphrag_default -t tigergraph/community:4.2.2
 ```
-> Use **tigergraph/tigergraph:4.2.1** if Enterprise Edition is preferred.
+> Use **tigergraph/tigergraph:4.2.2** if Enterprise Edition is preferred.
 > Setting up **DNS** or `/etc/hosts` properly is an alternative solution to ensure containers can connect to each other.
 > Or modify `hostname` in the `db_config` section of `configs/server_config.json`, replacing `http://tigergraph` with your TigerGraph container's IP address, e.g., `http://172.19.0.2`.
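The `hostname` override mentioned in the note above is a one-field edit in the `db_config` section of `configs/server_config.json`. A minimal sketch, using the illustrative container IP from the note (the remaining `db_config` fields keep the defaults documented in the DB configuration section):

```json
{
  "db_config": {
    "hostname": "http://172.19.0.2",
    "restppPort": "9000",
    "gsPort": "14240"
  }
}
```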
@@ -419,6 +421,8 @@ Copy the below into `configs/server_config.json` and edit the `hostname` and `ge "hostname": "http://tigergraph", "restppPort": "9000", "gsPort": "14240", + "username": "tigergraph", + "password": "tigergraph", "getToken": false, "default_timeout": 300, "default_mem_threshold": 5000, @@ -427,23 +431,65 @@ Copy the below into `configs/server_config.json` and edit the `hostname` and `ge } ``` +| Parameter | Type | Default | Description | +| --- | --- | --- | --- | +| `hostname` | string | `"http://tigergraph"` | TigerGraph server URL. | +| `restppPort` | string | `"9000"` | RESTPP port for TigerGraph API requests. | +| `gsPort` | string | `"14240"` | GSQL port for TigerGraph admin operations. | +| `username` | string | `"tigergraph"` | TigerGraph database username. | +| `password` | string | `"tigergraph"` | TigerGraph database password. | +| `getToken` | bool | `false` | Set to `true` if token authentication is enabled on TigerGraph. | +| `graphname` | string | `""` | Default graph name. Usually left empty (selected at runtime). | +| `apiToken` | string | `""` | Pre-generated API token. If set, token-based auth is used instead of username/password. | +| `default_timeout` | int | `300` | Default query timeout in seconds. | +| `default_mem_threshold` | int | `5000` | Memory threshold (MB) for query execution. | +| `default_thread_limit` | int | `8` | Max threads for query execution. | + ### GraphRAG configuration Copy the below code into `configs/server_config.json`. You shouldn’t need to change anything unless you change the port of the chat history service in the Docker Compose file. -`reuse_embedding` to `true` will skip re-generating the embedding if it already exists. -`ecc` and `chat_history_api` are the addresses of internal components of GraphRAG.If you use the Docker Compose file as is, you don’t need to change them. 
- ```json { "graphrag_config": { "reuse_embedding": false, - "ecc": "http://eventual-consistency-service:8001", - "chat_history_api": "http://chat-history:8002" + "ecc": "http://graphrag-ecc:8001", + "chat_history_api": "http://chat-history:8002", + "chunker": "semantic", + "extractor": "llm", + "top_k": 5, + "num_hops": 2 } } ``` -### Chat configuration +| Parameter | Type | Default | Description | +| --- | --- | --- | --- | +| `reuse_embedding` | bool | `true` | Reuse existing embeddings instead of regenerating them. | +| `ecc` | string | `"http://graphrag-ecc:8001"` | URL of the knowledge graph build service. No change needed when using the provided Docker Compose file. | +| `chat_history_api` | string | `"http://chat-history:8002"` | URL of the chat history service. No change needed when using the provided Docker Compose file. | +| `chunker` | string | `"semantic"` | Default document chunker. Options: `semantic`, `character`, `regex`, `markdown`, `html`, `recursive`. | +| `extractor` | string | `"llm"` | Entity extraction method. Options: `llm`, `graphrag`. | +| `chunker_config` | object | `{}` | Chunker-specific settings. For `character`/`markdown`/`recursive`: `chunk_size`, `overlap_size`. For `semantic`: `method`, `threshold`. For `regex`: `pattern`. | +| `top_k` | int | `5` | Number of top similar results to retrieve during search. | +| `num_hops` | int | `2` | Number of graph hops to traverse when expanding retrieved results. | +| `num_seen_min` | int | `2` | Minimum occurrence threshold for a node to be included in search results. | +| `community_level` | int | `2` | Community hierarchy level used for community search. | +| `chunk_only` | bool | `true` | If true, hybrid search only retrieves document chunks (not entities). | +| `doc_only` | bool | `false` | If true, hybrid search retrieves whole documents instead of chunks. | +| `with_chunk` | bool | `true` | If true, community search also includes document chunks in results. 
| +| `doc_process_switch` | bool | `true` | Enable/disable document processing during knowledge graph build. | +| `entity_extraction_switch` | bool | same as `doc_process_switch` | Enable/disable entity extraction during knowledge graph build. | +| `community_detection_switch` | bool | same as `entity_extraction_switch` | Enable/disable community detection during knowledge graph build. | +| `load_batch_size` | int | `500` | Batch size for document loading. | +| `upsert_delay` | int | `0` | Delay in seconds between loading batches. | +| `default_concurrency` | int | `10` | Base concurrency level for parallel processing. Configurable per graph. | +| `process_interval_seconds` | int | `300` | Interval (seconds) for background consistency processing. | +| `cleanup_interval_seconds` | int | `300` | Interval (seconds) for background cleanup. | +| `checker_batch_size` | int | `100` | Batch size for background consistency checking. | +| `enable_consistency_checker` | bool | `false` | Enable the background consistency checker. | +| `graph_names` | list | `[]` | Graphs to monitor when consistency checker is enabled. | + +### Chat History Configuration Copy the below code into `configs/server_config.json`. You shouldn’t need to change anything unless you change the port of the chat history service in the Docker Compose file. ```json @@ -464,6 +510,99 @@ Copy the below code into `configs/server_config.json`. You shouldn’t need to c ### LLM provider configuration In the `llm_config` section of `configs/server_config.json` file, copy JSON config template from below for your LLM provider, and fill out the appropriate fields. Only one provider is needed. +#### Structure overview + +```json +{ + "llm_config": { + "authentication_configuration": { + "OPENAI_API_KEY": "sk-..." 
+ }, + "completion_service": { + "llm_service": "openai", + "llm_model": "gpt-4.1-mini", + "model_kwargs": { "temperature": 0 }, + "prompt_path": "./common/prompts/openai_gpt4/" + }, + "embedding_service": { + "embedding_model_service": "openai", + "model_name": "text-embedding-3-small" + }, + "chat_service": { + "llm_model": "gpt-4.1" + }, + "multimodal_service": { + "llm_service": "openai", + "llm_model": "gpt-4o" + } + } +} +``` + +- `authentication_configuration`: Shared credentials for all services. Service-level keys take precedence over top-level keys. +- `completion_service` **(required)**: LLM for knowledge graph building and query generation. +- `embedding_service` **(required)**: Text embedding model for document indexing. +- `chat_service` *(optional)*: Chatbot LLM override. Missing keys are inherited from `completion_service`. Configurable per graph. +- `multimodal_service` *(optional)*: Vision/image model for document ingestion. + +#### Supported parameters + +**Top-level `llm_config` parameters:** + +| Parameter | Type | Default | Description | +| --- | --- | --- | --- | +| `authentication_configuration` | object | — | Shared authentication credentials for all services. Service-level values take precedence. | +| `token_limit` | int | — | Maximum token count for retrieved context. Inherited by all services if not set at service level. `0` or omitted means unlimited. | + +**`completion_service` parameters:** + +| Parameter | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `llm_service` | string | **Yes** | — | LLM provider. Options: `openai`, `azure`, `vertexai`, `genai`, `bedrock`, `sagemaker`, `groq`, `ollama`, `huggingface`, `watsonx`. | +| `llm_model` | string | **Yes** | — | Model name for knowledge graph building and query generation (e.g., `gpt-4.1-mini`). | +| `authentication_configuration` | object | No | inherited from top-level | Service-specific auth credentials. Overrides top-level values. 
| +| `model_kwargs` | object | No | `{}` | Additional model parameters (e.g., `{"temperature": 0}`). | +| `prompt_path` | string | No | `"./common/prompts/openai_gpt4/"` | Path to prompt template files. | +| `base_url` | string | No | — | Custom API endpoint URL. | +| `token_limit` | int | No | inherited from top-level | Max token count for retrieved context sent to the LLM. `0` or omitted means unlimited. | + +**`embedding_service` parameters:** + +| Parameter | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `embedding_model_service` | string | **Yes** | — | Embedding provider. Options: `openai`, `azure`, `vertexai`, `genai`, `bedrock`, `ollama`. | +| `model_name` | string | **Yes** | — | Embedding model name (e.g., `text-embedding-3-small`). | +| `dimensions` | int | No | `1536` | Embedding vector dimensions. | +| `authentication_configuration` | object | No | inherited from top-level | Service-specific auth credentials. Overrides top-level values. | + +**`chat_service` parameters (optional):** + +Chatbot LLM override. If not configured, inherits from `completion_service`. Configurable per graph via the UI. + +| Parameter | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `llm_service` | string | No | same as completion | LLM provider for the chatbot. | +| `llm_model` | string | No | same as completion | Model name for the chatbot. | +| `authentication_configuration` | object | No | inherited from completion | Auth credentials. Service-level values take precedence. | +| `model_kwargs` | object | No | inherited from completion | Additional model parameters (e.g., `{"temperature": 0}`). | +| `prompt_path` | string | No | inherited from completion | Path to prompt template files. | +| `base_url` | string | No | inherited from completion | Custom API endpoint URL. | +| `token_limit` | int | No | inherited from completion | Max token count for retrieved context sent to the chatbot LLM. 
`0` or omitted means unlimited. | + +**`multimodal_service` parameters (optional):** + +Vision model for image processing during document ingestion. If not configured, inherits from `completion_service` with a default vision model derived per provider. + +| Parameter | Type | Required | Default | Description | +| --- | --- | --- | --- | --- | +| `llm_service` | string | No | inherited from completion | Multimodal LLM provider. | +| `llm_model` | string | No | auto-derived per provider | Vision model name (e.g., `gpt-4o`). | +| `authentication_configuration` | object | No | inherited from completion | Service-specific auth credentials. Overrides top-level values. | +| `model_kwargs` | object | No | inherited from completion | Additional model parameters. | +| `prompt_path` | string | No | inherited from completion | Path to prompt template files. | + +#### Provider examples + #### OpenAI In addition to the `OPENAI_API_KEY`, `llm_model` and `model_name` can be edited to match your specific configuration details. diff --git a/common/config.py b/common/config.py index 2b58581..371e303 100644 --- a/common/config.py +++ b/common/config.py @@ -13,9 +13,18 @@ # limitations under the License. import json +import logging import os +import re +import threading from fastapi.security import HTTPBasic + +logger = logging.getLogger(__name__) + +# Lock for all reads/writes to SERVER_CONFIG to prevent concurrent modifications +# from different endpoints (LLM, DB, GraphRAG config saves) from overwriting each other. 
+_config_file_lock = threading.Lock() from pyTigerGraph import TigerGraphConnection from common.embeddings.embedding_services import ( @@ -40,7 +49,6 @@ OpenAI, IBMWatsonX ) -from common.logs.logwriter import LogWriter from common.session import SessionHandler from common.status import StatusManager @@ -51,6 +59,202 @@ # Configs SERVER_CONFIG = os.getenv("SERVER_CONFIG", "configs/server_config.json") + + +_VALID_GRAPHNAME_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") + + +def validate_graphname(graphname: str) -> str: + """Validate graphname to prevent path traversal. + + Raises ValueError if graphname contains path separators or other unsafe characters. + Returns the graphname unchanged if valid. + """ + if not graphname: + return graphname + if not _VALID_GRAPHNAME_RE.match(graphname): + raise ValueError(f"Invalid graph name: {graphname!r}") + return graphname + + +def _load_graph_config(graphname): + """Load entire graph-specific server config overrides, or empty dict if none exist.""" + if not graphname: + return {} + validate_graphname(graphname) + graph_path = f"configs/graph_configs/{graphname}/server_config.json" + if not os.path.exists(graph_path): + return {} + with open(graph_path, "r") as f: + return json.load(f) + + +def _load_graph_llm_config(graphname): + """Load graph-specific llm_config overrides, or empty dict if none exist.""" + return _load_graph_config(graphname).get("llm_config", {}) + + +def _resolve_service_config(base_config, override=None): + """ + Merge a service override on top of a base config (typically completion_service). 
+ + - Starts with base_config as the foundation + - Overlays override keys on top (if provided) + - authentication_configuration: override keys take precedence, + missing keys fall back to base auth + """ + result = base_config.copy() + + if not override: + return result + + for key, value in override.items(): + if key == "authentication_configuration": + continue # Handle separately below + result[key] = value + + if "authentication_configuration" in override: + merged_auth = result.get("authentication_configuration", {}).copy() + merged_auth.update(override["authentication_configuration"]) + result["authentication_configuration"] = merged_auth + # else: keep base's auth + + return result + + +def get_completion_config(graphname=None): + """ + Return completion_service config for the given graph. + + Resolution: merge graph-specific completion_service overrides on top of + global completion_service. Graph configs only store overrides, so unchanged + fields always inherit the latest global values. + """ + graph_llm = _load_graph_llm_config(graphname) + override = graph_llm.get("completion_service") + if override: + logger.debug(f"[get_completion_config] graph={graphname} using graph-specific overrides") + result = _resolve_service_config(llm_config["completion_service"], override) + + if graphname: + result["graphname"] = graphname + + return result + + +DEFAULT_MULTIMODAL_MODELS = { + "openai": "gpt-4o-mini", + "azure": "gpt-4o-mini", + "genai": "gemini-3.5-flash", + "vertexai": "gemini-3.5-flash", + "bedrock": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", +} + + +def get_chat_config(graphname=None): + """ + Return the chatbot LLM config for the given graph. + + Resolution chain: + 1. Start with global completion_service + 2. Merge graph-specific completion_service overrides (shared base for all services) + 3. Merge chat_service overrides (graph-specific > global > none) + + This ensures graph-level completion_service changes (e.g. 
prompt_path) + propagate to the chatbot config as well. + """ + graph_llm = _load_graph_llm_config(graphname) + + # Build per-graph base: global completion + graph completion overrides + base = _resolve_service_config( + llm_config["completion_service"], + graph_llm.get("completion_service"), + ) + + # Find chat override: graph-specific > global > None + chat_override = graph_llm.get("chat_service") + if chat_override: + logger.debug(f"[get_chat_config] graph={graphname} using graph-specific chat_service") + elif "chat_service" in llm_config: + chat_override = llm_config["chat_service"] + logger.debug(f"[get_chat_config] graph={graphname} using global chat_service") + else: + logger.debug(f"[get_chat_config] graph={graphname} falling back to completion_service") + + result = _resolve_service_config(base, chat_override) + + if graphname: + result["graphname"] = graphname + + return result + + +def _apply_default_multimodal_model(override, provider): + """Apply default vision model if llm_model is not explicitly set.""" + if override and "llm_model" not in override: + default_model = DEFAULT_MULTIMODAL_MODELS.get(provider) + if default_model: + return {**override, "llm_model": default_model} + return override + if not override: + default_model = DEFAULT_MULTIMODAL_MODELS.get(provider) + if default_model: + return {"llm_model": default_model} + return None + return override + + +def get_multimodal_config(graphname=None): + """ + Return the multimodal/vision config for the given graph. + + Resolution chain: + 1. Start with global completion_service + 2. Merge graph-specific completion_service overrides (shared base) + 3. Merge multimodal_service overrides (graph-specific > global > default model) + + Returns the merged config, or None if the provider doesn't support vision. 
+ """ + graph_llm = _load_graph_llm_config(graphname) + + # Build per-graph base: global completion + graph completion overrides + base = _resolve_service_config( + llm_config["completion_service"], + graph_llm.get("completion_service"), + ) + + # Find multimodal override: graph-specific > global > None + mm_override = graph_llm.get("multimodal_service") + if mm_override is None and "multimodal_service" in llm_config: + mm_override = llm_config["multimodal_service"] + + provider = (mm_override or {}).get("llm_service", base.get("llm_service", "")).lower() + mm_override = _apply_default_multimodal_model(mm_override, provider) + + if mm_override is None: + return None + + return _resolve_service_config(base, mm_override) + + +def get_graphrag_config(graphname=None): + """ + Return graphrag_config for the given graph. + + Resolution: merge graph-specific graphrag_config overrides on top of + global graphrag_config. Graph configs only store overrides, so unchanged + fields always inherit the latest global values. + """ + graph_cfg = _load_graph_config(graphname) + override = graph_cfg.get("graphrag_config") + if not override: + return graphrag_config + # Merge: global as base, graph overrides on top (simple dict merge, no auth logic) + result = graphrag_config.copy() + result.update(override) + return result + + PATH_PREFIX = os.getenv("PATH_PREFIX", "") PRODUCTION = os.getenv("PRODUCTION", "false").lower() == "true" @@ -83,64 +287,61 @@ if llm_config is None: raise Exception("llm_config is not found in SERVER_CONFIG") -completion_config = llm_config.get("completion_service") -if completion_config is None: +# Inject authentication_configuration into service configs so they have everything they need. +# Rule: service-level (lower) auth keys take precedence; missing keys fall back to top-level (upper). 
+if "authentication_configuration" in llm_config: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in llm_config: + svc = llm_config[svc_key] + if "authentication_configuration" not in svc: + svc["authentication_configuration"] = llm_config["authentication_configuration"].copy() + else: + # Merge: top-level as base, service-level on top (service-level wins) + merged = llm_config["authentication_configuration"].copy() + merged.update(svc["authentication_configuration"]) + svc["authentication_configuration"] = merged + +_comp = llm_config.get("completion_service") +if _comp is None: raise Exception("completion_service is not found in llm_config") -if "llm_service" not in completion_config: +if "llm_service" not in _comp: raise Exception("llm_service is not found in completion_service") -if "llm_model" not in completion_config: +if "llm_model" not in _comp: raise Exception("llm_model is not found in completion_service") -embedding_config = llm_config.get("embedding_service") -if embedding_config is None: + +# Log which model will be used for chatbot and ECC/GraphRAG +if "chat_service" in llm_config: + chat_svc = llm_config["chat_service"] + logger.info(f"[CHATBOT] Using chat_service: {chat_svc.get('llm_model', 'N/A')} (Provider: {chat_svc.get('llm_service', _comp['llm_service'])})") + logger.info(f"[ECC] Using completion_service: {_comp['llm_model']} (Provider: {_comp['llm_service']})") +else: + logger.info(f"[CHATBOT] Using completion_service llm_model: {_comp['llm_model']} (Provider: {_comp['llm_service']})") + logger.info(f"[ECC] Using completion_service: {_comp['llm_model']} (Provider: {_comp['llm_service']})") + +_emb = llm_config.get("embedding_service") +if _emb is None: raise Exception("embedding_service is not found in llm_config") -if "embedding_model_service" not in embedding_config: +if "embedding_model_service" not in _emb: raise Exception("embedding_model_service is not found in 
embedding_service") -if "model_name" not in embedding_config: +if "model_name" not in _emb: raise Exception("model_name is not found in embedding_service") -embedding_dimension = embedding_config.get("dimensions", 1536) +embedding_dimension = _emb.get("dimensions", 1536) + +# Log which embedding model will be used +logger.info(f"[EMBEDDING] Using model: {_emb.get('model_name', 'N/A')} (Provider: {_emb.get('embedding_model_service', 'N/A')})") # Get context window size from llm_config # <=0 means unlimited tokens (no truncation), otherwise use the specified limit if "token_limit" in llm_config: - if "token_limit" not in completion_config: - completion_config["token_limit"] = llm_config["token_limit"] - if "token_limit" not in embedding_config: - embedding_config["token_limit"] = llm_config["token_limit"] - -# Get multimodal_service config (optional, for vision/image tasks) -multimodal_config = llm_config.get("multimodal_service") - -# Merge shared authentication configuration from llm_config level into service configs -# Services can still override by defining their own authentication_configuration -shared_auth = llm_config.get("authentication_configuration", {}) -if shared_auth: - # Merge into embedding_config (service-specific auth takes precedence) - if "authentication_configuration" not in embedding_config: - embedding_config["authentication_configuration"] = shared_auth.copy() - else: - # Merge shared auth with service-specific auth (service-specific takes precedence) - merged_embedding_auth = shared_auth.copy() - merged_embedding_auth.update(embedding_config["authentication_configuration"]) - embedding_config["authentication_configuration"] = merged_embedding_auth - - # Merge into completion_config (service-specific auth takes precedence) - if "authentication_configuration" not in completion_config: - completion_config["authentication_configuration"] = shared_auth.copy() - else: - # Merge shared auth with service-specific auth (service-specific takes 
precedence) - merged_completion_auth = shared_auth.copy() - merged_completion_auth.update(completion_config["authentication_configuration"]) - completion_config["authentication_configuration"] = merged_completion_auth - - # Merge into multimodal_config if it exists (service-specific auth takes precedence) - if multimodal_config: - if "authentication_configuration" not in multimodal_config: - multimodal_config["authentication_configuration"] = shared_auth.copy() - else: - # Merge shared auth with service-specific auth (service-specific takes precedence) - merged_multimodal_auth = shared_auth.copy() - merged_multimodal_auth.update(multimodal_config["authentication_configuration"]) - multimodal_config["authentication_configuration"] = merged_multimodal_auth + if "token_limit" not in _comp: + _comp["token_limit"] = llm_config["token_limit"] + if "token_limit" not in _emb: + _emb["token_limit"] = llm_config["token_limit"] + +# Log multimodal_service config (optional, for vision/image tasks). 
+_mm_config = get_multimodal_config() +if _mm_config: + logger.info(f"[MULTIMODAL] Using model: {_mm_config.get('llm_model', 'N/A')} (Provider: {_mm_config.get('llm_service', 'N/A')})") if graphrag_config is None: graphrag_config = {"reuse_embedding": True} @@ -175,81 +376,38 @@ else: raise Exception("Embedding service not implemented") -def get_llm_service(llm_config) -> LLM_Model: - if llm_config["completion_service"]["llm_service"].lower() == "openai": - return OpenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "azure": - return AzureOpenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "sagemaker": - return AWS_SageMaker_Endpoint(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "vertexai": - return GoogleVertexAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "genai": - return GoogleGenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "bedrock": - return AWSBedrock(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "groq": - return Groq(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "ollama": - return Ollama(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "huggingface": - return HuggingFaceEndpoint(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "watsonx": - return IBMWatsonX(llm_config["completion_service"]) - else: - raise Exception("LLM Completion Service Not Supported") - -DEFAULT_MULTIMODAL_MODELS = { - "openai": "gpt-4o-mini", - "azure": "gpt-4o-mini", - "genai": "gemini-3.5-flash", - "vertexai": "gemini-3.5-flash", - "bedrock": "us.anthropic.claude-sonnet-4-5-20250929-v1:0", -} - -def 
get_multimodal_service() -> LLM_Model: - """ - Get the multimodal/vision LLM service for image description tasks. - Priority: - 1. Explicit multimodal_service config - 2. Auto-derived from completion_service with a default vision model - Currently supports: OpenAI, Azure, GenAI, VertexAI, Bedrock +def get_llm_service(service_config: dict) -> LLM_Model: """ - config_copy = completion_config.copy() - - if multimodal_config: - config_copy.update(multimodal_config) + Instantiate an LLM provider from a flat service config dict. - service_type = config_copy.get("llm_service", "").lower() - - if not multimodal_config or "llm_model" not in multimodal_config: - default_model = DEFAULT_MULTIMODAL_MODELS.get(service_type) - if default_model: - config_copy["llm_model"] = default_model - LogWriter.info( - f"Using default vision model '{default_model}' " - f"for provider '{service_type}'" - ) - - if "prompt_path" not in config_copy: - config_copy["prompt_path"] = "./common/prompts/openai_gpt4/" - - if service_type == "openai": - return OpenAI(config_copy) - elif service_type == "azure": - return AzureOpenAI(config_copy) - elif service_type == "genai": - return GoogleGenAI(config_copy) - elif service_type == "vertexai": - return GoogleVertexAI(config_copy) - elif service_type == "bedrock": - return AWSBedrock(config_copy) + The config must contain ``llm_service`` at the top level. + Use ``get_completion_config()`` or ``get_chat_config()`` to obtain + the appropriate config for ECC or chatbot callers respectively. 
+ """ + service_name = service_config["llm_service"].lower() + if service_name == "openai": + return OpenAI(service_config) + elif service_name == "azure": + return AzureOpenAI(service_config) + elif service_name == "sagemaker": + return AWS_SageMaker_Endpoint(service_config) + elif service_name == "vertexai": + return GoogleVertexAI(service_config) + elif service_name == "genai": + return GoogleGenAI(service_config) + elif service_name == "bedrock": + return AWSBedrock(service_config) + elif service_name == "groq": + return Groq(service_config) + elif service_name == "ollama": + return Ollama(service_config) + elif service_name == "huggingface": + return HuggingFaceEndpoint(service_config) + elif service_name == "watsonx": + return IBMWatsonX(service_config) else: - LogWriter.warning( - f"Multimodal/vision not supported for provider '{service_type}'. " - "Image descriptions will be skipped." - ) - return None + raise Exception(f"LLM service '{service_name}' not supported") + if os.getenv("INIT_EMBED_STORE", "true") == "true": conn = TigerGraphConnection( @@ -270,3 +428,203 @@ def get_multimodal_service() -> LLM_Model: support_ai_instance=True, ) service_status["embedding_store"] = {"status": "ok", "error": None} + + +def reload_llm_config(new_llm_config: dict = None): + """ + Reload LLM configuration and reinitialize services. + + Args: + new_llm_config: If provided, saves this config to file first. + If None, just reloads from existing file. 
+ + Returns: + dict: Status of reload operation + """ + global llm_config, embedding_service + + try: + with _config_file_lock: + # If new config provided, save it first + if new_llm_config is not None: + with open(SERVER_CONFIG, "r") as f: + server_config = json.load(f) + + server_config["llm_config"] = new_llm_config + + temp_file = f"{SERVER_CONFIG}.tmp" + with open(temp_file, "w") as f: + json.dump(server_config, f, indent=2) + os.replace(temp_file, SERVER_CONFIG) + + # Read/reload from file + with open(SERVER_CONFIG, "r") as f: + server_config = json.load(f) + + # Validate before updating + new_llm_config = server_config.get("llm_config") + if new_llm_config is None: + raise Exception("llm_config is not found in SERVER_CONFIG") + + # Inject authentication_configuration into service configs BEFORE updating globals. + # Rule: service-level (lower) auth keys take precedence; missing keys fall back to top-level (upper). + if "authentication_configuration" in new_llm_config: + for svc_key in ["completion_service", "embedding_service", "multimodal_service", "chat_service"]: + if svc_key in new_llm_config: + svc = new_llm_config[svc_key] + if "authentication_configuration" not in svc: + svc["authentication_configuration"] = new_llm_config["authentication_configuration"].copy() + else: + merged = new_llm_config["authentication_configuration"].copy() + merged.update(svc["authentication_configuration"]) + svc["authentication_configuration"] = merged + + new_completion_config = new_llm_config.get("completion_service") + new_embedding_config = new_llm_config.get("embedding_service") + + if new_completion_config is None: + raise Exception("completion_service is not found in llm_config") + if new_embedding_config is None: + raise Exception("embedding_service is not found in llm_config") + + # Validate required fields before touching globals + if "llm_service" not in new_completion_config: + raise Exception("llm_service is not found in completion_service") + if "llm_model" 
not in new_completion_config:
+                raise Exception("llm_model is not found in completion_service")
+
+            # Propagate top-level token_limit into service configs (same as startup)
+            if "token_limit" in new_llm_config:
+                if "token_limit" not in new_completion_config:
+                    new_completion_config["token_limit"] = new_llm_config["token_limit"]
+                if "token_limit" not in new_embedding_config:
+                    new_embedding_config["token_limit"] = new_llm_config["token_limit"]
+
+            # Update globals in place: delete stale keys first, then merge the new
+            # values. Unlike clear()+update(), this never opens a window in which
+            # readers would see an empty dict.
+            old_llm_keys = set(llm_config.keys())
+            for k in old_llm_keys - set(new_llm_config.keys()):
+                del llm_config[k]
+            llm_config.update(new_llm_config)
+
+            # Re-initialize embedding service
+            if new_embedding_config["embedding_model_service"].lower() == "openai":
+                embedding_service = OpenAI_Embedding(new_embedding_config)
+            elif new_embedding_config["embedding_model_service"].lower() == "azure":
+                embedding_service = AzureOpenAI_Ada002(new_embedding_config)
+            elif new_embedding_config["embedding_model_service"].lower() == "vertexai":
+                embedding_service = VertexAI_PaLM_Embedding(new_embedding_config)
+            elif new_embedding_config["embedding_model_service"].lower() == "genai":
+                embedding_service = GenAI_Embedding(new_embedding_config)
+            elif new_embedding_config["embedding_model_service"].lower() == "bedrock":
+                embedding_service = AWS_Bedrock_Embedding(new_embedding_config)
+            elif new_embedding_config["embedding_model_service"].lower() == "ollama":
+                embedding_service = Ollama_Embedding(new_embedding_config)
+            else:
+                raise Exception("Embedding service not implemented")
+
+            return {
+                "status": "success",
+                "message": "LLM configuration reloaded successfully"
+            }
+
+    except Exception as e:
+        return {
+            "status": "error",
+            "message": f"Failed to reload LLM config: {str(e)}"
+        }
+
+
+def reload_db_config(new_db_config: dict = None):
+    """
+ Reload DB configuration from server_config.json and update in-memory config. + + Args: + new_db_config: If provided, saves this config to file first. + If None, just reloads from existing file. + + Returns: + dict: Status of reload operation + """ + global db_config + + try: + with _config_file_lock: + if new_db_config is not None: + with open(SERVER_CONFIG, "r") as f: + server_config = json.load(f) + + server_config["db_config"] = new_db_config + + temp_file = f"{SERVER_CONFIG}.tmp" + with open(temp_file, "w") as f: + json.dump(server_config, f, indent=2) + os.replace(temp_file, SERVER_CONFIG) + + with open(SERVER_CONFIG, "r") as f: + server_config = json.load(f) + + new_db_config = server_config.get("db_config") + if new_db_config is None: + raise Exception("db_config is not found in SERVER_CONFIG") + + old_db_keys = set(db_config.keys()) + for k in old_db_keys - set(new_db_config.keys()): + del db_config[k] + db_config.update(new_db_config) + + return { + "status": "success", + "message": "DB configuration reloaded successfully" + } + except Exception as e: + return { + "status": "error", + "message": f"Failed to reload DB config: {str(e)}" + } + + +def reload_graphrag_config(): + """ + Reload GraphRAG configuration from server_config.json. + Updates the in-memory graphrag_config dict to reflect changes immediately. 
+ + Returns: + dict: Status of reload operation + """ + global graphrag_config + + try: + with _config_file_lock: + with open(SERVER_CONFIG, "r") as f: + server_config = json.load(f) + + new_graphrag_config = server_config.get("graphrag_config") + if new_graphrag_config is None: + new_graphrag_config = {"reuse_embedding": True} + + # Set defaults (same as startup logic) + if "chunker" not in new_graphrag_config: + new_graphrag_config["chunker"] = "semantic" + if "extractor" not in new_graphrag_config: + new_graphrag_config["extractor"] = "llm" + + # Update graphrag_config in-place to preserve references in other modules + old_graphrag_keys = set(graphrag_config.keys()) + for k in old_graphrag_keys - set(new_graphrag_config.keys()): + del graphrag_config[k] + graphrag_config.update(new_graphrag_config) + + logger.info(f"GraphRAG config reloaded: extractor={graphrag_config.get('extractor')}, chunker={graphrag_config.get('chunker')}, reuse_embedding={graphrag_config.get('reuse_embedding')}") + + return { + "status": "success", + "message": "GraphRAG configuration reloaded successfully" + } + + except Exception as e: + return { + "status": "error", + "message": f"Failed to reload GraphRAG config: {str(e)}" + } \ No newline at end of file diff --git a/common/embeddings/embedding_services.py b/common/embeddings/embedding_services.py index 1597cd2..6f170d0 100644 --- a/common/embeddings/embedding_services.py +++ b/common/embeddings/embedding_services.py @@ -184,9 +184,9 @@ class VertexAI_PaLM_Embedding(EmbeddingModel): def __init__(self, config): super().__init__(config, model_name=config.get("model_name", "VertexAI PaLM")) - from langchain.embeddings import VertexAIEmbeddings + from langchain_google_vertexai import VertexAIEmbeddings - self.embeddings = VertexAIEmbeddings(model_name=self.model_name) + self.embeddings = VertexAIEmbeddings(model=self.model_name) class GenAI_Embedding(EmbeddingModel): @@ -243,3 +243,4 @@ def __init__(self, config): model=self.model_name, 
base_url=base_url ) + diff --git a/common/extractors/GraphExtractor.py b/common/extractors/GraphExtractor.py index 2a7ba50..9cf44cc 100644 --- a/common/extractors/GraphExtractor.py +++ b/common/extractors/GraphExtractor.py @@ -2,13 +2,13 @@ from langchain_core.documents import Document from langchain_experimental.graph_transformers import LLMGraphTransformer -from common.config import get_llm_service, llm_config +from common.config import get_llm_service, get_completion_config from common.extractors.BaseExtractor import BaseExtractor class GraphExtractor(BaseExtractor): def __init__(self): - llm = get_llm_service(llm_config).llm + llm = get_llm_service(get_completion_config()).llm self.transformer = LLMGraphTransformer( llm=llm, node_properties=["description"], diff --git a/common/extractors/LLMEntityRelationshipExtractor.py b/common/extractors/LLMEntityRelationshipExtractor.py index b81a769..dec1753 100644 --- a/common/extractors/LLMEntityRelationshipExtractor.py +++ b/common/extractors/LLMEntityRelationshipExtractor.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import re from typing import List import logging @@ -37,6 +38,36 @@ def __init__( self.allowed_edge_types = allowed_relationship_types self.strict_mode = strict_mode + def _parse_json_output(self, content: str) -> dict: + """Parse JSON from LLM output with multiple fallback strategies. + + Tries in order: + 1. Direct json.loads + 2. Extract from ```json code fences + 3. 
Regex extraction of first JSON object
+        """
+        # Try direct parse, dropping any literal "content=" prefix (str.strip would strip a char set)
+        try:
+            return json.loads(content.removeprefix("content=").strip())
+        except (json.JSONDecodeError, ValueError):
+            pass
+
+        # Try ```json code fence
+        if "```json" in content:
+            try:
+                return json.loads(
+                    content.split("```")[1].strip("```").strip("json").strip()
+                )
+            except (json.JSONDecodeError, ValueError, IndexError):
+                pass
+
+        # Regex fallback: extract first JSON object
+        match = re.search(r'\{[\s\S]*\}', content)
+        if match:
+            return json.loads(match.group())
+
+        raise ValueError(f"Could not extract JSON from LLM output: {content[:200]}")
+
     async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
         try:
             logger.debug(str(doc))
@@ -47,12 +78,7 @@ async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]
         except Exception as e:
             return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]
         try:
-            if "```json" not in out.content:
-                json_out = json.loads(out.content.strip("content="))
-            else:
-                json_out = json.loads(
-                    out.content.split("```")[1].strip("```").strip("json").strip()
-                )
+            json_out = self._parse_json_output(out.content)
 
             formatted_rels = []
             for rels in json_out["rels"]:
@@ -124,7 +150,7 @@ async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]
                 for rel in formatted_rels
                 if rel["type"] in self.allowed_edge_types
             ]
-            
+
             nodes = []
             for node in formatted_nodes:
                 nodes.append(Node(id=node["id"],
@@ -141,7 +167,7 @@ async def _aextract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]
         except:
             return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]
-    
+
     def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
         try:
             out = chain.invoke(
@@ -150,12 +176,7 @@ def _extract_kg_from_doc(self, doc, chain, parser) -> list[GraphDocument]:
         except Exception as e:
             return [GraphDocument(nodes=[], relationships=[], source=Document(page_content=doc))]
         try:
-            if "```json" not
in out.content: - json_out = json.loads(out.content.strip("content=")) - else: - json_out = json.loads( - out.content.split("```")[1].strip("```").strip("json").strip() - ) + json_out = self._parse_json_output(out.content) formatted_rels = [] for rels in json_out["rels"]: @@ -278,7 +299,7 @@ async def adocument_er_extraction(self, document): if self.allowed_edge_types: prompt.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}")) prompt = ChatPromptTemplate.from_messages(prompt) - chain = prompt | self.llm_service.model # | parser + chain = prompt | self.llm_service.llm # | parser er = await self._aextract_kg_from_doc(document, chain, parser) return er @@ -316,7 +337,7 @@ def document_er_extraction(self, document): if self.allowed_edge_types: prompt.append(("human", f"Allowed Edge Types: {self.allowed_edge_types}")) prompt = ChatPromptTemplate.from_messages(prompt) - chain = prompt | self.llm_service.model # | parser + chain = prompt | self.llm_service.llm # | parser er = self._extract_kg_from_doc(document, chain, parser) return er diff --git a/common/llm_services/aws_bedrock_service.py b/common/llm_services/aws_bedrock_service.py index ba1b114..de6143a 100644 --- a/common/llm_services/aws_bedrock_service.py +++ b/common/llm_services/aws_bedrock_service.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import os import boto3, botocore from langchain_aws import ChatBedrock import logging @@ -57,53 +56,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated AWSBedrock model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def generate_cypher_prompt(self): - filepath = self.prompt_path + "generate_cypher.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().generate_cypher_prompt - - @property - def generate_gsql_prompt(self): - filepath = self.prompt_path + "generate_gsql.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().generate_gsql_prompt - - @property - def chatbot_response_prompt(self): - filepath = self.prompt_path + "chatbot_response.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().chatbot_response_prompt - - @property - def graphrag_scoring_prompt(self): - filepath = self.prompt_path + "graphrag_scoring.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().graphrag_scoring_prompt - - @property - def model(self): - return self.llm diff --git a/common/llm_services/aws_sagemaker_endpoint.py b/common/llm_services/aws_sagemaker_endpoint.py index fcc1cf3..5134497 100644 --- a/common/llm_services/aws_sagemaker_endpoint.py +++ b/common/llm_services/aws_sagemaker_endpoint.py @@ -54,15 +54,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated AWS_SageMaker_Endpoint 
model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def model(self): - return self.llm diff --git a/common/llm_services/azure_openai_service.py b/common/llm_services/azure_openai_service.py index e4dc6f5..bfb9279 100644 --- a/common/llm_services/azure_openai_service.py +++ b/common/llm_services/azure_openai_service.py @@ -21,7 +21,6 @@ def __init__(self, config): azure_deployment=config["azure_deployment"], openai_api_version=config["openai_api_version"], model_name=config["llm_model"], - max_tokens=config.get("token_limit"), temperature=config["model_kwargs"]["temperature"], ) @@ -29,21 +28,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated AzureOpenAI model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/base_llm.py b/common/llm_services/base_llm.py index 1dafd3d..ba1c770 100644 --- a/common/llm_services/base_llm.py +++ b/common/llm_services/base_llm.py @@ -12,6 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os +import re +import logging +from langchain_core.output_parsers import BaseOutputParser, PydanticOutputParser +from langchain_core.exceptions import OutputParserException +from langchain_core.prompts import BasePromptTemplate +from langchain_community.callbacks.manager import get_openai_callback + +logger = logging.getLogger(__name__) + + class LLM_Model: """Base LLM_Model Class @@ -20,26 +31,142 @@ class LLM_Model: def __init__(self, config): self.llm = None + self.config = config + from common.config import validate_graphname + self._graphname = validate_graphname(config.get("graphname")) + self.prompt_path = config.get("prompt_path", "") def _read_prompt_file(self, path): - with open(path) as f: - prompt = f.read() - return prompt + """Read a prompt file with per-graph override support. + + Resolution order: + 1. configs/graph_configs//prompts/ (if graphname is set) + 2. Original path (from prompt_path config) + + Returns the file content, or None if the file doesn't exist anywhere. + """ + filename = os.path.basename(path) + if self._graphname: + graph_override = os.path.join( + "configs", "graph_configs", self._graphname, "prompts", filename + ) + if os.path.exists(graph_override): + with open(graph_override) as f: + return f.read() + if os.path.exists(path): + with open(path) as f: + return f.read() + return None + + def invoke_with_parser( + self, + prompt: BasePromptTemplate, + parser: BaseOutputParser, + input_variables: dict, + caller_name: str = "unknown", + ): + """Invoke the LLM with a prompt and parse the output using the given parser. + + Builds a chain (prompt | llm), invokes it, and parses the output. + Supports PydanticOutputParser (with JSON extraction fallback) + and StrOutputParser (returns raw text). + + Args: + prompt: The prompt template. + parser: The output parser (PydanticOutputParser, StrOutputParser, etc.). + input_variables: Dict of variables to pass to the prompt. + caller_name: Name of the calling function (for logging). 
+
+        Returns:
+            The parsed output (a Pydantic model instance, or raw text for StrOutputParser).
+
+        Raises:
+            OutputParserException: If all parsing attempts fail.
+        """
+
+        chain = prompt | self.llm
+
+        usage_data = {}
+        with get_openai_callback() as cb:
+            raw_output = chain.invoke(input_variables)
+
+            usage_data["input_tokens"] = cb.prompt_tokens
+            usage_data["output_tokens"] = cb.completion_tokens
+            usage_data["total_tokens"] = cb.total_tokens
+            usage_data["cost"] = cb.total_cost
+            logger.info(f"{caller_name} usage: {usage_data}")
+
+        raw_text = raw_output.content if hasattr(raw_output, "content") else str(raw_output)
+
+        try:
+            return parser.parse(raw_text)
+        except OutputParserException:
+            logger.warning(f"{caller_name}: parser failed, attempting JSON extraction")
+            json_match = re.search(r'\{[\s\S]*\}', raw_text)
+            if json_match:
+                return parser.parse(json_match.group())
+            raise
+
+    async def ainvoke_with_parser(
+        self,
+        prompt: BasePromptTemplate,
+        parser: BaseOutputParser,
+        input_variables: dict,
+        caller_name: str = "unknown",
+    ):
+        """Async version of invoke_with_parser.
+
+        Uses chain.ainvoke() to avoid blocking the event loop,
+        suitable for async callers (e.g., ECC workers).
+ """ + + chain = prompt | self.llm + + usage_data = {} + with get_openai_callback() as cb: + raw_output = await chain.ainvoke(input_variables) + + usage_data["input_tokens"] = cb.prompt_tokens + usage_data["output_tokens"] = cb.completion_tokens + usage_data["total_tokens"] = cb.total_tokens + usage_data["cost"] = cb.total_cost + logger.info(f"{caller_name} usage: {usage_data}") + + raw_text = raw_output.content if hasattr(raw_output, "content") else str(raw_output) + + try: + return parser.parse(raw_text) + except OutputParserException: + logger.warning(f"{caller_name}: parser failed, attempting JSON extraction") + json_match = re.search(r'\{[\s\S]*\}', raw_text) + if json_match: + return parser.parse(json_match.group()) + raise @property def map_question_schema_prompt(self): """Property to get the prompt for the MapQuestionToSchema tool.""" - raise ("map_question_schema_prompt not supported in base class") + return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") @property def generate_function_prompt(self): """Property to get the prompt for the GenerateFunction tool.""" - raise ("generate_function_prompt not supported in base class") + return self._read_prompt_file(self.prompt_path + "generate_function.txt") + + @property + def entity_relationship_extraction_prompt(self): + """Property to get the prompt for the EntityRelationshipExtraction tool.""" + return self._read_prompt_file( + self.prompt_path + "entity_relationship_extraction.txt" + ) @property def generate_cypher_prompt(self): """Property to get the prompt for the GenerateCypher tool.""" - prompt = """You're an expert in OpenCypher programming. Given the following schema and history, what is the OpenCypher query that retrieves the {question} + result = self._read_prompt_file(self.prompt_path + "generate_cypher.txt") + if result is not None: + return result + return """You're an expert in OpenCypher programming. 
Given the following schema and history, what is the OpenCypher query that retrieves the {question} Only include attributes that are found in the schema. Never include any attributes that are not found in the schema. Use attributes instead of primary id if attribute name is closer to the keyword type in the question. Use as less vertex type, edge type and attributes as possible. If an attribute is not found in the schema, please exclude it from the query. @@ -65,12 +192,14 @@ def generate_cypher_prompt(self): Make sure to have correct attribute names in the OpenCypher query and not to name result aliases that are vertex or edge types. ONLY write the OpenCypher query in the response. Do not include any other information in the response.""" - return prompt @property def generate_gsql_prompt(self): """Property to get the prompt for the GenerateGSQL tool.""" - prompt = """You're an expert in GSQL (Graph SQL) programming for TigerGraph. Given the following schema: {schema}, what is the GSQL query that retrieves the answer for question: {question} + result = self._read_prompt_file(self.prompt_path + "generate_gsql.txt") + if result is not None: + return result + return """You're an expert in GSQL (Graph SQL) programming for TigerGraph. Given the following schema: {schema}, what is the GSQL query that retrieves the answer for question: {question} Only include attributes that are found in the schema. Never include any attributes that are not found in the schema. Use attributes instead of primary id if attribute name is more similar to the keyword type in the question. Use as few vertex types, edge types and attributes as possible. If an attribute is not found in the schema, please exclude it from the query. 
@@ -101,12 +230,14 @@ def generate_gsql_prompt(self): Make sure to have correct attribute names in the GSQL query and not to name result aliases that are vertex or edge types, operator or function names, and other reserved keywords, always construct alias with multiple words connected with underscore. ONLY write the GSQL query in the response. Do not include any other information in the response.""" - return prompt @property def route_response_prompt(self): """Property to get the prompt for the RouteResponse tool.""" - prompt = """\ + result = self._read_prompt_file(self.prompt_path + "route_response.txt") + if result is not None: + return result + return """\ You are an expert at routing a user question to a vectorstore, function calls, or conversation history. Use the conversation history for questions that are similar to previous ones or that reference earlier answers or responses. Use the vectorstore for questions that would be best suited by text documents. @@ -126,50 +257,74 @@ def route_response_prompt(self): Conversation history: {conversation} Format: {format_instructions}\ """ - return prompt @property def hyde_prompt(self): """Property to get the prompt for the HyDE tool.""" + result = self._read_prompt_file(self.prompt_path + "hyde.txt") + if result is not None: + return result return """You are a helpful agent that is writing an example of a document that might answer this question: {question} Answer:""" - @property - def entity_relationship_extraction_prompt(self): - """Property to get the prompt for the EntityRelationshipExtraction tool.""" - raise ("entity_relationship_extraction_prompt not supported in base class") - @property def chatbot_response_prompt(self): """Property to get the prompt for the SupportAI response.""" - prompt ="""Given the answer context in JSON format, rephrase it to answer the question. 
\n + result = self._read_prompt_file(self.prompt_path + "chatbot_response.txt") + if result is not None: + return result + return """Given the answer context in JSON format, rephrase it to answer the question. \n Use only the provided information in context without adding any reasoning or additional logic. \n Make sure all information in the answer are covered in the generated answer.\n Question: {question} \n Answer: {context} \n Format: {format_instructions}""" - return prompt @property def keyword_extraction_prompt(self): - """Property to get the prompt for the Question Expension response.""" + """Property to get the prompt for the Question Expansion response.""" + result = self._read_prompt_file(self.prompt_path + "keyword_extraction.txt") + if result is not None: + return result return """You are a helpful assistant responsible for extracting key terms (glossary) from all the questions below to represent their original meaning as much as possible. Each term should only contain a couple of words. Include a quality score for the each extracted glossary, based on how important and frequent it's in the given questions. 
The quality score should range from 0 (poor) to 100 (excellent), with higher scores indicating terms that are both significant and frequent in the context of the questions.\nThe output should only contain the extracted terms and their quality scores using the required format.\n\nQuestion: {question}\n\n{format_instructions}\n""" @property def question_expansion_prompt(self): - """Property to get the prompt for the Question Expension response.""" + """Property to get the prompt for the Question Expansion response.""" + result = self._read_prompt_file(self.prompt_path + "question_expansion.txt") + if result is not None: + return result return """You are a helpful assistant responsible for generating 10 new questions similar to the original question below to represent its meaning in a more clear way.\nInclude a quality score for the answer, based on how well it represents the meaning of the original question. The quality score should be between 0 (poor) and 100 (excellent).\n\nQuestion: {question}\n\n{format_instructions}\n""" @property def graphrag_scoring_prompt(self): """Property to get the prompt for the GraphRAG Scoring response.""" + result = self._read_prompt_file(self.prompt_path + "graphrag_scoring.txt") + if result is not None: + return result return """You are a helpful assistant responsible for generating an answer to the question below using the data provided.\nInclude a quality score for the answer, based on how well it answers the question. The quality score should be between 0 (poor) and 100 (excellent).\n\nQuestion: {question}\nContext: {context}\n\n{format_instructions}\n""" + @property + def community_summarize_prompt(self): + """Property to get the prompt for community summarization.""" + result = self._read_prompt_file(self.prompt_path + "community_summarization.txt") + if result is not None: + return result + raise FileNotFoundError( + f"Community summarization prompt file not found in {self.prompt_path}. 
" + "Please ensure community_summarization.txt exists in the configured prompt path." + ) + @property def contextualize_question_prompt(self): """Property to get the prompt for contextualizing a follow-up question into a standalone search query using conversation history.""" + result = self._read_prompt_file( + self.prompt_path + "contextualize_question.txt" + ) + if result is not None: + return result return ( "Given the following conversation history and a follow-up " "question, rewrite the follow-up question into a standalone, " @@ -180,7 +335,3 @@ def contextualize_question_prompt(self): "Standalone question:" ) - @property - def model(self): - """Property to get the external LLM model.""" - raise ("model not supported in base class") diff --git a/common/llm_services/google_genai_service.py b/common/llm_services/google_genai_service.py index c544978..54d3a20 100644 --- a/common/llm_services/google_genai_service.py +++ b/common/llm_services/google_genai_service.py @@ -36,7 +36,6 @@ def __init__(self, config): self.llm = ChatGoogleGenerativeAI( temperature=config["model_kwargs"]["temperature"], model=model_name, - max_tokens=config.get("token_limit"), timeout=None, max_retries=2, ) @@ -44,85 +43,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated OpenAI model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def generate_cypher_prompt(self): - filepath = self.prompt_path + "generate_cypher.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().generate_cypher_prompt - - @property - def generate_gsql_prompt(self): - filepath = self.prompt_path + "generate_gsql.txt" - if os.path.exists(filepath): - return 
self._read_prompt_file(filepath) - else: - return super().generate_gsql_prompt - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def route_response_prompt(self): - filepath = self.prompt_path + "route_response.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().route_response_prompt - - @property - def graphrag_scoring_prompt(self): - filepath = self.prompt_path + "graphrag_scoring.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().graphrag_scoring_prompt - - @property - def keyword_extraction_prompt(self): - filepath = self.prompt_path + "keyword_extraction.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().keyword_extraction_prompt - - @property - def question_expansion_prompt(self): - filepath = self.prompt_path + "question_expansion.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().question_expansion_prompt - - @property - def chatbot_response_prompt(self): - filepath = self.prompt_path + "chatbot_response.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().chatbot_response_prompt - - @property - def hyde_prompt(self): - filepath = self.prompt_path + "hyde.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().hyde_prompt - - @property - def model(self): - return self.llm diff --git a/common/llm_services/google_vertexai_service.py b/common/llm_services/google_vertexai_service.py index 22679f5..2bc9847 100644 --- a/common/llm_services/google_vertexai_service.py +++ b/common/llm_services/google_vertexai_service.py @@ -9,11 +9,11 @@ class GoogleVertexAI(LLM_Model): def __init__(self, config): super().__init__(config) - from 
langchain_community.llms import VertexAI + from langchain_google_vertexai import VertexAI model_name = config["llm_model"] self.llm = VertexAI( - model_name=model_name, max_output_tokens=1000, **config["model_kwargs"] + model=model_name, max_output_tokens=1000, **config["model_kwargs"] ) self.prompt_path = config["prompt_path"] @@ -21,20 +21,3 @@ def __init__(self, config): f"request_id={req_id_cv.get()} instantiated GoogleVertexAI model_name={model_name}" ) - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/groq_llm_service.py b/common/llm_services/groq_llm_service.py index afa6f89..b1e58ee 100644 --- a/common/llm_services/groq_llm_service.py +++ b/common/llm_services/groq_llm_service.py @@ -22,21 +22,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated OpenAI model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/huggingface_endpoint.py b/common/llm_services/huggingface_endpoint.py index 2151966..5b83916 100644 --- a/common/llm_services/huggingface_endpoint.py +++ b/common/llm_services/huggingface_endpoint.py 
@@ -31,21 +31,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated HuggingFace model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/ibm_watsonx_service.py b/common/llm_services/ibm_watsonx_service.py index b2504da..e4c9d99 100644 --- a/common/llm_services/ibm_watsonx_service.py +++ b/common/llm_services/ibm_watsonx_service.py @@ -30,21 +30,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated WatsonX model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/ollama.py b/common/llm_services/ollama.py index bdb0b44..40d5c97 100644 --- a/common/llm_services/ollama.py +++ b/common/llm_services/ollama.py @@ -17,21 +17,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated Ollama model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return 
self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def model(self): - return self.llm diff --git a/common/llm_services/openai_service.py b/common/llm_services/openai_service.py index f23e81b..e5f1c6d 100644 --- a/common/llm_services/openai_service.py +++ b/common/llm_services/openai_service.py @@ -43,85 +43,3 @@ def __init__(self, config): LogWriter.info( f"request_id={req_id_cv.get()} instantiated OpenAI model_name={model_name}" ) - - @property - def map_question_schema_prompt(self): - return self._read_prompt_file(self.prompt_path + "map_question_to_schema.txt") - - @property - def generate_function_prompt(self): - return self._read_prompt_file(self.prompt_path + "generate_function.txt") - - @property - def generate_cypher_prompt(self): - filepath = self.prompt_path + "generate_cypher.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().generate_cypher_prompt - - @property - def generate_gsql_prompt(self): - filepath = self.prompt_path + "generate_gsql.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().generate_gsql_prompt - - @property - def entity_relationship_extraction_prompt(self): - return self._read_prompt_file( - self.prompt_path + "entity_relationship_extraction.txt" - ) - - @property - def route_response_prompt(self): - filepath = self.prompt_path + "route_response.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().route_response_prompt - - @property - def graphrag_scoring_prompt(self): - filepath = self.prompt_path + "graphrag_scoring.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().graphrag_scoring_prompt - - @property - def keyword_extraction_prompt(self): - 
filepath = self.prompt_path + "keyword_extraction.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().keyword_extraction_prompt - - @property - def question_expansion_prompt(self): - filepath = self.prompt_path + "question_expansion.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().question_expansion_prompt - - @property - def chatbot_response_prompt(self): - filepath = self.prompt_path + "chatbot_response.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().chatbot_response_prompt - - @property - def hyde_prompt(self): - filepath = self.prompt_path + "hyde.txt" - if os.path.exists(filepath): - return self._read_prompt_file(filepath) - else: - return super().hyde_prompt - - @property - def model(self): - return self.llm diff --git a/common/metrics/tg_proxy.py b/common/metrics/tg_proxy.py index 804d66f..a9a325f 100644 --- a/common/metrics/tg_proxy.py +++ b/common/metrics/tg_proxy.py @@ -5,6 +5,7 @@ from common.logs.logwriter import LogWriter import logging from common.logs.log import req_id_cv +from common.config import db_config logger = logging.getLogger(__name__) @@ -47,7 +48,9 @@ def _runInstalledQuery(self, query_name, params, sizeLimit=None, usePost=False): metrics.tg_inprogress_requests.labels(query_name=query_name).inc() try: restppid = self._tg_connection.runInstalledQuery( - query_name, params, runAsync=True, usePost=usePost, sizeLimit=sizeLimit + query_name, params, runAsync=True, usePost=usePost, sizeLimit=sizeLimit, + threadLimit=db_config.get("default_thread_limit", 8), + memoryLimit=db_config.get("default_mem_threshold", 5000), ) LogWriter.info( f"request_id={req_id_cv.get()} query {query_name} started with RESTPP ID {restppid}" diff --git a/common/prompts/aws_bedrock_claude3haiku/community_summarization.txt b/common/prompts/aws_bedrock_claude3haiku/community_summarization.txt new file mode 100644 
index 0000000..50e4619 --- /dev/null +++ b/common/prompts/aws_bedrock_claude3haiku/community_summarization.txt @@ -0,0 +1,11 @@ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +You are given one or two entities and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. +Make sure it is written in third person, and include the entity names so we have the full context. + +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} + diff --git a/common/prompts/custom/aml/community_summarization.txt b/common/prompts/custom/aml/community_summarization.txt new file mode 100644 index 0000000..50e4619 --- /dev/null +++ b/common/prompts/custom/aml/community_summarization.txt @@ -0,0 +1,11 @@ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +You are given one or two entities and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. +Make sure it is written in third person, and include the entity names so we have the full context.
+ +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} + diff --git a/common/prompts/gcp_vertexai_palm/community_summarization.txt b/common/prompts/gcp_vertexai_palm/community_summarization.txt new file mode 100644 index 0000000..50e4619 --- /dev/null +++ b/common/prompts/gcp_vertexai_palm/community_summarization.txt @@ -0,0 +1,11 @@ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +You are given one or two entities and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. +Make sure it is written in third person, and include the entity names so we have the full context. + +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} + diff --git a/common/prompts/google_gemini/community_summarization.txt b/common/prompts/google_gemini/community_summarization.txt new file mode 100644 index 0000000..50e4619 --- /dev/null +++ b/common/prompts/google_gemini/community_summarization.txt @@ -0,0 +1,11 @@ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +You are given one or two entities and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description.
+Make sure it is written in third person, and include the entity names so we have the full context. + +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} + diff --git a/common/prompts/openai_gpt4/community_summarization.txt b/common/prompts/openai_gpt4/community_summarization.txt new file mode 100644 index 0000000..50e4619 --- /dev/null +++ b/common/prompts/openai_gpt4/community_summarization.txt @@ -0,0 +1,11 @@ +You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. +You are given one or two entities and a list of descriptions, all related to the same entity or group of entities. +Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. +If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. +Make sure it is written in third person, and include the entity names so we have the full context.
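A minimal sketch of how the `{entity_name}` and `{description_list}` placeholders in these prompt files get filled in. LangChain's `PromptTemplate` performs essentially this substitution; the template below is an abbreviated excerpt and the entity data is hypothetical, for illustration only:

```python
# Abbreviated tail of the community-summarization prompt above;
# the placeholder names match the new .txt files.
SUMMARIZE_TEMPLATE = (
    "-Data-\n"
    "Community Title: {entity_name}\n"
    "Description List: {description_list}\n"
)

# Hypothetical entity data for illustration only.
rendered = SUMMARIZE_TEMPLATE.format(
    entity_name="Acme Corp",
    description_list=["A manufacturer.", "Headquartered in Springfield."],
)
print(rendered)
```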
+ +####### +-Data- +Community Title: {entity_name} +Description List: {description_list} + diff --git a/common/requirements.txt b/common/requirements.txt index d5a2d5b..12c9fcf 100644 --- a/common/requirements.txt +++ b/common/requirements.txt @@ -80,6 +80,7 @@ kiwisolver==1.4.8 langchain>=0.3.26 langchain-core>=0.3.26 langchain_google_genai==2.1.8 +langchain-google-vertexai==2.1.2 langchain-community==0.3.26 langchain-experimental==0.3.5rc1 langchain-groq==0.3.4 diff --git a/common/utils/image_data_extractor.py b/common/utils/image_data_extractor.py index 19da86e..48f9b65 100644 --- a/common/utils/image_data_extractor.py +++ b/common/utils/image_data_extractor.py @@ -2,7 +2,7 @@ import io import logging from langchain_core.messages import HumanMessage, SystemMessage -from common.config import get_multimodal_service +from common.config import get_llm_service, get_multimodal_config logger = logging.getLogger(__name__) @@ -10,8 +10,11 @@ def _get_client(): global _multimodal_client - if _multimodal_client is None: - _multimodal_client = get_multimodal_service() + if _multimodal_client is None and get_multimodal_config(): + try: + _multimodal_client = get_llm_service(get_multimodal_config()) + except Exception: + logger.warning("Failed to create multimodal LLM client") return _multimodal_client def describe_image_with_llm(file_path): diff --git a/common/utils/token_calculator.py b/common/utils/token_calculator.py index 762e824..dfe4a76 100644 --- a/common/utils/token_calculator.py +++ b/common/utils/token_calculator.py @@ -61,12 +61,31 @@ def __init__(self, token_limit: int = 0, model_name: str = None): self.max_context_tokens = token_limit if token_limit else 0 self.model_name = model_name if model_name else "gpt-4" try: - self.token_encoding = tiktoken.encoding_for_model(self.model_name) + self.token_encoding = tiktoken.encoding_for_model(self._normalize_model_name(self.model_name)) except Exception as e: self.token_encoding = tiktoken.get_encoding("cl100k_base") -
logger.warning(f"Error getting encoding for model {self.model_name}, using cl100k_base: {e}") + logger.info(f"No tiktoken mapping for model {self.model_name}, using cl100k_base") logger.info(f"Initialized TokenCalculator with max_context_tokens: {self.max_context_tokens} and encoding: {self.token_encoding}") + @staticmethod + def _normalize_model_name(model_name: str) -> str: + """Normalize provider-specific model names for tiktoken lookup. + + Examples: + anthropic.claude-3-5-haiku-20241022-v1:0 → claude-3-5-haiku + us.anthropic.claude-3-5-haiku-20241022-v1:0 → claude-3-5-haiku + gpt-4o-mini → gpt-4o-mini (unchanged) + """ + name = model_name + # Strip Bedrock provider prefix (e.g., "anthropic." or "us.anthropic.") + if "." in name: + name = name.rsplit(".", 1)[-1] + # Strip version suffix (e.g., "-20241022-v1:0") + # Pattern: date stamp followed by version + import re + name = re.sub(r'-\d{8}-v\d+.*$', '', name) + return name + def set_max_context_tokens(self, max_tokens: int): """Set the maximum number of tokens allowed for retrieved context.""" self.max_context_tokens = max_tokens diff --git a/docker-compose.yml b/docker-compose.yml index b228151..97a0952 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -74,7 +74,7 @@ services: - graphrag # tigergraph: -# image: tigergraph/community:4.2.1 +# image: tigergraph/community:4.2.2 # container_name: tigergraph # platform: linux/amd64 # ports: diff --git a/docs/tutorials/configs/nginx.conf b/docs/tutorials/configs/nginx.conf index dc09929..975d8a0 100644 --- a/docs/tutorials/configs/nginx.conf +++ b/docs/tutorials/configs/nginx.conf @@ -14,6 +14,16 @@ server { proxy_pass http://graphrag-ui:3000/; } + location /setup { + rewrite ^/setup$ / break; + proxy_pass http://graphrag-ui:3000; + } + + location /setup/ { + rewrite ^/setup/.*$ / break; + proxy_pass http://graphrag-ui:3000; + } + location /chat-dialog { proxy_pass http://graphrag-ui:3000/; diff --git a/docs/tutorials/docker-compose.yml 
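The normalization helper added to `token_calculator.py` above can be sketched as a standalone function; the example model names come from the patch's own docstring, and the same caveat applies — names without a Bedrock-style prefix or date/version stamp pass through unchanged:

```python
import re

def normalize_model_name(model_name: str) -> str:
    """Strip provider prefixes and version suffixes so tiktoken can
    look up the base model name (mirrors _normalize_model_name above)."""
    name = model_name
    # Drop a Bedrock-style provider prefix, e.g. "anthropic." or "us.anthropic."
    if "." in name:
        name = name.rsplit(".", 1)[-1]
    # Drop a trailing date/version stamp, e.g. "-20241022-v1:0"
    name = re.sub(r"-\d{8}-v\d+.*$", "", name)
    return name

print(normalize_model_name("us.anthropic.claude-3-5-haiku-20241022-v1:0"))  # claude-3-5-haiku
```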
b/docs/tutorials/docker-compose.yml index 8be754b..2d5734c 100644 --- a/docs/tutorials/docker-compose.yml +++ b/docs/tutorials/docker-compose.yml @@ -74,7 +74,7 @@ services: - graphrag tigergraph: - image: tigergraph/community:4.2.1 + image: tigergraph/community:4.2.2 container_name: tigergraph platform: linux/amd64 ports: diff --git a/docs/tutorials/setup_graphrag.sh b/docs/tutorials/setup_graphrag.sh index cb818b8..a3540a6 100755 --- a/docs/tutorials/setup_graphrag.sh +++ b/docs/tutorials/setup_graphrag.sh @@ -40,7 +40,7 @@ cd $root_dir || { echo "Cannot switch to $root_dir!"; exit 5; } echo "Downloading GraphRAG service config..." mkdir -p configs || true -curl -sk https://raw.githubusercontent.com/tigergraph/graphrag/refs/heads/main/docs/tutorials/docker-compose.yml | sed "s/community:4.2.1/community:${tg_version}/g" > docker-compose.yml +curl -sk https://raw.githubusercontent.com/tigergraph/graphrag/refs/heads/main/docs/tutorials/docker-compose.yml | sed "s/community:4.2.2/community:${tg_version}/g" > docker-compose.yml curl -sk https://raw.githubusercontent.com/tigergraph/graphrag/refs/heads/main/docs/tutorials/configs/nginx.conf -o configs/nginx.conf curl -sk "https://raw.githubusercontent.com/tigergraph/graphrag/refs/heads/main/docs/tutorials/configs/server_config.json.${llm_provider}" | sed '/"gsPort": "14240"/a\ "username": "'${tg_username}'",\ diff --git a/ecc/app/ecc_util.py b/ecc/app/ecc_util.py index 35bbcaa..e17ce9f 100644 --- a/ecc/app/ecc_util.py +++ b/ecc/app/ecc_util.py @@ -1,21 +1,11 @@ from common.chunkers import character_chunker, regex_chunker, semantic_chunker, markdown_chunker, recursive_chunker, html_chunker, single_chunker -from common.config import graphrag_config, embedding_service, llm_config -from common.llm_services import ( - AWS_SageMaker_Endpoint, - AWSBedrock, - AzureOpenAI, - GoogleVertexAI, - GoogleGenAI, - Groq, - HuggingFaceEndpoint, - Ollama, - OpenAI, -) +from common.config import get_graphrag_config, embedding_service -def 
get_chunker(chunker_type: str = ""): +def get_chunker(chunker_type: str = "", graphname: str = None): + cfg = get_graphrag_config(graphname) if not chunker_type: - chunker_type = graphrag_config.get("chunker", "semantic") - chunker_config = graphrag_config.get("chunker_config", {}) + chunker_type = cfg.get("chunker", "semantic") + chunker_config = cfg.get("chunker_config", {}) if chunker_type == "semantic": chunker = semantic_chunker.SemanticChunker( embedding_service, @@ -55,26 +45,3 @@ def get_chunker(chunker_type: str = ""): raise ValueError(f"Invalid chunker type: {chunker_type}") return chunker - - -def get_llm_service(): - if llm_config["completion_service"]["llm_service"].lower() == "openai": - llm_provider = OpenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "azure": - llm_provider = AzureOpenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "sagemaker": - llm_provider = AWS_SageMaker_Endpoint(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "vertexai": - llm_provider = GoogleVertexAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "genai": - llm_provider = GoogleGenAI(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "bedrock": - llm_provider = AWSBedrock(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "groq": - llm_provider = Groq(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "ollama": - llm_provider = Ollama(llm_config["completion_service"]) - elif llm_config["completion_service"]["llm_service"].lower() == "huggingface": - llm_provider = HuggingFaceEndpoint(llm_config["completion_service"]) - - return llm_provider diff --git a/ecc/app/eventual_consistency_checker.py 
b/ecc/app/eventual_consistency_checker.py index 499bdc7..1c28b53 100644 --- a/ecc/app/eventual_consistency_checker.py +++ b/ecc/app/eventual_consistency_checker.py @@ -91,7 +91,7 @@ def _check_query_install(self, query_name): return True def _chunk_document(self, content): - chunker = ecc_util.get_chunker(content["ctype"]) + chunker = ecc_util.get_chunker(content["ctype"], graphname=self.graphname) return chunker.chunk(content["text"]) def _extract_entities(self, content): diff --git a/ecc/app/graphrag/community_summarizer.py b/ecc/app/graphrag/community_summarizer.py index 0bab35b..532b94f 100644 --- a/ecc/app/graphrag/community_summarizer.py +++ b/ecc/app/graphrag/community_summarizer.py @@ -13,25 +13,18 @@ # limitations under the License. import re +import logging from langchain_core.prompts import PromptTemplate +from langchain_core.output_parsers import PydanticOutputParser from common.llm_services import LLM_Model from common.py_schemas import CommunitySummary +logger = logging.getLogger(__name__) + + # src: https://github.com/microsoft/graphrag/blob/main/graphrag/index/graph/extractors/summarize/prompts.py -SUMMARIZE_PROMPT = PromptTemplate.from_template(""" -You are a helpful assistant responsible for generating a comprehensive summary of the data provided below. -Given one or two entities, and a list of descriptions, all related to the same entity or group of entities. -Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions. -If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary, but do not add any information that is not in the description. -Make sure it is written in third person, and include the entity names so we the have full context. 
- -####### --Data- -Commuinty Title: {entity_name} -Description List: {description_list} -""") id_pat = re.compile(r"[_\d]*") @@ -43,19 +36,22 @@ def __init__( ): self.llm_service = llm_service - async def summarize(self, name: str, text: list[str]) -> CommunitySummary: - structured_llm = self.llm_service.model.with_structured_output(CommunitySummary) - chain = SUMMARIZE_PROMPT | structured_llm + async def summarize(self, name: str, text: list[str]) -> dict: + summary_parser = PydanticOutputParser(pydantic_object=CommunitySummary) + prompt = PromptTemplate( + template=self.llm_service.community_summarize_prompt + "\n{format_instructions}", + input_variables=["entity_name", "description_list"], + partial_variables={"format_instructions": summary_parser.get_format_instructions()}, + ) # remove iteration tags from name name = id_pat.sub("", name) try: - summary = await chain.ainvoke( - { - "entity_name": name, - "description_list": text, - } + summary = await self.llm_service.ainvoke_with_parser( + prompt, summary_parser, + {"entity_name": name, "description_list": text}, + caller_name="community_summarize", ) except Exception as e: return {"error": True, "summary": "", "message": str(e)} - return {"error": False, "summary": summary.summary} + return {"error": False, "summary": summary.summary} \ No newline at end of file diff --git a/ecc/app/graphrag/graph_rag.py b/ecc/app/graphrag/graph_rag.py index 5544789..49a5760 100644 --- a/ecc/app/graphrag/graph_rag.py +++ b/ecc/app/graphrag/graph_rag.py @@ -36,7 +36,7 @@ ) from pyTigerGraph import AsyncTigerGraphConnection -from common.config import embedding_service, graphrag_config, entity_extraction_switch, community_detection_switch, doc_process_switch +from common.config import embedding_service, entity_extraction_switch, community_detection_switch, doc_process_switch, get_graphrag_config from common.embeddings.base_embedding_store import EmbeddingStore from common.extractors.BaseExtractor import BaseExtractor @@ -179,8 
+179,9 @@ async def upsert(upsert_chan: Channel): async def load(conn: AsyncTigerGraphConnection): logger.info("Data Loading Start") dd = lambda: defaultdict(dd) # infinite default dict - batch_size = graphrag_config.get("load_batch_size", 500) - upsert_delay = graphrag_config.get("upsert_delay", 0) + graph_cfg = get_graphrag_config(conn.graphname) + batch_size = graph_cfg.get("load_batch_size", 500) + upsert_delay = graph_cfg.get("upsert_delay", 0) # while the load q is still open or has contents while not load_q.closed() or not load_q.empty(): if load_q.closed(): @@ -259,7 +260,7 @@ async def embed( (v_id, content, index_name) = await embed_chan.get() v_id = (v_id, index_name) logger.info(f"Embed to {graphname}_{index_name}: {v_id}") - if graphrag_config.get("reuse_embedding", True) and embedding_store.has_embeddings([v_id]): + if get_graphrag_config(graphname).get("reuse_embedding", True) and embedding_store.has_embeddings([v_id]): logger.info(f"Embeddings for {v_id} already exists, skipping to save cost") continue grp.create_task( diff --git a/ecc/app/graphrag/util.py b/ecc/app/graphrag/util.py index f581057..f12157f 100644 --- a/ecc/app/graphrag/util.py +++ b/ecc/app/graphrag/util.py @@ -28,7 +28,8 @@ graphrag_config, embedding_service, get_llm_service, - llm_config, + get_completion_config, + get_graphrag_config, ) from common.embeddings.base_embedding_store import EmbeddingStore from common.embeddings.tigergraph_embedding_store import TigerGraphEmbeddingStore @@ -40,7 +41,11 @@ http_timeout = httpx.Timeout(15.0) -tg_sem = asyncio.Semaphore(graphrag_config.get("tg_concurrency", 10)) +_default_concurrency = graphrag_config.get("default_concurrency", 10) +# Worker amplifier: processing workers (chunk, embed, extract, community) run at 2x +# the base concurrency since each worker is mostly waiting on I/O (LLM/embedding API calls). 
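The worker pools above move from hard-coded `asyncio.Semaphore(20)` limits to a config-derived `_worker_concurrency`. A minimal, self-contained sketch of the underlying pattern — a semaphore capping how many I/O-bound workers are in flight at once (the sleep stands in for an LLM or embedding call):

```python
import asyncio

async def worker(sem: asyncio.Semaphore, counters: dict) -> None:
    # Each worker acquires the semaphore before its (simulated) I/O,
    # so at most `limit` workers run concurrently.
    async with sem:
        counters["active"] += 1
        counters["peak"] = max(counters["peak"], counters["active"])
        await asyncio.sleep(0.01)  # stand-in for an LLM/embedding API call
        counters["active"] -= 1

async def main(limit: int, n_tasks: int) -> int:
    sem = asyncio.Semaphore(limit)
    counters = {"active": 0, "peak": 0}
    await asyncio.gather(*(worker(sem, counters) for _ in range(n_tasks)))
    return counters["peak"]

peak = asyncio.run(main(limit=4, n_tasks=20))
print(peak)  # never exceeds the limit of 4
```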
+_worker_concurrency = _default_concurrency * 2 +tg_sem = asyncio.Semaphore(_default_concurrency) load_q = reusable_channel.ReuseableChannel() # will pause workers until the event is false @@ -132,10 +137,11 @@ async def init( await install_queries(requried_queries, conn) # extractor - if graphrag_config.get("extractor") == "graphrag": + graph_cfg = get_graphrag_config(conn.graphname) + if graph_cfg.get("extractor") == "graphrag": extractor = GraphExtractor() - elif graphrag_config.get("extractor") == "llm": - extractor = LLMEntityRelationshipExtractor(get_llm_service(llm_config)) + elif graph_cfg.get("extractor") == "llm": + extractor = LLMEntityRelationshipExtractor(get_llm_service(get_completion_config())) else: raise ValueError("Invalid extractor type") diff --git a/ecc/app/graphrag/workers.py b/ecc/app/graphrag/workers.py index 78f38be..c0b35cc 100644 --- a/ecc/app/graphrag/workers.py +++ b/ecc/app/graphrag/workers.py @@ -64,7 +64,7 @@ async def install_query( return {"result": res, "error": False} -chunk_sem = asyncio.Semaphore(20) +chunk_sem = asyncio.Semaphore(util._worker_concurrency) async def chunk_doc( @@ -98,7 +98,7 @@ async def chunk_doc( # Use get_chunker for all types (including images) # For images, get_chunker returns SingleChunker which preserves markdown image references - chunker = ecc_util.get_chunker(chunker_type) + chunker = ecc_util.get_chunker(chunker_type, graphname=conn.graphname) # decode the text return from tigergraph as it was encoded when written into jsonl file for uploading chunks = chunker.chunk(doc["attributes"]["text"].encode('raw_unicode_escape').decode('unicode_escape')) @@ -172,7 +172,7 @@ async def upsert_chunk(conn: AsyncTigerGraphConnection, doc_id, chunk_id, chunk) ) -embed_sem = asyncio.Semaphore(20) +embed_sem = asyncio.Semaphore(util._worker_concurrency) async def embed( @@ -220,7 +220,7 @@ async def get_vert_desc(conn, v_id, node: Node): return desc -extract_sem = asyncio.Semaphore(20) +extract_sem = 
asyncio.Semaphore(util._worker_concurrency) async def extract( @@ -406,7 +406,7 @@ async def extract( # right now, we're not embedding relationships in graphrag -comm_sem = asyncio.Semaphore(20) +comm_sem = asyncio.Semaphore(util._worker_concurrency) async def process_community( @@ -440,7 +440,8 @@ async def process_community( if len(children) == 1: summary = children[0] else: - llm = ecc_util.get_llm_service() + from common.config import get_llm_service, get_completion_config + llm = get_llm_service(get_completion_config(conn.graphname)) summarizer = community_summarizer.CommunitySummarizer(llm) summary = await summarizer.summarize(comm_id, children) if summary["error"]: diff --git a/ecc/app/main.py b/ecc/app/main.py index 5468391..0db691b 100644 --- a/ecc/app/main.py +++ b/ecc/app/main.py @@ -36,6 +36,9 @@ embedding_service, get_llm_service, llm_config, + get_completion_config, + get_graphrag_config, + reload_db_config, ) from common.db.connections import elevate_db_connection_to_token, get_db_connection_id_token from common.embeddings.base_embedding_store import EmbeddingStore @@ -97,28 +100,29 @@ def initialize_eventual_consistency_checker( embedding_service, support_ai_instance=False, ) - index_names = graphrag_config.get( + graph_cfg = get_graphrag_config(graphname) + index_names = graph_cfg.get( "indexes", ["DocumentChunk", "Community"], ) - if graphrag_config.get("extractor") == "llm": + if graph_cfg.get("extractor") == "llm": from common.extractors import LLMEntityRelationshipExtractor - extractor = LLMEntityRelationshipExtractor(get_llm_service(llm_config)) + extractor = LLMEntityRelationshipExtractor(get_llm_service(get_completion_config())) else: raise ValueError("Invalid extractor type") checker = EventualConsistencyChecker( - graphrag_config.get("process_interval_seconds", 300), - graphrag_config.get("cleanup_interval_seconds", 300), + graph_cfg.get("process_interval_seconds", 300), + graph_cfg.get("cleanup_interval_seconds", 300), graphname, 
embedding_service, embedding_store, index_names, conn, extractor, - graphrag_config.get("batch_size", 100), + graph_cfg.get("checker_batch_size", graph_cfg.get("batch_size", 100)), ) consistency_checkers[graphname] = checker @@ -213,6 +217,41 @@ async def run_with_tracking(task_key: str, run_func, graphname: str, conn): try: running_tasks[task_key] = {"status": "running", "started_at": time.time()} LogWriter.info(f"Starting ECC task: {task_key}") + + # Reload config at the start of each job to ensure latest settings are used + LogWriter.info("Reloading configuration for new job...") + from common.config import reload_llm_config, reload_graphrag_config, reload_db_config + + llm_result = reload_llm_config() + if llm_result["status"] == "success": + LogWriter.info(f"LLM config reloaded: {llm_result['message']}") + completion_service = llm_config.get("completion_service", {}) + ecc_model = completion_service.get("llm_model", "unknown") + ecc_provider = completion_service.get("llm_service", "unknown") + LogWriter.info( + f"[ECC] Using completion model={ecc_model} (provider={ecc_provider})" + ) + else: + LogWriter.warning(f"LLM config reload had issues: {llm_result['message']}") + + db_result = reload_db_config() + if db_result["status"] == "success": + LogWriter.info( + f"DB config reloaded: {db_result['message']} " + f"(host={db_config.get('hostname')}, " + f"restppPort={db_config.get('restppPort')}, " + f"gsPort={db_config.get('gsPort')})" + ) + else: + LogWriter.warning(f"DB config reload had issues: {db_result['message']}") + + graphrag_result = reload_graphrag_config() + if graphrag_result["status"] == "success": + LogWriter.info(f"GraphRAG config reloaded: {graphrag_result['message']}") + else: + LogWriter.warning(f"GraphRAG config reload had issues: {graphrag_result['message']}") + + # Now run the actual job with fresh config await run_func(graphname, conn) running_tasks[task_key] = {"status": "completed", "completed_at": time.time()} LogWriter.info(f"Completed 
ECC task: {task_key}") @@ -242,6 +281,17 @@ def consistency_update( response: Response, credentials = Depends(auth_credentials), ): + db_result = reload_db_config() + if db_result["status"] == "success": + LogWriter.info( + f"DB config reloaded: {db_result['message']} " + f"(host={db_config.get('hostname')}, " + f"restppPort={db_config.get('restppPort')}, " + f"gsPort={db_config.get('gsPort')})" + ) + else: + LogWriter.warning(f"DB config reload had issues: {db_result['message']}") + if isinstance(credentials, HTTPBasicCredentials): conn = elevate_db_connection_to_token( db_config.get("hostname"), diff --git a/ecc/app/supportai/util.py b/ecc/app/supportai/util.py index d3906ca..6269624 100644 --- a/ecc/app/supportai/util.py +++ b/ecc/app/supportai/util.py @@ -12,10 +12,11 @@ from pyTigerGraph import TigerGraphConnection from common.config import ( - graphrag_config, embedding_service, + graphrag_config, get_llm_service, - llm_config, + get_completion_config, + get_graphrag_config, ) from common.embeddings.base_embedding_store import EmbeddingStore from common.embeddings.tigergraph_embedding_store import TigerGraphEmbeddingStore @@ -26,7 +27,8 @@ logger = logging.getLogger(__name__) http_timeout = httpx.Timeout(15.0) -tg_sem = asyncio.Semaphore(100) +_default_concurrency = graphrag_config.get("default_concurrency", 10) +tg_sem = asyncio.Semaphore(_default_concurrency * 2) async def install_queries( requried_queries: list[str], @@ -109,10 +111,11 @@ async def init( await install_queries(requried_queries, conn) # extractor - if graphrag_config.get("extractor") == "graphrag": + graph_cfg = get_graphrag_config(conn.graphname) + if graph_cfg.get("extractor") == "graphrag": extractor = GraphExtractor() - elif graphrag_config.get("extractor") == "llm": - extractor = LLMEntityRelationshipExtractor(get_llm_service(llm_config)) + elif graph_cfg.get("extractor") == "llm": + extractor = LLMEntityRelationshipExtractor(get_llm_service(get_completion_config())) else: raise 
ValueError("Invalid extractor type") diff --git a/ecc/app/supportai/workers.py b/ecc/app/supportai/workers.py index 30fc9ab..2c62169 100644 --- a/ecc/app/supportai/workers.py +++ b/ecc/app/supportai/workers.py @@ -85,7 +85,7 @@ async def chunk_doc( # Use markdown chunker for all documents # Image descriptions wrapped in headers will naturally become single chunks - chunker = ecc_util.get_chunker(chunker_type) + chunker = ecc_util.get_chunker(chunker_type, graphname=conn.graphname) chunks = chunker.chunk(doc["attributes"]["text"]) logger.info(f"Chunking {v_id} into {len(chunks)} chunk(s)") diff --git a/graphrag-ui/Dockerfile b/graphrag-ui/Dockerfile index 9dc17e6..aec0713 100644 --- a/graphrag-ui/Dockerfile +++ b/graphrag-ui/Dockerfile @@ -12,4 +12,4 @@ RUN pnpm run build RUN pnpm i -g serve -CMD [ "serve", "dist" ] +CMD [ "serve", "-s", "dist" ] diff --git a/graphrag-ui/src/actions/ActionProvider.tsx b/graphrag-ui/src/actions/ActionProvider.tsx index d12c9f7..58fc7aa 100644 --- a/graphrag-ui/src/actions/ActionProvider.tsx +++ b/graphrag-ui/src/actions/ActionProvider.tsx @@ -55,8 +55,8 @@ const conversationManager = { if (onNewConversationCallback) { onNewConversationCallback(); } - // Clear conversation data from localStorage - localStorage.removeItem('selectedConversationData'); + // Clear conversation data from sessionStorage + sessionStorage.removeItem('selectedConversationData'); // Don't reload the page - just clear the chat state }, @@ -88,7 +88,7 @@ const ActionProvider: React.FC = ({ const { sendMessage, lastMessage, readyState } = useWebSocket(WS_URL, { onOpen: () => { // Send authentication credentials - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); console.log("Sending credentials, length:", creds ? 
creds.length : 0); queryGraphragWs2(creds!); @@ -115,7 +115,7 @@ const ActionProvider: React.FC = ({ // Initialize conversation manager and load conversation messages useEffect(() => { - const selectedConversationData = localStorage.getItem('selectedConversationData'); + const selectedConversationData = sessionStorage.getItem('selectedConversationData'); if (selectedConversationData) { try { const data = JSON.parse(selectedConversationData); @@ -163,7 +163,7 @@ const ActionProvider: React.FC = ({ // Create bot message const botMessage = createChatBotMessage({ content: msg.content || "", - response_type: msg.response_type || "text", + response_type: "history", query_sources: msg.query_sources, answered_question: msg.answered_question, }); diff --git a/graphrag-ui/src/components/Bot.tsx b/graphrag-ui/src/components/Bot.tsx index 1f4e4e6..6386dec 100644 --- a/graphrag-ui/src/components/Bot.tsx +++ b/graphrag-ui/src/components/Bot.tsx @@ -22,15 +22,15 @@ import { const Bot = ({ layout, getConversationId }: { layout?: string | undefined, getConversationId?:any }) => { const [store, setStore] = useState(); const [currentDate, setCurrentDate] = useState(''); - const [selectedGraph, setSelectedGraph] = useState(localStorage.getItem("selectedGraph") || ''); - const [ragPattern, setRagPattern] = useState(localStorage.getItem("ragPattern") || ''); + const [selectedGraph, setSelectedGraph] = useState(sessionStorage.getItem("selectedGraph") || ''); + const [ragPattern, setRagPattern] = useState(sessionStorage.getItem("ragPattern") || ''); const navigate = useNavigate(); const location = useLocation(); useEffect(() => { - // Function to load store from localStorage + // Function to load store from sessionStorage const loadStore = () => { - const parseStore = JSON.parse(localStorage.getItem("site") || "{}"); + const parseStore = JSON.parse(sessionStorage.getItem("site") || "{}"); setStore(parseStore); return parseStore; }; @@ -39,23 +39,23 @@ const Bot = ({ layout, 
getConversationId }: { layout?: string | undefined, getCo const parseStore = loadStore(); // Validate selectedGraph against the current graph list - const storedGraph = localStorage.getItem("selectedGraph"); + const storedGraph = sessionStorage.getItem("selectedGraph"); const availableGraphs = parseStore?.graphs || []; if (!storedGraph || !availableGraphs.includes(storedGraph)) { if (availableGraphs.length > 0) { const firstGraph = availableGraphs[0]; setSelectedGraph(firstGraph); - localStorage.setItem("selectedGraph", firstGraph); + sessionStorage.setItem("selectedGraph", firstGraph); } else { setSelectedGraph(''); - localStorage.removeItem("selectedGraph"); + sessionStorage.removeItem("selectedGraph"); } } - // Set default ragPattern if no value in localStorage - if (!localStorage.getItem("ragPattern")) { + // Set default ragPattern if no value in sessionStorage + if (!sessionStorage.getItem("ragPattern")) { setRagPattern("Hybrid Search"); - localStorage.setItem("ragPattern", "Hybrid Search"); + sessionStorage.setItem("ragPattern", "Hybrid Search"); } const date = new Date(); @@ -78,20 +78,21 @@ const Bot = ({ layout, getConversationId }: { layout?: string | undefined, getCo // Reload graph list when navigating back to chat (location change) useEffect(() => { - const parseStore = JSON.parse(localStorage.getItem("site") || "{}"); + const parseStore = JSON.parse(sessionStorage.getItem("site") || "{}"); setStore(parseStore); }, [location]); const handleSelect = (value) => { setSelectedGraph(value); - localStorage.setItem("selectedGraph", value); + sessionStorage.setItem("selectedGraph", value); + window.dispatchEvent(new Event("graphrag:selectedGraph")); navigate("/chat"); //window.location.reload(); }; const handleSelectRag = (value) => { setRagPattern(value); - localStorage.setItem("ragPattern", value); + sessionStorage.setItem("ragPattern", value); navigate("/chat"); //window.location.reload(); }; diff --git a/graphrag-ui/src/components/ConfigScopeToggle.tsx 
b/graphrag-ui/src/components/ConfigScopeToggle.tsx new file mode 100644 index 0000000..b311631 --- /dev/null +++ b/graphrag-ui/src/components/ConfigScopeToggle.tsx @@ -0,0 +1,100 @@ +import React from "react"; +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from "@/components/ui/select"; + +interface ConfigScopeToggleProps { + configScope: "global" | "graph"; + selectedGraph: string; + availableGraphs: string[]; + onScopeChange: (scope: "global" | "graph") => void; + onGraphChange: (graph: string) => void; + /** Optional hint rendered below the toggle when graph scope is active and a graph is selected */ + graphSelectedHint?: React.ReactNode; + /** CSS class for the outer wrapper (e.g. "mb-6") */ + className?: string; + /** When true, hides the "Edit global defaults" option and forces graph-specific scope */ + graphOnly?: boolean; +} + +const ConfigScopeToggle: React.FC = ({ + configScope, + selectedGraph, + availableGraphs, + onScopeChange, + onGraphChange, + graphSelectedHint, + className = "mb-6", + graphOnly = false, +}) => { + if (availableGraphs.length === 0) return null; + + return ( +
+ +
+ {!graphOnly && ( + + )} + {!graphOnly && ( + + )} + {graphOnly && ( + Edit graph-specific config for + )} + +
+      {configScope === "graph" && !selectedGraph && (
+
+        Please select a graph to edit its configuration.
+
+      )}
+      {configScope === "graph" && selectedGraph && graphSelectedHint && (
+
+        {graphSelectedHint}
+
+      )}
+
+ ); +}; + +export default ConfigScopeToggle; diff --git a/graphrag-ui/src/components/CustomChatMessage.tsx b/graphrag-ui/src/components/CustomChatMessage.tsx index 9c2c5ee..0aef2ea 100755 --- a/graphrag-ui/src/components/CustomChatMessage.tsx +++ b/graphrag-ui/src/components/CustomChatMessage.tsx @@ -60,10 +60,10 @@ const AuthenticatedImage: FC<{ src: string; alt: string }> = ({ src, alt }) => { useEffect(() => { const fetchImage = async () => { try { - // Get credentials from localStorage (same pattern as Interact.tsx and SideMenu.tsx) - const creds = localStorage.getItem("creds"); + // Get credentials from sessionStorage (same pattern as Interact.tsx and SideMenu.tsx) + const creds = sessionStorage.getItem("creds"); if (!creds) { - console.error("No credentials found in localStorage"); + console.error("No credentials found in sessionStorage"); setError(true); setLoading(false); return; @@ -173,7 +173,7 @@ export const CustomChatMessage: FC = ({ <> {typeof message === "string" ? (
-          {message}
+          {message}
) : message.key === null ? ( message @@ -181,9 +181,9 @@ export const CustomChatMessage: FC = ({
{message.response_type === "progress" ? ( -

{message.content}

+

{message.content}

) : ( - {message.content} + {message.content} )} = ({ const [feedback, setFeedback] = useState(Feedback.NoFeedback); const sendFeedback = async (action: Feedback, message: Message) => { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); setFeedback(action); message.feedback = action; await fetch(`${GRAPHRAG_URL}/ui/feedback`, { diff --git a/graphrag-ui/src/components/Login.tsx b/graphrag-ui/src/components/Login.tsx index 5455f08..d753124 100644 --- a/graphrag-ui/src/components/Login.tsx +++ b/graphrag-ui/src/components/Login.tsx @@ -36,12 +36,12 @@ const WS_URL = "/ui/ui-login"; export function Login() { const { i18n, t } = useTranslation(); const [user, setUser] = useState(""); - const [token, setToken] = useState(localStorage.getItem("site") || ""); + const [token, setToken] = useState(sessionStorage.getItem("site") || ""); const [hint, setHint] = useState(""); const navigate = useNavigate(); useEffect(() => { - const parseStore = JSON.parse(localStorage.getItem("site") || "{}"); + const parseStore = JSON.parse(sessionStorage.getItem("site") || "{}"); setToken(parseStore); }, []); @@ -49,23 +49,30 @@ export function Login() { const creds = btoa(`${data.email}:${data.password}`); const username = data.email; - const res = await fetch("/ui/ui-login", { + try { + const res = await fetch("/ui/ui-login", { method: "POST", headers: { - Authorization: `Basic ${creds}`, + Authorization: `Basic ${creds}`, }, }); - if (res.ok) { - const data = await res.json(); - localStorage.setItem("creds", creds); - localStorage.setItem("site", JSON.stringify(data)); - setUser(username); - localStorage.setItem("username", username); - navigate("/chat"); - } else { - // setError("Invalid credentials"); // This line was removed from the new_code, so it's removed here. 
- setHint("Invalid credentials"); + if (res.ok) { + const data = await res.json(); + sessionStorage.setItem("creds", creds); + sessionStorage.setItem("site", JSON.stringify(data)); + setUser(username); + sessionStorage.setItem("username", username); + navigate("/chat"); + } else if (res.status === 401 || res.status === 403) { + setHint("Invalid credentials"); + navigate("/"); + } else { + setHint(`Server error (${res.status}). Please try again later.`); + navigate("/"); + } + } catch { + setHint("Unable to connect to the server. Please try again later."); navigate("/"); } }; @@ -73,7 +80,7 @@ export function Login() { const logOut = () => { setUser(""); setToken(""); - localStorage.removeItem("site"); + sessionStorage.removeItem("site"); navigate("/"); }; diff --git a/graphrag-ui/src/components/ModeToggle.tsx b/graphrag-ui/src/components/ModeToggle.tsx index a8109bc..053ac9f 100644 --- a/graphrag-ui/src/components/ModeToggle.tsx +++ b/graphrag-ui/src/components/ModeToggle.tsx @@ -10,6 +10,7 @@ import { } from "@/components/ui/dropdown-menu"; import { useTheme } from "@/components/ThemeProvider"; import { useConfirm } from "@/hooks/useConfirm"; +import { useRoles } from "@/hooks/useRoles"; export function ModeToggle() { const { setTheme } = useTheme(); @@ -17,6 +18,7 @@ export function ModeToggle() { const location = useLocation(); const isLoginRoute = location.pathname === "/"; const [confirm, confirmDialog] = useConfirm(); + const { rolesLoaded, canAccessSetup } = useRoles(location.pathname); const handleLogout = async () => { // Show confirmation dialog @@ -46,7 +48,7 @@ export function ModeToggle() { return (
- {!isLoginRoute && ( + {!isLoginRoute && rolesLoaded && canAccessSetup && ( + {!isLoginRoute && ( + + )} diff --git a/graphrag-ui/src/components/SideMenu.tsx b/graphrag-ui/src/components/SideMenu.tsx index a1980b5..c4072db 100644 --- a/graphrag-ui/src/components/SideMenu.tsx +++ b/graphrag-ui/src/components/SideMenu.tsx @@ -72,8 +72,8 @@ const SideMenu = ({ height, setGetConversationId }: { height?: string, setGetCon const fetchHistory2 = useCallback(async () => { setConversationId([]); - const creds = localStorage.getItem("creds"); - const username = localStorage.getItem("username"); + const creds = sessionStorage.getItem("creds"); + const username = sessionStorage.getItem("username"); if (!username) { return; @@ -165,7 +165,7 @@ const SideMenu = ({ height, setGetConversationId }: { height?: string, setGetCon const handleNewChat = () => { conversationManager.startNewConversation(); // Clear any selected conversation data - localStorage.removeItem('selectedConversationData'); + sessionStorage.removeItem('selectedConversationData'); // Force navigation by reloading if already on chat page if (window.location.pathname === "/chat") { window.location.reload(); @@ -186,7 +186,7 @@ const SideMenu = ({ height, setGetConversationId }: { height?: string, setGetCon setExpandedConversations(prev => new Set([...prev, id])); // Store conversation data for the chat component - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); if (!creds) { return; } @@ -207,8 +207,8 @@ const SideMenu = ({ height, setGetConversationId }: { height?: string, setGetCon const data = await response.json(); setConversationId2(data); - // Store the conversation data in localStorage for the chat component - localStorage.setItem('selectedConversationData', JSON.stringify(data)); + // Store the conversation data in sessionStorage for the chat component + sessionStorage.setItem('selectedConversationData', JSON.stringify(data)); // Force reload to restart the 
WebSocket connection with the conversation ID // This ensures the Bot component re-initializes and loads the conversation messages diff --git a/graphrag-ui/src/components/Start.tsx b/graphrag-ui/src/components/Start.tsx index 6ddd927..2746932 100644 --- a/graphrag-ui/src/components/Start.tsx +++ b/graphrag-ui/src/components/Start.tsx @@ -4,7 +4,7 @@ import { useTheme } from "@/components/ThemeProvider"; const questions = (() => { - const selectedGraph = localStorage.getItem('selectedGraph'); + const selectedGraph = sessionStorage.getItem('selectedGraph'); if (selectedGraph?.includes('pyTigerGraphRAG') || selectedGraph?.includes('pyTG')) { return [ diff --git a/graphrag-ui/src/hooks/AuthProvider.tsx b/graphrag-ui/src/hooks/AuthProvider.tsx index e2a32ea..c64006e 100644 --- a/graphrag-ui/src/hooks/AuthProvider.tsx +++ b/graphrag-ui/src/hooks/AuthProvider.tsx @@ -8,7 +8,7 @@ const AuthContext = createContext(); const AuthProvider = ({ children }) => { const [user, setUser] = useState(""); - const [token, setToken] = useState(localStorage.getItem("site") || ""); + const [token, setToken] = useState(sessionStorage.getItem("site") || ""); const navigate = useNavigate(); const loginAction = async (data) => { try { @@ -23,7 +23,7 @@ const AuthProvider = ({ children }) => { if (res.data) { setUser(res.data.user); setToken(res.token); - localStorage.setItem("site", res.token); + sessionStorage.setItem("site", res.token); navigate("/dashboard"); return; } @@ -36,7 +36,7 @@ const AuthProvider = ({ children }) => { const logOut = () => { setUser(""); setToken(""); - localStorage.removeItem("site"); + sessionStorage.removeItem("site"); navigate("/login"); }; diff --git a/graphrag-ui/src/hooks/useIdleTimeout.ts b/graphrag-ui/src/hooks/useIdleTimeout.ts new file mode 100644 index 0000000..07f0486 --- /dev/null +++ b/graphrag-ui/src/hooks/useIdleTimeout.ts @@ -0,0 +1,72 @@ +import { useEffect, useRef, useCallback } from "react"; + +const DEFAULT_TIMEOUT_MS = 60 * 60 * 1000; // 1 
hour + +/** + * Monitors user activity and clears the session after a period of inactivity. + * Resets the timer on mouse, keyboard, scroll, and touch events. + * + * Components with long-running operations can pause/resume the timer: + * pauseIdleTimer() — stops the countdown (e.g. before a long ingest call) + * resumeIdleTimer() — restarts the countdown (e.g. when the call finishes) + */ +export function useIdleTimeout(timeoutMs: number = DEFAULT_TIMEOUT_MS) { + const timerRef = useRef | null>(null); + + const handleTimeout = useCallback(() => { + const creds = sessionStorage.getItem("creds"); + if (!creds) return; // Not logged in, nothing to do + + sessionStorage.clear(); + alert("Session expired due to inactivity. Please log in again."); + window.location.href = "/"; + }, []); + + const resetTimer = useCallback(() => { + if (timerRef.current) { + clearTimeout(timerRef.current); + } + // Only set timer if user is logged in + if (sessionStorage.getItem("creds")) { + timerRef.current = setTimeout(handleTimeout, timeoutMs); + } + }, [handleTimeout, timeoutMs]); + + const pause = useCallback(() => { + if (timerRef.current) { + clearTimeout(timerRef.current); + timerRef.current = null; + } + }, []); + + useEffect(() => { + const events = ["mousemove", "mousedown", "keydown", "scroll", "touchstart"]; + + const onPause = () => pause(); + const onResume = () => resetTimer(); + + events.forEach((event) => window.addEventListener(event, resetTimer)); + window.addEventListener("idle-timer-pause", onPause); + window.addEventListener("idle-timer-resume", onResume); + resetTimer(); // Start the timer + + return () => { + events.forEach((event) => window.removeEventListener(event, resetTimer)); + window.removeEventListener("idle-timer-pause", onPause); + window.removeEventListener("idle-timer-resume", onResume); + if (timerRef.current) { + clearTimeout(timerRef.current); + } + }; + }, [resetTimer, pause]); +} + +/** Pause the idle timer (e.g. 
during long-running backend operations). */ +export function pauseIdleTimer() { + window.dispatchEvent(new Event("idle-timer-pause")); +} + +/** Resume the idle timer (e.g. when a long-running operation completes). */ +export function resumeIdleTimer() { + window.dispatchEvent(new Event("idle-timer-resume")); +} diff --git a/graphrag-ui/src/hooks/useRoles.ts b/graphrag-ui/src/hooks/useRoles.ts new file mode 100644 index 0000000..b9fd578 --- /dev/null +++ b/graphrag-ui/src/hooks/useRoles.ts @@ -0,0 +1,122 @@ +import { useState, useEffect, useCallback } from "react"; + +export interface RolesState { + userRoles: string[]; + graphRoles: Record; + rolesLoaded: boolean; + hasCreds: boolean; + selectedGraph: string; + isSuperuser: boolean; + isGlobalDesigner: boolean; + isGraphAdmin: boolean; + canAccessSetup: boolean; +} + +function parseGraphRoles(raw: unknown): Record { + if (!raw || typeof raw !== "object") return {}; + return Object.fromEntries( + Object.entries(raw as Record).map(([graph, roles]) => [ + graph, + Array.isArray(roles) + ? 
roles.map((role: string) => role.toLowerCase()) + : [], + ]) + ); +} + +export function useRoles(refreshKey?: unknown): RolesState { + const [userRoles, setUserRoles] = useState([]); + const [graphRoles, setGraphRoles] = useState>({}); + const [rolesLoaded, setRolesLoaded] = useState(false); + const [hasCreds, setHasCreds] = useState(false); + const [selectedGraph, setSelectedGraph] = useState( + sessionStorage.getItem("selectedGraph") || "" + ); + + const loadRoles = useCallback(async () => { + const creds = sessionStorage.getItem("creds"); + if (!creds) { + setUserRoles([]); + setGraphRoles({}); + setHasCreds(false); + setRolesLoaded(true); + return; + } + + // Try loading from sessionStorage first (populated at login) + const site = JSON.parse(sessionStorage.getItem("site") || "{}"); + if (Array.isArray(site.roles)) { + const roles = site.roles.map((role: string) => role.toLowerCase()); + setUserRoles(roles); + setGraphRoles(parseGraphRoles(site.graph_roles)); + setSelectedGraph(sessionStorage.getItem("selectedGraph") || ""); + setHasCreds(true); + setRolesLoaded(true); + return; + } + + // Fallback: fetch from backend (for sessions created before login returned roles) + try { + const response = await fetch("/ui/roles", { + headers: { Authorization: `Basic ${creds}` }, + }); + if (!response.ok) { + setUserRoles([]); + setGraphRoles({}); + setHasCreds(false); + return; + } + const data = await response.json(); + const roles = Array.isArray(data.roles) ? 
data.roles : []; + setUserRoles(roles.map((role: string) => role.toLowerCase())); + setGraphRoles(parseGraphRoles(data.graph_roles)); + setSelectedGraph(sessionStorage.getItem("selectedGraph") || ""); + setHasCreds(true); + + // Persist to site so subsequent reads don't need a fetch + site.roles = data.roles; + site.graph_roles = data.graph_roles; + sessionStorage.setItem("site", JSON.stringify(site)); + } catch (err) { + console.error("Failed to fetch user roles:", err); + setUserRoles([]); + setGraphRoles({}); + setHasCreds(false); + } finally { + setRolesLoaded(true); + } + }, []); + + useEffect(() => { + loadRoles(); + }, [loadRoles, refreshKey]); + + useEffect(() => { + const handleGraphChange = () => { + setSelectedGraph(sessionStorage.getItem("selectedGraph") || ""); + }; + window.addEventListener("graphrag:selectedGraph", handleGraphChange); + return () => { + window.removeEventListener("graphrag:selectedGraph", handleGraphChange); + }; + }, []); + + const selectedGraphRoles = graphRoles[selectedGraph] || []; + const isSuperuser = userRoles.includes("superuser"); + const isGlobalDesigner = userRoles.includes("globaldesigner"); + const isGraphAdmin = selectedGraphRoles.includes("admin"); + const isAdminOnAnyGraph = (Object.values(graphRoles) as string[][]).some(roles => roles.includes("admin")); + const canAccessSetup = isSuperuser || isGlobalDesigner || isAdminOnAnyGraph; + + return { + userRoles, + graphRoles, + rolesLoaded, + hasCreds, + selectedGraph, + isSuperuser, + isGlobalDesigner, + isGraphAdmin, + canAccessSetup, + }; +} diff --git a/graphrag-ui/src/main.tsx b/graphrag-ui/src/main.tsx index 69cfa82..70a14d3 100755 --- a/graphrag-ui/src/main.tsx +++ b/graphrag-ui/src/main.tsx @@ -1,16 +1,32 @@ import ReactDOM from "react-dom/client"; import App from "./App.tsx"; import "./index.css"; -import { Outlet, RouterProvider, createBrowserRouter } from "react-router-dom"; +import { Outlet, RouterProvider, createBrowserRouter, Navigate } from 
"react-router-dom"; import Chat from "./pages/Chat"; import ChatDialog from "./pages/ChatDialog.tsx"; -import Setup from "./pages/Setup.tsx"; +import SetupLayout from "./pages/setup/SetupLayout.tsx"; +import KGAdmin from "./pages/setup/KGAdmin.tsx"; +import IngestGraph from "./pages/setup/IngestGraph.tsx"; +import LLMConfig from "./pages/setup/LLMConfig.tsx"; +import GraphDBConfig from "./pages/setup/GraphDBConfig.tsx"; +import GraphRAGConfig from "./pages/setup/GraphRAGConfig.tsx"; +import CustomizePrompts from "./pages/setup/CustomizePrompts.tsx"; import { ThemeProvider } from "./components/ThemeProvider.tsx"; import { ModeToggle } from "@/components/ModeToggle.tsx"; +import { useIdleTimeout } from "./hooks/useIdleTimeout.ts"; import "./components/i18n"; +/** Redirect to login if no credentials in session. */ +const RequireAuth = ({ children }: { children: any }) => { + if (!sessionStorage.getItem("creds")) { + return ; + } + return children; +}; + const Layout = () => { + useIdleTimeout(); return ( @@ -30,19 +46,57 @@ const router = createBrowserRouter([ }, { path: "/chat", - element: , + element: , }, { path: "/chat-dialog", - element: , + element: , }, { path: "/preferences", - element: , + element: , }, { path: "/setup", - element: , + element: , + children: [ + { + path: "", + element: , + }, + { + path: "kg-admin", + element: , + }, + { + path: "kg-admin/ingest", + element: , + }, + { + path: "server-config", + element: , + }, + { + path: "server-config/llm", + element: , + }, + { + path: "server-config/graphdb", + element: , + }, + { + path: "server-config/graphrag", + element: , + }, + { + path: "prompts", + element: , + }, + ], + }, + { + path: "*", + element: , }, ], }, diff --git a/graphrag-ui/src/pages/Setup.tsx b/graphrag-ui/src/pages/Setup.tsx index 3ec977d..d5674ac 100644 --- a/graphrag-ui/src/pages/Setup.tsx +++ b/graphrag-ui/src/pages/Setup.tsx @@ -102,7 +102,7 @@ const [activeTab, setActiveTab] = useState("upload"); if (!ingestGraphName) return; 
try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const response = await fetch(`/ui/${ingestGraphName}/uploads/list`, { headers: { Authorization: `Basic ${creds}` }, }); @@ -151,7 +151,7 @@ const [activeTab, setActiveTab] = useState("upload"); setUploadMessage("Uploading files..."); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const formData = new FormData(); filesArray.forEach((file) => formData.append("files", file)); @@ -200,7 +200,7 @@ const [activeTab, setActiveTab] = useState("upload"); setUploadMessage("Total size exceeds limit. Uploading files one by one..."); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); let uploadedCount = 0; let failedCount = 0; const totalFiles = filesArray.length; @@ -273,7 +273,7 @@ const [activeTab, setActiveTab] = useState("upload"); console.log("Deleting file:", filename); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Delete original file const url = `/ui/${ingestGraphName}/uploads?filename=${encodeURIComponent(filename)}`; @@ -301,7 +301,7 @@ const [activeTab, setActiveTab] = useState("upload"); if (!shouldDelete) return; try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const response = await fetch(`/ui/${ingestGraphName}/uploads`, { method: "DELETE", headers: { Authorization: `Basic ${creds}` }, @@ -323,7 +323,7 @@ const [activeTab, setActiveTab] = useState("upload"); if (!ingestGraphName) return; try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const response = await fetch(`/ui/${ingestGraphName}/cloud/list`, { headers: { Authorization: `Basic ${creds}` }, }); @@ -345,7 +345,7 @@ const [activeTab, setActiveTab] = useState("upload"); setDownloadMessage("Downloading files from cloud storage..."); try { - 
const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Prepare request body based on provider let requestBody: any = { provider: cloudProvider }; @@ -437,7 +437,7 @@ const [activeTab, setActiveTab] = useState("upload"); if (!ingestGraphName) return; try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Delete original file const url = `/ui/${ingestGraphName}/cloud/delete?filename=${encodeURIComponent(filename)}`; @@ -462,7 +462,7 @@ const [activeTab, setActiveTab] = useState("upload"); if (!shouldDelete) return; try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const response = await fetch(`/ui/${ingestGraphName}/cloud/delete`, { method: "DELETE", headers: { Authorization: `Basic ${creds}` }, @@ -485,7 +485,7 @@ const [activeTab, setActiveTab] = useState("upload"); setIsIngesting(true); setIngestMessage("Ingesting documents into knowledge graph..."); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const folderPath = sourceType === "uploaded" ? 
`uploads/${ingestGraphName}` : `downloaded_files_cloud/${ingestGraphName}`; // Use existing ingestJobData if available, otherwise construct from folder path @@ -547,7 +547,7 @@ const [activeTab, setActiveTab] = useState("upload"); setIngestMessage("Step 1/2: Creating ingest job..."); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Step 1: Create ingest job const createIngestConfig = { @@ -643,7 +643,7 @@ const [activeTab, setActiveTab] = useState("upload"); console.log("fileCount:", fileCount); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Call create_ingest to process files const createIngestConfig = { @@ -741,7 +741,7 @@ const [activeTab, setActiveTab] = useState("upload"); setIsIngesting(true); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); let loadingInfo: any = {}; if (skipBDAProcessing) { @@ -815,7 +815,7 @@ const [activeTab, setActiveTab] = useState("upload"); file_path: outputBucket, }; - const filesToIngest = createData.data_source_id.bda_jobs.map((job: any) => job.jobId.split("/")[-1]); + const filesToIngest = createData.data_source_id.bda_jobs.map((job: any) => job.jobId.split("/").at(-1)); setIngestMessage(`Step 2/2: Running document ingest for ${filesToIngest.length} files in ${outputBucket}...`); } @@ -859,7 +859,7 @@ const [activeTab, setActiveTab] = useState("upload"); } try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); const statusResponse = await fetch(`/ui/${graphName}/rebuild_status`, { method: "GET", headers: { @@ -930,7 +930,7 @@ const [activeTab, setActiveTab] = useState("upload"); setRefreshMessage("Verifying rebuild status..."); try { - const creds = localStorage.getItem("creds"); + const creds = sessionStorage.getItem("creds"); // Final status check to prevent race conditions const statusCheckResponse = await 
fetch(`/ui/${refreshGraphName}/rebuild_status`, { @@ -1000,9 +1000,9 @@ const [activeTab, setActiveTab] = useState("upload"); } }, [refreshOpen, refreshGraphName]); - // Load available graphs from localStorage on mount + // Load available graphs from sessionStorage on mount useEffect(() => { - const store = JSON.parse(localStorage.getItem("site") || "{}"); + const store = JSON.parse(sessionStorage.getItem("site") || "{}"); if (store.graphs && Array.isArray(store.graphs)) { setAvailableGraphs(store.graphs); // Auto-select first graph if available @@ -1036,8 +1036,8 @@ const [activeTab, setActiveTab] = useState("upload"); setStatusType(""); try { - // Get credentials from localStorage - const creds = localStorage.getItem("creds"); + // Get credentials from sessionStorage + const creds = sessionStorage.getItem("creds"); if (!creds) { throw new Error("Not authenticated. Please login first."); } @@ -1104,10 +1104,10 @@ const [activeTab, setActiveTab] = useState("upload"); setAvailableGraphs(prev => { if (!prev.includes(newGraph)) { const updated = [...prev, newGraph]; - // Update localStorage as well - const store = JSON.parse(localStorage.getItem("site") || "{}"); + // Update sessionStorage as well + const store = JSON.parse(sessionStorage.getItem("site") || "{}"); store.graphs = updated; - localStorage.setItem("site", JSON.stringify(store)); + sessionStorage.setItem("site", JSON.stringify(store)); return updated; } return prev; @@ -1139,7 +1139,7 @@ const [activeTab, setActiveTab] = useState("upload"); Back to Chat

-            Knowledge Graph Administration
+            Knowledge Graph Setup

Configure and manage your knowledge graphs

diff --git a/graphrag-ui/src/pages/setup/CustomizePrompts.tsx b/graphrag-ui/src/pages/setup/CustomizePrompts.tsx
new file mode 100644
index 0000000..d16fe59
--- /dev/null
+++ b/graphrag-ui/src/pages/setup/CustomizePrompts.tsx
@@ -0,0 +1,353 @@
+import React, { useState, useEffect } from "react";
+import { FileText, Save, Loader2 } from "lucide-react";
+import { Button } from "@/components/ui/button";
+import { Input } from "@/components/ui/input";
+import ConfigScopeToggle from "@/components/ConfigScopeToggle";
+import { useRoles } from "@/hooks/useRoles";
+import { useLocation } from "react-router-dom";
+
+const ALL_PROMPT_TYPES = [
+  { id: "chatbot_response", name: "Chatbot Responses", description: "Customize how the chatbot responds to user questions" },
+  { id: "entity_relationship", name: "Entity Relationships", description: "Configure entity and relationship extraction from document chunks" },
+  { id: "community_summarization", name: "Community Summarization", description: "Define how community summaries are generated" },
+  { id: "query_generation", name: "Schema Instructions", description: "Configure instructions for schema filtering and schema generation" },
+];
+
+const CustomizePrompts = () => {
+  const location = useLocation();
+  const { isSuperuser, isGlobalDesigner } = useRoles(location.pathname);
+  const graphOnly = !isSuperuser && !isGlobalDesigner;
+  const [configuredProvider, setConfiguredProvider] = useState("");
+  const [isLoading, setIsLoading] = useState(true);
+  const [expandedPrompt, setExpandedPrompt] = useState<string | null>(null);
+  // Only the prompt types returned by the backend (filtered by access level)
+  const [availablePromptIds, setAvailablePromptIds] = useState<string[]>([]);
+
+  // Prompts loaded from backend (editable content only)
+  const [prompts, setPrompts] = useState({
+    chatbot_response: "",
+    entity_relationship: "",
+    community_summarization: "",
+    query_generation: "",
+  });
+
+  // Template variables that should not be edited (stored separately)
+  const [promptTemplates, setPromptTemplates] = useState({
+    chatbot_response: "",
+    entity_relationship: "",
+    community_summarization: "",
+    query_generation: "",
+  });
+
+  // Only render prompt types the backend returned for this user
+  const promptTypes = ALL_PROMPT_TYPES.filter(p => availablePromptIds.includes(p.id));
+
+  const [isSaving, setIsSaving] = useState(false);
+  const [saveMessage, setSaveMessage] = useState("");
+  const [saveMessageType, setSaveMessageType] = useState<"success" | "error" | "">("");
+  const [configScope, setConfigScope] = useState<"global" | "graph">("global");
+  const [selectedGraph, setSelectedGraph] = useState(sessionStorage.getItem("selectedGraph") || "");
+  const [availableGraphs, setAvailableGraphs] = useState<string[]>([]);
+
+  const handleSavePrompt = async (promptId: string) => {
+    setIsSaving(true);
+    setSaveMessage("");
+    setSaveMessageType("");
+
+    try {
+      const creds = sessionStorage.getItem("creds");
+      const query = selectedGraph ? `?graphname=${encodeURIComponent(selectedGraph)}` : "";
+      const response = await fetch(`/ui/prompts${query}`, {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Authorization: `Basic ${creds}`,
+        },
+        body: JSON.stringify({
+          prompt_type: promptId,
+          editable_content: prompts[promptId as keyof typeof prompts],
+          template_variables: promptTemplates[promptId as keyof typeof promptTemplates],
+          graphname: selectedGraph || undefined,
+        }),
+      });
+
+      if (!response.ok) {
+        const errorData = await response.json();
+        throw new Error(errorData.detail || "Failed to save prompt");
+      }
+
+      const result = await response.json();
+      setSaveMessage(`✅ ${result.message}`);
+      setSaveMessageType("success");
+      setExpandedPrompt(null); // Collapse after successful save
+    } catch (error: any) {
+      console.error("Error saving prompt:", error);
+      setSaveMessage(`❌ Error: ${error.message}`);
+      setSaveMessageType("error");
+    } finally {
+      setIsSaving(false);
+    }
+  };
+
+  const handlePromptChange = (promptId: string, value: string) => {
+    setPrompts(prev => ({ ...prev, [promptId]: value }));
+  };
+
+  const fetchPrompts = async (graphname?: string) => {
+    setIsLoading(true);
+    const effectiveGraph = graphname ?? selectedGraph;
+    try {
+      const creds = sessionStorage.getItem("creds");
+      const query = effectiveGraph ? `?graphname=${encodeURIComponent(effectiveGraph)}` : "";
+      const response = await fetch(`/ui/prompts${query}`, {
+        headers: { Authorization: `Basic ${creds}` },
+      });
+
+      if (!response.ok) {
+        throw new Error("Failed to fetch prompts");
+      }
+
+      const data = await response.json();
+
+      // Track which prompts this user is allowed to see (backend filters by role)
+      setAvailablePromptIds(Object.keys(data.prompts));
+
+      // Update prompts with fetched data (editable content only)
+      setPrompts({
+        chatbot_response: data.prompts.chatbot_response?.editable_content !== undefined
+          ? data.prompts.chatbot_response.editable_content
+          : (typeof data.prompts.chatbot_response === 'string' ? data.prompts.chatbot_response : ""),
+        entity_relationship: data.prompts.entity_relationship?.editable_content !== undefined
+          ? data.prompts.entity_relationship.editable_content
+          : (typeof data.prompts.entity_relationship === 'string' ? data.prompts.entity_relationship : ""),
+        community_summarization: data.prompts.community_summarization?.editable_content !== undefined
+          ? data.prompts.community_summarization.editable_content
+          : (typeof data.prompts.community_summarization === 'string' ? data.prompts.community_summarization : ""),
+        query_generation: data.prompts.query_generation?.editable_content !== undefined
+          ? data.prompts.query_generation.editable_content
+          : (typeof data.prompts.query_generation === 'string' ? data.prompts.query_generation : ""),
+      });
+
+      // Store template variables separately
+      setPromptTemplates({
+        chatbot_response: data.prompts.chatbot_response?.template_variables || "",
+        entity_relationship: data.prompts.entity_relationship?.template_variables || "",
+        community_summarization: data.prompts.community_summarization?.template_variables || "",
+        query_generation: data.prompts.query_generation?.template_variables || "",
+      });
+
+      // Set configured provider
+      const providerMap: Record<string, string> = {
+        openai: "OpenAI",
+        azure: "Azure OpenAI",
+        genai: "Google GenAI (Gemini)",
+        vertexai: "Google Vertex AI",
+        bedrock: "AWS Bedrock",
+        ollama: "Ollama",
+      };
+      const provider = data.configured_provider?.toLowerCase() || "openai";
+      setConfiguredProvider(providerMap[provider] || data.configured_provider || "OpenAI");
+    } catch (error) {
+      console.error("Error loading prompts:", error);
+      setConfiguredProvider("OpenAI");
+    } finally {
+      setIsLoading(false);
+    }
+  };
+
+  // Fetch prompts and graph list on mount
+  useEffect(() => {
+    const site = JSON.parse(sessionStorage.getItem("site") || "{}");
+    const graphs = site.graphs || [];
+    setAvailableGraphs(graphs);
+    const storedGraph = sessionStorage.getItem("selectedGraph") || "";
+    if (graphOnly) {
+      // Graph admins must use graph-specific scope
+      setConfigScope("graph");
+      const graph = storedGraph || (graphs.length > 0 ? graphs[0] : "");
+      if (graph) {
+        setSelectedGraph(graph);
+        sessionStorage.setItem("selectedGraph", graph);
+        window.dispatchEvent(new Event("graphrag:selectedGraph"));
+        fetchPrompts(graph);
+      }
+    } else if (storedGraph) {
+      setConfigScope("graph");
+      setSelectedGraph(storedGraph);
+      fetchPrompts(storedGraph);
+    } else {
+      fetchPrompts("");
+    }
+  }, [graphOnly]);
+
+  return (

+    {/* [Layout wrapper tags were lost in extraction; recoverable content is preserved below] */}
+      {/* Page header */}
+          Customize Prompts
+          Customize the core prompts used by GraphRAG
+
+      {/* Config Scope Toggle — opening tag and leading props lost in extraction */}
+        {
+          setConfigScope(scope);
+          setSaveMessage("");
+          setSaveMessageType("");
+          if (scope === "global") {
+            setSelectedGraph("");
+            sessionStorage.removeItem("selectedGraph");
+            window.dispatchEvent(new Event("graphrag:selectedGraph"));
+            fetchPrompts("");
+          } else if (selectedGraph) {
+            fetchPrompts(selectedGraph);
+          }
+        }}
+        onGraphChange={(value) => {
+          setConfigScope("graph");
+          setSelectedGraph(value);
+          sessionStorage.setItem("selectedGraph", value);
+          window.dispatchEvent(new Event("graphrag:selectedGraph"));
+          setSaveMessage("");
+          setSaveMessageType("");
+          fetchPrompts(value);
+        }}
+        graphSelectedHint="Only customized prompts are stored per graph. Others fall back to global defaults."
+      />
+
+      {/* Configured Provider - Read Only */}
+        {isLoading && (
+          {/* [loading indicator markup lost in extraction] */}
+        )}
+        Prompts are configured for your currently active LLM provider. Change provider in Server Configuration.
+
+      {/* Save Message */}
+      {saveMessage && (
+        {saveMessage}
+      )}
+
+      {/* Prompt Templates */}
+          Prompt Templates
+
+        {promptTypes.map((prompt) => (
+          {/* [per-prompt card markup lost in extraction] */}
+              {prompt.name}
+              {prompt.description}
+
+          {expandedPrompt === prompt.id && (