Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion src/llm_orchestration_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
INPUT_GUARDRAIL_VIOLATION_MESSAGES,
OUTPUT_GUARDRAIL_VIOLATION_MESSAGE,
OUTPUT_GUARDRAIL_VIOLATION_MESSAGES,
QUERY_VALIDATION_FAILED_MESSAGES,
get_localized_message,
GUARDRAILS_BLOCKED_PHRASES,
TEST_DEPLOYMENT_ENVIRONMENT,
Expand All @@ -52,6 +53,7 @@
from src.utils.production_store import get_production_store
from src.utils.language_detector import detect_language, get_language_name
from src.utils.prompt_config_loader import PromptConfigurationLoader
from src.utils.query_validator import validate_query_basic
from src.guardrails import NeMoRailsAdapter, GuardrailCheckResult
from src.contextual_retrieval import ContextualRetriever
from src.llm_orchestrator_config.exceptions import (
Expand Down Expand Up @@ -170,7 +172,36 @@ def process_orchestration_request(
# Using setattr for type safety - adds dynamic attribute to Pydantic model instance
setattr(request, "_detected_language", detected_language)

# Initialize all service components
# STEP 0.5: Basic Query Validation (before expensive component initialization)
validation_result = validate_query_basic(request.message)
if not validation_result.is_valid:
logger.info(
f"[{request.chatId}] Query validation failed: {validation_result.rejection_reason}"
)
# Get localized message
validation_msg = get_localized_message(
QUERY_VALIDATION_FAILED_MESSAGES, detected_language
)

# Return appropriate response type without initializing components
if request.environment == TEST_DEPLOYMENT_ENVIRONMENT:
return TestOrchestrationResponse(
llmServiceActive=True,
questionOutOfLLMScope=False,
inputGuardFailed=False,
content=validation_msg,
chunks=None,
)
else:
return OrchestrationResponse(
chatId=request.chatId,
llmServiceActive=True,
questionOutOfLLMScope=False,
inputGuardFailed=False,
content=validation_msg,
)

# Initialize all service components (only for valid queries)
components = self._initialize_service_components(request)

# Execute the orchestration pipeline
Expand Down Expand Up @@ -299,6 +330,22 @@ async def stream_orchestration_response(
# Using setattr for type safety - adds dynamic attribute to Pydantic model instance
setattr(request, "_detected_language", detected_language)

# Step 0.5: Basic Query Validation (before guardrails)
validation_result = validate_query_basic(request.message)
if not validation_result.is_valid:
logger.info(
f"[{request.chatId}] Streaming - Query validation failed: {validation_result.rejection_reason}"
)
# Get localized message
validation_msg = get_localized_message(
QUERY_VALIDATION_FAILED_MESSAGES, detected_language
)

# Yield SSE format error + END marker
yield self._format_sse(request.chatId, validation_msg)
yield self._format_sse(request.chatId, "END")
return # Stop processing

# Use StreamManager for centralized tracking and guaranteed cleanup
async with stream_manager.managed_stream(
chat_id=request.chatId, author_id=request.authorId
Expand Down Expand Up @@ -953,6 +1000,9 @@ def _execute_orchestration_pipeline(
timing_dict: Dict[str, float],
) -> Union[OrchestrationResponse, TestOrchestrationResponse]:
"""Execute the main orchestration pipeline with all components."""
# Note: Query validation now happens in process_orchestration_request()
# before component initialization for true early rejection

# Step 1: Input Guardrails Check
if components["guardrails_adapter"]:
start_time = time.time()
Expand Down
54 changes: 54 additions & 0 deletions src/llm_orchestration_service_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,6 +839,60 @@ def refresh_prompt_config(http_request: Request) -> Dict[str, Any]:
},
) from e

try:
success = orchestration_service.prompt_config_loader.force_refresh()

if success:
# Get prompt metadata without exposing content (security)
custom_instructions = (
orchestration_service.prompt_config_loader.get_custom_instructions()
)
prompt_length = len(custom_instructions)

# Generate hash for verification purposes (without exposing content)
import hashlib

prompt_hash = hashlib.sha256(custom_instructions.encode()).hexdigest()[:16]

logger.info(
f"Prompt configuration cache refreshed successfully ({prompt_length} chars)"
)

return {
"refreshed": True,
"message": "Prompt configuration refreshed successfully",
"prompt_length": prompt_length,
"content_hash": prompt_hash, # Safe: hash instead of preview
}
else:
# No fresh data loaded - could be fetch failure or truly not found
error_id = generate_error_id()
logger.warning(
f"[{error_id}] Prompt configuration refresh returned empty result"
)
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
"error": "No prompt configuration found in database",
"error_id": error_id,
},
)

