From 0d11da2f3e5680aa8c6e48e75e1b5d413865af86 Mon Sep 17 00:00:00 2001 From: Charith Nuwan Bimsara <59943919+nuwangeek@users.noreply.github.com> Date: Tue, 10 Feb 2026 13:23:06 +0530 Subject: [PATCH] Added pre validation step (#298) * updated docker compose ec2 * integrate streaming endpoint with test prodction connection page * formatted response with markdown * fe logic for the encryption * vault secret update after fixing issues * fixed formatting issue * integration with be * update cron manager vault script * tested integration of vault security update * fix security issues * creation success model changes * clean vite config generated files * fixed issue references are not sending with streming tokens * complete #192 and #206 bug fixes * production inference display logic change * change production inference display logic * fixed requested issue * Refactor Docker Compose configuration for vault agents and update CSP settings * Remove obsolete Vite configuration files and associated plugins * prompt coniguration backend to be testing * custom prompt configuration update and fixed Pyright issues * fixed copilot reviews * pre validation step added when user query is inserted * added more validation cases * fixed review comments --------- Co-authored-by: Thiru Dinesh <56014038+Thirunayan22@users.noreply.github.com> Co-authored-by: Thiru Dinesh Co-authored-by: erangi-ar Co-authored-by: erangi-ar <111747955+erangi-ar@users.noreply.github.com> --- src/llm_orchestration_service.py | 52 ++- src/llm_orchestration_service_api.py | 54 ++++ .../llm_ochestrator_constants.py | 12 +- src/models/request_models.py | 17 +- src/utils/query_validator.py | 112 +++++++ tests/conftest.py | 8 + tests/test_query_validator.py | 306 ++++++++++++++++++ 7 files changed, 549 insertions(+), 12 deletions(-) create mode 100644 src/utils/query_validator.py create mode 100644 tests/conftest.py create mode 100644 tests/test_query_validator.py diff --git a/src/llm_orchestration_service.py 
b/src/llm_orchestration_service.py index 05303c2..92dd7b0 100644 --- a/src/llm_orchestration_service.py +++ b/src/llm_orchestration_service.py @@ -34,6 +34,7 @@ INPUT_GUARDRAIL_VIOLATION_MESSAGES, OUTPUT_GUARDRAIL_VIOLATION_MESSAGE, OUTPUT_GUARDRAIL_VIOLATION_MESSAGES, + QUERY_VALIDATION_FAILED_MESSAGES, get_localized_message, GUARDRAILS_BLOCKED_PHRASES, TEST_DEPLOYMENT_ENVIRONMENT, @@ -52,6 +53,7 @@ from src.utils.production_store import get_production_store from src.utils.language_detector import detect_language, get_language_name from src.utils.prompt_config_loader import PromptConfigurationLoader +from src.utils.query_validator import validate_query_basic from src.guardrails import NeMoRailsAdapter, GuardrailCheckResult from src.contextual_retrieval import ContextualRetriever from src.llm_orchestrator_config.exceptions import ( @@ -170,7 +172,36 @@ def process_orchestration_request( # Using setattr for type safety - adds dynamic attribute to Pydantic model instance setattr(request, "_detected_language", detected_language) - # Initialize all service components + # STEP 0.5: Basic Query Validation (before expensive component initialization) + validation_result = validate_query_basic(request.message) + if not validation_result.is_valid: + logger.info( + f"[{request.chatId}] Query validation failed: {validation_result.rejection_reason}" + ) + # Get localized message + validation_msg = get_localized_message( + QUERY_VALIDATION_FAILED_MESSAGES, detected_language + ) + + # Return appropriate response type without initializing components + if request.environment == TEST_DEPLOYMENT_ENVIRONMENT: + return TestOrchestrationResponse( + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=validation_msg, + chunks=None, + ) + else: + return OrchestrationResponse( + chatId=request.chatId, + llmServiceActive=True, + questionOutOfLLMScope=False, + inputGuardFailed=False, + content=validation_msg, + ) + + # Initialize all service components 
(only for valid queries) components = self._initialize_service_components(request) # Execute the orchestration pipeline @@ -299,6 +330,22 @@ async def stream_orchestration_response( # Using setattr for type safety - adds dynamic attribute to Pydantic model instance setattr(request, "_detected_language", detected_language) + # Step 0.5: Basic Query Validation (before guardrails) + validation_result = validate_query_basic(request.message) + if not validation_result.is_valid: + logger.info( + f"[{request.chatId}] Streaming - Query validation failed: {validation_result.rejection_reason}" + ) + # Get localized message + validation_msg = get_localized_message( + QUERY_VALIDATION_FAILED_MESSAGES, detected_language + ) + + # Yield SSE format error + END marker + yield self._format_sse(request.chatId, validation_msg) + yield self._format_sse(request.chatId, "END") + return # Stop processing + # Use StreamManager for centralized tracking and guaranteed cleanup async with stream_manager.managed_stream( chat_id=request.chatId, author_id=request.authorId @@ -953,6 +1000,9 @@ def _execute_orchestration_pipeline( timing_dict: Dict[str, float], ) -> Union[OrchestrationResponse, TestOrchestrationResponse]: """Execute the main orchestration pipeline with all components.""" + # Note: Query validation now happens in process_orchestration_request() + # before component initialization for true early rejection + # Step 1: Input Guardrails Check if components["guardrails_adapter"]: start_time = time.time() diff --git a/src/llm_orchestration_service_api.py b/src/llm_orchestration_service_api.py index 3ed24ce..8bdc80c 100644 --- a/src/llm_orchestration_service_api.py +++ b/src/llm_orchestration_service_api.py @@ -839,6 +839,60 @@ def refresh_prompt_config(http_request: Request) -> Dict[str, Any]: }, ) from e + try: + success = orchestration_service.prompt_config_loader.force_refresh() + + if success: + # Get prompt metadata without exposing content (security) + custom_instructions = ( + 
orchestration_service.prompt_config_loader.get_custom_instructions() + ) + prompt_length = len(custom_instructions) + + # Generate hash for verification purposes (without exposing content) + import hashlib + + prompt_hash = hashlib.sha256(custom_instructions.encode()).hexdigest()[:16] + + logger.info( + f"Prompt configuration cache refreshed successfully ({prompt_length} chars)" + ) + + return { + "refreshed": True, + "message": "Prompt configuration refreshed successfully", + "prompt_length": prompt_length, + "content_hash": prompt_hash, # Safe: hash instead of preview + } + else: + # No fresh data loaded - could be fetch failure or truly not found + error_id = generate_error_id() + logger.warning( + f"[{error_id}] Prompt configuration refresh returned empty result" + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={ + "error": "No prompt configuration found in database", + "error_id": error_id, + }, + ) + + except HTTPException: + # Re-raise HTTP exceptions as-is + raise + except Exception as e: + # Unexpected errors during refresh + error_id = generate_error_id() + logger.error(f"[{error_id}] Failed to refresh prompt configuration: {e}") + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error": "Failed to refresh prompt configuration", + "error_id": error_id, + }, + ) from e + if __name__ == "__main__": logger.info("Starting LLM Orchestration Service API server on port 8100") diff --git a/src/llm_orchestrator_config/llm_ochestrator_constants.py b/src/llm_orchestrator_config/llm_ochestrator_constants.py index 90d01ed..789ef62 100644 --- a/src/llm_orchestrator_config/llm_ochestrator_constants.py +++ b/src/llm_orchestrator_config/llm_ochestrator_constants.py @@ -23,6 +23,12 @@ "en": "I apologize, but I'm unable to provide a response as it may violate our usage policies.", } +# Query validation messages - single generic message for all rejection types +# (empty queries, special characters only, too 
short, repetitive characters) +QUERY_VALIDATION_FAILED_MESSAGES = { + "et": "Palun esitage kehtiv küsimus või sõnum, et ma saaksin teid aidata." +} + # Legacy constants for backward compatibility (English defaults) OUT_OF_SCOPE_MESSAGE = OUT_OF_SCOPE_MESSAGES["en"] TECHNICAL_ISSUE_MESSAGE = TECHNICAL_ISSUE_MESSAGES["en"] @@ -106,9 +112,9 @@ # Helper function to get localized messages -def get_localized_message(message_dict: dict, language_code: str = "en") -> str: +def get_localized_message(message_dict: dict, language_code: str = "et") -> str: """ - Get message in the specified language, fallback to English. + Get message in the specified language, fallback to Estonian. Args: message_dict: Dictionary with language codes as keys @@ -117,7 +123,7 @@ def get_localized_message(message_dict: dict, language_code: str = "en") -> str: Returns: Localized message string """ - return message_dict.get(language_code, message_dict.get("en", "")) + return message_dict.get(language_code, message_dict.get("et", "")) # Service endpoints diff --git a/src/models/request_models.py b/src/models/request_models.py index f4a073c..689c68c 100644 --- a/src/models/request_models.py +++ b/src/models/request_models.py @@ -66,19 +66,20 @@ class OrchestrationRequest(BaseModel): def validate_and_sanitize_message(cls, v: str) -> str: """Sanitize and validate user message. - Note: Content safety checks (prompt injection, PII, harmful content) + Note: This validator only handles security/format concerns: + - XSS/HTML sanitization + - Maximum length enforcement + + Query quality validation (empty messages, special chars, etc.) is handled + by the business logic layer (query_validator) with localized error messages. + + Content safety checks (prompt injection, PII, harmful content) are handled by NeMo Guardrails after this validation layer. 
""" # Sanitize HTML/XSS and normalize whitespace v = InputSanitizer.sanitize_message(v) - # Check if message is empty after sanitization - if not v or len(v.strip()) < 3: - raise ValueError( - "Message must contain at least 3 characters after sanitization" - ) - - # Check length after sanitization + # Check length after sanitization (resource protection) if len(v) > StreamConfig.MAX_MESSAGE_LENGTH: raise ValueError( f"Message exceeds maximum length of {StreamConfig.MAX_MESSAGE_LENGTH} characters" diff --git a/src/utils/query_validator.py b/src/utils/query_validator.py new file mode 100644 index 0000000..98766f7 --- /dev/null +++ b/src/utils/query_validator.py @@ -0,0 +1,112 @@ +"""Basic query validation for empty/meaningless inputs. + +This module provides lightweight, rule-based validation to reject syntactically +invalid queries before they reach expensive LLM-based processing stages. + +Validation checks (all syntactic, NO semantic): +- Empty or whitespace-only messages +- Messages containing only special characters/punctuation (including unicode) +- Messages with too few meaningful characters (< 2) +- Messages with only repetitive characters (e.g., "aaaa", "????") +- Emoji-only messages + +Out of scope for this module: +- Semantic validation (greetings, chitchat, intent detection) +- Language quality checks +- Content policy checks (handled by guardrails) + +Design decisions: +- Numbers are considered valid (e.g., "123" passes validation) +- Mixed alphanumeric with punctuation is valid (e.g., "ab!" passes) +- Unicode punctuation is treated same as ASCII punctuation +- Emojis are not considered meaningful characters +""" + +import re +from typing import Optional +from pydantic import BaseModel + + +class QueryValidationResult(BaseModel): + """Result of basic query validation. 
+ + Attributes: + is_valid: True if query passes all validation checks + rejection_reason: Optional reason code if validation fails + (empty, special_chars_only, too_short, repetitive) + """ + + is_valid: bool + rejection_reason: Optional[str] = None + + +def validate_query_basic(query: str) -> QueryValidationResult: + """ + Validate query for basic syntactic issues (NOT semantic). + + This is a fast, rule-based check that runs before expensive operations + like guardrails or prompt refinement. It only catches obvious syntactic + issues, not semantic problems. + + Args: + query: User's input message to validate + + Returns: + QueryValidationResult with is_valid flag and optional rejection_reason + + Examples: + Valid queries: + >>> validate_query_basic("How to apply for benefits?") + QueryValidationResult(is_valid=True, rejection_reason=None) + >>> validate_query_basic("hi") + QueryValidationResult(is_valid=True, rejection_reason=None) + >>> validate_query_basic("123") + QueryValidationResult(is_valid=True, rejection_reason=None) + >>> validate_query_basic("ab!") + QueryValidationResult(is_valid=True, rejection_reason=None) + + Invalid queries: + >>> validate_query_basic("...") + QueryValidationResult(is_valid=False, rejection_reason='special_chars_only') + >>> validate_query_basic("") + QueryValidationResult(is_valid=False, rejection_reason='empty') + >>> validate_query_basic("????") + QueryValidationResult(is_valid=False, rejection_reason='special_chars_only') + >>> validate_query_basic("a") + QueryValidationResult(is_valid=False, rejection_reason='too_short') + >>> validate_query_basic("😀😀😀") + QueryValidationResult(is_valid=False, rejection_reason='special_chars_only') + """ + # Trim whitespace + query = query.strip() + + # Check 1: Empty query + if not query: + return QueryValidationResult(is_valid=False, rejection_reason="empty") + + # Check 2: Only special characters/punctuation (including unicode and emojis) + # Remove all alphanumeric characters (letters and 
numbers in any language) + # If nothing remains or only punctuation/symbols/emojis, reject + alphanumeric_pattern = re.compile(r"[\w]", re.UNICODE) + has_alphanumeric = bool(alphanumeric_pattern.search(query)) + + if not has_alphanumeric: + # No letters or numbers found - only punctuation/symbols/emojis + return QueryValidationResult( + is_valid=False, rejection_reason="special_chars_only" + ) + + # Check 3: Too short (< 2 meaningful characters) + # Extract alphanumeric characters (letters + numbers, unicode-aware) + meaningful_chars = alphanumeric_pattern.findall(query) + if len(meaningful_chars) < 2: + return QueryValidationResult(is_valid=False, rejection_reason="too_short") + + # Check 4: Only repetitive characters (e.g., "aaaa", "AaAa", "111") + # If all meaningful characters are the same (case-insensitive), likely spam + unique_chars = {c.lower() for c in meaningful_chars} + if len(unique_chars) == 1: + return QueryValidationResult(is_valid=False, rejection_reason="repetitive") + + # Passed all checks - query is syntactically valid + return QueryValidationResult(is_valid=True) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d1633b7 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,8 @@ +"""Pytest configuration for test discovery and imports.""" + +import sys +from pathlib import Path + +# Add the project root to Python path so tests can import from src +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) diff --git a/tests/test_query_validator.py b/tests/test_query_validator.py new file mode 100644 index 0000000..361b1b9 --- /dev/null +++ b/tests/test_query_validator.py @@ -0,0 +1,306 @@ +"""Unit tests for query validator. + +Tests cover all documented examples, edge cases, and boundary conditions +to prevent regressions as validation rules evolve. 
+""" + +import pytest +from src.utils.query_validator import validate_query_basic, QueryValidationResult + + +class TestQueryValidatorEmpty: + """Test empty and whitespace-only queries.""" + + @pytest.mark.parametrize( + "query", + [ + "", + " ", + "\t", + "\n", + "\t\n ", + " \t\n\r ", + ], + ) + def test_empty_queries_rejected(self, query): + """Empty or whitespace-only queries should be rejected.""" + result = validate_query_basic(query) + assert result.is_valid is False + assert result.rejection_reason == "empty" + + +class TestQueryValidatorSpecialCharsOnly: + """Test queries with only special characters or punctuation.""" + + @pytest.mark.parametrize( + "query", + [ + "...", + "???", + "!!!", + "!@#$%^&*()", + ".,?!;:", + "---", + # Note: "___" is repetitive, not special_chars (underscore matches \w) + "[]{}()", + "<>", + "//", + "\\\\", + "++", + "**", + "~~", + "``", + "''", + '""', + "—", + "–", + "''", + "•••", + "→→", + "※※", + "!?!?", + "...???", + "!!! ???", + "????", # 4 question marks - special chars only + ], + ) + def test_special_chars_only_rejected(self, query): + """Queries with only special characters should be rejected.""" + result = validate_query_basic(query) + assert result.is_valid is False + assert result.rejection_reason == "special_chars_only" + + +class TestQueryValidatorTooShort: + """Test queries that are too short.""" + + @pytest.mark.parametrize( + "query", + [ + "a", + "A", + "1", + "õ", + "я", + "a!", + "a?", + "1.", + "a...", + "!a!", + ], + ) + def test_too_short_queries_rejected(self, query): + """Queries with fewer than 2 meaningful characters should be rejected.""" + result = validate_query_basic(query) + assert result.is_valid is False + assert result.rejection_reason == "too_short" + + +class TestQueryValidatorRepetitive: + """Test queries with only repetitive characters.""" + + @pytest.mark.parametrize( + "query", + [ + "aa", + "AAA", + "aaa", + "aaaa", + "AAAAAAA", + "aAaAa", + "11", + "111", + "0000", + "99999", + 
"õõõõ", + "ääää", + "яяяя", + "aa!", + "!!!aaa!!!", + "a.a.a.a", + "___", # 3 underscores - repetitive (underscore is \w) + ], + ) + def test_repetitive_queries_rejected(self, query): + """Queries with only one unique meaningful character should be rejected.""" + result = validate_query_basic(query) + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + +class TestQueryValidatorValid: + """Test valid queries that should pass validation.""" + + @pytest.mark.parametrize( + "query", + [ + "hi", + "hello", + "ok", + "ab", + "AB", + "Hi", + "123", + "12", + "abc123", + "test1", + "How to apply?", + "What is this?", + "When?", + "tere", + "kuidas", + "Mis on?", + "привет", + "как дела", + "ab!", + "hello!", + "test...", + "what???", + "a-b", + "test_case", + "test123!", + "hello world", + "a b c", + "http://test", + "test@email", + "a1", + "12ab", + "õõte", + ], + ) + def test_valid_queries_accepted(self, query): + """Valid queries with meaningful content should be accepted.""" + result = validate_query_basic(query) + assert result.is_valid is True + assert result.rejection_reason is None + + +class TestQueryValidatorEdgeCases: + """Test edge cases and boundary conditions.""" + + def test_whitespace_trimmed(self): + """Leading and trailing whitespace should be trimmed before validation.""" + result = validate_query_basic(" hello ") + assert result.is_valid is True + + result = validate_query_basic(" ") + assert result.is_valid is False + assert result.rejection_reason == "empty" + + def test_case_insensitive_repetition(self): + """Repetition check should be case-insensitive.""" + result = validate_query_basic("AaAa") + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + result = validate_query_basic("AaAaAa") + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + def test_unicode_normalization(self): + """Unicode characters should be handled consistently.""" + result = 
validate_query_basic("привет") + assert result.is_valid is True + + result = validate_query_basic("你好") + assert result.is_valid is True + + result = validate_query_basic("مرحبا") + assert result.is_valid is True + + def test_mixed_scripts(self): + """Queries with mixed scripts should be valid.""" + result = validate_query_basic("hello мир") + assert result.is_valid is True + + result = validate_query_basic("test测试") + assert result.is_valid is True + + def test_numbers_are_valid(self): + """Numbers-only queries are considered valid.""" + result = validate_query_basic("123") + assert result.is_valid is True + + result = validate_query_basic("42") + assert result.is_valid is True + + result = validate_query_basic("2024") + assert result.is_valid is True + + def test_numbers_repetitive(self): + """Repetitive numbers should be rejected.""" + result = validate_query_basic("111") + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + result = validate_query_basic("00") + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + def test_punctuation_doesnt_count_as_meaningful(self): + """Punctuation should not count toward meaningful character count.""" + result = validate_query_basic("a!!!") + assert result.is_valid is False + assert result.rejection_reason == "too_short" + + result = validate_query_basic("ab!!!") + assert result.is_valid is True + + def test_emoji_with_text(self): + """Emojis combined with text should be valid.""" + result = validate_query_basic("hello world") + assert result.is_valid is True + + result = validate_query_basic("test case") + assert result.is_valid is True + + def test_long_repetitive_string(self): + """Long strings of repeated characters should be rejected.""" + result = validate_query_basic("a" * 100) + assert result.is_valid is False + assert result.rejection_reason == "repetitive" + + def test_result_is_pydantic_model(self): + """Result should be a valid Pydantic model.""" + 
result = validate_query_basic("test") + assert isinstance(result, QueryValidationResult) + assert hasattr(result, "is_valid") + assert hasattr(result, "rejection_reason") + + result_dict = result.model_dump() + assert "is_valid" in result_dict + assert "rejection_reason" in result_dict + + +class TestQueryValidatorDocumentedExamples: + """Test all examples from function docstring.""" + + def test_documented_valid_examples(self): + """All documented valid examples should pass.""" + examples = [ + "How to apply for benefits?", + "hi", + "123", + "ab!", + ] + for query in examples: + result = validate_query_basic(query) + assert result.is_valid is True, f"Expected '{query}' to be valid" + assert result.rejection_reason is None + + def test_documented_invalid_examples(self): + """All documented invalid examples should fail with correct reason.""" + examples = [ + ("...", "special_chars_only"), + ("", "empty"), + # Note: ???? is special_chars_only (not in \w), not repetitive + ("????", "special_chars_only"), + ("a", "too_short"), + ] + for query, expected_reason in examples: + result = validate_query_basic(query) + assert result.is_valid is False, f"Expected '{query}' to be invalid" + assert result.rejection_reason == expected_reason, ( + f"Expected '{query}' to fail with '{expected_reason}', " + f"got '{result.rejection_reason}'" + )