except HTTPException:
# Re-raise HTTP exceptions as-is
raise
except Exception as e:
# Unexpected errors during refresh
error_id = generate_error_id()
logger.error(f"[{error_id}] Failed to refresh prompt configuration: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail={
"error": "Failed to refresh prompt configuration",
"error_id": error_id,
},
) from e


if __name__ == "__main__":
logger.info("Starting LLM Orchestration Service API server on port 8100")
Expand Down
12 changes: 9 additions & 3 deletions src/llm_orchestrator_config/llm_ochestrator_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@
"en": "I apologize, but I'm unable to provide a response as it may violate our usage policies.",
}

# Query validation messages - single generic message for all rejection types
# (empty queries, special characters only, too short, repetitive characters)
QUERY_VALIDATION_FAILED_MESSAGES = {
"et": "Palun esitage kehtiv küsimus või sõnum, et ma saaksin teid aidata."
}

# Legacy constants for backward compatibility (English defaults)
OUT_OF_SCOPE_MESSAGE = OUT_OF_SCOPE_MESSAGES["en"]
TECHNICAL_ISSUE_MESSAGE = TECHNICAL_ISSUE_MESSAGES["en"]
Expand Down Expand Up @@ -106,9 +112,9 @@


# Helper function to get localized messages
def get_localized_message(message_dict: dict, language_code: str = "en") -> str:
def get_localized_message(message_dict: dict, language_code: str = "et") -> str:
"""
Get message in the specified language, fallback to English.
Get message in the specified language, fallback to Estonian.

Args:
message_dict: Dictionary with language codes as keys
Expand All @@ -117,7 +123,7 @@ def get_localized_message(message_dict: dict, language_code: str = "en") -> str:
Returns:
Localized message string
"""
return message_dict.get(language_code, message_dict.get("en", ""))
return message_dict.get(language_code, message_dict.get("et", ""))


# Service endpoints
Expand Down
17 changes: 9 additions & 8 deletions src/models/request_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,20 @@ class OrchestrationRequest(BaseModel):
def validate_and_sanitize_message(cls, v: str) -> str:
"""Sanitize and validate user message.

Note: Content safety checks (prompt injection, PII, harmful content)
Note: This validator only handles security/format concerns:
- XSS/HTML sanitization
- Maximum length enforcement

Query quality validation (empty messages, special chars, etc.) is handled
by the business logic layer (query_validator) with localized error messages.

Content safety checks (prompt injection, PII, harmful content)
are handled by NeMo Guardrails after this validation layer.
"""
# Sanitize HTML/XSS and normalize whitespace
v = InputSanitizer.sanitize_message(v)

# Check if message is empty after sanitization
if not v or len(v.strip()) < 3:
raise ValueError(
"Message must contain at least 3 characters after sanitization"
)

# Check length after sanitization
# Check length after sanitization (resource protection)
if len(v) > StreamConfig.MAX_MESSAGE_LENGTH:
raise ValueError(
f"Message exceeds maximum length of {StreamConfig.MAX_MESSAGE_LENGTH} characters"
Expand Down
112 changes: 112 additions & 0 deletions src/utils/query_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""Basic query validation for empty/meaningless inputs.

This module provides lightweight, rule-based validation to reject syntactically
invalid queries before they reach expensive LLM-based processing stages.

Validation checks (all syntactic, NO semantic):
- Empty or whitespace-only messages
- Messages containing only special characters/punctuation (including unicode)
- Messages with too few meaningful characters (< 2)
- Messages with only repetitive characters (e.g., "aaaa", "????")
- Emoji-only messages

Out of scope for this module:
- Semantic validation (greetings, chitchat, intent detection)
- Language quality checks
- Content policy checks (handled by guardrails)

Design decisions:
- Numbers are considered valid (e.g., "123" passes validation)
- Mixed alphanumeric with punctuation is valid (e.g., "ab!" passes)
- Unicode punctuation is treated same as ASCII punctuation
- Emojis are not considered meaningful characters
"""

import re
from typing import Optional
from pydantic import BaseModel


class QueryValidationResult(BaseModel):
"""Result of basic query validation.

Attributes:
is_valid: True if query passes all validation checks
rejection_reason: Optional reason code if validation fails
(empty, special_chars_only, too_short, repetitive)
"""

is_valid: bool
rejection_reason: Optional[str] = None


def validate_query_basic(query: str) -> QueryValidationResult:
"""
Validate query for basic syntactic issues (NOT semantic).

This is a fast, rule-based check that runs before expensive operations
like guardrails or prompt refinement. It only catches obvious syntactic
issues, not semantic problems.

Args:
query: User's input message to validate

Returns:
QueryValidationResult with is_valid flag and optional rejection_reason

Examples:
Valid queries:
>>> validate_query_basic("How to apply for benefits?")
QueryValidationResult(is_valid=True, rejection_reason=None)
>>> validate_query_basic("hi")
QueryValidationResult(is_valid=True, rejection_reason=None)
>>> validate_query_basic("123")
QueryValidationResult(is_valid=True, rejection_reason=None)
>>> validate_query_basic("ab!")
QueryValidationResult(is_valid=True, rejection_reason=None)

Invalid queries:
>>> validate_query_basic("...")
QueryValidationResult(is_valid=False, rejection_reason='special_chars_only')
>>> validate_query_basic("")
QueryValidationResult(is_valid=False, rejection_reason='empty')
>>> validate_query_basic("????")
QueryValidationResult(is_valid=False, rejection_reason='repetitive')
>>> validate_query_basic("a")
QueryValidationResult(is_valid=False, rejection_reason='too_short')
>>> validate_query_basic("😀😀😀")
QueryValidationResult(is_valid=False, rejection_reason='special_chars_only')
"""
# Trim whitespace
query = query.strip()

# Check 1: Empty query
if not query:
return QueryValidationResult(is_valid=False, rejection_reason="empty")

# Check 2: Only special characters/punctuation (including unicode and emojis)
# Remove all alphanumeric characters (letters and numbers in any language)
# If nothing remains or only punctuation/symbols/emojis, reject
alphanumeric_pattern = re.compile(r"[\w]", re.UNICODE)
has_alphanumeric = bool(alphanumeric_pattern.search(query))

if not has_alphanumeric:
# No letters or numbers found - only punctuation/symbols/emojis
return QueryValidationResult(
is_valid=False, rejection_reason="special_chars_only"
)

# Check 3: Too short (< 2 meaningful characters)
# Extract alphanumeric characters (letters + numbers, unicode-aware)
meaningful_chars = alphanumeric_pattern.findall(query)
if len(meaningful_chars) < 2:
return QueryValidationResult(is_valid=False, rejection_reason="too_short")

# Check 4: Only repetitive characters (e.g., "aaaa", "????", "111")
# If all meaningful characters are the same (case-insensitive), likely spam
unique_chars = {c.lower() for c in meaningful_chars}
if len(unique_chars) == 1:
return QueryValidationResult(is_valid=False, rejection_reason="repetitive")

# Passed all checks - query is syntactically valid
return QueryValidationResult(is_valid=True)
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""Pytest configuration for test discovery and imports."""

import sys
from pathlib import Path

# Add the project root to Python path so tests can import from src
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
Loading
Loading