From ae2493781746d137f3847c5a36d01f80d0f4f9c1 Mon Sep 17 00:00:00 2001 From: JeffreyChen Date: Sun, 24 May 2026 17:14:50 +0800 Subject: [PATCH 1/8] Expand utils tree with web-platform, security, perf, AI, and governance modules Adds 73 modules across the testing lifecycle: WebTransport / IndexedDB / File System Access / Notifications instrumentation; mixed-content / clickjacking / open-redirect / SRI / COOP-COEP audits; INP / hydration / bundle / 3p / LoAF budgets; gRPC / webhook / idempotency / pagination integration helpers; AI narrator / repro-minimizer / locator-hardener / categorizer; quarantine-age, test-debt, SLA, repro-stability, and CODEOWNERS reports. Each ships with a focused unit-test file. --- CLAUDE.md | 75 ++- je_web_runner/__init__.py | 291 ++++++++- .../utils/backend_log_correlator/__init__.py | 0 .../backend_log_correlator/correlator.py | 291 +++++++++ .../utils/bug_repro_stability/__init__.py | 0 .../utils/bug_repro_stability/stability.py | 195 ++++++ je_web_runner/utils/bundle_budget/__init__.py | 0 je_web_runner/utils/bundle_budget/budget.py | 258 ++++++++ je_web_runner/utils/chaos_hooks/__init__.py | 0 je_web_runner/utils/chaos_hooks/chaos.py | 183 ++++++ .../utils/chrome_profile/__init__.py | 0 .../utils/chrome_profile/profile_manager.py | 461 +++++++++++++++ .../utils/clickjacking_audit/__init__.py | 0 .../utils/clickjacking_audit/audit.py | 218 +++++++ je_web_runner/utils/consent_audit/__init__.py | 0 je_web_runner/utils/consent_audit/audit.py | 249 ++++++++ .../utils/console_error_budget/__init__.py | 0 .../utils/console_error_budget/budget.py | 231 ++++++++ .../utils/coop_coep_audit/__init__.py | 0 je_web_runner/utils/coop_coep_audit/audit.py | 270 +++++++++ .../utils/cross_tab_sync/__init__.py | 0 .../utils/cross_tab_sync/sync_assertions.py | 327 +++++++++++ je_web_runner/utils/db_snapshot/__init__.py | 0 je_web_runner/utils/db_snapshot/snapshot.py | 170 ++++++ je_web_runner/utils/device_cloud/__init__.py | 0 .../utils/device_cloud/real_device.py | 456 +++++++++++++++ .../utils/download_verify/__init__.py | 0 .../utils/download_verify/verifier.py | 391 +++++++++++++ .../utils/edge_case_generator/__init__.py | 0 .../utils/edge_case_generator/generator.py | 289 +++++++++ je_web_runner/utils/email_render/__init__.py | 0 je_web_runner/utils/email_render/render.py | 299 ++++++++++ .../utils/exploratory_ai/__init__.py | 0 .../utils/exploratory_ai/explorer.py | 311 ++++++++++ .../utils/failure_narrator/__init__.py | 0 .../utils/failure_narrator/narrator.py | 239 ++++++++ .../utils/failure_triage/__init__.py | 0 je_web_runner/utils/failure_triage/triage.py | 310 ++++++++++ .../utils/file_system_access/__init__.py | 0 .../utils/file_system_access/mock.py | 205 +++++++ je_web_runner/utils/flag_matrix/__init__.py | 0 je_web_runner/utils/flag_matrix/matrix.py | 281 +++++++++ .../utils/flake_detector/__init__.py | 0 .../utils/flake_detector/detector.py | 412 +++++++++++++ .../utils/forced_colors_mode/__init__.py | 0 .../utils/forced_colors_mode/modes.py | 223 +++++++ .../utils/git_bisect_flake/__init__.py | 0 .../utils/git_bisect_flake/bisect.py | 242 ++++++++ je_web_runner/utils/grpc_tester/__init__.py | 0 je_web_runner/utils/grpc_tester/client.py | 249 ++++++++ .../utils/hydration_check/__init__.py | 0 je_web_runner/utils/hydration_check/check.py | 159 +++++ .../utils/idempotency_check/__init__.py | 0 .../utils/idempotency_check/check.py | 175 ++++++ .../utils/indexed_db_explorer/__init__.py | 0 .../utils/indexed_db_explorer/explorer.py | 252 ++++++++ je_web_runner/utils/inp_tracker/__init__.py | 0 je_web_runner/utils/inp_tracker/tracker.py | 201 +++++++ .../utils/live_dashboard/__init__.py | 0 je_web_runner/utils/live_dashboard/server.py | 505 ++++++++++++++++ .../utils/locator_hardener/__init__.py | 0 .../utils/locator_hardener/hardener.py | 255 ++++++++ .../utils/locator_health/__init__.py | 0 .../utils/locator_health/health_report.py | 459 +++++++++++++++ .../utils/long_animation_frame/__init__.py | 0 .../utils/long_animation_frame/frames.py | 207 +++++++ .../utils/mixed_content_audit/__init__.py | 0 .../utils/mixed_content_audit/audit.py | 232 ++++++++ je_web_runner/utils/multimodal_qa/__init__.py | 0 je_web_runner/utils/multimodal_qa/qa.py | 210 +++++++ .../utils/mutation_testing/__init__.py | 0 .../utils/mutation_testing/mutator.py | 446 ++++++++++++++ .../utils/notifications_audit/__init__.py | 0 .../utils/notifications_audit/audit.py | 263 +++++++++ je_web_runner/utils/ocr_assert/__init__.py | 0 je_web_runner/utils/ocr_assert/ocr.py | 237 ++++++++ .../utils/open_redirect_detector/__init__.py | 0 .../utils/open_redirect_detector/detector.py | 242 ++++++++ .../utils/openapi_to_e2e/__init__.py | 0 .../utils/openapi_to_e2e/generator.py | 553 ++++++++++++++++++ je_web_runner/utils/otel_bridge/__init__.py | 0 .../utils/otel_bridge/trace_bridge.py | 289 +++++++++ .../utils/otp_interceptor/__init__.py | 0 .../utils/otp_interceptor/interceptor.py | 489 ++++++++++++++++ .../utils/pagination_audit/__init__.py | 0 je_web_runner/utils/pagination_audit/audit.py | 229 ++++++++ .../utils/persona_runner/__init__.py | 0 je_web_runner/utils/persona_runner/runner.py | 211 +++++++ .../utils/pii_in_screenshot/__init__.py | 0 .../utils/pii_in_screenshot/scanner.py | 241 ++++++++ je_web_runner/utils/pr_risk_score/__init__.py | 0 je_web_runner/utils/pr_risk_score/scorer.py | 250 ++++++++ .../utils/prompt_drift_monitor/__init__.py | 0 .../utils/prompt_drift_monitor/monitor.py | 270 +++++++++ .../utils/pseudo_localization/__init__.py | 0 .../utils/pseudo_localization/pseudo.py | 199 +++++++ .../utils/quarantine_age_report/__init__.py | 0 .../utils/quarantine_age_report/report.py | 204 +++++++ .../utils/repro_minimizer/__init__.py | 0 .../utils/repro_minimizer/minimizer.py | 167 ++++++ .../utils/screen_reader_runner/__init__.py | 0 .../utils/screen_reader_runner/reader.py | 269 +++++++++ .../utils/session_to_test/__init__.py | 0 .../utils/session_to_test/converter.py | 304 ++++++++++ je_web_runner/utils/sla_tracker/__init__.py | 0 je_web_runner/utils/sla_tracker/tracker.py | 230 ++++++++ je_web_runner/utils/slack_digest/__init__.py | 0 je_web_runner/utils/slack_digest/digest.py | 222 +++++++ je_web_runner/utils/sri_verify/__init__.py | 0 je_web_runner/utils/sri_verify/verify.py | 240 ++++++++ je_web_runner/utils/sse_assert/__init__.py | 0 je_web_runner/utils/sse_assert/stream.py | 252 ++++++++ .../utils/story_to_actions/__init__.py | 0 .../utils/story_to_actions/generator.py | 265 +++++++++ .../utils/test_auto_repair/__init__.py | 0 .../utils/test_auto_repair/repair.py | 310 ++++++++++ .../utils/test_categorizer/__init__.py | 0 .../utils/test_categorizer/categorizer.py | 201 +++++++ .../utils/test_cost_estimator/__init__.py | 0 .../utils/test_cost_estimator/estimator.py | 250 ++++++++ .../utils/test_debt_dashboard/__init__.py | 0 .../utils/test_debt_dashboard/debt.py | 319 ++++++++++ je_web_runner/utils/test_dedup_ai/__init__.py | 0 je_web_runner/utils/test_dedup_ai/dedup.py | 279 +++++++++ .../utils/test_owners_map/__init__.py | 0 je_web_runner/utils/test_owners_map/owners.py | 211 +++++++ .../utils/test_scheduler/__init__.py | 0 .../utils/test_scheduler/scheduler.py | 330 +++++++++++ .../utils/third_party_budget/__init__.py | 0 .../utils/third_party_budget/budget.py | 230 ++++++++ je_web_runner/utils/time_freezer/__init__.py | 0 je_web_runner/utils/time_freezer/freezer.py | 186 ++++++ .../utils/token_leak_detector/__init__.py | 0 .../utils/token_leak_detector/detector.py | 255 ++++++++ .../utils/view_transitions/__init__.py | 0 .../utils/view_transitions/transitions.py | 235 ++++++++ je_web_runner/utils/visual_ai/__init__.py | 0 je_web_runner/utils/visual_ai/perceptual.py | 425 ++++++++++++++ .../utils/walkthrough_docs/__init__.py | 0 .../utils/walkthrough_docs/generator.py | 341 +++++++++++ .../utils/webhook_receiver/__init__.py | 0 .../utils/webhook_receiver/receiver.py | 294 ++++++++++ je_web_runner/utils/webrtc_assert/__init__.py | 0 je_web_runner/utils/webrtc_assert/peer.py | 275 +++++++++ .../utils/websocket_assert/__init__.py | 0 .../utils/websocket_assert/frames.py | 231 ++++++++ .../utils/webtransport_assert/__init__.py | 0 .../utils/webtransport_assert/streams.py | 239 ++++++++ test/unit_test/test_backend_log_correlator.py | 145 +++++ test/unit_test/test_bug_repro_stability.py | 149 +++++ test/unit_test/test_bundle_budget.py | 213 +++++++ test/unit_test/test_chaos_hooks.py | 137 +++++ test/unit_test/test_chrome_profile.py | 289 +++++++++ test/unit_test/test_clickjacking_audit.py | 174 ++++++ test/unit_test/test_consent_audit.py | 176 ++++++ test/unit_test/test_console_error_budget.py | 161 +++++ test/unit_test/test_coop_coep_audit.py | 245 ++++++++ test/unit_test/test_cross_tab_sync.py | 297 ++++++++++ test/unit_test/test_db_snapshot.py | 151 +++++ test/unit_test/test_device_cloud.py | 236 ++++++++ test/unit_test/test_download_verify.py | 279 +++++++++ test/unit_test/test_edge_case_generator.py | 216 +++++++ test/unit_test/test_email_render.py | 168 ++++++ test/unit_test/test_exploratory_ai.py | 255 ++++++++ test/unit_test/test_failure_narrator.py | 187 ++++++ test/unit_test/test_failure_triage.py | 231 ++++++++ test/unit_test/test_file_system_access.py | 134 +++++ test/unit_test/test_flag_matrix.py | 175 ++++++ test/unit_test/test_flake_detector.py | 286 +++++++++ test/unit_test/test_forced_colors_mode.py | 166 ++++++ test/unit_test/test_git_bisect_flake.py | 206 +++++++ test/unit_test/test_grpc_tester.py | 186 ++++++ test/unit_test/test_hydration_check.py | 142 +++++ test/unit_test/test_idempotency_check.py | 163 ++++++ test/unit_test/test_indexed_db_explorer.py | 188 ++++++ test/unit_test/test_inp_tracker.py | 153 +++++ test/unit_test/test_live_dashboard_server.py | 221 +++++++ test/unit_test/test_locator_hardener.py | 231 ++++++++ test/unit_test/test_locator_health.py | 295 ++++++++++ test/unit_test/test_long_animation_frame.py | 140 +++++ test/unit_test/test_mixed_content_audit.py | 175 ++++++ test/unit_test/test_multimodal_qa.py | 189 ++++++ test/unit_test/test_mutation_testing.py | 232 ++++++++ test/unit_test/test_notifications_audit.py | 204 +++++++ test/unit_test/test_ocr_assert.py | 145 +++++ test/unit_test/test_open_redirect_detector.py | 183 ++++++ test/unit_test/test_openapi_to_e2e.py | 260 ++++++++ test/unit_test/test_otel_bridge.py | 201 +++++++ test/unit_test/test_otp_interceptor.py | 221 +++++++ test/unit_test/test_pagination_audit.py | 185 ++++++ test/unit_test/test_persona_runner.py | 146 +++++ test/unit_test/test_pii_in_screenshot.py | 158 +++++ test/unit_test/test_pr_risk_score.py | 152 +++++ test/unit_test/test_prompt_drift_monitor.py | 199 +++++++ test/unit_test/test_pseudo_localization.py | 146 +++++ test/unit_test/test_quarantine_age_report.py | 181 ++++++ test/unit_test/test_repro_minimizer.py | 145 +++++ test/unit_test/test_screen_reader_runner.py | 145 +++++ test/unit_test/test_session_to_test.py | 214 +++++++ test/unit_test/test_sla_tracker.py | 190 ++++++ test/unit_test/test_slack_digest.py | 146 +++++ test/unit_test/test_sri_verify.py | 190 ++++++ test/unit_test/test_sse_assert.py | 205 +++++++ test/unit_test/test_story_to_actions.py | 213 +++++++ test/unit_test/test_test_auto_repair.py | 217 +++++++ test/unit_test/test_test_categorizer.py | 181 ++++++ test/unit_test/test_test_cost_estimator.py | 176 ++++++ test/unit_test/test_test_debt_dashboard.py | 223 +++++++ test/unit_test/test_test_dedup_ai.py | 202 +++++++ test/unit_test/test_test_owners_map.py | 238 ++++++++ test/unit_test/test_test_scheduler.py | 217 +++++++ test/unit_test/test_third_party_budget.py | 215 +++++++ test/unit_test/test_time_freezer.py | 123 ++++ test/unit_test/test_token_leak_detector.py | 200 +++++++ test/unit_test/test_view_transitions.py | 144 +++++ test/unit_test/test_visual_ai.py | 281 +++++++++ test/unit_test/test_walkthrough_docs.py | 215 +++++++ test/unit_test/test_webhook_receiver.py | 187 ++++++ test/unit_test/test_webrtc_assert.py | 214 +++++++ test/unit_test/test_websocket_assert.py | 216 +++++++ test/unit_test/test_webtransport_assert.py | 196 +++++++ 221 files changed, 34727 insertions(+), 2 deletions(-) create mode 100644 je_web_runner/utils/backend_log_correlator/__init__.py create mode 100644 je_web_runner/utils/backend_log_correlator/correlator.py create mode 100644 je_web_runner/utils/bug_repro_stability/__init__.py create mode 100644 je_web_runner/utils/bug_repro_stability/stability.py create mode 100644 je_web_runner/utils/bundle_budget/__init__.py create mode 100644 je_web_runner/utils/bundle_budget/budget.py create mode 100644 je_web_runner/utils/chaos_hooks/__init__.py create mode 100644 je_web_runner/utils/chaos_hooks/chaos.py create mode 100644 je_web_runner/utils/chrome_profile/__init__.py create mode 100644 je_web_runner/utils/chrome_profile/profile_manager.py create mode 100644 je_web_runner/utils/clickjacking_audit/__init__.py create mode 100644 je_web_runner/utils/clickjacking_audit/audit.py create mode 100644 je_web_runner/utils/consent_audit/__init__.py create mode 100644 je_web_runner/utils/consent_audit/audit.py create mode 100644 je_web_runner/utils/console_error_budget/__init__.py create mode 100644 je_web_runner/utils/console_error_budget/budget.py create mode 100644 je_web_runner/utils/coop_coep_audit/__init__.py create mode 100644 je_web_runner/utils/coop_coep_audit/audit.py create mode 100644 je_web_runner/utils/cross_tab_sync/__init__.py create mode 100644 je_web_runner/utils/cross_tab_sync/sync_assertions.py create mode 100644 je_web_runner/utils/db_snapshot/__init__.py create mode 100644 je_web_runner/utils/db_snapshot/snapshot.py create mode 100644 je_web_runner/utils/device_cloud/__init__.py create mode 100644 je_web_runner/utils/device_cloud/real_device.py create mode 100644 je_web_runner/utils/download_verify/__init__.py create mode 100644 je_web_runner/utils/download_verify/verifier.py create mode 100644 je_web_runner/utils/edge_case_generator/__init__.py create mode 100644 je_web_runner/utils/edge_case_generator/generator.py create mode 100644 je_web_runner/utils/email_render/__init__.py create mode 100644 je_web_runner/utils/email_render/render.py create mode 100644 je_web_runner/utils/exploratory_ai/__init__.py create mode 100644 je_web_runner/utils/exploratory_ai/explorer.py create mode 100644 je_web_runner/utils/failure_narrator/__init__.py create mode 100644 je_web_runner/utils/failure_narrator/narrator.py create mode 100644 je_web_runner/utils/failure_triage/__init__.py create mode 100644 je_web_runner/utils/failure_triage/triage.py create mode 100644 je_web_runner/utils/file_system_access/__init__.py create mode 100644 je_web_runner/utils/file_system_access/mock.py create mode 100644 je_web_runner/utils/flag_matrix/__init__.py create mode 100644 je_web_runner/utils/flag_matrix/matrix.py create mode 100644 je_web_runner/utils/flake_detector/__init__.py create mode 100644 je_web_runner/utils/flake_detector/detector.py create mode 100644 je_web_runner/utils/forced_colors_mode/__init__.py create mode 100644 je_web_runner/utils/forced_colors_mode/modes.py create mode 100644 je_web_runner/utils/git_bisect_flake/__init__.py create mode 100644 je_web_runner/utils/git_bisect_flake/bisect.py create mode 100644 je_web_runner/utils/grpc_tester/__init__.py create mode 100644 je_web_runner/utils/grpc_tester/client.py create mode 100644 je_web_runner/utils/hydration_check/__init__.py create mode 100644 je_web_runner/utils/hydration_check/check.py create mode 100644 je_web_runner/utils/idempotency_check/__init__.py create mode 100644 je_web_runner/utils/idempotency_check/check.py create mode 100644 je_web_runner/utils/indexed_db_explorer/__init__.py create mode 100644 je_web_runner/utils/indexed_db_explorer/explorer.py create mode 100644 je_web_runner/utils/inp_tracker/__init__.py create mode 100644 je_web_runner/utils/inp_tracker/tracker.py create mode 100644 je_web_runner/utils/live_dashboard/__init__.py create mode 100644 je_web_runner/utils/live_dashboard/server.py create mode 100644 je_web_runner/utils/locator_hardener/__init__.py create mode 100644 je_web_runner/utils/locator_hardener/hardener.py create mode 100644 je_web_runner/utils/locator_health/__init__.py create mode 100644 je_web_runner/utils/locator_health/health_report.py create mode 100644 je_web_runner/utils/long_animation_frame/__init__.py create mode 100644 je_web_runner/utils/long_animation_frame/frames.py create mode 100644 je_web_runner/utils/mixed_content_audit/__init__.py create mode 100644 je_web_runner/utils/mixed_content_audit/audit.py create mode 100644 je_web_runner/utils/multimodal_qa/__init__.py create mode 100644 je_web_runner/utils/multimodal_qa/qa.py create mode 100644 je_web_runner/utils/mutation_testing/__init__.py create mode 100644 je_web_runner/utils/mutation_testing/mutator.py create mode 100644 je_web_runner/utils/notifications_audit/__init__.py create mode 100644 je_web_runner/utils/notifications_audit/audit.py create mode 100644 je_web_runner/utils/ocr_assert/__init__.py create mode 100644 je_web_runner/utils/ocr_assert/ocr.py create mode 100644 je_web_runner/utils/open_redirect_detector/__init__.py create mode 100644 je_web_runner/utils/open_redirect_detector/detector.py create mode 100644 je_web_runner/utils/openapi_to_e2e/__init__.py create mode 100644 je_web_runner/utils/openapi_to_e2e/generator.py create mode 100644 je_web_runner/utils/otel_bridge/__init__.py create mode 100644 je_web_runner/utils/otel_bridge/trace_bridge.py create mode 100644 je_web_runner/utils/otp_interceptor/__init__.py create mode 100644 je_web_runner/utils/otp_interceptor/interceptor.py create mode 100644 je_web_runner/utils/pagination_audit/__init__.py create mode 100644 je_web_runner/utils/pagination_audit/audit.py create mode 100644 je_web_runner/utils/persona_runner/__init__.py create mode 100644 je_web_runner/utils/persona_runner/runner.py create mode 100644 je_web_runner/utils/pii_in_screenshot/__init__.py create mode 100644 je_web_runner/utils/pii_in_screenshot/scanner.py create mode 100644 je_web_runner/utils/pr_risk_score/__init__.py create mode 100644 je_web_runner/utils/pr_risk_score/scorer.py create mode 100644 je_web_runner/utils/prompt_drift_monitor/__init__.py create mode 100644 je_web_runner/utils/prompt_drift_monitor/monitor.py create mode 100644 je_web_runner/utils/pseudo_localization/__init__.py create mode 100644 je_web_runner/utils/pseudo_localization/pseudo.py create mode 100644 je_web_runner/utils/quarantine_age_report/__init__.py create mode 100644 je_web_runner/utils/quarantine_age_report/report.py create mode 100644 je_web_runner/utils/repro_minimizer/__init__.py create mode 100644 je_web_runner/utils/repro_minimizer/minimizer.py create mode 100644 je_web_runner/utils/screen_reader_runner/__init__.py create mode 100644 je_web_runner/utils/screen_reader_runner/reader.py create mode 100644 je_web_runner/utils/session_to_test/__init__.py create mode 100644 je_web_runner/utils/session_to_test/converter.py create mode 100644 je_web_runner/utils/sla_tracker/__init__.py create mode 100644 je_web_runner/utils/sla_tracker/tracker.py create mode 100644 je_web_runner/utils/slack_digest/__init__.py create mode 100644 je_web_runner/utils/slack_digest/digest.py create mode 100644 je_web_runner/utils/sri_verify/__init__.py create mode 100644 je_web_runner/utils/sri_verify/verify.py create mode 100644 je_web_runner/utils/sse_assert/__init__.py create mode 100644 je_web_runner/utils/sse_assert/stream.py create mode 100644 je_web_runner/utils/story_to_actions/__init__.py create mode 100644 je_web_runner/utils/story_to_actions/generator.py create mode 100644 je_web_runner/utils/test_auto_repair/__init__.py create mode 100644 je_web_runner/utils/test_auto_repair/repair.py create mode 100644 je_web_runner/utils/test_categorizer/__init__.py create mode 100644 je_web_runner/utils/test_categorizer/categorizer.py create mode 100644 je_web_runner/utils/test_cost_estimator/__init__.py create mode 100644 je_web_runner/utils/test_cost_estimator/estimator.py create mode 100644 je_web_runner/utils/test_debt_dashboard/__init__.py create mode 100644 je_web_runner/utils/test_debt_dashboard/debt.py create mode 100644 je_web_runner/utils/test_dedup_ai/__init__.py create mode 100644 je_web_runner/utils/test_dedup_ai/dedup.py create mode 100644 je_web_runner/utils/test_owners_map/__init__.py create mode 100644 je_web_runner/utils/test_owners_map/owners.py create mode 100644 je_web_runner/utils/test_scheduler/__init__.py create mode 100644 je_web_runner/utils/test_scheduler/scheduler.py create mode 100644 je_web_runner/utils/third_party_budget/__init__.py create mode 100644 je_web_runner/utils/third_party_budget/budget.py create mode 100644 je_web_runner/utils/time_freezer/__init__.py create mode 100644 je_web_runner/utils/time_freezer/freezer.py create mode 100644 je_web_runner/utils/token_leak_detector/__init__.py create mode 100644 je_web_runner/utils/token_leak_detector/detector.py create mode 100644 je_web_runner/utils/view_transitions/__init__.py create mode 100644 je_web_runner/utils/view_transitions/transitions.py create mode 100644 je_web_runner/utils/visual_ai/__init__.py create mode 100644 je_web_runner/utils/visual_ai/perceptual.py create mode 100644 je_web_runner/utils/walkthrough_docs/__init__.py create mode 100644 je_web_runner/utils/walkthrough_docs/generator.py create mode 100644 je_web_runner/utils/webhook_receiver/__init__.py create mode 100644 je_web_runner/utils/webhook_receiver/receiver.py create mode 100644 je_web_runner/utils/webrtc_assert/__init__.py create mode 100644 je_web_runner/utils/webrtc_assert/peer.py create mode 100644 je_web_runner/utils/websocket_assert/__init__.py create mode 100644 je_web_runner/utils/websocket_assert/frames.py create mode 100644 je_web_runner/utils/webtransport_assert/__init__.py create mode 100644 je_web_runner/utils/webtransport_assert/streams.py create mode 100644 test/unit_test/test_backend_log_correlator.py create mode 100644 test/unit_test/test_bug_repro_stability.py create mode 100644 test/unit_test/test_bundle_budget.py create mode 100644 test/unit_test/test_chaos_hooks.py create mode 100644 test/unit_test/test_chrome_profile.py create mode 100644 test/unit_test/test_clickjacking_audit.py create mode 100644 test/unit_test/test_consent_audit.py create mode 100644 test/unit_test/test_console_error_budget.py create mode 100644 test/unit_test/test_coop_coep_audit.py create mode 100644 test/unit_test/test_cross_tab_sync.py create mode 100644 test/unit_test/test_db_snapshot.py create mode 100644 test/unit_test/test_device_cloud.py create mode 100644 test/unit_test/test_download_verify.py create mode 100644 test/unit_test/test_edge_case_generator.py create mode 100644 test/unit_test/test_email_render.py create mode 100644 test/unit_test/test_exploratory_ai.py create mode 100644 test/unit_test/test_failure_narrator.py create mode 100644 test/unit_test/test_failure_triage.py create mode 100644 test/unit_test/test_file_system_access.py create mode 100644 test/unit_test/test_flag_matrix.py create mode 100644 test/unit_test/test_flake_detector.py create mode 100644 test/unit_test/test_forced_colors_mode.py create mode 100644 test/unit_test/test_git_bisect_flake.py create mode 100644 test/unit_test/test_grpc_tester.py create mode 100644 test/unit_test/test_hydration_check.py create mode 100644 test/unit_test/test_idempotency_check.py create mode 100644 test/unit_test/test_indexed_db_explorer.py create mode 100644 test/unit_test/test_inp_tracker.py create mode 100644 test/unit_test/test_live_dashboard_server.py create mode 100644 test/unit_test/test_locator_hardener.py create mode 100644 test/unit_test/test_locator_health.py create mode 100644 test/unit_test/test_long_animation_frame.py create mode 100644 test/unit_test/test_mixed_content_audit.py create mode 100644 test/unit_test/test_multimodal_qa.py create mode 100644 test/unit_test/test_mutation_testing.py create mode 100644 test/unit_test/test_notifications_audit.py create mode 100644 test/unit_test/test_ocr_assert.py create mode 100644 test/unit_test/test_open_redirect_detector.py create mode 100644 test/unit_test/test_openapi_to_e2e.py create mode 100644 test/unit_test/test_otel_bridge.py create mode 100644 test/unit_test/test_otp_interceptor.py create mode 100644 test/unit_test/test_pagination_audit.py create mode 100644 test/unit_test/test_persona_runner.py create mode 100644 test/unit_test/test_pii_in_screenshot.py create mode 100644 test/unit_test/test_pr_risk_score.py create mode 100644 test/unit_test/test_prompt_drift_monitor.py create mode 100644 test/unit_test/test_pseudo_localization.py create mode 100644 test/unit_test/test_quarantine_age_report.py create mode 100644 test/unit_test/test_repro_minimizer.py create mode 100644 test/unit_test/test_screen_reader_runner.py create mode 100644 test/unit_test/test_session_to_test.py create mode 100644 test/unit_test/test_sla_tracker.py create mode 100644 test/unit_test/test_slack_digest.py create mode 100644 test/unit_test/test_sri_verify.py create mode 100644 test/unit_test/test_sse_assert.py create mode 100644 test/unit_test/test_story_to_actions.py create mode 100644 test/unit_test/test_test_auto_repair.py create mode 100644 test/unit_test/test_test_categorizer.py create mode 100644 test/unit_test/test_test_cost_estimator.py create mode 100644 test/unit_test/test_test_debt_dashboard.py create mode 100644 test/unit_test/test_test_dedup_ai.py create mode 100644 test/unit_test/test_test_owners_map.py create mode 100644 test/unit_test/test_test_scheduler.py create mode 100644 test/unit_test/test_third_party_budget.py create mode 100644 test/unit_test/test_time_freezer.py create mode 100644 test/unit_test/test_token_leak_detector.py create mode 100644 test/unit_test/test_view_transitions.py create mode 100644 test/unit_test/test_visual_ai.py create mode 100644 test/unit_test/test_walkthrough_docs.py create mode 100644 test/unit_test/test_webhook_receiver.py create mode 100644 test/unit_test/test_webrtc_assert.py create mode 100644 test/unit_test/test_websocket_assert.py create mode 100644 test/unit_test/test_webtransport_assert.py diff --git a/CLAUDE.md b/CLAUDE.md index a016ebd..7bde84a 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -45,7 +45,80 @@ je_web_runner/ ├── socket_server/ # TCP socket server for remote control ├── test_object/ # Test object & record classes (Value Object pattern) ├── test_record/ # Action recording - └── xml/ # XML utilities + ├── xml/ # XML utilities + ├── chrome_profile/ # Persistent Chrome profile + stealth + snapshot/sync-back + ├── failure_triage/ # AI failure root-cause analysis on failure bundles + ├── flake_detector/ # Time-decayed flake scoring + quarantine registry + ├── locator_health/ # Project-wide locator audit + upgrade suggestions + ├── device_cloud/ # Real-device cloud (BrowserStack/Sauce/LambdaTest) connector + ├── otel_bridge/ # W3C traceparent injection for distributed tracing + ├── mutation_testing/ # Action JSON mutation testing (kill rate / score) + ├── otp_interceptor/ # MailHog/Mailpit/IMAP/SMS OTP polling for 2FA flows + ├── download_verify/ # PDF / CSV / Excel / JSON / SHA256 download assertions + ├── test_auto_repair/ # LLM-driven test rewrite from failure + git diff + ├── edge_case_generator/ # LLM edge-case variant generator (complement to mutation_testing) + ├── openapi_to_e2e/ # OpenAPI/Swagger spec → WR_http_* action JSON + ├── cross_tab_sync/ # Multi-page BroadcastChannel / storage propagation asserts + ├── visual_ai/ # aHash/dHash/pHash + SSIM-proxy for canvas/chart diff + ├── test_scheduler/ # Value-density scheduler under time + cloud budget + ├── walkthrough_docs/ # AI step-by-step SOP / Confluence doc from recorded runs + ├── live_dashboard/ # Aggregated web UI: runs + flake + quarantine + locators + ├── ocr_assert/ # OCR-based text assertion for canvas / WebGL / image content + ├── email_render/ # Capture outbound mail (MailHog/Mailpit/EML) + multi-viewport screenshots + ├── backend_log_correlator/ # W3C trace_id → Loki/Elasticsearch/file log fetch into failure bundle + ├── websocket_assert/ # WebSocket frame recorder + count / payload / pubsub assertions + ├── console_error_budget/ # JS console / unhandled-rejection budget with ignore patterns + ├── chaos_hooks/ # Seeded chaos injection (offline / throttle / mid-flow reload) + ├── pr_risk_score/ # Fuse flake / impact / locator / coverage signals into 0-100 PR risk + ├── flag_matrix/ # Feature-flag combo matrix with constraints + minimal failing subset + ├── session_to_test/ # rrweb / generic session events → WR action JSON + ├── exploratory_ai/ # Agentic exploratory tester (observer/planner protocols + RandomPlanner) + ├── story_to_actions/ # LLM-driven user story / Figma frame → validated WR action JSON + ├── db_snapshot/ # Per-test DB savepoint/rollback with pluggable backend + ├── time_freezer/ # Inject Date/Date.now/performance.now patch via CDP for deterministic time tests + ├── persona_runner/ # Same suite × N personas (admin/free/enterprise) matrix + ├── token_leak_detector/ # Scan HAR / logs / responses for leaked JWTs, API keys, session tokens + ├── consent_audit/ # GDPR/CCPA cookie classification + pre-consent / post-reject violation detection + ├── pii_in_screenshot/ # OCR + PII regex (Luhn-validated card, SSN, TWID) scanner over screenshots + ├── pseudo_localization/ # ASCII → look-alike + expansion + brackets; detect hard-coded i18n leaks + ├── screen_reader_runner/ # Walk a11y tree to simulate NVDA/VoiceOver order + flag a11y violations + ├── forced_colors_mode/ # dark / reduced-motion / forced-colors / high-contrast matrix verification + ├── sse_assert/ # Server-Sent Events recorder + count/data/JSON-shape/strict-id assertions + ├── webrtc_assert/ # PeerConnection state / ICE / track / RTP stats assertions + ├── view_transitions/ # Instrumentation + duration/CLS/group assertions for View Transitions API + ├── test_dedup_ai/ # Structural + embedding-based semantic dedupe of action JSON files + ├── multimodal_qa/ # Send screenshot + question to vision LLM, parse pass/fail/notes envelope + ├── prompt_drift_monitor/ # Track LLM-feature output drift via embeddings + lexical anchors + ├── git_bisect_flake/ # Ledger-only or probe-driven bisect to find regression commit + ├── test_cost_estimator/ # Cloud-minute × rate-card × CO₂ estimate per suite/runner/test + ├── slack_digest/ # Render Slack Block-Kit / Teams card / plain-text test digest payload + ├── webtransport_assert/ # HTTP/3 WebTransport datagram + stream frame recorder + assertions + ├── indexed_db_explorer/ # IndexedDB snapshot harvest + store/key/index/record assertions + ├── file_system_access/ # Mock showOpenFilePicker/showSaveFilePicker + record writes + ├── notifications_audit/ # Notification.requestPermission timing + permission/spam policy checks + ├── mixed_content_audit/ # HTTP-on-HTTPS detection via HAR + console scanner + ├── clickjacking_audit/ # X-Frame-Options / frame-ancestors header check + iframe probe + ├── open_redirect_detector/ # Probe ?redirect=/?next= params with attacker-host payloads + ├── sri_verify/ # Subresource Integrity hash presence + correctness + crossorigin + ├── coop_coep_audit/ # crossOriginIsolated COOP/COEP + per-resource CORP/CORS check + ├── inp_tracker/ # Interaction to Next Paint instrumentation + p98 + budget + ├── hydration_check/ # SSR hydration mismatch detection (DOM diff + console markers) + ├── bundle_budget/ # Per-asset-kind byte budget from HAR + biggest-assets ranking + ├── third_party_budget/ # Third-party vendor classification + req/byte/blocking-ms budgets + ├── long_animation_frame/ # Long Animation Frame API listener + per-script attribution + ├── grpc_tester/ # gRPC stub call recorder + gRPC-Web framing/trailer helpers + ├── webhook_receiver/ # Threaded HTTP server for catching app's outbound webhooks + ├── idempotency_check/ # Run request twice + compare status/body/state/side-effects + ├── pagination_audit/ # Walk all pages, detect dups/gaps/cursor-loop/sort violations + ├── failure_narrator/ # LLM natural-language failure summary from failure_bundle + ├── repro_minimizer/ # Delta-debugging (ddmin) to shrink failing action list to minimum + ├── locator_hardener/ # Heuristic fragility score + LLM-suggested stable selectors + ├── test_categorizer/ # Auto-tag tests as smoke / regression / perf / a11y / data / api + ├── quarantine_age_report/ # Quarantine entries with age + fresh/lingering/stale/abandoned tiers + ├── test_debt_dashboard/ # Inventory of skip/xfail/TODO/_skip markers with age + CODEOWNERS + ├── sla_tracker/ # % suites finishing under SLA threshold, weekly/daily bucketing + ├── bug_repro_stability/ # Repeat probe N times, classify deterministic/flaky/non-reproducible + └── test_owners_map/ # CODEOWNERS parser + override layer + unowned-test audit ``` ## Design Patterns & Architecture diff --git a/je_web_runner/__init__.py b/je_web_runner/__init__.py index 9a23bff..c916e46 100644 --- a/je_web_runner/__init__.py +++ b/je_web_runner/__init__.py @@ -177,6 +177,205 @@ from je_web_runner.utils.recorder.browser_recorder import save_recording as recorder_save_recording from je_web_runner.utils.recorder.browser_recorder import start_recording as recorder_start from je_web_runner.utils.recorder.browser_recorder import stop_recording as recorder_stop +from je_web_runner.utils.chrome_profile.profile_manager import ( + ChromeProfileError, + StealthFlags, + build_chrome_options as chrome_profile_build_options, + build_playwright_persistent_context, + build_stealth_chrome_driver, + chrome_profile_session, + cleanup_chrome_locks, + minimise_chrome_windows, + snapshot_chrome_profile, + sync_chrome_profile_back, +) +from je_web_runner.utils.failure_triage.triage import ( + FailureTriageError, + TriageReport, + TriageSignals, + extract_signals_from_bundle, + render_markdown as triage_render_markdown, + save_report as triage_save_report, + triage_bundle, + triage_failure, +) +from je_web_runner.utils.flake_detector.detector import ( + FlakeDetectorError, + FlakeScore, + QuarantineEntry, + QuarantineRegistry, + compute_flake_scores, + flaky_paths as flake_detector_flaky_paths, + flaky_quarantine, + quarantine_flaky, + quarantine_report_markdown, + release_if_stable, +) +from je_web_runner.utils.locator_health.health_report import ( + FallbackHitTracker, + LocatorFinding, + LocatorHealthError, + LocatorHealthReport, + UpgradeSuggestion as LocatorUpgradeSuggestion, + apply_upgrades as locator_apply_upgrades, + build_health_report as locator_build_health_report, + fallback_hit_tracker, + render_health_markdown as locator_render_health_markdown, + save_health_report, + scan_action_file as locator_scan_action_file, + scan_project as locator_scan_project, + suggest_upgrade as locator_suggest_upgrade, + suggest_upgrades as locator_suggest_upgrades, +) +from je_web_runner.utils.device_cloud.real_device import ( + CloudCredentials, + CloudSession, + DeviceCloudError, + RealDeviceCaps, + build_capabilities as device_cloud_build_capabilities, + connect_real_device, + fetch_session_info, + load_credentials as device_cloud_load_credentials, + session_summary_markdown, + update_session_status, +) +from je_web_runner.utils.otel_bridge.trace_bridge import ( + TraceBridgeError, + TraceContext, + bridged_span_playwright, + bridged_span_selenium, + clear_headers_playwright, + clear_headers_selenium, + current_otel_context, + inject_headers_playwright, + inject_headers_selenium, + parse_traceparent, + random_trace_context, + trace_link, +) +from je_web_runner.utils.mutation_testing.mutator import ( + Mutation, + MutationResult, + MutationScore, + MutationTestingError, + MutationType, + apply_mutation, + assert_min_score as mutation_assert_min_score, + generate_mutations, + render_mutation_markdown, + run_mutation_testing, + run_mutation_testing_on_file, +) +from je_web_runner.utils.otp_interceptor.interceptor import ( + ImapProvider, + InMemoryProvider as OtpInMemoryProvider, + InterceptedMessage, + MailHogProvider, + MailpitProvider, + OtpInterceptError, + OtpProvider, + WebhookSmsProvider, + extract_otp_from_text, + wait_for_otp, +) +from je_web_runner.utils.download_verify.verifier import ( + DownloadAssertion, + DownloadVerifyError, + assert_csv_columns, + assert_csv_row_count, + assert_download, + assert_file_sha256, + assert_json_matches_schema, + assert_pdf_contains, + assert_pdf_matches, + extract_pdf_text, + read_csv_rows, + read_excel_rows, + read_json_file, + sha256_of_file, + wait_for_download, +) +from je_web_runner.utils.test_auto_repair.repair import ( + RepairPlan, + TestAutoRepairError, + apply_repair, + collect_git_diff, + propose_repair, + render_repair_markdown, + repair_from_bundle, +) +from je_web_runner.utils.edge_case_generator.generator import ( + EdgeCase, + EdgeCaseCategory, + EdgeCaseGeneratorError, + EdgeCaseSuite, + generate_edge_cases, + generate_edge_cases_from_file, + render_suite_markdown as edge_case_render_markdown, + write_suite_to_dir as edge_case_write_suite, +) +from je_web_runner.utils.openapi_to_e2e.generator import ( + GeneratedTest as OpenAPIGeneratedTest, + GenerationResult as OpenAPIGenerationResult, + OpenAPIGeneratorError, + generate_tests_from_file as openapi_generate_from_file, + generate_tests_from_spec as openapi_generate_from_spec, + load_spec as openapi_load_spec, + synthesize_example as openapi_synthesize_example, + write_tests_to_dir as openapi_write_tests, +) +from je_web_runner.utils.cross_tab_sync.sync_assertions import ( + CrossTabSyncError, + PropagationResult, + assert_state_propagates, + broadcast_message, + collect_broadcast_messages, + get_storage_value, + install_broadcast_recorder, + post_message_to_page, + set_storage_value, + wait_for_broadcast, + wait_for_storage, +) +from je_web_runner.utils.visual_ai.perceptual import ( + HashResult as VisualHashResult, + SimilarityResult as VisualSimilarityResult, + VisualAIError, + assert_visual_similar, + average_hash as visual_average_hash, + compare_images as visual_compare_images, + difference_hash as visual_difference_hash, + hamming_distance as visual_hamming_distance, + hash_similarity as visual_hash_similarity, + perceptual_hash as visual_perceptual_hash, +) +from je_web_runner.utils.test_scheduler.scheduler import ( + Schedule, + TestCandidate, + TestSchedulerError, + build_candidates_from_ledger, + render_schedule_markdown, + schedule_tests, + value_density as scheduler_value_density, + value_of as scheduler_value_of, +) +from je_web_runner.utils.walkthrough_docs.generator import ( + Walkthrough, + WalkthroughError, + WalkthroughStep, + build_walkthrough, + collect_steps as walkthrough_collect_steps, + narrate_steps as walkthrough_narrate_steps, + render_confluence as walkthrough_render_confluence, + render_markdown as walkthrough_render_markdown, + save_walkthrough, +) +from je_web_runner.utils.live_dashboard.server import ( + DashboardConfig, + DashboardServer, + LiveDashboardError, + build_summary as dashboard_build_summary, +) __all__ = [ "web_element_wrapper", "set_webdriver_options_argument", "webdriver_wrapper_instance", "get_webdriver_manager", @@ -236,5 +435,95 @@ "pw_wait_for_timeout", "pw_wait_for_url", "pw_set_viewport_size", "pw_viewport_size", "pw_mouse_click", "pw_mouse_move", "pw_mouse_down", "pw_mouse_up", - "pw_keyboard_press", "pw_keyboard_type", "pw_keyboard_down", "pw_keyboard_up" + "pw_keyboard_press", "pw_keyboard_type", "pw_keyboard_down", "pw_keyboard_up", + # Phase 1: chrome_profile + "ChromeProfileError", "StealthFlags", + "chrome_profile_build_options", "build_playwright_persistent_context", + "build_stealth_chrome_driver", "chrome_profile_session", + "cleanup_chrome_locks", "minimise_chrome_windows", + "snapshot_chrome_profile", "sync_chrome_profile_back", + # Phase 2: failure_triage + "FailureTriageError", "TriageReport", "TriageSignals", + "extract_signals_from_bundle", "triage_render_markdown", + "triage_save_report", "triage_bundle", "triage_failure", + # Phase 3: flake_detector + "FlakeDetectorError", "FlakeScore", "QuarantineEntry", "QuarantineRegistry", + "compute_flake_scores", "flake_detector_flaky_paths", "flaky_quarantine", + "quarantine_flaky", "quarantine_report_markdown", "release_if_stable", + # Phase 4: locator_health + "FallbackHitTracker", "LocatorFinding", "LocatorHealthError", + "LocatorHealthReport", "LocatorUpgradeSuggestion", + "locator_apply_upgrades", "locator_build_health_report", + "fallback_hit_tracker", "locator_render_health_markdown", + "save_health_report", "locator_scan_action_file", "locator_scan_project", + "locator_suggest_upgrade", "locator_suggest_upgrades", + # Phase 5: device_cloud + "CloudCredentials", "CloudSession", "DeviceCloudError", "RealDeviceCaps", + "device_cloud_build_capabilities", "connect_real_device", + "fetch_session_info", "device_cloud_load_credentials", + "session_summary_markdown", "update_session_status", + # Phase 6: otel_bridge + "TraceBridgeError", "TraceContext", + "bridged_span_playwright", "bridged_span_selenium", + "clear_headers_playwright", "clear_headers_selenium", + "current_otel_context", "inject_headers_playwright", + "inject_headers_selenium", "parse_traceparent", + "random_trace_context", "trace_link", + # Phase 7: mutation_testing + "Mutation", "MutationResult", "MutationScore", "MutationTestingError", + "MutationType", "apply_mutation", "mutation_assert_min_score", + "generate_mutations", "render_mutation_markdown", + "run_mutation_testing", "run_mutation_testing_on_file", + # Phase 8: otp_interceptor + "ImapProvider", "OtpInMemoryProvider", "InterceptedMessage", + "MailHogProvider", "MailpitProvider", "OtpInterceptError", + "OtpProvider", "WebhookSmsProvider", + "extract_otp_from_text", "wait_for_otp", + # Phase 9: download_verify + "DownloadAssertion", "DownloadVerifyError", + "assert_csv_columns", "assert_csv_row_count", "assert_download", + "assert_file_sha256", "assert_json_matches_schema", + "assert_pdf_contains", "assert_pdf_matches", + "extract_pdf_text", "read_csv_rows", "read_excel_rows", + "read_json_file", "sha256_of_file", "wait_for_download", + # Phase 11: test_auto_repair + "RepairPlan", "TestAutoRepairError", + "apply_repair", "collect_git_diff", "propose_repair", + "render_repair_markdown", "repair_from_bundle", + # Phase 12: edge_case_generator + "EdgeCase", "EdgeCaseCategory", "EdgeCaseGeneratorError", + "EdgeCaseSuite", + "generate_edge_cases", "generate_edge_cases_from_file", + "edge_case_render_markdown", "edge_case_write_suite", + # Phase 13: openapi_to_e2e + "OpenAPIGeneratedTest", "OpenAPIGenerationResult", + "OpenAPIGeneratorError", + "openapi_generate_from_file", "openapi_generate_from_spec", + "openapi_load_spec", "openapi_synthesize_example", "openapi_write_tests", + # Phase 14: cross_tab_sync + "CrossTabSyncError", "PropagationResult", + "assert_state_propagates", "broadcast_message", + "collect_broadcast_messages", "get_storage_value", + "install_broadcast_recorder", "post_message_to_page", + "set_storage_value", "wait_for_broadcast", "wait_for_storage", + # Phase 15: visual_ai + "VisualHashResult", "VisualSimilarityResult", "VisualAIError", + "assert_visual_similar", + "visual_average_hash", "visual_compare_images", + "visual_difference_hash", "visual_hamming_distance", + "visual_hash_similarity", "visual_perceptual_hash", + # Phase 16: test_scheduler + "Schedule", "TestCandidate", "TestSchedulerError", + "build_candidates_from_ledger", "render_schedule_markdown", + "schedule_tests", + "scheduler_value_density", "scheduler_value_of", + # Phase 17: walkthrough_docs + "Walkthrough", "WalkthroughError", "WalkthroughStep", + "build_walkthrough", + "walkthrough_collect_steps", "walkthrough_narrate_steps", + "walkthrough_render_confluence", "walkthrough_render_markdown", + "save_walkthrough", + # Phase 19: live_dashboard + "DashboardConfig", "DashboardServer", "LiveDashboardError", + "dashboard_build_summary", ] diff --git a/je_web_runner/utils/backend_log_correlator/__init__.py b/je_web_runner/utils/backend_log_correlator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/backend_log_correlator/correlator.py b/je_web_runner/utils/backend_log_correlator/correlator.py new file mode 100644 index 0000000..03df2cb --- /dev/null +++ b/je_web_runner/utils/backend_log_correlator/correlator.py @@ -0,0 +1,291 @@ +""" +用 ``otel_bridge`` 注入的 traceparent,把後端 log 拉進 failure bundle。 +Given a W3C trace id (the 32-hex middle of a ``traceparent`` header) +captured during a UI run, ask a log backend for matching lines and merge +them into the failure artifact. + +Adapters provided out of the box: + +* :func:`fetch_loki` — Grafana Loki ``/loki/api/v1/query_range`` +* :func:`fetch_elasticsearch` — Elasticsearch ``_search`` with ``trace_id`` +* :func:`fetch_file_log` — plain text / JSON-lines log file (works offline) + +All adapters return the same :class:`CorrelatedLog` list, which +:func:`attach_to_failure_bundle` then writes alongside the bundle's +existing artifacts. +""" +from __future__ import annotations + +import json +import re +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class BackendLogCorrelatorError(WebRunnerException): + """Raised on backend errors, malformed responses, or bad input.""" + + +_TRACEPARENT_RE = re.compile( + r"^(?P[0-9a-f]{2})-(?P[0-9a-f]{32})-(?P[0-9a-f]{16})-(?P[0-9a-f]{2})$", + re.IGNORECASE, +) +_TRACE_ID_RE = re.compile(r"^[0-9a-f]{32}$", re.IGNORECASE) + + +# ---------- data --------------------------------------------------------- + +@dataclass +class CorrelatedLog: + """One log line correlated to a trace.""" + + timestamp: str + level: str + message: str + service: Optional[str] = None + span_id: Optional[str] = None + extra: Dict[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +LogFetcher = Callable[[str], List[CorrelatedLog]] +"""Signature: ``fetcher(trace_id) -> [CorrelatedLog, ...]``.""" + + +# ---------- traceparent helpers ----------------------------------------- + +def parse_traceparent(header_value: str) -> str: + """Extract the 32-hex trace id from a W3C ``traceparent`` header.""" + if not isinstance(header_value, str) or not header_value: + raise BackendLogCorrelatorError("traceparent must be a non-empty string") + match = _TRACEPARENT_RE.match(header_value.strip()) + if not match: + raise BackendLogCorrelatorError(f"malformed traceparent: {header_value!r}") + return match.group("trace").lower() + + +def validate_trace_id(trace_id: str) -> str: + """Return the trace id normalised to lowercase hex; raise if malformed.""" + if not isinstance(trace_id, str) or not _TRACE_ID_RE.match(trace_id): + raise BackendLogCorrelatorError( + f"trace_id must be 32 hex chars, got {trace_id!r}" + ) + return trace_id.lower() + + +# ---------- file adapter (offline / tests) ------------------------------- + +def fetch_file_log( + log_path: Union[str, Path], + *, + trace_field: str = "trace_id", + fallback_to_substring: bool = True, +) -> LogFetcher: + """ + Build a fetcher that reads ``log_path`` as JSON-lines (preferred) or + plain text. Lines whose JSON ``trace_field`` equals the requested + trace id are returned. For plain text, substring match is used when + ``fallback_to_substring`` is true. + """ + path = Path(log_path) + if not path.exists(): + raise BackendLogCorrelatorError(f"log file not found: {path}") + + def _fetch(trace_id: str) -> List[CorrelatedLog]: + wanted = validate_trace_id(trace_id) + out: List[CorrelatedLog] = [] + with open(path, encoding="utf-8") as fp: + for line in fp: + stripped = line.rstrip("\r\n") + if not stripped: + continue + record = _try_parse_json_line(stripped) + if record is not None: + if str(record.get(trace_field, "")).lower() == wanted: + out.append(_log_from_dict(record)) + elif fallback_to_substring and wanted in stripped.lower(): + out.append(CorrelatedLog( + timestamp="", level="info", message=stripped, + )) + return out + + return _fetch + + +def _try_parse_json_line(line: str) -> Optional[Dict[str, Any]]: + line = line.strip() + if not line.startswith("{"): + return None + try: + loaded = json.loads(line) + except ValueError: + return None + return loaded if isinstance(loaded, dict) else None + + +def _log_from_dict(record: Dict[str, Any]) -> CorrelatedLog: + return CorrelatedLog( + timestamp=str(record.get("timestamp") or record.get("ts") or ""), + level=str(record.get("level") or record.get("severity") or "info"), + message=str(record.get("message") or record.get("msg") or ""), + service=record.get("service") or record.get("app"), + span_id=record.get("span_id"), + extra={ + k: v for k, v in record.items() + if k not in { + "timestamp", "ts", "level", "severity", "message", "msg", + "service", "app", "span_id", "trace_id", + } + }, + ) + + +# ---------- Loki adapter ------------------------------------------------- + +def _require_requests() -> Any: + try: + import requests # type: ignore[import-not-found] + return requests + except ImportError as error: + raise BackendLogCorrelatorError( + "requests is required for Loki/Elasticsearch fetchers. " + "Install: pip install requests" + ) from error + + +def fetch_loki( + base_url: str, + *, + label: str = "trace_id", + timeout: float = 15.0, + limit: int = 1000, +) -> LogFetcher: + """Build a fetcher that queries Grafana Loki by label-equals match.""" + requests = _require_requests() + url = base_url.rstrip("/") + "/loki/api/v1/query_range" + + def _fetch(trace_id: str) -> List[CorrelatedLog]: + wanted = validate_trace_id(trace_id) + params = {"query": f'{{{label}="{wanted}"}}', "limit": int(limit)} + try: + response = requests.get(url, params=params, timeout=timeout) + response.raise_for_status() + payload = response.json() + except (requests.RequestException, ValueError) as error: + raise BackendLogCorrelatorError(f"loki fetch failed: {error!r}") from error + return _parse_loki_payload(payload) + + return _fetch + + +def _parse_loki_payload(payload: Any) -> List[CorrelatedLog]: + out: List[CorrelatedLog] = [] + if not isinstance(payload, dict): + return out + streams = ((payload.get("data") or {}).get("result")) or [] + for stream in streams: + labels = stream.get("stream") or {} + for entry in stream.get("values") or []: + if not isinstance(entry, list) or len(entry) != 2: + continue + ts_ns, line = entry + out.append(CorrelatedLog( + timestamp=str(ts_ns), + level=str(labels.get("level") or "info"), + message=str(line), + service=labels.get("service") or labels.get("app"), + span_id=labels.get("span_id"), + extra={k: v for k, v in labels.items() + if k not in {"level", "service", "app", "span_id"}}, + )) + return out + + +# ---------- Elasticsearch adapter --------------------------------------- + +def fetch_elasticsearch( + base_url: str, + index: str, + *, + trace_field: str = "trace_id", + timeout: float = 15.0, + size: int = 500, +) -> LogFetcher: + """Build a fetcher that does ``GET {index}/_search`` with a term query.""" + requests = _require_requests() + url = f"{base_url.rstrip('/')}/{index}/_search" + + def _fetch(trace_id: str) -> List[CorrelatedLog]: + wanted = validate_trace_id(trace_id) + body = {"size": int(size), "query": {"term": {trace_field: wanted}}} + try: + response = requests.post(url, json=body, timeout=timeout) + response.raise_for_status() + payload = response.json() + except (requests.RequestException, ValueError) as error: + raise BackendLogCorrelatorError(f"elastic fetch failed: {error!r}") from error + return _parse_elasticsearch_payload(payload) + + return _fetch + + +def _parse_elasticsearch_payload(payload: Any) -> List[CorrelatedLog]: + if not isinstance(payload, dict): + return [] + hits = ((payload.get("hits") or {}).get("hits")) or [] + out: List[CorrelatedLog] = [] + for hit in hits: + source = hit.get("_source") if isinstance(hit, dict) else None + if isinstance(source, dict): + out.append(_log_from_dict(source)) + return out + + +# ---------- bundle integration ------------------------------------------ + +def correlate( + trace_id_or_header: str, + fetchers: Sequence[LogFetcher], +) -> List[CorrelatedLog]: + """ + Resolve ``trace_id_or_header`` (raw id or full traceparent) and call + every fetcher in turn, concatenating their results. + """ + if not fetchers: + raise BackendLogCorrelatorError("at least one fetcher is required") + raw = trace_id_or_header.strip() if isinstance(trace_id_or_header, str) else "" + trace_id = parse_traceparent(raw) if "-" in raw else validate_trace_id(raw) + merged: List[CorrelatedLog] = [] + for fetcher in fetchers: + try: + merged.extend(fetcher(trace_id)) + except BackendLogCorrelatorError: + raise + except Exception as error: + web_runner_logger.warning(f"correlator fetcher failed: {error!r}") + return merged + + +def attach_to_failure_bundle( + bundle_dir: Union[str, Path], + logs: Iterable[CorrelatedLog], + *, + filename: str = "backend_logs.json", +) -> Path: + """Write ``logs`` as JSON into an existing failure-bundle directory.""" + bundle = Path(bundle_dir) + if not bundle.exists(): + raise BackendLogCorrelatorError(f"failure bundle dir not found: {bundle}") + if not bundle.is_dir(): + raise BackendLogCorrelatorError(f"failure bundle path is not a directory: {bundle}") + payload = [log.to_dict() for log in logs] + target = bundle / filename + with open(target, "w", encoding="utf-8") as fp: + json.dump(payload, fp, ensure_ascii=False, indent=2) + return target diff --git a/je_web_runner/utils/bug_repro_stability/__init__.py b/je_web_runner/utils/bug_repro_stability/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/bug_repro_stability/stability.py b/je_web_runner/utils/bug_repro_stability/stability.py new file mode 100644 index 0000000..6b6a8f4 --- /dev/null +++ b/je_web_runner/utils/bug_repro_stability/stability.py @@ -0,0 +1,195 @@ +""" +重複跑同一個失敗 test N 次,報告重現率 — 區分 deterministic vs flaky bug。 +When triaging a regression, the first question is "does this always +break, or just sometimes?". This module runs the test N times via a +caller-supplied runner and rolls up: + +* repro percentage +* longest pass / fail streak +* category: deterministic / flaky / non_reproducible +* per-error grouping (e.g. all failures hit the same exception line) +""" +from __future__ import annotations + +import time +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class BugReproStabilityError(WebRunnerException): + """Raised on bad inputs or runner failure.""" + + +class ReproCategory(str, Enum): + DETERMINISTIC = "deterministic" # 100% repro + FLAKY = "flaky" # 1..99% repro + NON_REPRODUCIBLE = "non_reproducible" # 0% repro + + +# ---------- runner contract ------------------------------------------- + +@dataclass +class RunOutcome: + """One probe outcome.""" + + passed: bool + error_signature: Optional[str] = None + duration_seconds: float = 0.0 + + +ProbeFn = Callable[[int], RunOutcome] +"""``probe(attempt_index) -> RunOutcome``.""" + + +# ---------- report ---------------------------------------------------- + +@dataclass +class StabilityReport: + """Roll-up of N attempts.""" + + attempts: int + failures: int + repro_pct: float + category: ReproCategory + longest_pass_streak: int = 0 + longest_fail_streak: int = 0 + errors: Dict[str, int] = field(default_factory=dict) + durations: List[float] = field(default_factory=list) + + def passed(self) -> bool: + return self.category == ReproCategory.NON_REPRODUCIBLE + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "category": self.category.value} + + +def _classify(repro_pct: float) -> ReproCategory: + if repro_pct >= 100.0: + return ReproCategory.DETERMINISTIC + if repro_pct <= 0.0: + return ReproCategory.NON_REPRODUCIBLE + return ReproCategory.FLAKY + + +# ---------- core ------------------------------------------------------- + +def repeat( + probe: ProbeFn, + *, + attempts: int = 10, + stop_on_first_failure: bool = False, + stop_on_first_pass: bool = False, +) -> StabilityReport: + """ + Drive ``probe`` ``attempts`` times, return :class:`StabilityReport`. + Set ``stop_on_first_failure=True`` to short-circuit when only + confirming whether the bug *ever* repros. + """ + if not callable(probe): + raise BugReproStabilityError("probe must be callable") + if attempts <= 0: + raise BugReproStabilityError("attempts must be > 0") + + failures = 0 + longest_pass = 0 + longest_fail = 0 + pass_streak = 0 + fail_streak = 0 + errors: Dict[str, int] = {} + durations: List[float] = [] + actual_attempts = 0 + for index in range(attempts): + actual_attempts += 1 + try: + outcome = probe(index) + except Exception as error: + raise BugReproStabilityError( + f"probe raised at attempt {index}: {error!r}" + ) from error + if not isinstance(outcome, RunOutcome): + raise BugReproStabilityError( + f"probe must return RunOutcome, got {type(outcome).__name__}" + ) + durations.append(outcome.duration_seconds) + if outcome.passed: + pass_streak += 1 + fail_streak = 0 + longest_pass = max(longest_pass, pass_streak) + if stop_on_first_pass: + break + else: + failures += 1 + fail_streak += 1 + pass_streak = 0 + longest_fail = max(longest_fail, fail_streak) + sig = outcome.error_signature or "(unspecified)" + errors[sig] = errors.get(sig, 0) + 1 + if stop_on_first_failure: + break + web_runner_logger.debug( + f"bug_repro_stability attempt {index + 1}/{attempts}: " + f"passed={outcome.passed}" + ) + repro_pct = (failures / actual_attempts) * 100.0 + return StabilityReport( + attempts=actual_attempts, + failures=failures, + repro_pct=round(repro_pct, 2), + category=_classify(repro_pct), + longest_pass_streak=longest_pass, + longest_fail_streak=longest_fail, + errors=errors, + durations=[round(d, 4) for d in durations], + ) + + +# ---------- assertions ------------------------------------------------- + +def assert_deterministic(report: StabilityReport) -> None: + """Raise unless the report is :attr:`ReproCategory.DETERMINISTIC`.""" + if not isinstance(report, StabilityReport): + raise BugReproStabilityError("expects StabilityReport") + if report.category != ReproCategory.DETERMINISTIC: + raise BugReproStabilityError( + f"expected deterministic repro, got {report.category.value} " + f"({report.repro_pct:.1f}%)" + ) + + +def assert_min_repro_pct(report: StabilityReport, *, minimum: float) -> None: + """Assert ``report.repro_pct >= minimum``.""" + if not isinstance(report, StabilityReport): + raise BugReproStabilityError("expects StabilityReport") + if not 0 <= minimum <= 100: + raise BugReproStabilityError("minimum must be in [0, 100]") + if report.repro_pct < minimum: + raise BugReproStabilityError( + f"repro {report.repro_pct:.1f}% below threshold {minimum:.1f}%" + ) + + +# ---------- formatting ------------------------------------------------- + +def report_markdown(report: StabilityReport) -> str: + """Render a compact markdown summary.""" + if not isinstance(report, StabilityReport): + raise BugReproStabilityError("expects StabilityReport") + avg = sum(report.durations) / len(report.durations) if report.durations else 0.0 + lines = [ + f"### Repro stability: **{report.category.value}** " + f"({report.failures}/{report.attempts} = {report.repro_pct:.1f}%)", + "", + f"- longest fail streak: {report.longest_fail_streak}", + f"- longest pass streak: {report.longest_pass_streak}", + f"- avg attempt duration: {avg:.2f}s", + ] + if report.errors: + lines.append("") + lines.append("**Error signatures:**") + for sig, count in sorted(report.errors.items(), key=lambda kv: -kv[1]): + lines.append(f"- `{sig}` × {count}") + return "\n".join(lines) + "\n" diff --git a/je_web_runner/utils/bundle_budget/__init__.py b/je_web_runner/utils/bundle_budget/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/bundle_budget/budget.py b/je_web_runner/utils/bundle_budget/budget.py new file mode 100644 index 0000000..3410c5b --- /dev/null +++ b/je_web_runner/utils/bundle_budget/budget.py @@ -0,0 +1,258 @@ +""" +每頁 JS / CSS / image / font 載入大小預算 + 違規清單。 +The classic "Lighthouse budget" but driven from a HAR file so it works +inside any E2E framework (Selenium / Playwright / WebDriver BiDi). +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlparse + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class BundleBudgetError(WebRunnerException): + """Raised on bad HAR / budget input or breached budget.""" + + +class AssetKind(str, Enum): + SCRIPT = "script" + STYLESHEET = "stylesheet" + IMAGE = "image" + FONT = "font" + MEDIA = "media" + DOCUMENT = "document" + XHR = "xhr" + OTHER = "other" + + +_MIME_KIND_MAP = { + "application/javascript": AssetKind.SCRIPT, + "application/x-javascript": AssetKind.SCRIPT, + "text/javascript": AssetKind.SCRIPT, + "module": AssetKind.SCRIPT, + "text/css": AssetKind.STYLESHEET, + "font/woff": AssetKind.FONT, + "font/woff2": AssetKind.FONT, + "application/font-woff": AssetKind.FONT, +} + +_RESOURCE_TYPE_KIND_MAP = { + "script": AssetKind.SCRIPT, + "stylesheet": AssetKind.STYLESHEET, + "image": AssetKind.IMAGE, + "imageset": AssetKind.IMAGE, + "font": AssetKind.FONT, + "media": AssetKind.MEDIA, + "video": AssetKind.MEDIA, + "audio": AssetKind.MEDIA, + "document": AssetKind.DOCUMENT, + "xhr": AssetKind.XHR, + "fetch": AssetKind.XHR, +} + + +# ---------- assets ----------------------------------------------------- + +@dataclass +class Asset: + """One downloaded resource.""" + + url: str + kind: AssetKind + transfer_bytes: int + content_bytes: int + + @property + def hostname(self) -> str: + try: + return (urlparse(self.url).hostname or "").lower() + except (ValueError, AttributeError): + return "" + + +def _kind_of(entry: Dict[str, Any]) -> AssetKind: + resource_type = str( + entry.get("_resourceType") or entry.get("resourceType") or "" + ).lower() + if resource_type in _RESOURCE_TYPE_KIND_MAP: + return _RESOURCE_TYPE_KIND_MAP[resource_type] + mime = str( + ((entry.get("response") or {}).get("content") or {}).get("mimeType") or "", + ).split(";")[0].strip().lower() + if mime.startswith("image/"): + return AssetKind.IMAGE + if mime.startswith("video/") or mime.startswith("audio/"): + return AssetKind.MEDIA + return _MIME_KIND_MAP.get(mime, AssetKind.OTHER) + + +def _sizes(entry: Dict[str, Any]) -> Tuple[int, int]: + response = entry.get("response") or {} + content = response.get("content") or {} + transfer = response.get("_transferSize") or response.get("bodySize") + body = content.get("size") + return ( + max(0, int(transfer or 0)), + max(0, int(body or transfer or 0)), + ) + + +def assets_from_har(har: Union[str, Dict[str, Any]]) -> List[Asset]: + """Reduce a HAR object to a flat list of :class:`Asset`.""" + har_obj = _coerce_har(har) + entries = ((har_obj.get("log") or {}).get("entries")) or [] + if not isinstance(entries, list): + raise BundleBudgetError("har log.entries must be a list") + out: List[Asset] = [] + for entry in entries: + if not isinstance(entry, dict): + continue + url = ((entry.get("request") or {}).get("url")) or "" + if not url: + continue + transfer, body = _sizes(entry) + out.append(Asset( + url=url, + kind=_kind_of(entry), + transfer_bytes=transfer, + content_bytes=body, + )) + return out + + +def _coerce_har(har: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(har, str): + try: + parsed = json.loads(har) + except ValueError as error: + raise BundleBudgetError(f"har not JSON: {error}") from error + if not isinstance(parsed, dict): + raise BundleBudgetError("har JSON must be an object") + return parsed + if isinstance(har, dict): + return har + raise BundleBudgetError( + f"assets_from_har expects str/dict, got {type(har).__name__}" + ) + + +# ---------- budget ------------------------------------------------------ + +@dataclass(frozen=True) +class Budget: + """Per-kind size budget (transfer-encoded bytes).""" + + kind: AssetKind + max_bytes: int + + def __post_init__(self) -> None: + if self.max_bytes <= 0: + raise BundleBudgetError("max_bytes must be > 0") + + +DEFAULT_BUDGETS: Sequence[Budget] = ( + Budget(kind=AssetKind.SCRIPT, max_bytes=350 * 1024), + Budget(kind=AssetKind.STYLESHEET, max_bytes=100 * 1024), + Budget(kind=AssetKind.IMAGE, max_bytes=800 * 1024), + Budget(kind=AssetKind.FONT, max_bytes=150 * 1024), + Budget(kind=AssetKind.MEDIA, max_bytes=2 * 1024 * 1024), +) + + +@dataclass +class BudgetBreach: + """One budget violation.""" + + kind: AssetKind + actual_bytes: int + max_bytes: int + over_bytes: int + + +@dataclass +class BudgetReport: + """Roll-up returned by :func:`evaluate_budget`.""" + + totals: Dict[AssetKind, int] = field(default_factory=dict) + breaches: List[BudgetBreach] = field(default_factory=list) + biggest_assets: List[Asset] = field(default_factory=list) + + def passed(self) -> bool: + return not self.breaches + + +def evaluate_budget( + assets: Sequence[Asset], + budgets: Sequence[Budget] = DEFAULT_BUDGETS, + *, + biggest_n: int = 10, +) -> BudgetReport: + """Aggregate per-kind sizes and compare against ``budgets``.""" + if not assets: + raise BundleBudgetError("assets must be non-empty") + if biggest_n < 0: + raise BundleBudgetError("biggest_n must be >= 0") + totals: Dict[AssetKind, int] = {} + for asset in assets: + totals[asset.kind] = totals.get(asset.kind, 0) + max( + asset.transfer_bytes, asset.content_bytes, + ) + breaches: List[BudgetBreach] = [] + for budget in budgets: + if not isinstance(budget, Budget): + raise BundleBudgetError("budgets entries must be Budget instances") + actual = totals.get(budget.kind, 0) + if actual > budget.max_bytes: + breaches.append(BudgetBreach( + kind=budget.kind, + actual_bytes=actual, + max_bytes=budget.max_bytes, + over_bytes=actual - budget.max_bytes, + )) + biggest = sorted( + assets, key=lambda a: -max(a.transfer_bytes, a.content_bytes), + )[:biggest_n] + return BudgetReport(totals=totals, breaches=breaches, biggest_assets=biggest) + + +def assert_within_budget(report: BudgetReport) -> None: + """Raise unless every budget was respected.""" + if not isinstance(report, BudgetReport): + raise BundleBudgetError("assert_within_budget expects BudgetReport") + if report.passed(): + return + parts = [ + f"{b.kind.value}: {b.actual_bytes}>{b.max_bytes} (+{b.over_bytes})" + for b in report.breaches + ] + raise BundleBudgetError("bundle budget breached — " + "; ".join(parts)) + + +def report_markdown(report: BudgetReport) -> str: + """Render a small markdown table for PR comments.""" + if not isinstance(report, BudgetReport): + raise BundleBudgetError("report_markdown expects BudgetReport") + lines = ["### Bundle budget", "", "| Kind | Bytes |", "|------|-------|"] + for kind, total in sorted(report.totals.items(), key=lambda kv: -kv[1]): + lines.append(f"| {kind.value} | {total:,} |") + if report.breaches: + lines.append("") + lines.append("**Breaches:**") + for b in report.breaches: + lines.append( + f"- {b.kind.value}: {b.actual_bytes:,}B > {b.max_bytes:,}B " + f"(over by {b.over_bytes:,}B)" + ) + if report.biggest_assets: + lines.append("") + lines.append("**Biggest assets:**") + for asset in report.biggest_assets[:5]: + lines.append( + f"- `{asset.url}` ({asset.kind.value}) " + f"{max(asset.transfer_bytes, asset.content_bytes):,}B" + ) + return "\n".join(lines) + "\n" diff --git a/je_web_runner/utils/chaos_hooks/__init__.py b/je_web_runner/utils/chaos_hooks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/chaos_hooks/chaos.py b/je_web_runner/utils/chaos_hooks/chaos.py new file mode 100644 index 0000000..40ef671 --- /dev/null +++ b/je_web_runner/utils/chaos_hooks/chaos.py @@ -0,0 +1,183 @@ +""" +在 action 流程中隨機注入混亂條件:網路斷線、CPU 節流、中途 reload。 +Verifies the UX recovers, retries, or shows the right error UI — not just +that the happy path works on a perfect machine. + +Three deliberately decoupled pieces: + +* :class:`ChaosPlan` — pure scheduling. Given a list of action names and + a seed, deterministically decides which step gets which fault. No + browser dependency; fully unit-testable. +* :class:`ChaosFaultType` — enum of fault categories. Each maps to an + injector callable provided by the user (so this module doesn't import + Selenium/CDP/Playwright). +* :class:`ChaosRunner` — runs a plan against an executor by invoking the + matching injector before the chosen step. +""" +from __future__ import annotations + +import random +from dataclasses import dataclass, field +from enum import Enum +from typing import Callable, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class ChaosHooksError(WebRunnerException): + """Raised on bad plan parameters or missing injector for a chosen fault.""" + + +class ChaosFaultType(str, Enum): + """Categories of chaos a runner can inject.""" + + NETWORK_OFFLINE = "network_offline" + NETWORK_SLOW = "network_slow" + CPU_THROTTLE = "cpu_throttle" + MID_FLOW_RELOAD = "mid_flow_reload" + TAB_BACKGROUND = "tab_background" + + +# ---------- planning ---------------------------------------------------- + +@dataclass(frozen=True) +class ChaosEvent: + """One scheduled fault.""" + + step_index: int + step_name: str + fault: ChaosFaultType + + +@dataclass +class ChaosPlan: + """A reproducible (seeded) injection schedule.""" + + events: List[ChaosEvent] = field(default_factory=list) + seed: Optional[int] = None + skipped: List[int] = field(default_factory=list) + + def faults_for_step(self, index: int) -> List[ChaosFaultType]: + return [e.fault for e in self.events if e.step_index == index] + + def describe(self) -> str: + if not self.events: + return "no chaos planned" + rows = [ + f"step {e.step_index} ({e.step_name}): {e.fault.value}" + for e in self.events + ] + return "; ".join(rows) + + +def plan_chaos( + step_names: Sequence[str], + *, + faults: Sequence[ChaosFaultType] = tuple(ChaosFaultType), + fault_rate: float = 0.2, + max_events: Optional[int] = None, + skip_first: int = 1, + skip_last: int = 0, + seed: Optional[int] = None, +) -> ChaosPlan: + """ + 決定每個 step 是否注入 chaos,以及注入哪種類型。 + Each non-skipped step independently has ``fault_rate`` chance of + getting a randomly-chosen fault from ``faults``. ``max_events`` caps + the total. ``skip_first`` / ``skip_last`` keep setup / teardown safe. + """ + if not 0.0 <= fault_rate <= 1.0: + raise ChaosHooksError("fault_rate must be in [0, 1]") + if not faults: + raise ChaosHooksError("faults must be a non-empty sequence") + if skip_first < 0 or skip_last < 0: + raise ChaosHooksError("skip_first / skip_last must be >= 0") + total = len(step_names) + rng = random.Random(seed) + events: List[ChaosEvent] = [] + skipped: List[int] = [] + for index, name in enumerate(step_names): + if index < skip_first or index >= total - skip_last: + skipped.append(index) + continue + if rng.random() >= fault_rate: + continue + fault = rng.choice(list(faults)) + events.append(ChaosEvent(step_index=index, step_name=name, fault=fault)) + if max_events is not None and len(events) >= max_events: + break + return ChaosPlan(events=events, seed=seed, skipped=skipped) + + +# ---------- runner ------------------------------------------------------ + +Injector = Callable[[ChaosEvent], None] +"""Callable that performs the side effect for one event (offline, throttle, etc.).""" + + +@dataclass +class ChaosRunner: + """Runs a :class:`ChaosPlan` by invoking the matching injector pre-step.""" + + plan: ChaosPlan + injectors: Dict[ChaosFaultType, Injector] = field(default_factory=dict) + raise_on_missing: bool = True + + def __post_init__(self) -> None: + if not isinstance(self.plan, ChaosPlan): + raise ChaosHooksError("plan must be a ChaosPlan") + missing = [ + event.fault for event in self.plan.events + if event.fault not in self.injectors + ] + if missing and self.raise_on_missing: + unique = sorted({m.value for m in missing}) + raise ChaosHooksError( + f"no injector registered for fault types: {unique}" + ) + + def before_step(self, index: int, name: str) -> List[ChaosEvent]: + """Fire every injector scheduled for ``index``; return events fired.""" + fired: List[ChaosEvent] = [] + for event in self.plan.events: + if event.step_index != index: + continue + injector = self.injectors.get(event.fault) + if injector is None: + web_runner_logger.warning( + f"chaos: no injector for {event.fault.value} at step {index} ({name})" + ) + continue + try: + injector(event) + except Exception as error: + raise ChaosHooksError( + f"injector {event.fault.value} raised at step {index}: {error!r}" + ) from error + fired.append(event) + web_runner_logger.info( + f"chaos: injected {event.fault.value} before step {index} ({name})" + ) + return fired + + +# ---------- convenience ------------------------------------------------- + +def run_with_chaos( + step_names: Sequence[str], + step_fn: Callable[[int, str], None], + *, + plan: ChaosPlan, + injectors: Dict[ChaosFaultType, Injector], +) -> List[ChaosEvent]: + """ + Drive ``step_fn(index, name)`` for every step, firing scheduled + injectors immediately before each step. Returns the events that fired. + """ + runner = ChaosRunner(plan=plan, injectors=injectors) + fired: List[ChaosEvent] = [] + for index, name in enumerate(step_names): + fired.extend(runner.before_step(index, name)) + step_fn(index, name) + return fired diff --git a/je_web_runner/utils/chrome_profile/__init__.py b/je_web_runner/utils/chrome_profile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/chrome_profile/profile_manager.py b/je_web_runner/utils/chrome_profile/profile_manager.py new file mode 100644 index 0000000..e53b506 --- /dev/null +++ b/je_web_runner/utils/chrome_profile/profile_manager.py @@ -0,0 +1,461 @@ +""" +持久化 Chrome profile:snapshot → 跑 → sync-back 模式 + SingletonLock 清理 + stealth flag。 +Persistent Chrome profile helpers: snapshot-launch-sync pattern, singleton-lock +cleanup, stealth flags, fake user-agent. Mirrors the pattern proven in +Jeffrey_RPA's NovelAI scraper where the actual `.chrome_profile/` directory is +often locked by Defender / OneDrive / Explorer; we copy session-critical files +into a disposable snapshot, run Chrome against that, and sync the login state +back on exit. + +Supports Selenium (Options builder + driver factory) and Playwright (persistent +context launch). +""" +from __future__ import annotations + +import os +import shutil +import sys +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Iterator, List, Optional, Sequence, Tuple + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class ChromeProfileError(WebRunnerException): + """Raised on snapshot / sync-back / spawn problems.""" + + +# Files Chrome leaves in user-data-dir on unclean exit. A new Chrome that +# sees any of these will refuse to launch with the same profile +# (SessionNotCreatedException). Removing them is safe — they do NOT contain +# cookies or login data. +SINGLETON_LOCK_FILES: Tuple[str, ...] = ( + "SingletonLock", + "SingletonCookie", + "SingletonSocket", + "lockfile", + "RunningChromeVersion", +) + +# Relative paths inside the profile that we treat as session-critical: +# everything we need to preserve a logged-in session. The journal sidecar +# files are SQLite WAL artefacts and must be copied alongside the main DB. +SESSION_CRITICAL_PATHS: Tuple[str, ...] = ( + "Default/Cookies", + "Default/Cookies-journal", + "Default/Login Data", + "Default/Login Data-journal", + "Default/Login Data For Account", + "Default/Login Data For Account-journal", + "Default/Web Data", + "Default/Web Data-journal", + "Default/Preferences", + "Default/Secure Preferences", + "Default/History", + "Default/Bookmarks", + "Local State", + "First Run", +) + +# Directories we strip from the snapshot — disposable caches that Chrome +# rebuilds on demand. Keeping them in sync bloats the snapshot tenfold and +# slows every launch. +SNAPSHOT_IGNORE_NAMES: frozenset = frozenset({ + "Cache", "Code Cache", "GPUCache", "ShaderCache", + "GraphiteDawnCache", "DawnGraphiteCache", "GrShaderCache", + "Service Worker", "blob_storage", "Crashpad", "CrashpadMetrics", + "GrShaderCache", "Subresource Filter", "optimization_guide_model_store", + "Download Service", "VideoDecodeStats", "Trust Tokens", + "Network", "Session Storage", "IndexedDB", +}) + +# Realistic desktop Chrome user-agent. Update periodically — Chrome bumps +# the major version every ~6 weeks and stale UA strings raise some +# anti-bot heuristics. +DEFAULT_USER_AGENT: str = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " + "AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/131.0.0.0 Safari/537.36" +) + + +@dataclass +class StealthFlags: + """ + 可序列化的 stealth 設定,方便由 action JSON 帶入。 + Serialisable stealth knobs so action JSON can configure them. + """ + user_agent: str = DEFAULT_USER_AGENT + language: str = "en-US" + window_size: Optional[Tuple[int, int]] = None + disable_blink_features: bool = True + exclude_automation_switches: bool = True + headless: bool = False + extra_args: List[str] = field(default_factory=list) + + +def cleanup_chrome_locks(profile_dir: Path) -> List[str]: + """ + 清理 SingletonLock / lockfile 等殘留檔。Locked file 改用 rename 規避。 + Remove the singleton lock files Chrome leaves behind. Files held by the + OS get renamed (which Windows usually allows) so a fresh Chrome can + create its own lock. Returns a per-file status list for logging. + """ + profile_dir = Path(profile_dir) + if not profile_dir.exists(): + return [] + statuses: List[str] = [] + for fname in SINGLETON_LOCK_FILES: + target = profile_dir / fname + if not target.exists() and not target.is_symlink(): + continue + try: + target.unlink() + statuses.append(f"removed {fname}") + except OSError: + try: + renamed = profile_dir / f"{fname}.stale.{int(time.time())}" + target.rename(renamed) + statuses.append(f"renamed {fname} → {renamed.name}") + except OSError as error: + statuses.append(f"failed {fname}: {error!r}") + if statuses: + web_runner_logger.info(f"cleanup_chrome_locks: {statuses}") + return statuses + + +def _is_session_critical(rel_path: str) -> bool: + """Test whether ``rel_path`` (POSIX-style) is in the session-critical list.""" + normalised = rel_path.replace("\\", "/") + return normalised in SESSION_CRITICAL_PATHS + + +def snapshot_chrome_profile( + profile_dir: Path, + snapshot_dir: Path, + *, + full_copy: bool = False, +) -> Path: + """ + 把 profile 複製到 snapshot_dir,跳過 cache / lock。 + Copy the profile into a disposable snapshot directory, skipping cache + folders and singleton locks. ``full_copy=True`` copies everything except + the lock files (useful for migrations where you want extensions too). + + Returns the snapshot path. Per-file copy errors are swallowed and + logged — partial snapshots are still usable since session-critical + files are independent of cache files. + """ + profile_dir = Path(profile_dir) + snapshot_dir = Path(snapshot_dir) + if not profile_dir.exists(): + raise ChromeProfileError(f"profile dir does not exist: {profile_dir}") + + if snapshot_dir.exists(): + try: + shutil.rmtree(snapshot_dir) + except OSError as error: + raise ChromeProfileError( + f"cannot clear previous snapshot at {snapshot_dir}: {error!r}" + ) from error + snapshot_dir.mkdir(parents=True, exist_ok=True) + + copied = 0 + skipped: List[str] = [] + profile_root_str = str(profile_dir) + for root, dirs, files in os.walk(profile_dir, topdown=True): + if not full_copy: + dirs[:] = [d for d in dirs if d not in SNAPSHOT_IGNORE_NAMES] + rel_root = os.path.relpath(root, profile_root_str) + target_root = snapshot_dir if rel_root == "." else snapshot_dir / rel_root + target_root.mkdir(parents=True, exist_ok=True) + for fname in files: + if fname in SINGLETON_LOCK_FILES: + continue + src = Path(root) / fname + dst = target_root / fname + try: + shutil.copy2(src, dst) + copied += 1 + except (OSError, shutil.Error) as error: + skipped.append(f"{rel_root}/{fname}: {error!r}") + web_runner_logger.info( + f"snapshot_chrome_profile: copied={copied} skipped={len(skipped)} → {snapshot_dir}" + ) + if skipped and len(skipped) < 30: + web_runner_logger.info(f"snapshot skipped detail: {skipped}") + return snapshot_dir + + +def sync_chrome_profile_back( + snapshot_dir: Path, + profile_dir: Path, + *, + paths: Sequence[str] = SESSION_CRITICAL_PATHS, +) -> List[str]: + """ + 把 snapshot 內 session-critical 檔複製回原 profile。 + Copy session-critical files from the snapshot back into the persistent + profile so a future run picks up the latest cookies / login data. + Returns a list of "copied" / "skipped" status strings for logging. + """ + snapshot_dir = Path(snapshot_dir) + profile_dir = Path(profile_dir) + if not snapshot_dir.exists(): + raise ChromeProfileError(f"snapshot dir does not exist: {snapshot_dir}") + profile_dir.mkdir(parents=True, exist_ok=True) + + statuses: List[str] = [] + for rel in paths: + src = snapshot_dir / rel + if not src.exists(): + continue + dst = profile_dir / rel + dst.parent.mkdir(parents=True, exist_ok=True) + try: + shutil.copy2(src, dst) + statuses.append(f"copied {rel}") + except (OSError, shutil.Error) as error: + statuses.append(f"skipped {rel}: {error!r}") + web_runner_logger.info(f"sync_chrome_profile_back: {len(statuses)} entries") + return statuses + + +def build_chrome_options( + profile_dir: Path, + flags: Optional[StealthFlags] = None, +): + """ + 產出帶 stealth 設定的 ChromeOptions。 + Build a Selenium ``ChromeOptions`` instance with stealth flags, the + given user-data-dir and (optionally) a real desktop UA. + """ + from selenium.webdriver.chrome.options import Options as ChromeOptions # local import keeps the module import-light + flags = flags or StealthFlags() + profile_dir = Path(profile_dir) + profile_dir.mkdir(parents=True, exist_ok=True) + + opts = ChromeOptions() + if flags.headless: + opts.add_argument("--headless=new") + if flags.disable_blink_features: + opts.add_argument("--disable-blink-features=AutomationControlled") + opts.add_argument(f"--lang={flags.language}") + opts.add_argument(f"--user-agent={flags.user_agent}") + opts.add_argument(f"--user-data-dir={profile_dir}") + if flags.window_size: + width, height = flags.window_size + opts.add_argument(f"--window-size={int(width)},{int(height)}") + for arg in flags.extra_args: + opts.add_argument(arg) + if flags.exclude_automation_switches: + opts.add_experimental_option("excludeSwitches", ["enable-automation"]) + opts.add_experimental_option("useAutomationExtension", False) + return opts + + +def build_stealth_chrome_driver( + profile_dir: Path, + *, + snapshot_dir: Optional[Path] = None, + flags: Optional[StealthFlags] = None, + chromedriver_log: Optional[Path] = None, + retry_once: bool = True, +): + """ + Spawn Chrome 用 snapshot profile + stealth flags + 一次 retry。 + Spawn a Selenium Chrome driver against a *snapshot* of the persistent + profile to side-step file locks held by AV / OneDrive / Explorer. + Caller is responsible for calling ``sync_chrome_profile_back`` on quit + if they want to preserve cookies — or use ``chrome_profile_session``. + + Returns ``(driver, snapshot_path)``. ``snapshot_path`` is ``None`` when + ``snapshot_dir`` is falsy (driver runs directly against ``profile_dir``). + """ + from selenium import webdriver + from selenium.webdriver.chrome.service import Service as ChromeService + + profile_dir = Path(profile_dir) + flags = flags or StealthFlags() + + if snapshot_dir is not None: + snapshot_path = snapshot_chrome_profile(profile_dir, Path(snapshot_dir)) + cleanup_chrome_locks(snapshot_path) + run_dir = snapshot_path + else: + snapshot_path = None + cleanup_chrome_locks(profile_dir) + run_dir = profile_dir + + opts = build_chrome_options(run_dir, flags=flags) + service_kwargs = {} + if chromedriver_log is not None: + service_kwargs["log_path"] = str(chromedriver_log) + try: + service = ChromeService(**service_kwargs) + driver = webdriver.Chrome(service=service, options=opts) + except Exception as first_err: # noqa: BLE001 — Selenium wraps many causes + if not retry_once: + raise ChromeProfileError( + f"chrome spawn failed: {first_err!r}" + ) from first_err + web_runner_logger.warning( + f"first chrome spawn failed: {first_err!r}; rebuilding snapshot and retrying" + ) + if snapshot_path is not None: + snapshot_path = snapshot_chrome_profile(profile_dir, Path(snapshot_dir)) + cleanup_chrome_locks(snapshot_path) + run_dir = snapshot_path + opts = build_chrome_options(run_dir, flags=flags) + time.sleep(1.5) + try: + service = ChromeService(**service_kwargs) + driver = webdriver.Chrome(service=service, options=opts) + except Exception as second_err: # noqa: BLE001 + raise ChromeProfileError( + f"chrome spawn failed twice: {second_err!r}" + ) from second_err + web_runner_logger.info( + f"stealth chrome driver spawned: profile={profile_dir} snapshot={snapshot_path}" + ) + return driver, snapshot_path + + +@contextmanager +def chrome_profile_session( + profile_dir: Path, + *, + snapshot_dir: Optional[Path] = None, + flags: Optional[StealthFlags] = None, + chromedriver_log: Optional[Path] = None, + sync_back: bool = True, +) -> Iterator[Any]: + """ + Context manager:spawn driver → yield → quit → sync-back。 + Context-managed lifecycle. ``snapshot_dir`` defaults to + ``profile_dir.parent / (profile_dir.name + "_snap")``. The yielded + value is the Selenium driver; on exit we quit the driver and sync + session-critical files back into the persistent profile. + """ + profile_dir = Path(profile_dir) + if snapshot_dir is None: + snapshot_dir = profile_dir.parent / f"{profile_dir.name}_snap" + driver, snapshot_path = build_stealth_chrome_driver( + profile_dir, + snapshot_dir=snapshot_dir, + flags=flags, + chromedriver_log=chromedriver_log, + ) + try: + yield driver + finally: + try: + driver.quit() + except Exception as error: # noqa: BLE001 — best effort on teardown + web_runner_logger.warning(f"driver.quit failed: {error!r}") + if sync_back and snapshot_path is not None: + try: + sync_chrome_profile_back(snapshot_path, profile_dir) + except ChromeProfileError as error: + web_runner_logger.warning(f"sync_back failed: {error!r}") + + +def build_playwright_persistent_context( + playwright_browser_type: Any, + profile_dir: Path, + *, + flags: Optional[StealthFlags] = None, + extra_launch_kwargs: Optional[dict] = None, +) -> Any: + """ + 用 Playwright 的 persistent context 開瀏覽器並套 stealth flag。 + Launch Playwright with ``launch_persistent_context``, passing stealth + flags and a stable user-agent. Caller passes the chromium browser type + (e.g. ``playwright.chromium``); we do not import playwright at module + load time because it is an optional dependency. + + Returns a ``BrowserContext`` ready to ``new_page()``. + """ + flags = flags or StealthFlags() + profile_dir = Path(profile_dir) + profile_dir.mkdir(parents=True, exist_ok=True) + cleanup_chrome_locks(profile_dir) + + args: List[str] = [f"--lang={flags.language}"] + if flags.disable_blink_features: + args.append("--disable-blink-features=AutomationControlled") + args.extend(flags.extra_args) + + launch_kwargs = { + "user_data_dir": str(profile_dir), + "headless": flags.headless, + "user_agent": flags.user_agent, + "args": args, + } + if flags.window_size: + launch_kwargs["viewport"] = { + "width": int(flags.window_size[0]), + "height": int(flags.window_size[1]), + } + if extra_launch_kwargs: + launch_kwargs.update(extra_launch_kwargs) + web_runner_logger.info( + f"playwright persistent context launching: profile={profile_dir}" + ) + return playwright_browser_type.launch_persistent_context(**launch_kwargs) + + +# ----- optional Windows-only window minimise hook -------------------------- + +def minimise_chrome_windows(profile_dir: Path) -> int: + """ + Win32 路徑 minimise 所有跑這個 profile 的 chrome.exe 視窗。 + Best-effort: enumerate top-level windows whose owner chrome.exe has a + cmdline arg referencing ``profile_dir``, and ``ShowWindow(SW_MINIMIZE)`` + each. Returns the count of windows hidden. No-op on non-Windows or when + pywin32 / psutil is unavailable. + """ + if sys.platform != "win32": + return 0 + try: + import psutil # type: ignore + import win32con # type: ignore + import win32gui # type: ignore + import win32process # type: ignore + except ImportError: + web_runner_logger.info( + "minimise_chrome_windows: pywin32/psutil not installed; skipping" + ) + return 0 + + profile_marker = str(Path(profile_dir)).replace("\\", "/").lower() + target_pids = set() + for proc in psutil.process_iter(["pid", "name", "cmdline"]): + try: + name = (proc.info.get("name") or "").lower() + if name != "chrome.exe": + continue + cmdline = " ".join(proc.info.get("cmdline") or []).replace("\\", "/").lower() + if profile_marker in cmdline: + target_pids.add(proc.info["pid"]) + except (psutil.NoSuchProcess, psutil.AccessDenied): + continue + + hidden = 0 + + def _maybe_minimise(hwnd, _arg): + nonlocal hidden + if not win32gui.IsWindowVisible(hwnd): + return + _, pid = win32process.GetWindowThreadProcessId(hwnd) + if pid in target_pids: + win32gui.ShowWindow(hwnd, win32con.SW_MINIMIZE) + hidden += 1 + + win32gui.EnumWindows(_maybe_minimise, None) + web_runner_logger.info(f"minimise_chrome_windows: hidden={hidden}") + return hidden diff --git a/je_web_runner/utils/clickjacking_audit/__init__.py b/je_web_runner/utils/clickjacking_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/clickjacking_audit/audit.py b/je_web_runner/utils/clickjacking_audit/audit.py new file mode 100644 index 0000000..502b6a4 --- /dev/null +++ b/je_web_runner/utils/clickjacking_audit/audit.py @@ -0,0 +1,218 @@ +""" +X-Frame-Options / CSP `frame-ancestors` 驗證 + iframe 嵌入探測。 +Two layers of defence against clickjacking: + +* **Header policy** — ``X-Frame-Options`` (deprecated but still honored) + + ``Content-Security-Policy: frame-ancestors`` (modern). At least one + should be present and restrict third-party framing. +* **Practical probe** — render a tiny test page that loads the target in + an `` + + +""".strip() + + +def build_probe_page(target_url: str) -> str: + """Render an HTML probe page that tries to embed ``target_url``.""" + if not isinstance(target_url, str) or not target_url: + raise ClickjackingAuditError("target_url must be non-empty string") + parsed = urlparse(target_url) + if parsed.scheme not in ("http", "https"): + raise ClickjackingAuditError( + f"target_url must be http(s), got {parsed.scheme!r}" + ) + return _PROBE_TEMPLATE % {"target_url": target_url} + + +PROBE_STATUS_SCRIPT = "return document.getElementById('status').textContent;" + + +# ---------- assertions ------------------------------------------------ + +@dataclass +class AuditReport: + """Combined header + probe outcome.""" + + target_url: str + verdict: Verdict + policy: HeaderPolicy + probe_status: Optional[str] = None + notes: List[str] = field(default_factory=list) + + def passed(self) -> bool: + if self.verdict in (Verdict.STRICT, Verdict.SAMEORIGIN): + if self.probe_status is None: + return True + return self.probe_status.upper().startswith("BLOCKED") + return False + + def to_dict(self) -> Dict[str, Any]: + return { + "target_url": self.target_url, + "verdict": self.verdict.value, + "policy": asdict(self.policy), + "probe_status": self.probe_status, + "notes": list(self.notes), + "passed": self.passed(), + } + + +def audit( + target_url: str, + headers: Iterable[Tuple[str, str]], + *, + probe_status: Optional[str] = None, +) -> AuditReport: + """One-shot: parse headers → classify → (optionally) consider probe.""" + policy = parse_response_headers(headers) + verdict = classify(policy) + notes: List[str] = [] + if verdict == Verdict.MISSING: + notes.append("no X-Frame-Options or frame-ancestors set") + if verdict == Verdict.ALLOWED: + notes.append("policy permits third-party embedding") + if probe_status: + notes.append(f"probe result: {probe_status}") + return AuditReport( + target_url=target_url, + verdict=verdict, + policy=policy, + probe_status=probe_status, + notes=notes, + ) + + +def assert_protected(report: AuditReport) -> None: + """Raise unless the report passes (strict / sameorigin and probe blocked).""" + if not isinstance(report, AuditReport): + raise ClickjackingAuditError("assert_protected expects AuditReport") + if report.passed(): + return + raise ClickjackingAuditError( + f"clickjacking risk for {report.target_url!r}: " + f"verdict={report.verdict.value}, probe={report.probe_status!r}" + ) diff --git a/je_web_runner/utils/consent_audit/__init__.py b/je_web_runner/utils/consent_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/consent_audit/audit.py b/je_web_runner/utils/consent_audit/audit.py new file mode 100644 index 0000000..7ab9b94 --- /dev/null +++ b/je_web_runner/utils/consent_audit/audit.py @@ -0,0 +1,249 @@ +""" +GDPR / CCPA 風格 cookie 分類 + 偵測 pre-consent 載入的 non-essential cookies。 +Two assertions teams hit most often: + +* "No analytics / advertising / social cookies must be set before the + user clicks 'Accept'." +* "When the user opts out, marketing cookies must not be re-introduced + by any subsequent page load." + +This module compares two cookie snapshots (``before_consent`` and +``after_consent``), classifies each cookie against a built-in catalogue +of well-known vendors, and produces a :class:`ConsentReport`. +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ConsentAuditError(WebRunnerException): + """Raised on malformed cookie inputs or invalid catalogue overrides.""" + + +class CookieCategory(str, Enum): + """Standard GDPR / IAB categories.""" + + NECESSARY = "necessary" + PREFERENCES = "preferences" + ANALYTICS = "analytics" + MARKETING = "marketing" + SOCIAL = "social" + UNKNOWN = "unknown" + + +# ---------- catalogue --------------------------------------------------- + +@dataclass(frozen=True) +class CookieRule: + """Match by cookie name regex and / or domain suffix.""" + + name_pattern: Optional[str] + domain_suffix: Optional[str] + category: CookieCategory + vendor: str + + +_CATALOGUE: Sequence[CookieRule] = ( + CookieRule(r"^_ga(_|$)|^_gid$|^_gat", None, CookieCategory.ANALYTICS, "google_analytics"), + CookieRule(r"^_fbp$|^_fbc$", None, CookieCategory.MARKETING, "facebook_pixel"), + CookieRule(r"^_hjSessionUser_|^_hjSession_|^_hjAbsoluteSessionInProgress$", + None, CookieCategory.ANALYTICS, "hotjar"), + CookieRule(r"^IDE$", "doubleclick.net", CookieCategory.MARKETING, "google_dv360"), + CookieRule(r"^MUID$|^MUIDB$", "bing.com", CookieCategory.MARKETING, "microsoft_ads"), + CookieRule(r"^_pin_unauth$|^_pinterest_", None, CookieCategory.MARKETING, "pinterest"), + CookieRule(r"^li_gc$|^lidc$|^bcookie$|^bscookie$", "linkedin.com", + CookieCategory.SOCIAL, "linkedin"), + CookieRule(r"^datr$|^sb$", "facebook.com", CookieCategory.SOCIAL, "facebook"), + CookieRule(r"^optimizelyEndUserId$|^optimizely_", None, + CookieCategory.ANALYTICS, "optimizely"), + CookieRule(r"^mp_", None, CookieCategory.ANALYTICS, "mixpanel"), + CookieRule(r"^amplitude_id_", None, CookieCategory.ANALYTICS, "amplitude"), + CookieRule(r"^XSRF-TOKEN$|^csrftoken$|^__Host-csrf$", + None, CookieCategory.NECESSARY, "csrf"), + CookieRule(r"^(session|JSESSIONID|connect\.sid|laravel_session|PHPSESSID)$", + None, CookieCategory.NECESSARY, "session"), + CookieRule(r"^locale$|^lang$|^i18n_", None, CookieCategory.PREFERENCES, "i18n"), + CookieRule(r"^theme$|^darkmode$", None, CookieCategory.PREFERENCES, "ui_preferences"), +) + + +# ---------- cookie model ----------------------------------------------- + +@dataclass(frozen=True) +class Cookie: + """One browser cookie.""" + + name: str + domain: str = "" + value: Optional[str] = None + secure: bool = True + same_site: Optional[str] = None + + def __post_init__(self) -> None: + if not self.name or not isinstance(self.name, str): + raise ConsentAuditError("Cookie.name must be a non-empty string") + + +@dataclass +class ClassifiedCookie: + """A cookie with its assigned category + vendor.""" + + cookie: Cookie + category: CookieCategory + vendor: str + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.cookie.name, + "domain": self.cookie.domain, + "category": self.category.value, + "vendor": self.vendor, + } + + +# ---------- classification --------------------------------------------- + +def classify_cookie( + cookie: Cookie, + *, + extra_rules: Sequence[CookieRule] = (), +) -> ClassifiedCookie: + """Run the catalogue + caller-supplied rules against one cookie.""" + if not isinstance(cookie, Cookie): + raise ConsentAuditError( + f"classify_cookie expects Cookie, got {type(cookie).__name__}" + ) + for rule in (*extra_rules, *_CATALOGUE): + if _matches(rule, cookie): + return ClassifiedCookie(cookie=cookie, category=rule.category, vendor=rule.vendor) + return ClassifiedCookie( + cookie=cookie, category=CookieCategory.UNKNOWN, vendor="unknown", + ) + + +def _matches(rule: CookieRule, cookie: Cookie) -> bool: + if rule.name_pattern: + if not re.search(rule.name_pattern, cookie.name): + return False + if rule.domain_suffix: + if not cookie.domain or not cookie.domain.lower().endswith(rule.domain_suffix.lower()): + return False + return rule.name_pattern is not None or rule.domain_suffix is not None + + +def classify_all( + cookies: Iterable[Cookie], + *, + extra_rules: Sequence[CookieRule] = (), +) -> List[ClassifiedCookie]: + """Convenience: classify every cookie in ``cookies``.""" + return [classify_cookie(c, extra_rules=extra_rules) for c in cookies] + + +# ---------- audit ------------------------------------------------------- + +@dataclass +class ConsentReport: + """Outcome of :func:`audit_consent`.""" + + pre_consent_total: int + post_consent_total: int + pre_consent_violations: List[ClassifiedCookie] = field(default_factory=list) + post_consent_reintroduced: List[ClassifiedCookie] = field(default_factory=list) + unknown_cookies: List[ClassifiedCookie] = field(default_factory=list) + + def passed(self) -> bool: + return not self.pre_consent_violations and not self.post_consent_reintroduced + + def to_dict(self) -> Dict[str, Any]: + return { + "pre_consent_total": self.pre_consent_total, + "post_consent_total": self.post_consent_total, + "pre_consent_violations": [c.to_dict() for c in self.pre_consent_violations], + "post_consent_reintroduced": [c.to_dict() for c in self.post_consent_reintroduced], + "unknown_cookies": [c.to_dict() for c in self.unknown_cookies], + "passed": self.passed(), + } + + +NON_ESSENTIAL = frozenset({ + CookieCategory.ANALYTICS, + CookieCategory.MARKETING, + CookieCategory.SOCIAL, +}) + + +def audit_consent( + before_consent: Sequence[Cookie], + after_consent: Sequence[Cookie] = (), + *, + user_rejected: bool = False, + extra_rules: Sequence[CookieRule] = (), +) -> ConsentReport: + """ + Cross-check that no non-essential cookies are set pre-consent, and + (when ``user_rejected``) that none re-appear post-rejection. + """ + before_classified = classify_all(before_consent, extra_rules=extra_rules) + after_classified = classify_all(after_consent, extra_rules=extra_rules) + + pre_violations = [ + c for c in before_classified if c.category in NON_ESSENTIAL + ] + unknown = [ + c for c in before_classified if c.category == CookieCategory.UNKNOWN + ] + reintroduced: List[ClassifiedCookie] = [] + if user_rejected: + reintroduced = [ + c for c in after_classified if c.category in NON_ESSENTIAL + ] + return ConsentReport( + pre_consent_total=len(before_classified), + post_consent_total=len(after_classified), + pre_consent_violations=pre_violations, + post_consent_reintroduced=reintroduced, + unknown_cookies=unknown, + ) + + +# ---------- helpers ----------------------------------------------------- + +def assert_passes(report: ConsentReport) -> None: + """Raise unless ``report.passed()``.""" + if not isinstance(report, ConsentReport): + raise ConsentAuditError("assert_passes expects ConsentReport") + if report.passed(): + return + parts = [] + if report.pre_consent_violations: + names = ", ".join(c.cookie.name for c in report.pre_consent_violations) + parts.append(f"pre-consent non-essential: {names}") + if report.post_consent_reintroduced: + names = ", ".join(c.cookie.name for c in report.post_consent_reintroduced) + parts.append(f"reintroduced after reject: {names}") + raise ConsentAuditError("; ".join(parts)) + + +def from_selenium_cookies(cookies: Iterable[Dict[str, Any]]) -> List[Cookie]: + """Convert Selenium ``driver.get_cookies()`` dicts to :class:`Cookie`.""" + out: List[Cookie] = [] + for entry in cookies: + if not isinstance(entry, dict): + continue + name = entry.get("name") + if not isinstance(name, str) or not name: + continue + out.append(Cookie( + name=name, + domain=str(entry.get("domain") or ""), + value=entry.get("value"), + secure=bool(entry.get("secure", True)), + same_site=entry.get("sameSite"), + )) + return out diff --git a/je_web_runner/utils/console_error_budget/__init__.py b/je_web_runner/utils/console_error_budget/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/console_error_budget/budget.py b/je_web_runner/utils/console_error_budget/budget.py new file mode 100644 index 0000000..78b52fc --- /dev/null +++ b/je_web_runner/utils/console_error_budget/budget.py @@ -0,0 +1,231 @@ +""" +追蹤 JS console errors / unhandled rejections,設立可調的「錯誤預算」。 +A budget exists so non-critical noise doesn't fail every CI run, while +genuinely-spiked runs still trip. Three knobs: + +* **Severity filter** — ignore ``log`` / ``info`` / ``debug`` by default; + ``error`` and ``warning`` (configurable) count. +* **Pattern allowlist** — regex skiplist for known-third-party noise + (e.g. ``/extensions/.*ResizeObserver/``). +* **Max count** — overall cap. ``allowed_warnings`` is a separate softer + budget so a single error still fails. + +Designed to be fed by anything that produces :class:`ConsoleMessage`s — +a CDP listener (``Runtime.consoleAPICalled`` / ``Runtime.exceptionThrown``), +``selenium.webdriver.remote.webdriver.WebDriver.get_log("browser")``, +Playwright ``page.on('console')``, etc. +""" +from __future__ import annotations + +import re +import time +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Pattern, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class ConsoleBudgetError(WebRunnerException): + """Raised when the budget is exceeded or input is malformed.""" + + +_KNOWN_SEVERITIES = ("debug", "info", "log", "warning", "error") + + +# ---------- model ------------------------------------------------------- + +@dataclass +class ConsoleMessage: + """One console line. Severity is normalised to the strings above.""" + + severity: str + text: str + url: Optional[str] = None + line: Optional[int] = None + timestamp: float = field(default_factory=time.time) + source: Optional[str] = None # 'console' or 'exception' or driver-specific + + def __post_init__(self) -> None: + normalised = (self.severity or "").lower() + if normalised == "warn": + normalised = "warning" + elif normalised == "severe": # selenium's level for console.error + normalised = "error" + if normalised not in _KNOWN_SEVERITIES: + normalised = "info" + object.__setattr__(self, "severity", normalised) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class BudgetReport: + """Outcome of :func:`evaluate`.""" + + passed: bool + error_count: int + warning_count: int + ignored_count: int + breaches: List[str] = field(default_factory=list) + sampled: List[ConsoleMessage] = field(default_factory=list) + + def raise_if_failed(self) -> None: + if not self.passed: + joined = "; ".join(self.breaches) or "budget exceeded" + raise ConsoleBudgetError(joined) + + +@dataclass +class ErrorBudget: + """Per-suite knobs for what counts and how much is allowed.""" + + max_errors: int = 0 + max_warnings: int = 5 + count_warnings: bool = True + ignore_patterns: Sequence[Union[str, Pattern[str]]] = () + sample_size: int = 10 + + def __post_init__(self) -> None: + if self.max_errors < 0 or self.max_warnings < 0: + raise ConsoleBudgetError("max_errors / max_warnings must be >= 0") + if self.sample_size < 0: + raise ConsoleBudgetError("sample_size must be >= 0") + + +# ---------- evaluator --------------------------------------------------- + +def _compiled_patterns( + patterns: Sequence[Union[str, Pattern[str]]], +) -> List[Pattern[str]]: + compiled: List[Pattern[str]] = [] + for p in patterns: + if hasattr(p, "search"): + compiled.append(p) # type: ignore[arg-type] + else: + try: + compiled.append(re.compile(str(p))) + except re.error as error: + raise ConsoleBudgetError(f"bad ignore pattern {p!r}: {error}") from error + return compiled + + +def _is_ignored(message: ConsoleMessage, patterns: List[Pattern[str]]) -> bool: + if not patterns: + return False + haystack = f"{message.text}\n{message.url or ''}" + return any(p.search(haystack) for p in patterns) + + +def evaluate( + messages: Iterable[ConsoleMessage], + budget: ErrorBudget, +) -> BudgetReport: + """Score ``messages`` against ``budget`` and return a :class:`BudgetReport`.""" + if not isinstance(budget, ErrorBudget): + raise ConsoleBudgetError("budget must be an ErrorBudget instance") + patterns = _compiled_patterns(budget.ignore_patterns) + errors: List[ConsoleMessage] = [] + warnings: List[ConsoleMessage] = [] + ignored = 0 + for msg in messages: + if not isinstance(msg, ConsoleMessage): + raise ConsoleBudgetError( + f"evaluate expects ConsoleMessage, got {type(msg).__name__}" + ) + if _is_ignored(msg, patterns): + ignored += 1 + continue + if msg.severity == "error": + errors.append(msg) + elif msg.severity == "warning" and budget.count_warnings: + warnings.append(msg) + breaches: List[str] = [] + if len(errors) > budget.max_errors: + breaches.append( + f"errors {len(errors)} > max_errors {budget.max_errors}" + ) + if budget.count_warnings and len(warnings) > budget.max_warnings: + breaches.append( + f"warnings {len(warnings)} > max_warnings {budget.max_warnings}" + ) + sampled = (errors + warnings)[: budget.sample_size] + report = BudgetReport( + passed=not breaches, + error_count=len(errors), + warning_count=len(warnings), + ignored_count=ignored, + breaches=breaches, + sampled=sampled, + ) + if breaches: + web_runner_logger.warning( + f"console budget breach: {breaches} (errors={len(errors)}, " + f"warnings={len(warnings)}, ignored={ignored})" + ) + return report + + +# ---------- adapters ---------------------------------------------------- + +def from_selenium_log(entries: Iterable[Dict[str, Any]]) -> List[ConsoleMessage]: + """Convert Selenium ``driver.get_log('browser')`` entries to messages.""" + out: List[ConsoleMessage] = [] + for entry in entries: + if not isinstance(entry, dict): + continue + out.append(ConsoleMessage( + severity=str(entry.get("level") or "info"), + text=str(entry.get("message") or ""), + timestamp=float(entry.get("timestamp") or 0) / 1000.0 or time.time(), + source="selenium-browser-log", + )) + return out + + +def from_cdp_console_events(events: Iterable[Dict[str, Any]]) -> List[ConsoleMessage]: + """ + Convert CDP ``Runtime.consoleAPICalled`` payloads into messages. + Each event dict is expected to have ``type`` and ``args`` like CDP returns. + """ + out: List[ConsoleMessage] = [] + for event in events: + if not isinstance(event, dict): + continue + args = event.get("args") or [] + text_parts: List[str] = [] + for arg in args: + if isinstance(arg, dict): + value = arg.get("value") + if value is not None: + text_parts.append(str(value)) + elif arg.get("description"): + text_parts.append(str(arg["description"])) + out.append(ConsoleMessage( + severity=str(event.get("type") or "log"), + text=" ".join(text_parts).strip(), + url=(event.get("stackTrace") or {}).get("url"), + timestamp=float(event.get("timestamp") or 0) / 1000.0 or time.time(), + source="cdp-console", + )) + return out + + +def from_cdp_exception_events(events: Iterable[Dict[str, Any]]) -> List[ConsoleMessage]: + """Convert CDP ``Runtime.exceptionThrown`` payloads to error messages.""" + out: List[ConsoleMessage] = [] + for event in events: + if not isinstance(event, dict): + continue + details = event.get("exceptionDetails") or {} + text = (details.get("exception") or {}).get("description") or details.get("text") or "" + out.append(ConsoleMessage( + severity="error", + text=str(text), + url=details.get("url"), + line=details.get("lineNumber"), + timestamp=float(event.get("timestamp") or 0) / 1000.0 or time.time(), + source="cdp-exception", + )) + return out diff --git a/je_web_runner/utils/coop_coep_audit/__init__.py b/je_web_runner/utils/coop_coep_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/coop_coep_audit/audit.py b/je_web_runner/utils/coop_coep_audit/audit.py new file mode 100644 index 0000000..67a2375 --- /dev/null +++ b/je_web_runner/utils/coop_coep_audit/audit.py @@ -0,0 +1,270 @@ +""" +COOP / COEP / CORP cross-origin isolation header 稽核。 +SharedArrayBuffer, high-resolution timers, WebGPU, and an increasing +number of "powerful" APIs require ``crossOriginIsolated`` to be true. +That needs: + +* ``Cross-Origin-Opener-Policy: same-origin`` (COOP) +* ``Cross-Origin-Embedder-Policy: require-corp`` or ``credentialless`` (COEP) +* Every cross-origin sub-resource served with ``Cross-Origin-Resource-Policy`` + or CORS that satisfies COEP. + +This module parses headers (page + per-resource HAR) and reports what +prevents isolation, with actionable detail rather than just yes/no. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union +from urllib.parse import urlparse + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class CoopCoepAuditError(WebRunnerException): + """Raised on bad header / HAR input or failed assertion.""" + + +# ---------- enums ------------------------------------------------------ + +class CoopValue(str, Enum): + UNSAFE_NONE = "unsafe-none" + SAME_ORIGIN_ALLOW_POPUPS = "same-origin-allow-popups" + SAME_ORIGIN = "same-origin" + + +class CoepValue(str, Enum): + UNSAFE_NONE = "unsafe-none" + CREDENTIALLESS = "credentialless" + REQUIRE_CORP = "require-corp" + + +class CorpValue(str, Enum): + SAME_SITE = "same-site" + SAME_ORIGIN = "same-origin" + CROSS_ORIGIN = "cross-origin" + + +# ---------- header parsing --------------------------------------------- + +def _enum_or(value: Optional[str], cls, default): + if value is None: + return default + try: + return cls(value.strip().lower()) + except ValueError: + return default + + +@dataclass +class PagePolicy: + """Cross-origin-isolation policy for the top-level document.""" + + coop: CoopValue = CoopValue.UNSAFE_NONE + coep: CoepValue = CoepValue.UNSAFE_NONE + + def isolated(self) -> bool: + return ( + self.coop == CoopValue.SAME_ORIGIN + and self.coep in (CoepValue.REQUIRE_CORP, CoepValue.CREDENTIALLESS) + ) + + +def parse_page_headers( + headers: Iterable[Tuple[str, str]], +) -> PagePolicy: + """Parse a header iterable into a :class:`PagePolicy`.""" + coop_raw: Optional[str] = None + coep_raw: Optional[str] = None + for name, value in headers: + if not isinstance(name, str) or not isinstance(value, str): + continue + n = name.strip().lower() + if n == "cross-origin-opener-policy": + coop_raw = value + elif n == "cross-origin-embedder-policy": + coep_raw = value + return PagePolicy( + coop=_enum_or(coop_raw, CoopValue, CoopValue.UNSAFE_NONE), + coep=_enum_or(coep_raw, CoepValue, CoepValue.UNSAFE_NONE), + ) + + +# ---------- per-resource audit ---------------------------------------- + +@dataclass +class ResourceFinding: + """One sub-resource that violates COEP.""" + + url: str + reason: str + corp: Optional[str] = None + cors_present: bool = False + + +def _same_origin(a: str, b: str) -> bool: + try: + pa = urlparse(a) + pb = urlparse(b) + except ValueError: + return False + return (pa.scheme, pa.hostname, pa.port) == (pb.scheme, pb.hostname, pb.port) + + +def _header_lookup(entry: Dict[str, Any]) -> Dict[str, str]: + out: Dict[str, str] = {} + headers = ((entry.get("response") or {}).get("headers")) or [] + if not isinstance(headers, list): + return out + for h in headers: + if not isinstance(h, dict): + continue + name = h.get("name") + value = h.get("value") + if isinstance(name, str) and isinstance(value, str): + out[name.strip().lower()] = value + return out + + +def scan_har_resources( + har: Union[str, Dict[str, Any]], + *, + page_url: str, + coep: CoepValue, +) -> List[ResourceFinding]: + """ + Walk HAR entries; any cross-origin entry must satisfy the page's COEP. + Returns one :class:`ResourceFinding` per violation; empty list means OK. + """ + if coep == CoepValue.UNSAFE_NONE: + return [] + if not isinstance(page_url, str) or not page_url: + raise CoopCoepAuditError("page_url must be non-empty string") + har_obj = _coerce_har(har) + entries = ((har_obj.get("log") or {}).get("entries")) or [] + if not isinstance(entries, list): + raise CoopCoepAuditError("har log.entries must be a list") + findings: List[ResourceFinding] = [] + for entry in entries: + if not isinstance(entry, dict): + continue + request_url = ((entry.get("request") or {}).get("url")) or "" + if not request_url or _same_origin(request_url, page_url): + continue + headers = _header_lookup(entry) + corp = headers.get("cross-origin-resource-policy") + cors_origin = headers.get("access-control-allow-origin") + cors_credentials = headers.get("access-control-allow-credentials") + if coep == CoepValue.REQUIRE_CORP: + if corp == CorpValue.CROSS_ORIGIN.value: + continue + if cors_origin and cors_origin != "null": + continue + findings.append(ResourceFinding( + url=request_url, + reason="require-corp: needs CORP cross-origin OR CORS allow", + corp=corp, cors_present=bool(cors_origin), + )) + elif coep == CoepValue.CREDENTIALLESS: + # credentialless allows credentialled requests to fail open but + # still requires CORP or CORS for credentialled fetches. + if corp or cors_origin: + continue + findings.append(ResourceFinding( + url=request_url, + reason="credentialless: needs CORP or CORS", + corp=corp, cors_present=bool(cors_origin), + )) + return findings + + +def _coerce_har(har: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(har, str): + try: + parsed = json.loads(har) + except ValueError as error: + raise CoopCoepAuditError(f"har not JSON: {error}") from error + if not isinstance(parsed, dict): + raise CoopCoepAuditError("har JSON must be an object") + return parsed + if isinstance(har, dict): + return har + raise CoopCoepAuditError( + f"scan_har_resources expects str/dict, got {type(har).__name__}" + ) + + +# ---------- combined audit -------------------------------------------- + +@dataclass +class IsolationReport: + """Result of :func:`audit_isolation`.""" + + page_url: str + policy: PagePolicy + isolated: bool + resource_findings: List[ResourceFinding] = field(default_factory=list) + notes: List[str] = field(default_factory=list) + + def passed(self) -> bool: + return self.isolated and not self.resource_findings + + def to_dict(self) -> Dict[str, Any]: + return { + "page_url": self.page_url, + "policy": { + "coop": self.policy.coop.value, + "coep": self.policy.coep.value, + }, + "isolated": self.isolated, + "resource_findings": [asdict(f) for f in self.resource_findings], + "notes": list(self.notes), + "passed": self.passed(), + } + + +def audit_isolation( + page_url: str, + page_headers: Iterable[Tuple[str, str]], + *, + har: Optional[Union[str, Dict[str, Any]]] = None, +) -> IsolationReport: + """Combined page + resource audit. ``har`` is optional but recommended.""" + if not isinstance(page_url, str) or not page_url: + raise CoopCoepAuditError("page_url must be non-empty string") + policy = parse_page_headers(page_headers) + isolated = policy.isolated() + notes: List[str] = [] + if policy.coop != CoopValue.SAME_ORIGIN: + notes.append(f"COOP is {policy.coop.value}, want same-origin") + if policy.coep not in (CoepValue.REQUIRE_CORP, CoepValue.CREDENTIALLESS): + notes.append(f"COEP is {policy.coep.value}, want require-corp/credentialless") + resource_findings: List[ResourceFinding] = [] + if har is not None and policy.coep != CoepValue.UNSAFE_NONE: + resource_findings = scan_har_resources( + har, page_url=page_url, coep=policy.coep, + ) + return IsolationReport( + page_url=page_url, + policy=policy, + isolated=isolated, + resource_findings=resource_findings, + notes=notes, + ) + + +def assert_isolated(report: IsolationReport) -> None: + if not isinstance(report, IsolationReport): + raise CoopCoepAuditError("assert_isolated expects IsolationReport") + if report.passed(): + return + if not report.isolated: + raise CoopCoepAuditError( + f"not crossOriginIsolated: {', '.join(report.notes) or 'unknown reason'}" + ) + bad = report.resource_findings[0] + raise CoopCoepAuditError( + f"resource violates COEP: {bad.url} ({bad.reason})" + ) diff --git a/je_web_runner/utils/cross_tab_sync/__init__.py b/je_web_runner/utils/cross_tab_sync/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/cross_tab_sync/sync_assertions.py b/je_web_runner/utils/cross_tab_sync/sync_assertions.py new file mode 100644 index 0000000..ee5250c --- /dev/null +++ b/je_web_runner/utils/cross_tab_sync/sync_assertions.py @@ -0,0 +1,327 @@ +""" +跨 tab / context 狀態同步測試:BroadcastChannel / localStorage / +SharedWorker / Window.postMessage,以多個 Playwright page 驗證 +更新在其他 tab 即時反映。 + +Common assertions: + +* :func:`wait_for_storage` — wait for ``localStorage[key]`` on a tab to + match ``expected`` (with optional JSON parsing). +* :func:`broadcast_message` — send a structured BroadcastChannel message + from one tab, optionally on a named channel. +* :func:`assert_state_propagates` — write storage / broadcast on the + source tab, expect each listener tab to observe it within ``timeout``. +* :func:`collect_broadcast_messages` — install a recorder on a tab so + later assertions can pop messages it received. + +All helpers operate on Playwright ``Page`` objects (no direct dep on +Selenium — the cross-tab story works far better with Playwright's +multi-page model). +""" +from __future__ import annotations + +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class CrossTabSyncError(WebRunnerException): + """Raised when an expected propagation does not happen in time.""" + + +# ---------- localStorage / sessionStorage -------------------------------- + +def set_storage_value( + page: Any, + key: str, + value: Any, + *, + storage: str = "localStorage", +) -> None: + """ + 在 page 上設一個 storage 值。 + ``value`` is JSON-serialised so callers can hand in dicts/lists. + """ + _ensure_storage_name(storage) + if page is None: + raise CrossTabSyncError("page is None") + payload = json.dumps(value) + script = ( + f"({{ key, raw }}) => window.{storage}.setItem(key, raw)" + ) + try: + page.evaluate(script, {"key": key, "raw": payload}) + except Exception as error: # noqa: BLE001 — playwright surfaces many + raise CrossTabSyncError(f"set_storage_value failed: {error!r}") from error + web_runner_logger.info(f"set_storage_value: {storage}[{key}] = {payload[:80]}") + + +def get_storage_value( + page: Any, + key: str, + *, + storage: str = "localStorage", + json_parse: bool = True, +) -> Any: + """Read ``storage[key]`` from the page. Returns ``None`` when absent.""" + _ensure_storage_name(storage) + if page is None: + raise CrossTabSyncError("page is None") + script = f"(key) => window.{storage}.getItem(key)" + try: + raw = page.evaluate(script, key) + except Exception as error: # noqa: BLE001 + raise CrossTabSyncError(f"get_storage_value failed: {error!r}") from error + if raw is None: + return None + if not json_parse: + return raw + try: + return json.loads(raw) + except (TypeError, ValueError): + return raw + + +def wait_for_storage( + page: Any, + key: str, + expected: Any, + *, + storage: str = "localStorage", + timeout: float = 5.0, + poll_interval: float = 0.1, + sleep_fn: Callable[[float], None] = time.sleep, + time_fn: Callable[[], float] = time.time, +) -> Any: + """ + 輪詢直到 ``storage[key]`` 等於 ``expected`` 或 timeout。 + Comparison is JSON-aware: a dict ``expected`` will match a JSON-encoded + value stored as a string. Returns the matched value. + """ + _validate_timeout(timeout, poll_interval) + start = time_fn() + while True: + current = get_storage_value(page, key, storage=storage, json_parse=True) + if current == expected: + return current + if time_fn() - start >= timeout: + raise CrossTabSyncError( + f"timeout: {storage}[{key}] = {current!r}, expected {expected!r}" + ) + sleep_fn(poll_interval) + + +# ---------- BroadcastChannel --------------------------------------------- + +def install_broadcast_recorder(page: Any, channel_name: str) -> None: + """ + Hook 一個 ``window.__wr_broadcast_log__`` 蒐集所有 BroadcastChannel 訊息。 + Idempotent — installing twice replaces the previous recorder. + """ + if page is None: + raise CrossTabSyncError("page is None") + if not channel_name: + raise CrossTabSyncError("channel_name is required") + script = """ + (channelName) => { + if (window.__wr_broadcast_channels__ && + window.__wr_broadcast_channels__[channelName]) { + window.__wr_broadcast_channels__[channelName].close(); + } + window.__wr_broadcast_log__ = window.__wr_broadcast_log__ || {}; + window.__wr_broadcast_log__[channelName] = []; + window.__wr_broadcast_channels__ = window.__wr_broadcast_channels__ || {}; + const ch = new BroadcastChannel(channelName); + ch.onmessage = (ev) => { + window.__wr_broadcast_log__[channelName].push({ + data: ev.data, + receivedAt: Date.now(), + }); + }; + window.__wr_broadcast_channels__[channelName] = ch; + return true; + } + """ + try: + page.evaluate(script, channel_name) + except Exception as error: # noqa: BLE001 + raise CrossTabSyncError( + f"install_broadcast_recorder failed: {error!r}" + ) from error + + +def broadcast_message(page: Any, channel_name: str, data: Any) -> None: + """Post one message to ``channel_name`` from ``page``.""" + if page is None: + raise CrossTabSyncError("page is None") + if not channel_name: + raise CrossTabSyncError("channel_name is required") + script = """ + ({ channelName, payload }) => { + const ch = new BroadcastChannel(channelName); + ch.postMessage(payload); + ch.close(); + return true; + } + """ + try: + page.evaluate(script, {"channelName": channel_name, "payload": data}) + except Exception as error: # noqa: BLE001 + raise CrossTabSyncError(f"broadcast_message failed: {error!r}") from error + web_runner_logger.info( + f"broadcast_message: channel={channel_name!r}" + ) + + +def collect_broadcast_messages( + page: Any, + channel_name: str, +) -> List[Dict[str, Any]]: + """Return everything the recorder on ``page`` has captured for ``channel_name``.""" + if page is None: + raise CrossTabSyncError("page is None") + script = """ + (channelName) => { + if (!window.__wr_broadcast_log__) return []; + return window.__wr_broadcast_log__[channelName] || []; + } + """ + try: + result = page.evaluate(script, channel_name) + except Exception as error: # noqa: BLE001 + raise CrossTabSyncError( + f"collect_broadcast_messages failed: {error!r}" + ) from error + if not isinstance(result, list): + return [] + return result + + +def wait_for_broadcast( + page: Any, + channel_name: str, + matcher: Callable[[Any], bool], + *, + timeout: float = 5.0, + poll_interval: float = 0.1, + sleep_fn: Callable[[float], None] = time.sleep, + time_fn: Callable[[], float] = time.time, +) -> Dict[str, Any]: + """ + 輪詢 recorder 直到出現一條符合 ``matcher`` 的訊息。 + Returns the matching message entry (with ``data`` and ``receivedAt``). + """ + _validate_timeout(timeout, poll_interval) + start = time_fn() + while True: + messages = collect_broadcast_messages(page, channel_name) + for entry in messages: + data = entry.get("data") if isinstance(entry, dict) else None + try: + hit = matcher(data) + except Exception: # noqa: BLE001 — matcher should be cheap + hit = False + if hit: + return entry + if time_fn() - start >= timeout: + raise CrossTabSyncError( + f"timeout: no broadcast on {channel_name!r} matched within {timeout}s" + ) + sleep_fn(poll_interval) + + +# ---------- assert_state_propagates -------------------------------------- + +@dataclass +class PropagationResult: + """Outcome of one :func:`assert_state_propagates` call.""" + + propagated_to: List[int] = field(default_factory=list) + elapsed_seconds: float = 0.0 + + +def assert_state_propagates( + source_page: Any, + listener_pages: Sequence[Any], + *, + key: str, + value: Any, + storage: str = "localStorage", + timeout: float = 5.0, + poll_interval: float = 0.1, + sleep_fn: Callable[[float], None] = time.sleep, + time_fn: Callable[[], float] = time.time, +) -> PropagationResult: + """ + 在 source_page 設 storage,要求每個 listener_pages 在 timeout 內看到。 + Raises if any listener does not observe the value before ``timeout``. + Returns ``PropagationResult`` listing the listener indices observed + plus the total elapsed wait time. + """ + if source_page is None: + raise CrossTabSyncError("source_page is None") + if not listener_pages: + raise CrossTabSyncError("at least one listener_pages entry is required") + set_storage_value(source_page, key, value, storage=storage) + start = time_fn() + seen = [False] * len(listener_pages) + while True: + for idx, page in enumerate(listener_pages): + if seen[idx]: + continue + current = get_storage_value(page, key, storage=storage, json_parse=True) + if current == value: + seen[idx] = True + if all(seen): + return PropagationResult( + propagated_to=[i for i, s in enumerate(seen) if s], + elapsed_seconds=time_fn() - start, + ) + if time_fn() - start >= timeout: + missing = [i for i, s in enumerate(seen) if not s] + raise CrossTabSyncError( + f"timeout: storage {key!r} did not propagate to tabs {missing} " + f"after {timeout}s" + ) + sleep_fn(poll_interval) + + +# ---------- helpers ------------------------------------------------------ + +_ALLOWED_STORAGE = {"localStorage", "sessionStorage"} + + +def _ensure_storage_name(name: str) -> None: + if name not in _ALLOWED_STORAGE: + raise CrossTabSyncError( + f"storage must be one of {sorted(_ALLOWED_STORAGE)}, got {name!r}" + ) + + +def _validate_timeout(timeout: float, poll_interval: float) -> None: + if timeout <= 0: + raise CrossTabSyncError("timeout must be positive") + if poll_interval <= 0: + raise CrossTabSyncError("poll_interval must be positive") + + +# ---------- post-message helper ------------------------------------------ + +def post_message_to_page( + page: Any, + data: Any, + *, + target_origin: str = "*", +) -> None: + """``window.postMessage(data, target_origin)`` on the page's main window.""" + if page is None: + raise CrossTabSyncError("page is None") + script = "({ payload, origin }) => window.postMessage(payload, origin)" + try: + page.evaluate(script, {"payload": data, "origin": target_origin}) + except Exception as error: # noqa: BLE001 + raise CrossTabSyncError(f"post_message_to_page failed: {error!r}") from error diff --git a/je_web_runner/utils/db_snapshot/__init__.py b/je_web_runner/utils/db_snapshot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/db_snapshot/snapshot.py b/je_web_runner/utils/db_snapshot/snapshot.py new file mode 100644 index 0000000..292866c --- /dev/null +++ b/je_web_runner/utils/db_snapshot/snapshot.py @@ -0,0 +1,170 @@ +""" +Per-test DB savepoint/rollback isolation, decoupled from the actual driver. +Tests that mutate shared state (orders, users, audit logs) used to leak +into each other; the usual workaround is to truncate everything between +tests, which is slow. This module takes a savepoint before the test runs +and rolls back to it after, regardless of pass/fail. + +The driver isn't hard-coded: you implement :class:`SnapshotBackend` (two +methods, ``savepoint`` and ``rollback_to``) for psycopg / mysqlclient / +sqlite3 / SQLAlchemy / whatever you actually use. The included +:class:`InMemoryBackend` is for unit-testing the workflow itself. +""" +from __future__ import annotations + +import contextlib +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Protocol + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class DbSnapshotError(WebRunnerException): + """Raised on backend failure, mis-ordered rollback, or invalid id.""" + + +# ---------- backend protocol -------------------------------------------- + +class SnapshotBackend(Protocol): + """Minimal DB interface: take a named savepoint, roll back to it.""" + + def savepoint(self, name: str) -> None: ... + def rollback_to(self, name: str) -> None: ... + + +# ---------- in-memory backend (for tests / dry runs) -------------------- + +@dataclass +class InMemoryBackend: + """ + Simulates a DB with a single row dict per table. Used to exercise the + snapshot workflow without a real database; also a useful fallback for + `--dry-run` style unit tests. + """ + + tables: Dict[str, Dict[Any, Any]] = field(default_factory=dict) + _snapshots: Dict[str, Dict[str, Dict[Any, Any]]] = field(default_factory=dict) + + def insert(self, table: str, key: Any, value: Any) -> None: + self.tables.setdefault(table, {})[key] = value + + def delete(self, table: str, key: Any) -> None: + if table in self.tables: + self.tables[table].pop(key, None) + + def savepoint(self, name: str) -> None: + if name in self._snapshots: + raise DbSnapshotError(f"savepoint {name!r} already exists") + self._snapshots[name] = {t: dict(rows) for t, rows in self.tables.items()} + web_runner_logger.info(f"db_snapshot in-memory savepoint {name!r}") + + def rollback_to(self, name: str) -> None: + snap = self._snapshots.pop(name, None) + if snap is None: + raise DbSnapshotError(f"no savepoint named {name!r}") + self.tables = {t: dict(rows) for t, rows in snap.items()} + web_runner_logger.info(f"db_snapshot in-memory rollback to {name!r}") + + +# ---------- core scoping API -------------------------------------------- + +@dataclass +class SnapshotHandle: + """Returned by :meth:`SnapshotScope.create`; pass back to :meth:`rollback`.""" + + name: str + + +@dataclass +class SnapshotScope: + """ + Stack of active snapshots, scoping cleanly via :func:`snapshot` ctx mgr. + The stack lets nested test sections each take their own savepoint and + unwind in the right order — rollback of a stale handle is rejected. + """ + + backend: SnapshotBackend + prefix: str = "wr_snap" + _stack: List[SnapshotHandle] = field(default_factory=list) + + def create(self) -> SnapshotHandle: + name = f"{self.prefix}_{uuid.uuid4().hex[:12]}" + try: + self.backend.savepoint(name) + except DbSnapshotError: + raise + except Exception as error: + raise DbSnapshotError(f"backend.savepoint failed: {error!r}") from error + handle = SnapshotHandle(name=name) + self._stack.append(handle) + return handle + + def rollback(self, handle: SnapshotHandle) -> None: + if not self._stack: + raise DbSnapshotError("no active snapshots to roll back") + top = self._stack[-1] + if top.name != handle.name: + raise DbSnapshotError( + f"snapshot stack mismatch: top is {top.name!r}, " + f"got {handle.name!r} (rolled back out of order?)" + ) + try: + self.backend.rollback_to(handle.name) + except DbSnapshotError: + raise + except Exception as error: + raise DbSnapshotError(f"backend.rollback_to failed: {error!r}") from error + self._stack.pop() + + def active(self) -> int: + return len(self._stack) + + +@contextlib.contextmanager +def snapshot(scope: SnapshotScope): + """Context manager: take a savepoint, roll back on exit (success or fail).""" + handle = scope.create() + try: + yield handle + finally: + scope.rollback(handle) + + +# ---------- pytest helper (optional) ------------------------------------ + +def pytest_fixture_factory(backend: SnapshotBackend) -> Callable[..., Any]: + """ + Build a pytest fixture that wraps each test in its own snapshot. + Usage in ``conftest.py``:: + + from je_web_runner.utils.db_snapshot.snapshot import ( + InMemoryBackend, pytest_fixture_factory, + ) + backend = InMemoryBackend() + db_snapshot = pytest_fixture_factory(backend) + + Then add ``db_snapshot`` as an argument on tests that need isolation. + """ + scope = SnapshotScope(backend=backend) + + def _factory(*_args: Any, **_kwargs: Any): + handle = scope.create() + try: + yield backend + finally: + scope.rollback(handle) + + return _factory + + +# ---------- convenience ------------------------------------------------- + +def assert_no_active_snapshots(scope: SnapshotScope) -> None: + """Raise if any savepoint is still on the stack (use at suite teardown).""" + if scope.active() > 0: + raise DbSnapshotError( + f"{scope.active()} snapshot(s) still active at teardown — " + "test forgot to roll back" + ) diff --git a/je_web_runner/utils/device_cloud/__init__.py b/je_web_runner/utils/device_cloud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/device_cloud/real_device.py b/je_web_runner/utils/device_cloud/real_device.py new file mode 100644 index 0000000..a9c39ad --- /dev/null +++ b/je_web_runner/utils/device_cloud/real_device.py @@ -0,0 +1,456 @@ +""" +Real-device 雲端連線:在既有 cloud_grid 上補上行動裝置 caps、session +metadata、test status 回寫、env-var 認證。 + +Adds three things the existing :mod:`cloud_grid` module doesn't cover: + +* **Real-device capabilities** — BrowserStack App-Automate / Sauce RDC / + LambdaTest Real-Device style caps (``deviceName``, ``platformVersion``, + ``realMobile`` toggles). +* **Session metadata** — pull the dashboard URL + video URL of a started + session so the report can link directly to it. +* **Status update** — write pass/fail/reason back to the provider so the + cloud dashboard isn't perpetually "running". + +Credentials are loaded from env vars by default (so they never have to be +typed into action JSON): + +* BrowserStack: ``BROWSERSTACK_USERNAME``, ``BROWSERSTACK_ACCESS_KEY`` +* Sauce Labs: ``SAUCE_USERNAME``, ``SAUCE_ACCESS_KEY`` +* LambdaTest: ``LT_USERNAME``, ``LT_ACCESS_KEY`` +""" +from __future__ import annotations + +import json +import os +import ssl +import time +import urllib.request +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, Optional, Tuple + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class DeviceCloudError(WebRunnerException): + """Raised on connection / status update / metadata fetch errors.""" + + +SUPPORTED_PROVIDERS: Tuple[str, ...] = ("browserstack", "saucelabs", "lambdatest") + + +_ENV_VAR_MAP: Dict[str, Tuple[str, str]] = { + "browserstack": ("BROWSERSTACK_USERNAME", "BROWSERSTACK_ACCESS_KEY"), + "saucelabs": ("SAUCE_USERNAME", "SAUCE_ACCESS_KEY"), + "lambdatest": ("LT_USERNAME", "LT_ACCESS_KEY"), +} + +_REST_BASES: Dict[str, str] = { + "browserstack": "https://api.browserstack.com", + "saucelabs": "https://api.us-west-1.saucelabs.com", + "lambdatest": "https://api.lambdatest.com", +} + +_DASHBOARD_BASES: Dict[str, str] = { + "browserstack": "https://automate.browserstack.com/dashboard/v2/sessions", + "saucelabs": "https://app.saucelabs.com/tests", + "lambdatest": "https://automation.lambdatest.com/test", +} + + +# ---------- credentials -------------------------------------------------- + +@dataclass(frozen=True) +class CloudCredentials: + """Cloud provider credentials. Never logged, never serialised.""" + + username: str + access_key: str + + def redacted(self) -> Dict[str, str]: + return { + "username": self.username, + "access_key": "***" if self.access_key else "", + } + + +def _normalise_provider(provider: str) -> str: + normalised = (provider or "").lower().strip() + if normalised not in SUPPORTED_PROVIDERS: + raise DeviceCloudError( + f"unsupported provider {provider!r}; expected one of {SUPPORTED_PROVIDERS}" + ) + return normalised + + +def load_credentials(provider: str) -> CloudCredentials: + """ + 從環境變數讀取對應 provider 的 credentials。 + Read credentials from env vars, raising if either is missing. Use this + in CI to avoid putting secrets into action JSON. + """ + key = _normalise_provider(provider) + user_var, access_var = _ENV_VAR_MAP[key] + username = os.environ.get(user_var, "") + access_key = os.environ.get(access_var, "") + if not username or not access_key: + raise DeviceCloudError( + f"missing credentials for {key!r}: set {user_var} and {access_var}" + ) + return CloudCredentials(username=username, access_key=access_key) + + +# ---------- capability builders ----------------------------------------- + +@dataclass +class RealDeviceCaps: + """ + Common spec for a real-device session, converted per provider on demand. + Keeping this provider-neutral lets callers swap clouds without rewriting. + """ + + device_name: str + platform_name: str # "iOS" | "Android" + platform_version: str + browser_name: str = "Chrome" # or "Safari" on iOS + real_mobile: bool = True + build: Optional[str] = None + name: Optional[str] = None + project: Optional[str] = None + extra: Dict[str, Any] = field(default_factory=dict) + + +def _to_browserstack(caps: RealDeviceCaps) -> Dict[str, Any]: + bstack: Dict[str, Any] = { + "deviceName": caps.device_name, + "osVersion": caps.platform_version, + "realMobile": "true" if caps.real_mobile else "false", + } + if caps.project: + bstack["projectName"] = caps.project + if caps.build: + bstack["buildName"] = caps.build + if caps.name: + bstack["sessionName"] = caps.name + out: Dict[str, Any] = { + "browserName": caps.browser_name, + "platformName": caps.platform_name, + "bstack:options": bstack, + } + out.update(caps.extra) + return out + + +def _to_saucelabs(caps: RealDeviceCaps) -> Dict[str, Any]: + sauce: Dict[str, Any] = {} + if caps.build: + sauce["build"] = caps.build + if caps.name: + sauce["name"] = caps.name + appium_caps: Dict[str, Any] = { + "appium:deviceName": caps.device_name, + "appium:platformVersion": caps.platform_version, + "appium:automationName": "XCUITest" if caps.platform_name.lower() == "ios" else "UiAutomator2", + } + out: Dict[str, Any] = { + "browserName": caps.browser_name, + "platformName": caps.platform_name, + "sauce:options": sauce, + **appium_caps, + } + out.update(caps.extra) + return out + + +def _to_lambdatest(caps: RealDeviceCaps) -> Dict[str, Any]: + lt: Dict[str, Any] = { + "deviceName": caps.device_name, + "platformVersion": caps.platform_version, + "isRealMobile": caps.real_mobile, + } + if caps.build: + lt["build"] = caps.build + if caps.name: + lt["name"] = caps.name + if caps.project: + lt["project"] = caps.project + out: Dict[str, Any] = { + "browserName": caps.browser_name, + "platformName": caps.platform_name, + "LT:Options": lt, + } + out.update(caps.extra) + return out + + +_CAPS_DISPATCH: Dict[str, Callable[[RealDeviceCaps], Dict[str, Any]]] = { + "browserstack": _to_browserstack, + "saucelabs": _to_saucelabs, + "lambdatest": _to_lambdatest, +} + + +def build_capabilities(provider: str, caps: RealDeviceCaps) -> Dict[str, Any]: + """Project a :class:`RealDeviceCaps` into provider-native capabilities.""" + key = _normalise_provider(provider) + return _CAPS_DISPATCH[key](caps) + + +# ---------- connect with retry ------------------------------------------ + +@dataclass +class CloudSession: + """Metadata about a started cloud session.""" + + provider: str + session_id: str + dashboard_url: str + video_url: Optional[str] = None + status: str = "running" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def _dashboard_url(provider: str, session_id: str) -> str: + base = _DASHBOARD_BASES[provider] + if provider == "browserstack": + return f"{base}/{session_id}" + if provider == "saucelabs": + return f"{base}/{session_id}" + return f"{base}?testID={session_id}" + + +def connect_real_device( + provider: str, + caps: RealDeviceCaps, + *, + credentials: Optional[CloudCredentials] = None, + retries: int = 2, + backoff_seconds: float = 3.0, + connector: Optional[Callable[..., Any]] = None, +) -> Tuple[Any, CloudSession]: + """ + 開一個 real-device session,回傳 (driver, CloudSession)。 + Connect to a cloud provider's real-device cloud with retries. Returns + the Selenium ``Remote`` driver and a :class:`CloudSession` carrying the + session id and dashboard URL. + + ``connector`` is the underlying connect function; defaults to the + relevant ``cloud_drivers.connect_*`` helper so tests can inject a fake. + """ + key = _normalise_provider(provider) + creds = credentials or load_credentials(key) + capabilities = build_capabilities(key, caps) + web_runner_logger.info( + f"connect_real_device provider={key} device={caps.device_name!r} " + f"build={caps.build!r} creds={creds.redacted()}" + ) + + chosen = connector or _default_connector(key) + last_error: Optional[Exception] = None + for attempt in range(max(1, retries + 1)): + try: + driver = chosen(creds.username, creds.access_key, capabilities) + session_id = getattr(driver, "session_id", None) or "" + if not session_id: + raise DeviceCloudError("driver returned without a session_id") + session = CloudSession( + provider=key, + session_id=session_id, + dashboard_url=_dashboard_url(key, session_id), + ) + return driver, session + except Exception as error: # noqa: BLE001 — cloud connect surface varies + last_error = error + web_runner_logger.warning( + f"connect_real_device attempt {attempt + 1} failed: {error!r}" + ) + if attempt < retries: + time.sleep(backoff_seconds * (attempt + 1)) + raise DeviceCloudError( + f"connect_real_device failed after {retries + 1} attempts: {last_error!r}" + ) from last_error + + +def _default_connector(provider: str) -> Callable[..., Any]: + from je_web_runner.utils.cloud_grid.cloud_drivers import ( + connect_browserstack, + connect_lambdatest, + connect_saucelabs, + ) + return { + "browserstack": connect_browserstack, + "saucelabs": connect_saucelabs, + "lambdatest": connect_lambdatest, + }[provider] + + +# ---------- REST helpers ------------------------------------------------- + +def _basic_auth_header(creds: CloudCredentials) -> str: + import base64 + raw = f"{creds.username}:{creds.access_key}".encode("utf-8") + return "Basic " + base64.b64encode(raw).decode("ascii") + + +def _rest_request( + method: str, + url: str, + credentials: CloudCredentials, + payload: Optional[Dict[str, Any]] = None, + timeout: float = 15.0, +) -> Any: + if not url.startswith("https://"): + raise DeviceCloudError(f"refusing non-https URL: {url!r}") + data = None if payload is None else json.dumps(payload).encode("utf-8") + req = urllib.request.Request(url, data=data, method=method) + req.add_header("Authorization", _basic_auth_header(credentials)) + req.add_header("Accept", "application/json") + if data is not None: + req.add_header("Content-Type", "application/json") + context = ssl.create_default_context() + try: + with urllib.request.urlopen( # nosec B310 — https-only enforced above + req, timeout=timeout, context=context, + ) as response: + body = response.read().decode("utf-8") + except (OSError, ValueError) as error: + raise DeviceCloudError(f"REST call failed: {error!r}") from error + if not body: + return None + try: + return json.loads(body) + except ValueError as error: + raise DeviceCloudError(f"non-JSON response: {error}") from error + + +def _session_info_url(provider: str, session_id: str) -> str: + base = _REST_BASES[provider] + if provider == "browserstack": + return f"{base}/automate/sessions/{session_id}.json" + if provider == "saucelabs": + return f"{base}/rest/v1.1/jobs/{session_id}" + return f"{base}/automation/api/v1/sessions/{session_id}" + + +def _session_status_url(provider: str, session_id: str) -> str: + base = _REST_BASES[provider] + if provider == "browserstack": + return f"{base}/automate/sessions/{session_id}.json" + if provider == "saucelabs": + return f"{base}/rest/v1.1/jobs/{session_id}" + return f"{base}/automation/api/v1/sessions/{session_id}" + + +def fetch_session_info( + provider: str, + session_id: str, + credentials: Optional[CloudCredentials] = None, + *, + request_fn: Optional[Callable[..., Any]] = None, +) -> CloudSession: + """ + 讀取 session 的 metadata,補上 video URL 與目前 status。 + Fetch session metadata so the report can include the dashboard + video. + ``request_fn`` lets tests stub out the HTTP call. + """ + key = _normalise_provider(provider) + creds = credentials or load_credentials(key) + url = _session_info_url(key, session_id) + caller = request_fn or _rest_request + payload = caller("GET", url, creds) + if not isinstance(payload, dict): + raise DeviceCloudError(f"unexpected session info payload: {type(payload).__name__}") + return CloudSession( + provider=key, + session_id=session_id, + dashboard_url=_dashboard_url(key, session_id), + video_url=_extract_video_url(key, payload), + status=_extract_status(key, payload), + ) + + +def _extract_video_url(provider: str, payload: Dict[str, Any]) -> Optional[str]: + if provider == "browserstack": + info = payload.get("automation_session") or {} + url = info.get("video_url") + if isinstance(url, str): + return url + if provider == "saucelabs": + url = payload.get("video_url") + if isinstance(url, str): + return url + if provider == "lambdatest": + data = payload.get("data") or payload + url = data.get("video_url") if isinstance(data, dict) else None + if isinstance(url, str): + return url + return None + + +def _extract_status(provider: str, payload: Dict[str, Any]) -> str: + if provider == "browserstack": + info = payload.get("automation_session") or {} + return str(info.get("status") or "unknown") + if provider == "saucelabs": + return str(payload.get("status") or "unknown") + data = payload.get("data") or payload + if isinstance(data, dict): + return str(data.get("status_ind") or data.get("status") or "unknown") + return "unknown" + + +def update_session_status( + provider: str, + session_id: str, + *, + passed: bool, + reason: Optional[str] = None, + credentials: Optional[CloudCredentials] = None, + request_fn: Optional[Callable[..., Any]] = None, +) -> None: + """ + 把測試結果回寫到 provider,讓 dashboard 從 running 變 passed / failed。 + Write the final status back so the cloud dashboard reflects reality. + """ + key = _normalise_provider(provider) + creds = credentials or load_credentials(key) + url = _session_status_url(key, session_id) + caller = request_fn or _rest_request + if key == "browserstack": + payload = { + "status": "passed" if passed else "failed", + "reason": reason or "", + } + caller("PUT", url, creds, payload) + return + if key == "saucelabs": + payload = { + "passed": bool(passed), + "name": reason or "", + } + caller("PUT", url, creds, payload) + return + # lambdatest + payload = { + "status_ind": "passed" if passed else "failed", + "reason": reason or "", + } + caller("PATCH", url, creds, payload) + + +# ---------- small report helper ----------------------------------------- + +def session_summary_markdown(session: CloudSession) -> str: + """Render the session metadata as a markdown bullet list for reports.""" + pieces = [ + f"- **Provider:** {session.provider}", + f"- **Session ID:** `{session.session_id}`", + f"- **Dashboard:** [{session.dashboard_url}]({session.dashboard_url})", + ] + if session.video_url: + pieces.append(f"- **Video:** [{session.video_url}]({session.video_url})") + pieces.append(f"- **Status:** `{session.status}`") + return "\n".join(pieces) + "\n" diff --git a/je_web_runner/utils/download_verify/__init__.py b/je_web_runner/utils/download_verify/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/download_verify/verifier.py b/je_web_runner/utils/download_verify/verifier.py new file mode 100644 index 0000000..a01ed05 --- /dev/null +++ b/je_web_runner/utils/download_verify/verifier.py @@ -0,0 +1,391 @@ +""" +下載檔案驗證:PDF/CSV/Excel/JSON/檔名/SHA256 比對,給 E2E 測試用。 + +Download verification helpers. Pairs with the existing browser download +action commands: after a test triggers a file save, these utilities +poll the download directory, extract content (PDF text / CSV rows / +Excel rows / JSON), and assert on it. + +Soft dependencies — only required when the matching extractor is used: + +* PDF text → ``pypdf`` (preferred) or ``pdfplumber`` +* Excel rows → ``openpyxl`` + +CSV and JSON use the standard library. +""" +from __future__ import annotations + +import csv +import hashlib +import json +import re +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Pattern, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class DownloadVerifyError(WebRunnerException): + """Raised on missing files, parse errors, or assertion failures.""" + + +# ---------- waiting ------------------------------------------------------- + +def wait_for_download( + download_dir: Union[str, Path], + *, + pattern: Union[str, Pattern[str]] = r".+", + timeout: float = 30.0, + poll_interval: float = 0.5, + stable_for: float = 0.5, + sleep_fn: Callable[[float], None] = time.sleep, + time_fn: Callable[[], float] = time.time, + exclude_extensions: Iterable[str] = (".crdownload", ".part", ".tmp"), +) -> Path: + """ + 等到 download_dir 內出現符合 pattern 的「完成」檔案。 + Poll a download dir until a file matches ``pattern`` AND has stayed + the same size for ``stable_for`` seconds. Skips partial-download + suffixes (``.crdownload`` etc.). + + Returns the absolute path to the file. + """ + directory = Path(download_dir) + if not directory.is_dir(): + raise DownloadVerifyError(f"download dir not found: {directory}") + regex = re.compile(pattern) if isinstance(pattern, str) else pattern + excluded = tuple(e.lower() for e in exclude_extensions) + start = time_fn() + last_seen_size: Dict[str, int] = {} + last_seen_time: Dict[str, float] = {} + while True: + for entry in directory.iterdir(): + if not entry.is_file(): + continue + name = entry.name + if any(name.lower().endswith(suffix) for suffix in excluded): + continue + if not regex.search(name): + continue + size = entry.stat().st_size + prev_size = last_seen_size.get(name) + now = time_fn() + if prev_size == size: + if now - last_seen_time.get(name, now) >= stable_for and size > 0: + web_runner_logger.info(f"wait_for_download: matched {entry}") + return entry.resolve() + else: + last_seen_size[name] = size + last_seen_time[name] = now + if time_fn() - start >= timeout: + raise DownloadVerifyError( + f"timeout waiting for {pattern!r} in {directory} after {timeout}s" + ) + sleep_fn(poll_interval) + + +# ---------- hashing ------------------------------------------------------- + +def sha256_of_file(path: Union[str, Path], *, chunk_size: int = 65_536) -> str: + """Stream-hash a file with SHA-256.""" + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"not a file: {p}") + hasher = hashlib.sha256() + with open(p, "rb") as fp: + while True: + chunk = fp.read(chunk_size) + if not chunk: + break + hasher.update(chunk) + return hasher.hexdigest() + + +def assert_file_sha256(path: Union[str, Path], expected: str) -> None: + """Raise unless ``path``'s SHA-256 equals ``expected`` (case-insensitive).""" + actual = sha256_of_file(path) + if actual.lower() != expected.lower(): + raise DownloadVerifyError( + f"sha256 mismatch for {path}: expected {expected}, got {actual}" + ) + + +# ---------- PDF ----------------------------------------------------------- + +def extract_pdf_text( + path: Union[str, Path], + *, + page_separator: str = "\n", +) -> str: + """ + 用 pypdf / pdfplumber 抽出整份 PDF 的文字。 + Concatenate per-page text. ``pypdf`` is tried first (lighter, pure + Python); ``pdfplumber`` is the fallback. Raises if neither is + installed or the file isn't a valid PDF. + """ + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"PDF not found: {p}") + try: + from pypdf import PdfReader # type: ignore[import-not-found] + reader = PdfReader(str(p)) + return page_separator.join( + (page.extract_text() or "") for page in reader.pages + ) + except ImportError: + pass + try: + import pdfplumber # type: ignore[import-not-found] + pieces: List[str] = [] + with pdfplumber.open(str(p)) as pdf: + for page in pdf.pages: + pieces.append(page.extract_text() or "") + return page_separator.join(pieces) + except ImportError as error: + raise DownloadVerifyError( + "PDF text extraction requires pypdf or pdfplumber. " + "Install one: pip install pypdf" + ) from error + except Exception as error: # noqa: BLE001 — library-specific parse errors + raise DownloadVerifyError(f"failed to extract PDF text from {p}: {error!r}") from error + + +def assert_pdf_contains(path: Union[str, Path], substring: str) -> None: + """Raise if the extracted PDF text doesn't include ``substring``.""" + text = extract_pdf_text(path) + if substring not in text: + raise DownloadVerifyError( + f"PDF {path} does not contain substring {substring!r}" + ) + + +def assert_pdf_matches(path: Union[str, Path], pattern: Union[str, Pattern[str]]) -> str: + """Raise unless the PDF text matches ``pattern``; returns the match.""" + text = extract_pdf_text(path) + regex = re.compile(pattern) if isinstance(pattern, str) else pattern + match = regex.search(text) + if match is None: + raise DownloadVerifyError( + f"PDF {path} does not match pattern {pattern!r}" + ) + return match.group(0) + + +# ---------- CSV ----------------------------------------------------------- + +def read_csv_rows( + path: Union[str, Path], + *, + encoding: str = "utf-8-sig", + dialect: str = "excel", +) -> List[Dict[str, str]]: + """Read a CSV file as a list of dicts (header-driven).""" + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"CSV not found: {p}") + try: + with open(p, encoding=encoding, newline="") as fp: + reader = csv.DictReader(fp, dialect=dialect) + return [dict(row) for row in reader] + except (OSError, csv.Error) as error: + raise DownloadVerifyError(f"cannot read CSV {p}: {error!r}") from error + + +def assert_csv_columns(path: Union[str, Path], expected_columns: Iterable[str]) -> None: + """Raise if the CSV is missing any of ``expected_columns``.""" + rows = read_csv_rows(path) + if not rows: + raise DownloadVerifyError(f"CSV {path} is empty") + present = set(rows[0].keys()) + missing = [c for c in expected_columns if c not in present] + if missing: + raise DownloadVerifyError( + f"CSV {path} missing columns {missing} (have {sorted(present)})" + ) + + +def assert_csv_row_count( + path: Union[str, Path], + *, + minimum: Optional[int] = None, + maximum: Optional[int] = None, +) -> int: + """Raise unless the row count is within bounds. Returns the actual count.""" + count = len(read_csv_rows(path)) + if minimum is not None and count < minimum: + raise DownloadVerifyError( + f"CSV {path} has {count} rows, expected >= {minimum}" + ) + if maximum is not None and count > maximum: + raise DownloadVerifyError( + f"CSV {path} has {count} rows, expected <= {maximum}" + ) + return count + + +# ---------- Excel --------------------------------------------------------- + +def read_excel_rows( + path: Union[str, Path], + *, + sheet: Optional[Union[str, int]] = None, +) -> List[Dict[str, Any]]: + """ + 讀 .xlsx 為 list of dict (假設第一列是 header)。 + Read an Excel sheet (defaults to the first/active one). Requires + ``openpyxl``; raises with an install hint when missing. + """ + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"Excel file not found: {p}") + try: + from openpyxl import load_workbook # type: ignore[import-not-found] + except ImportError as error: + raise DownloadVerifyError( + "Excel extraction requires openpyxl. Install: pip install openpyxl" + ) from error + try: + wb = load_workbook(filename=str(p), read_only=True, data_only=True) + except Exception as error: # noqa: BLE001 — openpyxl raises many types + raise DownloadVerifyError(f"cannot open {p}: {error!r}") from error + try: + if sheet is None: + ws = wb.active + elif isinstance(sheet, int): + ws = wb.worksheets[sheet] + else: + ws = wb[sheet] + rows_iter = ws.iter_rows(values_only=True) + try: + header = next(rows_iter) + except StopIteration: + return [] + headers = [str(h) if h is not None else f"col_{i}" for i, h in enumerate(header)] + return [ + {headers[i]: value for i, value in enumerate(row) if i < len(headers)} + for row in rows_iter + ] + finally: + wb.close() + + +# ---------- JSON ---------------------------------------------------------- + +def read_json_file(path: Union[str, Path]) -> Any: + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"JSON file not found: {p}") + try: + with open(p, encoding="utf-8") as fp: + return json.load(fp) + except (OSError, ValueError) as error: + raise DownloadVerifyError(f"cannot read JSON {p}: {error!r}") from error + + +def assert_json_matches_schema(path: Union[str, Path], schema: Dict[str, Any]) -> None: + """ + 用簡化版 schema 驗證 JSON:``{"key": "type" or {nested}}``。 + Minimal schema validator (NO jsonschema dependency). Schema example:: + + {"name": "str", "age": "int", "address": {"city": "str"}} + + Raises on mismatch; type names are Python type aliases: ``str``, + ``int``, ``float``, ``bool``, ``list``, ``dict``. + """ + payload = read_json_file(path) + _check_schema(payload, schema, path=str(path)) + + +_TYPE_ALIASES: Dict[str, type] = { + "str": str, + "int": int, + "float": float, + "bool": bool, + "list": list, + "dict": dict, + "any": object, +} + + +def _check_schema(payload: Any, schema: Any, *, path: str, prefix: str = "$") -> None: + if isinstance(schema, str): + expected = _TYPE_ALIASES.get(schema) + if expected is None: + raise DownloadVerifyError(f"unknown schema type {schema!r} at {prefix}") + if expected is object: + return + if not isinstance(payload, expected): + raise DownloadVerifyError( + f"JSON {path} at {prefix}: expected {schema}, got {type(payload).__name__}" + ) + return + if isinstance(schema, dict): + if not isinstance(payload, dict): + raise DownloadVerifyError( + f"JSON {path} at {prefix}: expected object, got {type(payload).__name__}" + ) + for key, sub_schema in schema.items(): + if key not in payload: + raise DownloadVerifyError( + f"JSON {path} at {prefix}: missing key {key!r}" + ) + _check_schema(payload[key], sub_schema, path=path, prefix=f"{prefix}.{key}") + return + raise DownloadVerifyError(f"unsupported schema node at {prefix}: {schema!r}") + + +# ---------- one-shot all-in-one ------------------------------------------ + +@dataclass +class DownloadAssertion: + """All the constraints :func:`assert_download` can check at once.""" + + filename_pattern: Optional[Union[str, Pattern[str]]] = None + sha256: Optional[str] = None + pdf_contains: Optional[str] = None + csv_columns: Optional[List[str]] = None + json_schema: Optional[Dict[str, Any]] = None + min_size_bytes: Optional[int] = None + max_size_bytes: Optional[int] = None + + +def assert_download(path: Union[str, Path], assertion: DownloadAssertion) -> None: + """ + 一次跑完整套 download 驗證,任何一條不過就 raise。 + Combined check that walks every populated field of ``assertion`` and + raises on the first failure. Use this when a download has multiple + constraints (filename + content + hash). + """ + p = Path(path) + if not p.is_file(): + raise DownloadVerifyError(f"download file missing: {p}") + if assertion.filename_pattern is not None: + regex = ( + re.compile(assertion.filename_pattern) + if isinstance(assertion.filename_pattern, str) + else assertion.filename_pattern + ) + if regex.search(p.name) is None: + raise DownloadVerifyError( + f"filename {p.name!r} does not match {assertion.filename_pattern!r}" + ) + size = p.stat().st_size + if assertion.min_size_bytes is not None and size < assertion.min_size_bytes: + raise DownloadVerifyError( + f"file {p} size {size} < min {assertion.min_size_bytes}" + ) + if assertion.max_size_bytes is not None and size > assertion.max_size_bytes: + raise DownloadVerifyError( + f"file {p} size {size} > max {assertion.max_size_bytes}" + ) + if assertion.sha256 is not None: + assert_file_sha256(p, assertion.sha256) + if assertion.pdf_contains is not None: + assert_pdf_contains(p, assertion.pdf_contains) + if assertion.csv_columns is not None: + assert_csv_columns(p, assertion.csv_columns) + if assertion.json_schema is not None: + assert_json_matches_schema(p, assertion.json_schema) diff --git a/je_web_runner/utils/edge_case_generator/__init__.py b/je_web_runner/utils/edge_case_generator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/edge_case_generator/generator.py b/je_web_runner/utils/edge_case_generator/generator.py new file mode 100644 index 0000000..271fdc0 --- /dev/null +++ b/je_web_runner/utils/edge_case_generator/generator.py @@ -0,0 +1,289 @@ +""" +AI 邊界案例產生器:從 passing test 出發,LLM 列出 boundary / 異常 / race / unicode +變體,每個變體都是可執行的 action JSON 草稿,可選擇直接寫入新檔。 + +Complements :mod:`mutation_testing`: + +* mutation_testing **breaks** existing tests to verify they're sensitive. +* edge_case_generator **invents** new tests to widen coverage. + +The LLM picks edge-case categories (boundary, unicode, network, timing, +permission, etc.) from a fixed catalogue so output is enumerable and +explainable. Each generated variant ships with a one-line rationale +and the action JSON ready to drop into the suite. +""" +from __future__ import annotations + +import json +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +from je_web_runner.utils.ai_assist.llm_assist import LLMAssistError, _invoke +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class EdgeCaseGeneratorError(WebRunnerException): + """Raised when input is malformed or LLM output is unusable.""" + + +class EdgeCaseCategory(str, Enum): + """Categories the LLM is asked to draw from — keeps output enumerable.""" + BOUNDARY = "boundary" # min/max numeric, empty string, max-length + UNICODE = "unicode" # RTL, emoji, combining marks, zero-width + NETWORK = "network" # offline, slow 3G, intermittent 500s + TIMING = "timing" # double-click, rapid retry, debounce edge + PERMISSION = "permission" # denied geolocation, denied notifications + AUTH = "auth" # expired session, missing CSRF, no cookies + RACE = "race" # two concurrent submits, back-button mid-flow + INPUT_VALIDATION = "input_validation" # XSS attempt, SQL-like, control chars + LOCALE = "locale" # RTL layout, non-Latin numerals + A11Y = "a11y" # keyboard-only nav, screen-reader path + + +DEFAULT_CATEGORIES: Sequence[EdgeCaseCategory] = tuple(EdgeCaseCategory) + + +@dataclass +class EdgeCase: + """One generated edge-case variant.""" + + name: str + category: EdgeCaseCategory + rationale: str + actions: List[Any] + expected_outcome: str = "fail" # "fail" | "pass" — what the LLM thinks should happen + severity: str = "medium" # "low" | "medium" | "high" + + def to_dict(self) -> Dict[str, Any]: + out = asdict(self) + out["category"] = self.category.value + return out + + +@dataclass +class EdgeCaseSuite: + """A bundle of edge-case variants for one source test.""" + + source_test_name: str + cases: List[EdgeCase] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "source_test_name": self.source_test_name, + "cases": [c.to_dict() for c in self.cases], + } + + +# ---------- prompt ------------------------------------------------------- + +_GEN_PROMPT = ( + "You are a senior web-QA engineer brainstorming edge cases for a " + "passing test. Generate exactly {n} variants the original suite " + "does NOT cover. Output ONLY a JSON object (no prose outside the " + "envelope) with one key:\n" + " cases: list of objects with keys " + '{{"name": str, "category": str, "rationale": str, ' + '"actions": , "expected_outcome": "fail" | "pass", ' + '"severity": "low" | "medium" | "high"}}\n\n' + "Allowed categories: {categories}\n" + "Test name: {test_name}\n" + "Original action JSON:\n{actions}\n" + "Domain context (optional): {context}\n" +) + + +_JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL) + + +def _parse_payload(text: str) -> Dict[str, Any]: + match = _JSON_OBJECT_RE.search(text) + if match is None: + raise EdgeCaseGeneratorError("LLM did not return a JSON object") + try: + payload = json.loads(match.group(0)) + except ValueError as error: + raise EdgeCaseGeneratorError(f"LLM JSON did not parse: {error}") from error + if not isinstance(payload, dict): + raise EdgeCaseGeneratorError( + f"LLM payload not object: {type(payload).__name__}" + ) + return payload + + +def _coerce_category(value: Any) -> EdgeCaseCategory: + text = str(value or "").strip().lower() + for member in EdgeCaseCategory: + if member.value == text: + return member + return EdgeCaseCategory.BOUNDARY + + +def _coerce_severity(value: Any) -> str: + text = str(value or "").strip().lower() + return text if text in {"low", "medium", "high"} else "medium" + + +def _coerce_outcome(value: Any) -> str: + text = str(value or "").strip().lower() + return text if text in {"pass", "fail"} else "fail" + + +def _parse_case(raw: Any) -> Optional[EdgeCase]: + if not isinstance(raw, dict): + return None + actions = raw.get("actions") + if not isinstance(actions, list): + return None + name = str(raw.get("name") or "").strip() or "" + return EdgeCase( + name=name, + category=_coerce_category(raw.get("category")), + rationale=str(raw.get("rationale") or "").strip(), + actions=actions, + expected_outcome=_coerce_outcome(raw.get("expected_outcome")), + severity=_coerce_severity(raw.get("severity")), + ) + + +def generate_edge_cases( + actions: List[Any], + *, + test_name: str = "", + n: int = 5, + categories: Sequence[EdgeCaseCategory] = DEFAULT_CATEGORIES, + context: str = "", +) -> EdgeCaseSuite: + """ + 呼叫 LLM 對 ``actions`` 生 ``n`` 個 edge-case 變體。 + Returns an :class:`EdgeCaseSuite` whose cases are ready to run. + Cases with malformed shapes are dropped rather than aborting the + whole batch — partial output is better than none. + """ + if not isinstance(actions, list): + raise EdgeCaseGeneratorError( + f"actions must be a list, got {type(actions).__name__}" + ) + if n <= 0: + raise EdgeCaseGeneratorError("n must be positive") + if not categories: + categories = DEFAULT_CATEGORIES + cat_names = ", ".join(c.value for c in categories) + prompt = _GEN_PROMPT.format( + n=n, + categories=cat_names, + test_name=test_name or "", + actions=json.dumps(actions, ensure_ascii=False, indent=2)[:4500], + context=context or "", + ) + try: + raw = _invoke(prompt) + except LLMAssistError as error: + raise EdgeCaseGeneratorError(str(error)) from error + payload = _parse_payload(raw) + cases_raw = payload.get("cases") + if not isinstance(cases_raw, list): + raise EdgeCaseGeneratorError("LLM payload missing 'cases' list") + cases: List[EdgeCase] = [] + for item in cases_raw: + parsed = _parse_case(item) + if parsed is not None: + cases.append(parsed) + web_runner_logger.info( + f"generate_edge_cases: test={test_name!r} requested={n} parsed={len(cases)}" + ) + return EdgeCaseSuite(source_test_name=test_name or "", cases=cases) + + +def generate_edge_cases_from_file( + action_path: Union[str, Path], + *, + n: int = 5, + categories: Sequence[EdgeCaseCategory] = DEFAULT_CATEGORIES, + context: str = "", +) -> EdgeCaseSuite: + """Load actions from disk then call :func:`generate_edge_cases`.""" + path = Path(action_path) + if not path.is_file(): + raise EdgeCaseGeneratorError(f"action file not found: {path}") + try: + with open(path, encoding="utf-8") as fp: + actions = json.load(fp) + except (OSError, ValueError) as error: + raise EdgeCaseGeneratorError(f"cannot parse {path}: {error!r}") from error + if not isinstance(actions, list): + raise EdgeCaseGeneratorError(f"top-level JSON must be a list: {path}") + return generate_edge_cases( + actions, + test_name=path.stem, + n=n, + categories=categories, + context=context, + ) + + +# ---------- writing ------------------------------------------------------ + +def write_suite_to_dir( + suite: EdgeCaseSuite, + output_dir: Union[str, Path], + *, + filename_prefix: Optional[str] = None, +) -> List[Path]: + """ + 把每個 edge case 寫成一個 action JSON 檔到 ``output_dir``。 + File names are slug-of-name with a numeric prefix so they sort in the + same order the LLM produced them. Returns the list of written paths. + """ + target = Path(output_dir) + target.mkdir(parents=True, exist_ok=True) + prefix = filename_prefix or _slugify(suite.source_test_name) or "edge" + written: List[Path] = [] + for idx, case in enumerate(suite.cases, 1): + slug = _slugify(case.name) or f"case-{idx}" + path = target / f"{prefix}__{idx:02d}__{slug}.json" + with open(path, "w", encoding="utf-8") as fp: + json.dump(case.actions, fp, ensure_ascii=False, indent=2) + written.append(path) + web_runner_logger.info( + f"write_suite_to_dir: wrote {len(written)} edge-case files to {target}" + ) + return written + + +_SLUG_RE = re.compile(r"[^A-Za-z0-9_-]+") + + +def _slugify(value: str) -> str: + if not value: + return "" + cleaned = _SLUG_RE.sub("-", value.strip().lower()) + return cleaned.strip("-")[:60] + + +# ---------- rendering --------------------------------------------------- + +def render_suite_markdown(suite: EdgeCaseSuite) -> str: + """Markdown view of the suite for PR comments / review.""" + pieces = [ + f"## AI Edge Cases for `{suite.source_test_name}`", + "", + f"- **Generated cases:** {len(suite.cases)}", + "", + ] + if not suite.cases: + pieces.append("_(no cases parsed)_") + return "\n".join(pieces).rstrip() + "\n" + pieces.append("| # | Category | Severity | Expects | Name | Rationale |") + pieces.append("|---|----------|----------|---------|------|-----------|") + for i, case in enumerate(suite.cases, 1): + rationale = (case.rationale[:100] + "…") if len(case.rationale) > 100 else case.rationale + pieces.append( + f"| {i} | `{case.category.value}` | `{case.severity}` | " + f"`{case.expected_outcome}` | {case.name} | {rationale} |" + ) + pieces.append("") + return "\n".join(pieces).rstrip() + "\n" diff --git a/je_web_runner/utils/email_render/__init__.py b/je_web_runner/utils/email_render/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/email_render/render.py b/je_web_runner/utils/email_render/render.py new file mode 100644 index 0000000..87f8b66 --- /dev/null +++ b/je_web_runner/utils/email_render/render.py @@ -0,0 +1,299 @@ +""" +攔截應用寄出的 email,渲染 HTML 與跨 client 截圖比對。 +Capture outbound mail from MailHog / Mailpit (or a directory of ``.eml`` +files), normalise it into a :class:`CapturedEmail`, then optionally render +the HTML body inside multiple width / dark-mode viewport "clients" and +save screenshots for visual diff. + +The fetch / render layers are deliberately decoupled: + +* :func:`fetch_mailhog`, :func:`fetch_mailpit`, :func:`load_eml_file` produce + :class:`CapturedEmail` records — no rendering, no browser needed. +* :func:`render_email_in_viewports` accepts a callable that drives whatever + browser the user already wired up (Selenium, Playwright, ``cdp`` module). + This avoids hard-coupling email_render to one browser stack. +""" +from __future__ import annotations + +import email +import json +from dataclasses import dataclass, field +from email.message import Message +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class EmailRenderError(WebRunnerException): + """Raised on capture-server I/O, malformed EML, or render driver failure.""" + + +# ---------- data --------------------------------------------------------- + +@dataclass +class CapturedEmail: + """One inbox message normalised across providers.""" + + message_id: str + subject: str + from_addr: str + to: List[str] + html_body: Optional[str] = None + text_body: Optional[str] = None + headers: Dict[str, str] = field(default_factory=dict) + raw: Optional[str] = None + + def has_html(self) -> bool: + return bool(self.html_body and self.html_body.strip()) + + +@dataclass(frozen=True) +class ViewportProfile: + """One render target (e.g. 'gmail-desktop', 'apple-mail-dark').""" + + name: str + width: int + height: int + dark_mode: bool = False + user_agent: Optional[str] = None + + +DEFAULT_VIEWPORTS: Sequence[ViewportProfile] = ( + ViewportProfile("desktop-light", 800, 1200, dark_mode=False), + ViewportProfile("desktop-dark", 800, 1200, dark_mode=True), + ViewportProfile("mobile-light", 390, 844, dark_mode=False), +) + + +@dataclass +class RenderArtifact: + """A single render-and-screenshot output.""" + + viewport: str + screenshot_path: Path + width: int + height: int + + +# ---------- helpers ------------------------------------------------------ + +def _require_requests() -> Any: + try: + import requests # type: ignore[import-not-found] + return requests + except ImportError as error: + raise EmailRenderError( + "requests is required to fetch from MailHog/Mailpit. " + "Install: pip install requests" + ) from error + + +def _split_addresses(value: Any) -> List[str]: + if value is None: + return [] + if isinstance(value, list): + return [str(v) for v in value if v] + return [part.strip() for part in str(value).split(",") if part.strip()] + + +def _parse_eml(raw: Union[str, bytes]) -> CapturedEmail: + if isinstance(raw, bytes): + msg = email.message_from_bytes(raw) + else: + msg = email.message_from_string(raw) + return _from_message(msg, raw_text=raw if isinstance(raw, str) else raw.decode("utf-8", "replace")) + + +def _from_message(msg: Message, *, raw_text: Optional[str] = None) -> CapturedEmail: + html_body: Optional[str] = None + text_body: Optional[str] = None + if msg.is_multipart(): + for part in msg.walk(): + ctype = part.get_content_type() + if ctype == "text/html" and html_body is None: + html_body = _decode_payload(part) + elif ctype == "text/plain" and text_body is None: + text_body = _decode_payload(part) + else: + body = _decode_payload(msg) + if msg.get_content_type() == "text/html": + html_body = body + else: + text_body = body + headers = {k: v for k, v in msg.items()} + return CapturedEmail( + message_id=str(msg.get("Message-ID", "")), + subject=str(msg.get("Subject", "")), + from_addr=str(msg.get("From", "")), + to=_split_addresses(msg.get("To")), + html_body=html_body, + text_body=text_body, + headers=headers, + raw=raw_text, + ) + + +def _decode_payload(part: Message) -> Optional[str]: + payload = part.get_payload(decode=True) + if payload is None: + return None + charset = part.get_content_charset() or "utf-8" + try: + return payload.decode(charset, errors="replace") + except (LookupError, AttributeError): + return payload.decode("utf-8", errors="replace") + + +# ---------- fetchers ----------------------------------------------------- + +def load_eml_file(path: Union[str, Path]) -> CapturedEmail: + """Parse a single ``.eml`` file.""" + eml_path = Path(path) + if not eml_path.exists(): + raise EmailRenderError(f"eml file not found: {eml_path}") + raw = eml_path.read_bytes() + return _parse_eml(raw) + + +def load_eml_dir(directory: Union[str, Path]) -> List[CapturedEmail]: + """Parse every ``.eml`` file in ``directory`` (non-recursive).""" + dir_path = Path(directory) + if not dir_path.is_dir(): + raise EmailRenderError(f"eml directory not found: {dir_path}") + out: List[CapturedEmail] = [] + for child in sorted(dir_path.glob("*.eml")): + out.append(load_eml_file(child)) + return out + + +def fetch_mailhog(base_url: str, *, timeout: float = 10.0) -> List[CapturedEmail]: + """Fetch messages from a MailHog server's ``/api/v2/messages`` endpoint.""" + requests = _require_requests() + url = base_url.rstrip("/") + "/api/v2/messages" + try: + response = requests.get(url, timeout=timeout) + response.raise_for_status() + payload = response.json() + except (requests.RequestException, ValueError) as error: + raise EmailRenderError(f"mailhog fetch failed: {error!r}") from error + items = payload.get("items") if isinstance(payload, dict) else None + if not isinstance(items, list): + return [] + return [_parse_mailhog_item(item) for item in items if isinstance(item, dict)] + + +def _parse_mailhog_item(item: Dict[str, Any]) -> CapturedEmail: + content = item.get("Content") or {} + headers = content.get("Headers") or {} + raw_body = content.get("Body") or "" + # MailHog gives us the raw body; pass it through email module to split parts. + header_text = "".join( + f"{name}: {values[0]}\n" if values else "" for name, values in headers.items() + ) + raw = header_text + "\n" + raw_body + return _parse_eml(raw) + + +def fetch_mailpit(base_url: str, *, timeout: float = 10.0, limit: int = 50) -> List[CapturedEmail]: + """Fetch messages from a Mailpit server's ``/api/v1/messages`` listing.""" + requests = _require_requests() + list_url = f"{base_url.rstrip('/')}/api/v1/messages?limit={int(limit)}" + try: + listing = requests.get(list_url, timeout=timeout) + listing.raise_for_status() + listing_payload = listing.json() + except (requests.RequestException, ValueError) as error: + raise EmailRenderError(f"mailpit list failed: {error!r}") from error + ids = [] + if isinstance(listing_payload, dict): + for entry in listing_payload.get("messages", []) or []: + if isinstance(entry, dict) and entry.get("ID"): + ids.append(entry["ID"]) + out: List[CapturedEmail] = [] + for msg_id in ids: + raw_url = f"{base_url.rstrip('/')}/api/v1/message/{msg_id}/raw" + try: + raw_resp = requests.get(raw_url, timeout=timeout) + raw_resp.raise_for_status() + except requests.RequestException as error: + web_runner_logger.warning(f"mailpit raw fetch failed for {msg_id}: {error!r}") + continue + out.append(_parse_eml(raw_resp.content)) + return out + + +# ---------- rendering ---------------------------------------------------- + +RenderDriver = Callable[[str, ViewportProfile, Path], Path] +"""Signature: ``driver(html, viewport, target_png) -> actual_png_path``.""" + + +def render_email_in_viewports( + captured: CapturedEmail, + driver: RenderDriver, + output_dir: Union[str, Path], + *, + viewports: Sequence[ViewportProfile] = DEFAULT_VIEWPORTS, +) -> List[RenderArtifact]: + """ + Render ``captured.html_body`` in each viewport via ``driver`` and write + screenshots into ``output_dir``. The driver receives the HTML, viewport + profile, and a target PNG path; it must return the path of the file it + actually wrote (so wrappers that pick their own filename still work). + """ + if not captured.has_html(): + raise EmailRenderError(f"captured email has no HTML body: {captured.message_id!r}") + out_dir = Path(output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + artifacts: List[RenderArtifact] = [] + for viewport in viewports: + target = out_dir / f"{_safe_slug(captured.message_id) or 'msg'}__{viewport.name}.png" + written = Path(driver(captured.html_body or "", viewport, target)) + artifacts.append(RenderArtifact( + viewport=viewport.name, + screenshot_path=written, + width=viewport.width, + height=viewport.height, + )) + return artifacts + + +def _safe_slug(value: str) -> str: + cleaned = "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in value) + return cleaned.strip("_")[:80] + + +# ---------- assertions --------------------------------------------------- + +def assert_subject_contains(captured: CapturedEmail, needle: str) -> None: + """Raise unless ``needle`` is a substring of the captured subject.""" + if not isinstance(needle, str) or not needle: + raise EmailRenderError("needle must be a non-empty string") + if needle not in (captured.subject or ""): + raise EmailRenderError( + f"subject does not contain {needle!r}: actual={captured.subject!r}" + ) + + +def export_summary_json( + captures: Sequence[CapturedEmail], + output_path: Union[str, Path], +) -> Path: + """Persist a compact JSON list of captured emails for downstream tooling.""" + payload = [ + { + "message_id": c.message_id, + "subject": c.subject, + "from": c.from_addr, + "to": c.to, + "has_html": c.has_html(), + } + for c in captures + ] + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as fp: + json.dump(payload, fp, ensure_ascii=False, indent=2) + return path diff --git a/je_web_runner/utils/exploratory_ai/__init__.py b/je_web_runner/utils/exploratory_ai/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/exploratory_ai/explorer.py b/je_web_runner/utils/exploratory_ai/explorer.py new file mode 100644 index 0000000..0bf89f3 --- /dev/null +++ b/je_web_runner/utils/exploratory_ai/explorer.py @@ -0,0 +1,311 @@ +""" +代理式探索測試員:LLM 主動決定下一步點哪、輸入甚麼,過程中蒐集 bug 報告線索。 +The browser-driving and LLM-asking layers are pluggable so the core loop +stays unit-testable: + +* :class:`PageObserver` — duck-typed protocol the explorer asks for the + current page's URL / title / actionable interactive elements / console + errors. A real implementation wraps Selenium or Playwright. +* :class:`ActionPlanner` — duck-typed protocol the explorer asks for the + next :class:`PlannedAction` given the observation list. The default + :class:`RandomPlanner` is deterministic with a seed and useful as a + fuzz-style fallback when no LLM is configured. +* :class:`Explorer` — the loop. Runs N steps, gathers + :class:`BugSignal`s from observed console errors / 4xx-5xx network + hits, and returns a :class:`ExplorationReport`. +""" +from __future__ import annotations + +import random +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Protocol, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class ExploratoryAiError(WebRunnerException): + """Raised on observer/planner failures or invalid configuration.""" + + +class ActionKind(str, Enum): + """What :class:`PlannedAction` instructs the runner to do.""" + + CLICK = "click" + TYPE = "type" + NAVIGATE = "navigate" + SCROLL = "scroll" + DONE = "done" + + +# ---------- data models ------------------------------------------------- + +@dataclass +class InteractiveElement: + """A clickable / typeable element the observer surfaced.""" + + selector: str + tag: str + text: str = "" + role: Optional[str] = None + is_visible: bool = True + is_enabled: bool = True + + def __post_init__(self) -> None: + if not self.selector or not isinstance(self.selector, str): + raise ExploratoryAiError("InteractiveElement.selector must be non-empty string") + if not self.tag or not isinstance(self.tag, str): + raise ExploratoryAiError("InteractiveElement.tag must be non-empty string") + + +@dataclass +class PageObservation: + """A snapshot of the page state passed to the planner.""" + + url: str + title: str + elements: List[InteractiveElement] = field(default_factory=list) + console_errors: List[str] = field(default_factory=list) + network_errors: List[Dict[str, Any]] = field(default_factory=list) + step: int = 0 + + def actionable(self) -> List[InteractiveElement]: + return [e for e in self.elements if e.is_visible and e.is_enabled] + + +@dataclass +class PlannedAction: + """The next step the explorer wants the runner to perform.""" + + kind: ActionKind + selector: Optional[str] = None + value: Optional[str] = None + rationale: str = "" + + def __post_init__(self) -> None: + if self.kind in (ActionKind.CLICK,) and not self.selector: + raise ExploratoryAiError("click action requires selector") + if self.kind == ActionKind.TYPE and (not self.selector or self.value is None): + raise ExploratoryAiError("type action requires selector and value") + if self.kind == ActionKind.NAVIGATE and not self.value: + raise ExploratoryAiError("navigate action requires value (url)") + + +@dataclass +class BugSignal: + """Something that looks broken; raised to the report.""" + + step: int + url: str + kind: str # 'console_error' | 'network_error' | 'planner_stuck' | ... + detail: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class ExplorationReport: + """Roll-up returned by :func:`Explorer.run`.""" + + steps_taken: int + pages_visited: List[str] = field(default_factory=list) + bugs: List[BugSignal] = field(default_factory=list) + actions: List[PlannedAction] = field(default_factory=list) + stopped_reason: str = "" + + def has_bugs(self) -> bool: + return bool(self.bugs) + + +# ---------- protocols --------------------------------------------------- + +class PageObserver(Protocol): + """Implementations wrap a browser driver to produce observations.""" + + def observe(self, step: int) -> PageObservation: ... + + +class ActionPlanner(Protocol): + """Decides the next action. Stateful planners may carry their own memory.""" + + def plan(self, observation: PageObservation) -> PlannedAction: ... + + +# ---------- planners ---------------------------------------------------- + +class RandomPlanner: + """ + Deterministic fuzz planner. Picks a random visible element to click, + or types a random short string into the first input it finds. Useful + on its own (no LLM needed) and as the fallback when an LLM planner + fails. + """ + + def __init__( + self, + *, + seed: Optional[int] = None, + sample_strings: Sequence[str] = ("test", "1234", "x"), + type_bias: float = 0.3, + ) -> None: + if not 0.0 <= type_bias <= 1.0: + raise ExploratoryAiError("type_bias must be in [0, 1]") + self._rng = random.Random(seed) + self._samples = list(sample_strings) or ["x"] + self._type_bias = type_bias + + def plan(self, observation: PageObservation) -> PlannedAction: + actionable = observation.actionable() + if not actionable: + return PlannedAction( + kind=ActionKind.DONE, + rationale="no actionable elements on page", + ) + inputs = [e for e in actionable if e.tag.lower() in {"input", "textarea"}] + if inputs and self._rng.random() < self._type_bias: + target = self._rng.choice(inputs) + return PlannedAction( + kind=ActionKind.TYPE, + selector=target.selector, + value=self._rng.choice(self._samples), + rationale="random fuzz: fill an input", + ) + target = self._rng.choice(actionable) + return PlannedAction( + kind=ActionKind.CLICK, + selector=target.selector, + rationale="random fuzz: click a visible element", + ) + + +# ---------- the loop ---------------------------------------------------- + +ActionExecutor = Callable[[PlannedAction], None] +"""Callable that performs the action against the real browser.""" + + +@dataclass +class Explorer: + """The exploratory loop. Hold one per session.""" + + observer: PageObserver + planner: ActionPlanner + executor: ActionExecutor + max_steps: int = 25 + max_repeat_loops: int = 3 + stop_on_bugs: int = 0 # 0 = never stop early + + def __post_init__(self) -> None: + if self.max_steps <= 0: + raise ExploratoryAiError("max_steps must be > 0") + if self.max_repeat_loops < 0: + raise ExploratoryAiError("max_repeat_loops must be >= 0") + if self.stop_on_bugs < 0: + raise ExploratoryAiError("stop_on_bugs must be >= 0") + + def run(self) -> ExplorationReport: + report = ExplorationReport(steps_taken=0) + repeat_counter: Dict[str, int] = {} + for step in range(self.max_steps): + observation = self._safe_observe(step) + if observation.url and ( + not report.pages_visited or report.pages_visited[-1] != observation.url + ): + report.pages_visited.append(observation.url) + self._collect_bug_signals(observation, report) + if self.stop_on_bugs and len(report.bugs) >= self.stop_on_bugs: + report.stopped_reason = ( + f"hit stop_on_bugs={self.stop_on_bugs} ({len(report.bugs)} signals)" + ) + break + action = self._safe_plan(observation, report, repeat_counter) + if action is None: + report.stopped_reason = "planner repeatedly proposed same action; stopping" + break + if action.kind == ActionKind.DONE: + report.stopped_reason = action.rationale or "planner said done" + break + report.actions.append(action) + try: + self.executor(action) + except Exception as error: + report.bugs.append(BugSignal( + step=step, + url=observation.url, + kind="action_error", + detail=f"{action.kind.value} failed: {error!r}", + )) + report.steps_taken = step + 1 + else: + report.stopped_reason = f"reached max_steps={self.max_steps}" + return report + + def _safe_observe(self, step: int) -> PageObservation: + try: + obs = self.observer.observe(step) + except Exception as error: + raise ExploratoryAiError( + f"observer.observe failed at step {step}: {error!r}" + ) from error + if not isinstance(obs, PageObservation): + raise ExploratoryAiError( + f"observer.observe returned {type(obs).__name__}, want PageObservation" + ) + return obs + + def _safe_plan( + self, + observation: PageObservation, + report: ExplorationReport, + repeat_counter: Dict[str, int], + ) -> Optional[PlannedAction]: + try: + action = self.planner.plan(observation) + except Exception as error: + web_runner_logger.warning(f"planner failed; stopping: {error!r}") + report.bugs.append(BugSignal( + step=observation.step, + url=observation.url, + kind="planner_error", + detail=repr(error), + )) + return None + if not isinstance(action, PlannedAction): + raise ExploratoryAiError( + f"planner returned {type(action).__name__}, want PlannedAction" + ) + key = f"{action.kind.value}:{action.selector or ''}:{action.value or ''}" + repeat_counter[key] = repeat_counter.get(key, 0) + 1 + if repeat_counter[key] > self.max_repeat_loops: + report.bugs.append(BugSignal( + step=observation.step, + url=observation.url, + kind="planner_stuck", + detail=f"action {key!r} chosen {repeat_counter[key]} times in a row", + )) + return None + return action + + def _collect_bug_signals( + self, + observation: PageObservation, + report: ExplorationReport, + ) -> None: + for message in observation.console_errors: + report.bugs.append(BugSignal( + step=observation.step, + url=observation.url, + kind="console_error", + detail=message, + )) + for record in observation.network_errors: + status = record.get("status") if isinstance(record, dict) else None + url = record.get("url", "") if isinstance(record, dict) else "" + report.bugs.append(BugSignal( + step=observation.step, + url=observation.url, + kind="network_error", + detail=f"{status} {url}", + )) diff --git a/je_web_runner/utils/failure_narrator/__init__.py b/je_web_runner/utils/failure_narrator/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/failure_narrator/narrator.py b/je_web_runner/utils/failure_narrator/narrator.py new file mode 100644 index 0000000..511ad5a --- /dev/null +++ b/je_web_runner/utils/failure_narrator/narrator.py @@ -0,0 +1,239 @@ +""" +LLM 從 failure_bundle 寫出自然語言的「為什麼這個 test 失敗了」報告。 +Different from ``failure_triage`` (root-cause analysis with hypotheses): +this is the *human-friendly summary* you want in a PR comment or Slack +thread. "Login test failed because the submit button wasn't visible — +likely because feature flag `new_login_ui` was on for this PR." + +The LLM client is abstracted so tests can stub responses; the prompt +template is exported so teams can tune tone without forking. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class FailureNarratorError(WebRunnerException): + """Raised on missing bundle, malformed input, or LLM client failure.""" + + +# ---------- bundle inputs ---------------------------------------------- + +@dataclass +class FailureBundle: + """Pre-digested failure facts used to build the prompt.""" + + test_id: str + action: str = "" + error_message: str = "" + error_class: str = "" + last_url: str = "" + last_dom_excerpt: str = "" + console_errors: List[str] = field(default_factory=list) + network_errors: List[str] = field(default_factory=list) + failed_assertion: str = "" + git_commit: str = "" + flake_history: str = "" # e.g. "flaky in 3/10 recent runs" + extra_context: List[str] = field(default_factory=list) + + def __post_init__(self) -> None: + if not isinstance(self.test_id, str) or not self.test_id: + raise FailureNarratorError("test_id must be non-empty string") + + +def load_bundle_dir(path: Union[str, Path]) -> FailureBundle: + """Read a failure-bundle directory laid out as JSON + text files.""" + bundle_dir = Path(path) + if not bundle_dir.exists() or not bundle_dir.is_dir(): + raise FailureNarratorError(f"bundle dir not found: {bundle_dir}") + meta_path = bundle_dir / "meta.json" + if not meta_path.exists(): + raise FailureNarratorError(f"bundle missing meta.json: {meta_path}") + try: + meta = json.loads(meta_path.read_text(encoding="utf-8")) + except ValueError as error: + raise FailureNarratorError(f"meta.json not JSON: {error}") from error + if not isinstance(meta, dict): + raise FailureNarratorError("meta.json must be an object") + test_id = meta.get("test_id") or meta.get("path") or bundle_dir.name + if not isinstance(test_id, str) or not test_id: + raise FailureNarratorError("bundle has no usable test_id") + return FailureBundle( + test_id=test_id, + action=str(meta.get("action") or ""), + error_message=str(meta.get("error_message") or ""), + error_class=str(meta.get("error_class") or ""), + last_url=str(meta.get("last_url") or ""), + last_dom_excerpt=_read_text(bundle_dir / "dom.html", limit=2000), + console_errors=_read_lines(bundle_dir / "console.log"), + network_errors=_read_lines(bundle_dir / "network_errors.log"), + failed_assertion=str(meta.get("failed_assertion") or ""), + git_commit=str(meta.get("git_commit") or ""), + flake_history=str(meta.get("flake_history") or ""), + extra_context=[str(x) for x in meta.get("extra_context") or []], + ) + + +def _read_text(path: Path, *, limit: int) -> str: + if not path.exists(): + return "" + text = path.read_text(encoding="utf-8", errors="replace") + return text[:limit] + + +def _read_lines(path: Path) -> List[str]: + if not path.exists(): + return [] + return [line.rstrip("\n") for line in path.read_text( + encoding="utf-8", errors="replace", + ).splitlines() if line.strip()] + + +# ---------- LLM client protocol ---------------------------------------- + +class NarratorClient(Protocol): + """The LLM client interface.""" + + def complete(self, prompt: str) -> str: ... + + +# ---------- prompt ------------------------------------------------------ + +PROMPT_TEMPLATE = """\ +You are an SRE assistant explaining why an end-to-end test failed. +Write a concise, factual, blame-free report. + +# Failure facts +- Test: {test_id} +- Failing action: {action} +- Error: {error_class}: {error_message} +- Last URL: {last_url} +- Failed assertion: {failed_assertion} +- Recent flake history: {flake_history} +- Git commit under test: {git_commit} + +# Console errors (sampled) +{console_errors} + +# Network errors (sampled) +{network_errors} + +# DOM excerpt (first 2k chars) +``` +{last_dom_excerpt} +``` + +# Extra context +{extra_context} + +# Instructions +Return strictly a JSON object with keys: +- "summary": one sentence +- "likely_cause": one or two sentences +- "next_step": one sentence with what an engineer should investigate first +- "confidence": "low" | "medium" | "high" +""" + + +def build_prompt(bundle: FailureBundle) -> str: + """Render the deterministic prompt for the LLM.""" + if not isinstance(bundle, FailureBundle): + raise FailureNarratorError("build_prompt expects FailureBundle") + return PROMPT_TEMPLATE.format( + test_id=bundle.test_id, + action=bundle.action or "(unknown)", + error_class=bundle.error_class or "Error", + error_message=bundle.error_message or "(no message)", + last_url=bundle.last_url or "(unknown)", + failed_assertion=bundle.failed_assertion or "(none)", + flake_history=bundle.flake_history or "(unknown)", + git_commit=bundle.git_commit or "(unknown)", + console_errors=_join_for_prompt(bundle.console_errors), + network_errors=_join_for_prompt(bundle.network_errors), + last_dom_excerpt=bundle.last_dom_excerpt or "(none captured)", + extra_context=_join_for_prompt(bundle.extra_context), + ) + + +def _join_for_prompt(lines: Sequence[str]) -> str: + if not lines: + return "(none)" + return "\n".join(f"- {line}" for line in lines[:10]) + + +# ---------- response parsing ------------------------------------------- + +@dataclass +class NarrationReport: + """Parsed LLM response.""" + + summary: str + likely_cause: str + next_step: str + confidence: str + raw: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + def markdown(self) -> str: + return ( + f"**Why this failed**: {self.summary}\n\n" + f"**Likely cause** ({self.confidence}): {self.likely_cause}\n\n" + f"**Next step**: {self.next_step}\n" + ) + + +def parse_response(raw: str) -> NarrationReport: + """Decode the LLM's JSON envelope into a :class:`NarrationReport`.""" + if not isinstance(raw, str) or not raw.strip(): + raise FailureNarratorError("LLM returned empty response") + start = raw.find("{") + end = raw.rfind("}") + if start == -1 or end == -1 or end <= start: + raise FailureNarratorError(f"no JSON object in response: {raw[:160]!r}") + try: + obj = json.loads(raw[start:end + 1]) + except ValueError as error: + raise FailureNarratorError( + f"response was not JSON ({error}): {raw[:160]!r}" + ) from error + if not isinstance(obj, dict): + raise FailureNarratorError("response JSON must be an object") + for field_name in ("summary", "likely_cause", "next_step", "confidence"): + if field_name not in obj or not isinstance(obj[field_name], str): + raise FailureNarratorError(f"response missing string {field_name!r}") + confidence = obj["confidence"].strip().lower() + if confidence not in ("low", "medium", "high"): + raise FailureNarratorError( + f"unknown confidence {confidence!r}; want low/medium/high" + ) + return NarrationReport( + summary=obj["summary"].strip(), + likely_cause=obj["likely_cause"].strip(), + next_step=obj["next_step"].strip(), + confidence=confidence, + raw=raw, + ) + + +# ---------- end-to-end ------------------------------------------------- + +def narrate( + bundle: FailureBundle, + client: NarratorClient, +) -> NarrationReport: + """Build prompt → call LLM → parse → return.""" + prompt = build_prompt(bundle) + try: + raw = client.complete(prompt) + except Exception as error: + raise FailureNarratorError( + f"narrator client failed: {error!r}" + ) from error + return parse_response(raw) diff --git a/je_web_runner/utils/failure_triage/__init__.py b/je_web_runner/utils/failure_triage/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/failure_triage/triage.py b/je_web_runner/utils/failure_triage/triage.py new file mode 100644 index 0000000..af2c71b --- /dev/null +++ b/je_web_runner/utils/failure_triage/triage.py @@ -0,0 +1,310 @@ +""" +AI 失敗根因分析:把 failure_bundle / cluster signature / 最近動作 餵給 LLM, +得到結構化 RCA(likely_cause / evidence / next_steps / suggested_fix / confidence), +再轉成 markdown 報告與 PR comment 用 body。 + +AI-driven failure triage. Reuses the existing ``failure_bundle``, +``failure_cluster``, ``ai_assist`` and ``pr_comment`` modules: + +* Bundle ⇒ structured signal extraction (last N steps, console tail, + network tail, DOM excerpt, cluster bucket). +* Signals + JSON-only prompt ⇒ ``TriageReport`` dataclass. +* ``render_markdown`` ⇒ human-readable summary. + +No LLM provider is bundled — caller registers any +``Callable[[str], str]`` through :mod:`je_web_runner.utils.ai_assist.llm_assist`. +""" +from __future__ import annotations + +import json +import re +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, List, Optional, Sequence, Union + +from je_web_runner.utils.ai_assist.llm_assist import LLMAssistError, _invoke +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.failure_bundle.bundle import extract_bundle +from je_web_runner.utils.failure_cluster.clustering import normalise_error +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class FailureTriageError(WebRunnerException): + """Raised when triage input is malformed or LLM output cannot be parsed.""" + + +# ---------- signal extraction -------------------------------------------- + +_DEFAULT_MAX_STEPS = 12 +_DEFAULT_MAX_CONSOLE = 30 +_DEFAULT_MAX_NETWORK = 20 +_DOM_EXCERPT_CHARS = 4000 +_ERROR_EXCERPT_CHARS = 1500 + + +@dataclass +class TriageSignals: + """Signals distilled from a failure bundle, ready for an LLM prompt.""" + + test_name: str + error_repr: str + error_signature: str + last_steps: List[Any] = field(default_factory=list) + console_tail: List[Dict[str, Any]] = field(default_factory=list) + network_tail: List[Dict[str, Any]] = field(default_factory=list) + dom_excerpt: str = "" + metadata: Dict[str, Any] = field(default_factory=dict) + has_screenshot: bool = False + + +def _slice_tail(items: Sequence[Any], limit: int) -> List[Any]: + if not items: + return [] + return list(items[-limit:]) + + +def _read_bundle_json(files: Dict[str, bytes], rel: str) -> Any: + raw = files.get(rel) + if raw is None: + return None + try: + return json.loads(raw.decode("utf-8")) + except (UnicodeDecodeError, ValueError): + return None + + +def _read_bundle_text(files: Dict[str, bytes], rel: str) -> str: + raw = files.get(rel) + if raw is None: + return "" + try: + return raw.decode("utf-8") + except UnicodeDecodeError: + return "" + + +def extract_signals_from_bundle( + bundle_path: Union[str, Path], + *, + steps: Optional[Sequence[Any]] = None, + max_steps: int = _DEFAULT_MAX_STEPS, + max_console: int = _DEFAULT_MAX_CONSOLE, + max_network: int = _DEFAULT_MAX_NETWORK, +) -> TriageSignals: + """ + 從 failure_bundle zip 抽出 LLM 餵食用的訊號。 + Read a failure bundle written by :class:`FailureBundle`, slice down the + long-tail signals (console / network / steps) and produce a + :class:`TriageSignals` payload. ``steps`` is the action history captured + separately by the runner — pass it in or rely on ``manifest.metadata``. + """ + extracted = extract_bundle(bundle_path) + manifest = extracted["manifest"] + files = extracted["files"] + if not isinstance(manifest, dict): + raise FailureTriageError("bundle manifest is not a dict") + + error_repr = str(manifest.get("error_repr") or "") + console = _read_bundle_json(files, "artifacts/console.json") or [] + network = _read_bundle_json(files, "artifacts/network.json") or [] + dom_html = _read_bundle_text(files, "artifacts/dom.html") + if steps is None: + steps = manifest.get("metadata", {}).get("steps") or [] + has_screenshot = any(name.endswith(".png") for name in files) + + return TriageSignals( + test_name=str(manifest.get("test_name") or ""), + error_repr=error_repr[:_ERROR_EXCERPT_CHARS], + error_signature=normalise_error(error_repr), + last_steps=_slice_tail(list(steps), max_steps), + console_tail=_slice_tail(console if isinstance(console, list) else [], max_console), + network_tail=_slice_tail(network if isinstance(network, list) else [], max_network), + dom_excerpt=dom_html[:_DOM_EXCERPT_CHARS], + metadata=manifest.get("metadata") or {}, + has_screenshot=has_screenshot, + ) + + +# ---------- LLM prompt ---------------------------------------------------- + +_TRIAGE_PROMPT = ( + "You are a senior web-QA engineer doing failure triage. The user has " + "given you the error message, the last few action steps, console + " + "network tails, and a DOM excerpt. Identify the single most likely " + "root cause. Output ONLY a JSON object with these keys (no prose " + "outside the JSON envelope):\n" + " likely_cause: one-sentence summary\n" + " category: one of {{locator, timing, network, assertion, " + "environment, data, browser, unknown}}\n" + " evidence: list of short strings citing the specific signals\n" + " next_steps: ordered list of concrete fix attempts\n" + " suggested_fix: one-paragraph code or config change\n" + " confidence: number in [0, 1]\n\n" + "Test name: {test_name}\n" + "Error signature: {error_signature}\n" + "Error message: {error_repr}\n\n" + "Last steps (most recent last):\n{steps}\n\n" + "Console tail:\n{console}\n\n" + "Network tail:\n{network}\n\n" + "DOM excerpt:\n{dom}\n" +) + +_JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL) + + +@dataclass +class TriageReport: + """Structured RCA result returned by :func:`triage_failure`.""" + + likely_cause: str + category: str + evidence: List[str] + next_steps: List[str] + suggested_fix: str + confidence: float + test_name: str = "" + error_signature: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +_ALLOWED_CATEGORIES = frozenset({ + "locator", "timing", "network", "assertion", + "environment", "data", "browser", "unknown", +}) + + +def _parse_triage_payload(text: str) -> Dict[str, Any]: + match = _JSON_OBJECT_RE.search(text) + if match is None: + raise FailureTriageError("LLM did not return a JSON object") + try: + payload = json.loads(match.group(0)) + except ValueError as error: + raise FailureTriageError(f"LLM JSON did not parse: {error}") from error + if not isinstance(payload, dict): + raise FailureTriageError(f"LLM payload is not an object: {type(payload).__name__}") + return payload + + +def _coerce_str_list(value: Any) -> List[str]: + if isinstance(value, list): + return [str(item) for item in value] + if isinstance(value, str): + return [value] + return [] + + +def _coerce_confidence(value: Any) -> float: + try: + score = float(value) + except (TypeError, ValueError): + return 0.0 + if score < 0.0: + return 0.0 + if score > 1.0: + return 1.0 + return score + + +def _coerce_category(value: Any) -> str: + text = str(value or "").strip().lower() + return text if text in _ALLOWED_CATEGORIES else "unknown" + + +def triage_failure(signals: TriageSignals) -> TriageReport: + """ + 呼叫已註冊的 LLM callable 對失敗訊號做根因分析。 + Send ``signals`` through the LLM registered via + :func:`set_llm_callable` and parse the JSON response into a + :class:`TriageReport`. Raises :class:`FailureTriageError` if the response + is missing required keys or has the wrong shape. + """ + prompt = _TRIAGE_PROMPT.format( + test_name=signals.test_name, + error_signature=signals.error_signature or "", + error_repr=signals.error_repr or "", + steps=json.dumps(signals.last_steps, ensure_ascii=False, indent=2)[:2500], + console=json.dumps(signals.console_tail, ensure_ascii=False, indent=2)[:2500], + network=json.dumps(signals.network_tail, ensure_ascii=False, indent=2)[:2500], + dom=signals.dom_excerpt[:_DOM_EXCERPT_CHARS] or "", + ) + try: + raw = _invoke(prompt) + except LLMAssistError as error: + raise FailureTriageError(str(error)) from error + payload = _parse_triage_payload(raw) + missing = {"likely_cause", "evidence", "next_steps", "confidence"} - set(payload) + if missing: + raise FailureTriageError(f"LLM payload missing keys: {sorted(missing)}") + report = TriageReport( + likely_cause=str(payload.get("likely_cause") or "").strip(), + category=_coerce_category(payload.get("category")), + evidence=_coerce_str_list(payload.get("evidence")), + next_steps=_coerce_str_list(payload.get("next_steps")), + suggested_fix=str(payload.get("suggested_fix") or "").strip(), + confidence=_coerce_confidence(payload.get("confidence")), + test_name=signals.test_name, + error_signature=signals.error_signature, + ) + web_runner_logger.info( + f"triage_failure: test={report.test_name!r} category={report.category} " + f"confidence={report.confidence:.2f}" + ) + return report + + +def triage_bundle( + bundle_path: Union[str, Path], + *, + steps: Optional[Sequence[Any]] = None, +) -> TriageReport: + """One-shot helper: extract signals + run triage.""" + signals = extract_signals_from_bundle(bundle_path, steps=steps) + return triage_failure(signals) + + +# ---------- rendering ----------------------------------------------------- + +def render_markdown(report: TriageReport, *, heading_level: int = 2) -> str: + """ + 把 TriageReport 印成適合 PR comment / 報告的 markdown。 + Render a triage report as markdown suitable for ``post_or_update_comment`` + or saving as a standalone file. + """ + h = "#" * max(1, min(heading_level, 6)) + h2 = "#" * max(1, min(heading_level + 1, 6)) + pieces = [ + f"{h} AI Failure Triage — {report.test_name or 'unknown'}", + "", + f"- **Likely cause:** {report.likely_cause or '_unspecified_'}", + f"- **Category:** `{report.category}`", + f"- **Confidence:** {report.confidence:.0%}", + ] + if report.error_signature: + pieces.append(f"- **Error signature:** `{report.error_signature[:160]}`") + pieces.append("") + if report.evidence: + pieces.append(f"{h2} Evidence") + pieces.extend(f"- {line}" for line in report.evidence) + pieces.append("") + if report.next_steps: + pieces.append(f"{h2} Next steps") + pieces.extend(f"{idx}. {line}" for idx, line in enumerate(report.next_steps, 1)) + pieces.append("") + if report.suggested_fix: + pieces.append(f"{h2} Suggested fix") + pieces.append(report.suggested_fix) + pieces.append("") + return "\n".join(pieces).rstrip() + "\n" + + +def save_report(report: TriageReport, output_path: Union[str, Path]) -> Path: + """Persist a report as JSON next to its bundle for later inspection.""" + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as fp: + json.dump(report.to_dict(), fp, ensure_ascii=False, indent=2) + web_runner_logger.info(f"save_report: wrote {path}") + return path diff --git a/je_web_runner/utils/file_system_access/__init__.py b/je_web_runner/utils/file_system_access/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/file_system_access/mock.py b/je_web_runner/utils/file_system_access/mock.py new file mode 100644 index 0000000..80d9cbd --- /dev/null +++ b/je_web_runner/utils/file_system_access/mock.py @@ -0,0 +1,205 @@ +""" +模擬 File System Access API:``showOpenFilePicker`` / ``showSaveFilePicker``。 +這些 API 預設會跳系統檔案對話框 → 在無頭瀏覽器 / Selenium 環境基本沒救。 +此模組產生 JS shim,把對話框替換成「直接給定的 fake file handle」,並 +記錄 app 後續寫入了什麼,讓測試斷言「點了 Save 之後寫入內容是 X」。 +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class FileSystemAccessError(WebRunnerException): + """Raised on bad mock-file definitions or harvest payload.""" + + +# ---------- model ------------------------------------------------------- + +@dataclass(frozen=True) +class MockFile: + """One pre-populated file the picker should return.""" + + name: str + contents: str = "" + mime_type: str = "text/plain" + + def __post_init__(self) -> None: + if not self.name or not isinstance(self.name, str): + raise FileSystemAccessError("MockFile.name must be non-empty string") + if not isinstance(self.contents, str): + raise FileSystemAccessError("MockFile.contents must be a string") + + +@dataclass +class WriteEvent: + """One write the app performed against a mocked save handle.""" + + file_name: str + sequence: int + data: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +# ---------- script generation ------------------------------------------- + +_TEMPLATE = """ +(function() { + if (window.__wr_fsa_installed__) return; + window.__wr_fsa_installed__ = true; + window.__wr_fsa_writes__ = []; + const openFiles = %(open_files)s; + const saveName = %(save_name)s; + let writeSeq = 0; + + function makeFile(spec) { + const blob = new Blob([spec.contents], {type: spec.mime_type}); + return new File([blob], spec.name, {type: spec.mime_type}); + } + + function makeReadHandle(spec) { + return { + kind: 'file', + name: spec.name, + getFile: async function() { return makeFile(spec); } + }; + } + + function makeWriteHandle(name) { + return { + kind: 'file', + name: name, + createWritable: async function() { + return { + write: async function(chunk) { + const text = typeof chunk === 'string' + ? chunk + : (chunk && chunk.data) ? String(chunk.data) : ''; + writeSeq += 1; + window.__wr_fsa_writes__.push({ + file_name: name, sequence: writeSeq, data: text + }); + }, + truncate: async function() {}, + close: async function() {} + }; + }, + getFile: async function() { + return new File([''], name, {type: 'application/octet-stream'}); + } + }; + } + + window.showOpenFilePicker = async function() { + return openFiles.map(makeReadHandle); + }; + window.showSaveFilePicker = async function(opts) { + const finalName = (opts && opts.suggestedName) || saveName || 'untitled.txt'; + return makeWriteHandle(finalName); + }; + window.showDirectoryPicker = async function() { + return { + kind: 'directory', name: 'mocked', + values: async function*() { + for (const spec of openFiles) yield makeReadHandle(spec); + } + }; + }; +})(); +""".strip() + + +def build_install_script( + open_files: Sequence[MockFile] = (), + *, + save_suggested_name: Optional[str] = None, +) -> str: + """Render the JS shim. Inject once per page via init-script.""" + files_payload = [ + {"name": f.name, "contents": f.contents, "mime_type": f.mime_type} + for f in open_files + ] + return _TEMPLATE % { + "open_files": json.dumps(files_payload), + "save_name": json.dumps(save_suggested_name) if save_suggested_name else "null", + } + + +HARVEST_SCRIPT = "return window.__wr_fsa_writes__ || [];" + + +# ---------- harvest ----------------------------------------------------- + +def parse_writes(payload: Any) -> List[WriteEvent]: + """Convert the harvested array into typed :class:`WriteEvent` records.""" + if not isinstance(payload, list): + raise FileSystemAccessError( + f"writes payload must be list, got {type(payload).__name__}" + ) + out: List[WriteEvent] = [] + for raw in payload: + if not isinstance(raw, dict): + continue + try: + out.append(WriteEvent( + file_name=str(raw["file_name"]), + sequence=int(raw["sequence"]), + data=str(raw.get("data") or ""), + )) + except (KeyError, TypeError, ValueError) as error: + raise FileSystemAccessError( + f"malformed write entry {raw!r}: {error}" + ) from error + return out + + +# ---------- assertions -------------------------------------------------- + +def assert_no_writes(writes: Sequence[WriteEvent]) -> None: + """Assert the app did not write anything.""" + if writes: + first = writes[0] + raise FileSystemAccessError( + f"unexpected write to {first.file_name!r}: {first.data[:80]!r}" + ) + + +def assert_wrote( + writes: Sequence[WriteEvent], + *, + file_name: Optional[str] = None, + contains: Optional[str] = None, +) -> WriteEvent: + """Assert at least one write matches name and/or substring.""" + if file_name is None and contains is None: + raise FileSystemAccessError( + "provide at least one of file_name / contains" + ) + for write in writes: + if file_name is not None and write.file_name != file_name: + continue + if contains is not None and contains not in write.data: + continue + return write + raise FileSystemAccessError( + f"no write matched file_name={file_name!r} contains={contains!r} " + f"({len(writes)} writes seen)" + ) + + +def combined_payload( + writes: Sequence[WriteEvent], file_name: str, +) -> str: + """Concatenate every write for one file in sequence order.""" + if not isinstance(file_name, str) or not file_name: + raise FileSystemAccessError("file_name must be non-empty string") + matches = sorted( + (w for w in writes if w.file_name == file_name), + key=lambda w: w.sequence, + ) + return "".join(w.data for w in matches) diff --git a/je_web_runner/utils/flag_matrix/__init__.py b/je_web_runner/utils/flag_matrix/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/flag_matrix/matrix.py b/je_web_runner/utils/flag_matrix/matrix.py new file mode 100644 index 0000000..ee4db38 --- /dev/null +++ b/je_web_runner/utils/flag_matrix/matrix.py @@ -0,0 +1,281 @@ +""" +Feature flag 組合矩陣執行,自動剪掉冗餘 / 不可能組合。 +Brute-force cartesian on N flags blows up fast (3 flags × 3 variants = 27). +This module lets you declare: + +* **flags & variants** — what to permute +* **constraints** — pairs that must / must not appear together +* **pinned combos** — must-include baselines (e.g. "all off", "all on") +* **sample_size** — cap, with deterministic seeded sampling + +It produces a :class:`FlagMatrix` of dict combos that downstream test +runners iterate over. There is also a tiny "result accumulator" to record +pass/fail per combo and pick the minimal failing subset for the report. +""" +from __future__ import annotations + +import itertools +import json +import random +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class FlagMatrixError(WebRunnerException): + """Raised on bad flag definitions, impossible constraints, or sample size.""" + + +Combo = Dict[str, Any] +Constraint = Callable[[Combo], bool] + + +# ---------- definitions ------------------------------------------------- + +@dataclass +class FlagSpec: + """A single flag and the values it can take.""" + + name: str + variants: Sequence[Any] + + def __post_init__(self) -> None: + if not self.name or not isinstance(self.name, str): + raise FlagMatrixError(f"flag name must be a non-empty string, got {self.name!r}") + if not self.variants: + raise FlagMatrixError(f"flag {self.name!r} has no variants") + if len(set(map(repr, self.variants))) != len(self.variants): + raise FlagMatrixError(f"flag {self.name!r} has duplicate variants") + + +@dataclass +class FlagMatrix: + """The materialised set of combos and metadata.""" + + combos: List[Combo] = field(default_factory=list) + total_possible: int = 0 + pinned_count: int = 0 + constrained_out: int = 0 + sampled: bool = False + seed: Optional[int] = None + + def __len__(self) -> int: + return len(self.combos) + + def __iter__(self): + return iter(self.combos) + + +# ---------- builders ---------------------------------------------------- + +def build_matrix( + flags: Sequence[FlagSpec], + *, + constraints: Sequence[Constraint] = (), + pinned: Sequence[Combo] = (), + sample_size: Optional[int] = None, + seed: Optional[int] = None, +) -> FlagMatrix: + """ + Materialise the combo list. ``constraints`` returning False drop a + combo. ``pinned`` always appears at the front. ``sample_size`` (if + set) limits the total combo count via deterministic seeded sampling. + """ + if not flags: + raise FlagMatrixError("at least one FlagSpec is required") + seen_names = [f.name for f in flags] + if len(set(seen_names)) != len(seen_names): + raise FlagMatrixError(f"duplicate flag name in: {seen_names}") + if sample_size is not None and sample_size <= 0: + raise FlagMatrixError("sample_size must be > 0 when provided") + + names = [f.name for f in flags] + variant_lists = [list(f.variants) for f in flags] + total_possible = 1 + for variants in variant_lists: + total_possible *= len(variants) + + all_combos: List[Combo] = [] + for tup in itertools.product(*variant_lists): + combo = dict(zip(names, tup)) + all_combos.append(combo) + + pinned_combos: List[Combo] = [] + pinned_keys = set() + for combo in pinned: + _validate_pinned(combo, names, variant_lists) + key = _combo_key(combo) + if key in pinned_keys: + continue + pinned_keys.add(key) + pinned_combos.append(combo) + + filtered: List[Combo] = [] + constrained_out = 0 + for combo in all_combos: + if _combo_key(combo) in pinned_keys: + continue + if _passes_all(combo, constraints): + filtered.append(combo) + else: + constrained_out += 1 + + if not pinned_combos and not filtered: + raise FlagMatrixError("all combos were filtered out by constraints") + + if sample_size is not None and len(filtered) > max(0, sample_size - len(pinned_combos)): + rng = random.Random(seed) + keep_count = max(0, sample_size - len(pinned_combos)) + filtered = rng.sample(filtered, keep_count) + sampled = True + else: + sampled = False + + combos = pinned_combos + filtered + return FlagMatrix( + combos=combos, + total_possible=total_possible, + pinned_count=len(pinned_combos), + constrained_out=constrained_out, + sampled=sampled, + seed=seed, + ) + + +def _validate_pinned( + combo: Combo, + names: Sequence[str], + variant_lists: Sequence[Sequence[Any]], +) -> None: + if not isinstance(combo, dict): + raise FlagMatrixError(f"pinned combo must be a dict, got {type(combo).__name__}") + if set(combo.keys()) != set(names): + raise FlagMatrixError( + f"pinned combo keys {sorted(combo.keys())} != flag names {sorted(names)}" + ) + for name, variants in zip(names, variant_lists): + if combo[name] not in variants: + raise FlagMatrixError( + f"pinned combo value {combo[name]!r} for flag {name!r} " + f"is not in declared variants {variants!r}" + ) + + +def _combo_key(combo: Combo) -> str: + return json.dumps(combo, sort_keys=True, default=str) + + +def _passes_all(combo: Combo, constraints: Sequence[Constraint]) -> bool: + for constraint in constraints: + try: + if not constraint(combo): + return False + except Exception as error: + raise FlagMatrixError( + f"constraint raised on combo {combo}: {error!r}" + ) from error + return True + + +# ---------- constraint helpers ------------------------------------------ + +def forbid(pair: Tuple[Tuple[str, Any], Tuple[str, Any]]) -> Constraint: + """Block combos containing both ``(flag_a, val_a)`` AND ``(flag_b, val_b)``.""" + (a_flag, a_val), (b_flag, b_val) = pair + + def _constraint(combo: Combo) -> bool: + return not (combo.get(a_flag) == a_val and combo.get(b_flag) == b_val) + return _constraint + + +def require(pair: Tuple[Tuple[str, Any], Tuple[str, Any]]) -> Constraint: + """If ``(flag_a, val_a)`` is set, ``(flag_b, val_b)`` must also be set.""" + (a_flag, a_val), (b_flag, b_val) = pair + + def _constraint(combo: Combo) -> bool: + if combo.get(a_flag) != a_val: + return True + return combo.get(b_flag) == b_val + return _constraint + + +# ---------- results ----------------------------------------------------- + +@dataclass +class ComboResult: + """Outcome of executing one combo.""" + + combo: Combo + passed: bool + duration_seconds: float = 0.0 + error: Optional[str] = None + + +@dataclass +class MatrixReport: + """Roll-up of every :class:`ComboResult`.""" + + total: int + passed: int + failed: int + failures: List[ComboResult] = field(default_factory=list) + average_seconds: float = 0.0 + + +def summarise_results(results: Iterable[ComboResult]) -> MatrixReport: + """Compute counts and pull out the failures.""" + total = 0 + passed = 0 + failures: List[ComboResult] = [] + total_seconds = 0.0 + for result in results: + if not isinstance(result, ComboResult): + raise FlagMatrixError( + f"summarise_results expects ComboResult, got {type(result).__name__}" + ) + total += 1 + total_seconds += result.duration_seconds + if result.passed: + passed += 1 + else: + failures.append(result) + avg = (total_seconds / total) if total else 0.0 + return MatrixReport( + total=total, + passed=passed, + failed=total - passed, + failures=failures, + average_seconds=round(avg, 4), + ) + + +def smallest_failing_subset(failures: Sequence[ComboResult]) -> List[str]: + """ + Pick out the smallest set of flags that, alone, explain every failure. + Greedy minimum-set-cover on ``{flag=value}`` strings. Useful for the + PR comment so reviewers see "all failures involve checkout=v2" rather + than 30 individual combos. + """ + if not failures: + return [] + universe = set(range(len(failures))) + sets: Dict[str, set] = {} + for index, failure in enumerate(failures): + for flag, value in failure.combo.items(): + sets.setdefault(f"{flag}={value!r}", set()).add(index) + chosen: List[str] = [] + covered: set = set() + while covered != universe: + best_key = None + best_gain = -1 + for key, indices in sets.items(): + gain = len(indices - covered) + if gain > best_gain: + best_gain = gain + best_key = key + if best_key is None or best_gain <= 0: + break + chosen.append(best_key) + covered |= sets[best_key] + return chosen diff --git a/je_web_runner/utils/flake_detector/__init__.py b/je_web_runner/utils/flake_detector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/flake_detector/detector.py b/je_web_runner/utils/flake_detector/detector.py new file mode 100644 index 0000000..f216a58 --- /dev/null +++ b/je_web_runner/utils/flake_detector/detector.py @@ -0,0 +1,412 @@ +""" +Flaky test 偵測 + 自動隔離。 + +Flaky-test detection built on top of the ``run_ledger`` history. Two +ideas added on top of the existing :mod:`run_ledger.flaky` heuristic: + +* **Time-decayed flake score** — recent flips count for more than ancient + ones. Score is roughly ``hits / runs`` weighted by half-life decay. +* **Persistent quarantine registry** — JSON file tracking which tests are + currently isolated and why. Stable across CI runs, sortable, releasable + by hand or by ``release_if_stable`` once the score drops. + +Plus a ``@flaky_quarantine`` decorator that defers to the registry at +runtime (skip with reason if the test id is currently quarantined). +""" +from __future__ import annotations + +import functools +import json +import math +import time +from dataclasses import asdict, dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class FlakeDetectorError(WebRunnerException): + """Raised when ledger / registry I/O fails or input is malformed.""" + + +_SECONDS_PER_DAY = 86_400.0 +_DEFAULT_HALF_LIFE_DAYS = 7.0 +_DEFAULT_MIN_RUNS = 3 +_DEFAULT_FLAKE_THRESHOLD = 0.25 + + +@dataclass +class FlakeScore: + """Per-test rollup of the run history.""" + + path: str + runs: int + passes: int + fails: int + pass_rate: float + flake_score: float + last_run: Optional[str] = None + is_flaky: bool = False + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def _load_runs(ledger_path: Union[str, Path]) -> List[Dict[str, Any]]: + path = Path(ledger_path) + if not path.exists(): + return [] + try: + with open(path, encoding="utf-8") as fp: + data = json.load(fp) + except (OSError, ValueError) as error: + raise FlakeDetectorError(f"cannot read ledger {ledger_path}: {error!r}") from error + if not isinstance(data, dict) or "runs" not in data: + raise FlakeDetectorError(f"ledger missing 'runs' key: {ledger_path}") + runs = data.get("runs") + if not isinstance(runs, list): + raise FlakeDetectorError(f"ledger 'runs' is not a list: {ledger_path}") + return [r for r in runs if isinstance(r, dict)] + + +def _parse_run_time(value: Any, fallback_now: float) -> float: + """Best-effort ISO-time → epoch seconds. Unknown formats fall back to ``now``.""" + if not isinstance(value, str) or not value: + return fallback_now + try: + if value.endswith("Z"): + value = value[:-1] + "+00:00" + dt = datetime.fromisoformat(value) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except ValueError: + return fallback_now + + +def _decay_weight(age_seconds: float, half_life_days: float) -> float: + if half_life_days <= 0: + return 1.0 + half_life_seconds = half_life_days * _SECONDS_PER_DAY + return math.pow(0.5, age_seconds / half_life_seconds) + + +def compute_flake_scores( + ledger_path: Union[str, Path], + *, + half_life_days: float = _DEFAULT_HALF_LIFE_DAYS, + min_runs: int = _DEFAULT_MIN_RUNS, + threshold: float = _DEFAULT_FLAKE_THRESHOLD, + now_epoch: Optional[float] = None, +) -> Dict[str, FlakeScore]: + """ + 從 ledger 歷史計算每個 file 的 time-decayed flake score。 + Produce per-file :class:`FlakeScore` records. A test is flagged ``is_flaky`` + when it has at least ``min_runs`` recorded runs AND its decayed flip rate + exceeds ``threshold``. Pass rate is the *unweighted* ratio so dashboards + stay readable. + """ + runs = _load_runs(ledger_path) + now = now_epoch if now_epoch is not None else time.time() + buckets: Dict[str, Dict[str, Any]] = {} + for run in runs: + path = run.get("path") + if not isinstance(path, str): + continue + record = buckets.setdefault(path, { + "runs": 0, "passes": 0, "fails": 0, + "weight_total": 0.0, "weight_fails": 0.0, + "last_run": None, + }) + record["runs"] += 1 + run_epoch = _parse_run_time(run.get("time"), now) + age = max(0.0, now - run_epoch) + weight = _decay_weight(age, half_life_days) + record["weight_total"] += weight + if run.get("passed"): + record["passes"] += 1 + else: + record["fails"] += 1 + record["weight_fails"] += weight + last_run = run.get("time") + if isinstance(last_run, str): + existing = record["last_run"] + if existing is None or last_run > existing: + record["last_run"] = last_run + + out: Dict[str, FlakeScore] = {} + for path, rec in buckets.items(): + runs_n = rec["runs"] + passes = rec["passes"] + fails = rec["fails"] + pass_rate = (passes / runs_n) if runs_n else 0.0 + weight_total = rec["weight_total"] + flake_score = (rec["weight_fails"] / weight_total) if weight_total else 0.0 + has_both = passes > 0 and fails > 0 + is_flaky = runs_n >= min_runs and has_both and flake_score >= threshold + out[path] = FlakeScore( + path=path, + runs=runs_n, + passes=passes, + fails=fails, + pass_rate=round(pass_rate, 4), + flake_score=round(flake_score, 4), + last_run=rec["last_run"], + is_flaky=is_flaky, + ) + return out + + +def flaky_paths( + ledger_path: Union[str, Path], + *, + half_life_days: float = _DEFAULT_HALF_LIFE_DAYS, + min_runs: int = _DEFAULT_MIN_RUNS, + threshold: float = _DEFAULT_FLAKE_THRESHOLD, +) -> List[str]: + """Return paths whose decayed flake score is at or above ``threshold``.""" + scores = compute_flake_scores( + ledger_path, + half_life_days=half_life_days, + min_runs=min_runs, + threshold=threshold, + ) + flagged = [score for score in scores.values() if score.is_flaky] + flagged.sort(key=lambda s: (-s.flake_score, s.path)) + return [s.path for s in flagged] + + +# ---------- quarantine registry ------------------------------------------ + +@dataclass +class QuarantineEntry: + """One quarantined test record.""" + + test_id: str + reason: str + flake_score: float + quarantined_at: str + runs_when_added: int = 0 + triage_url: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def _utc_now_iso() -> str: + return datetime.now(tz=timezone.utc).isoformat(timespec="seconds") + + +class QuarantineRegistry: + """ + JSON-backed registry of currently quarantined tests. + Stable across CI runs; intended to be checked into git or stored + alongside the ledger so the pytest plugin can read it on every run. + """ + + def __init__(self, registry_path: Union[str, Path]) -> None: + self.registry_path = Path(registry_path) + self._entries: Dict[str, QuarantineEntry] = {} + self._load() + + def _load(self) -> None: + if not self.registry_path.exists(): + return + try: + with open(self.registry_path, encoding="utf-8") as fp: + data = json.load(fp) + except (OSError, ValueError) as error: + raise FlakeDetectorError( + f"cannot read quarantine registry {self.registry_path}: {error!r}" + ) from error + entries = data.get("entries") if isinstance(data, dict) else None + if not isinstance(entries, list): + raise FlakeDetectorError( + f"registry missing 'entries' list: {self.registry_path}" + ) + for entry in entries: + if not isinstance(entry, dict) or "test_id" not in entry: + continue + self._entries[entry["test_id"]] = QuarantineEntry( + test_id=str(entry["test_id"]), + reason=str(entry.get("reason") or ""), + flake_score=float(entry.get("flake_score") or 0.0), + quarantined_at=str(entry.get("quarantined_at") or _utc_now_iso()), + runs_when_added=int(entry.get("runs_when_added") or 0), + triage_url=entry.get("triage_url"), + ) + + def _save(self) -> None: + self.registry_path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "updated_at": _utc_now_iso(), + "entries": [e.to_dict() for e in self._entries.values()], + } + with open(self.registry_path, "w", encoding="utf-8") as fp: + json.dump(payload, fp, ensure_ascii=False, indent=2) + + def is_quarantined(self, test_id: str) -> bool: + return test_id in self._entries + + def get(self, test_id: str) -> Optional[QuarantineEntry]: + return self._entries.get(test_id) + + def add(self, entry: QuarantineEntry) -> None: + self._entries[entry.test_id] = entry + self._save() + web_runner_logger.info( + f"quarantine add: {entry.test_id} reason={entry.reason!r} " + f"score={entry.flake_score:.2f}" + ) + + def remove(self, test_id: str) -> bool: + existing = self._entries.pop(test_id, None) + if existing is None: + return False + self._save() + web_runner_logger.info(f"quarantine remove: {test_id}") + return True + + def list(self) -> List[QuarantineEntry]: + return sorted( + self._entries.values(), + key=lambda e: (-e.flake_score, e.test_id), + ) + + +def quarantine_flaky( + ledger_path: Union[str, Path], + registry_path: Union[str, Path], + *, + half_life_days: float = _DEFAULT_HALF_LIFE_DAYS, + min_runs: int = _DEFAULT_MIN_RUNS, + threshold: float = _DEFAULT_FLAKE_THRESHOLD, + reason_template: str = "auto: flake_score={score:.2f} after {runs} runs", +) -> List[str]: + """ + 自動把 flake score ≥ threshold 的 test 加入 quarantine registry。 + Walk the ledger, score each test, and write any newly-flaky tests into + the registry. Returns the list of newly-quarantined test ids (already- + quarantined tests are left alone — their original metadata persists). + """ + scores = compute_flake_scores( + ledger_path, + half_life_days=half_life_days, + min_runs=min_runs, + threshold=threshold, + ) + registry = QuarantineRegistry(registry_path) + newly_added: List[str] = [] + for score in scores.values(): + if not score.is_flaky: + continue + if registry.is_quarantined(score.path): + continue + entry = QuarantineEntry( + test_id=score.path, + reason=reason_template.format(score=score.flake_score, runs=score.runs), + flake_score=score.flake_score, + quarantined_at=_utc_now_iso(), + runs_when_added=score.runs, + ) + registry.add(entry) + newly_added.append(score.path) + return newly_added + + +def release_if_stable( + ledger_path: Union[str, Path], + registry_path: Union[str, Path], + *, + half_life_days: float = _DEFAULT_HALF_LIFE_DAYS, + release_threshold: float = 0.05, + min_runs_since: int = 5, +) -> List[str]: + """ + 放出 flake score 已穩定下降到 ``release_threshold`` 以下的 quarantine test。 + Promote stable tests out of quarantine: each entry whose current score + is below ``release_threshold`` AND has been observed ``min_runs_since`` + times in the ledger is removed. Returns the released test ids. + """ + scores = compute_flake_scores( + ledger_path, + half_life_days=half_life_days, + min_runs=min_runs_since, + threshold=release_threshold + 1.0, # ensure is_flaky=False is meaningful + ) + registry = QuarantineRegistry(registry_path) + released: List[str] = [] + for entry in registry.list(): + current = scores.get(entry.test_id) + if current is None: + continue + if current.runs < min_runs_since: + continue + if current.flake_score <= release_threshold: + if registry.remove(entry.test_id): + released.append(entry.test_id) + return released + + +# ---------- decorator ----------------------------------------------------- + +def flaky_quarantine( + test_id: str, + registry_path: Union[str, Path], + *, + skip_when_quarantined: bool = True, +) -> Callable: + """ + Decorator:執行前查 quarantine registry,若被隔離則 skip 並標明原因。 + Wrap a callable (typically a pytest test function). At call time, look + up the registry; if the test id is quarantined, skip with the reason + string when ``skip_when_quarantined`` is true, else just log it and run. + + Skipping uses ``pytest.skip`` when pytest is importable; falls back to + raising :class:`FlakeDetectorError` otherwise so non-pytest harnesses can + detect and handle the quarantine themselves. + """ + def decorator(fn: Callable) -> Callable: + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + registry = QuarantineRegistry(registry_path) + entry = registry.get(test_id) + if entry is None: + return fn(*args, **kwargs) + web_runner_logger.warning( + f"flaky_quarantine: {test_id} is quarantined ({entry.reason})" + ) + if not skip_when_quarantined: + return fn(*args, **kwargs) + try: + import pytest # local import keeps decorator pytest-optional + except ImportError: + raise FlakeDetectorError( + f"test {test_id!r} is quarantined: {entry.reason}" + ) + pytest.skip(f"flaky-quarantine: {entry.reason}") + return None + return wrapper + return decorator + + +# ---------- reporting ---------------------------------------------------- + +def quarantine_report_markdown(registry: QuarantineRegistry) -> str: + """Render the current quarantine list as a markdown table.""" + entries = registry.list() + if not entries: + return "_No quarantined tests._\n" + rows = [ + "| Test | Score | Reason | Since |", + "|------|-------|--------|-------|", + ] + for entry in entries: + rows.append( + f"| `{entry.test_id}` | {entry.flake_score:.2f} | " + f"{entry.reason} | {entry.quarantined_at} |" + ) + return "\n".join(rows) + "\n" diff --git a/je_web_runner/utils/forced_colors_mode/__init__.py b/je_web_runner/utils/forced_colors_mode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/forced_colors_mode/modes.py b/je_web_runner/utils/forced_colors_mode/modes.py new file mode 100644 index 0000000..2c6dd2d --- /dev/null +++ b/je_web_runner/utils/forced_colors_mode/modes.py @@ -0,0 +1,223 @@ +""" +High-contrast / dark-mode / reduced-motion / forced-colors 矩陣驗證。 +The four CSS media queries that most app teams forget: + +* ``prefers-color-scheme: dark`` +* ``prefers-reduced-motion: reduce`` +* ``forced-colors: active`` (Windows High Contrast) +* ``prefers-contrast: more`` + +This module: + +1. Builds the CDP ``Emulation.setEmulatedMedia`` payload for any combo. +2. Defines a default matrix of "important" combos plus a knob for + restricting it (e.g. CI does dark-mode only). +3. Diffs visible CSS properties (computed background / color / outline) + between modes to catch "white-on-white text in high-contrast" bugs. + +CDP application is delegated to a user-supplied callable so the module +stays driver-agnostic. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ForcedColorsModeError(WebRunnerException): + """Raised on invalid mode combo, bad CSS payload, or CDP failure.""" + + +class ColorScheme(str, Enum): + LIGHT = "light" + DARK = "dark" + + +class ReducedMotion(str, Enum): + NO_PREFERENCE = "no-preference" + REDUCE = "reduce" + + +class ForcedColors(str, Enum): + NONE = "none" + ACTIVE = "active" + + +class Contrast(str, Enum): + NO_PREFERENCE = "no-preference" + MORE = "more" + LESS = "less" + + +# ---------- profile model ---------------------------------------------- + +@dataclass(frozen=True) +class MediaProfile: + """One full CSS-media combo.""" + + name: str + color_scheme: ColorScheme = ColorScheme.LIGHT + reduced_motion: ReducedMotion = ReducedMotion.NO_PREFERENCE + forced_colors: ForcedColors = ForcedColors.NONE + contrast: Contrast = Contrast.NO_PREFERENCE + + def to_cdp_features(self) -> List[Dict[str, str]]: + """Render the ``features`` payload for ``Emulation.setEmulatedMedia``.""" + return [ + {"name": "prefers-color-scheme", "value": self.color_scheme.value}, + {"name": "prefers-reduced-motion", "value": self.reduced_motion.value}, + {"name": "forced-colors", "value": self.forced_colors.value}, + {"name": "prefers-contrast", "value": self.contrast.value}, + ] + + +DEFAULT_PROFILES: Sequence[MediaProfile] = ( + MediaProfile(name="baseline"), + MediaProfile(name="dark", color_scheme=ColorScheme.DARK), + MediaProfile(name="reduced-motion", reduced_motion=ReducedMotion.REDUCE), + MediaProfile(name="high-contrast", + forced_colors=ForcedColors.ACTIVE, + contrast=Contrast.MORE), + MediaProfile(name="dark-high-contrast", + color_scheme=ColorScheme.DARK, + forced_colors=ForcedColors.ACTIVE, + contrast=Contrast.MORE), +) + + +# ---------- CDP integration -------------------------------------------- + +CdpEmulate = Callable[[List[Dict[str, str]]], Any] +"""Callable that pushes a features list to ``Emulation.setEmulatedMedia``.""" + + +def apply_profile(profile: MediaProfile, cdp_emulate: CdpEmulate) -> Any: + """Hand the profile's features to the user's CDP-emulate callable.""" + if not isinstance(profile, MediaProfile): + raise ForcedColorsModeError("apply_profile expects MediaProfile") + try: + return cdp_emulate(profile.to_cdp_features()) + except Exception as error: + raise ForcedColorsModeError( + f"CDP setEmulatedMedia failed: {error!r}" + ) from error + + +# ---------- per-element style snapshot --------------------------------- + +@dataclass(frozen=True) +class StyleSnapshot: + """Subset of computed styles we compare across modes.""" + + background_color: str + color: str + outline: str = "" + border_color: str = "" + visibility: str = "visible" + + def is_invisible(self) -> bool: + """Heuristic: same colour as background = invisible.""" + return ( + self.background_color.strip().lower() == self.color.strip().lower() + and self.background_color.strip() != "" + ) + + +@dataclass +class ElementDiff: + """Difference for one element between two modes.""" + + selector: str + baseline_mode: str + other_mode: str + became_invisible: bool + changed_fields: Dict[str, Any] = field(default_factory=dict) + + +def diff_snapshot( + selector: str, + baseline_mode: str, + other_mode: str, + baseline: StyleSnapshot, + other: StyleSnapshot, +) -> Optional[ElementDiff]: + """Return a :class:`ElementDiff` iff the snapshots meaningfully differ.""" + if not isinstance(baseline, StyleSnapshot) or not isinstance(other, StyleSnapshot): + raise ForcedColorsModeError("snapshots must be StyleSnapshot instances") + changed: Dict[str, Any] = {} + for field_name in asdict(baseline): + a = getattr(baseline, field_name) + b = getattr(other, field_name) + if a != b: + changed[field_name] = {"baseline": a, "other": b} + became_invisible = other.is_invisible() and not baseline.is_invisible() + if not changed and not became_invisible: + return None + return ElementDiff( + selector=selector, + baseline_mode=baseline_mode, + other_mode=other_mode, + became_invisible=became_invisible, + changed_fields=changed, + ) + + +# ---------- matrix audit ------------------------------------------------ + +@dataclass +class ModeAuditReport: + """Roll-up returned by :func:`audit_modes`.""" + + diffs: List[ElementDiff] = field(default_factory=list) + invisible_in_modes: Dict[str, List[str]] = field(default_factory=dict) + + def passed(self) -> bool: + return not self.invisible_in_modes + + +def audit_modes( + baseline_mode: str, + snapshots_by_mode: Dict[str, Dict[str, StyleSnapshot]], +) -> ModeAuditReport: + """ + Given per-mode { selector → StyleSnapshot }, diff every non-baseline + mode against the baseline. Selectors that become invisible are + flagged as failures; other diffs are recorded for review. + """ + if baseline_mode not in snapshots_by_mode: + raise ForcedColorsModeError( + f"baseline_mode {baseline_mode!r} not in snapshots_by_mode" + ) + baseline = snapshots_by_mode[baseline_mode] + report = ModeAuditReport() + for mode, snapshots in snapshots_by_mode.items(): + if mode == baseline_mode: + continue + for selector, snap in snapshots.items(): + if selector not in baseline: + continue + diff = diff_snapshot( + selector, baseline_mode, mode, baseline[selector], snap, + ) + if diff is None: + continue + report.diffs.append(diff) + if diff.became_invisible: + report.invisible_in_modes.setdefault(mode, []).append(selector) + return report + + +def assert_no_invisible(report: ModeAuditReport) -> None: + """Raise if any element became invisible in any non-baseline mode.""" + if not isinstance(report, ModeAuditReport): + raise ForcedColorsModeError("assert_no_invisible expects ModeAuditReport") + if report.passed(): + return + parts = ", ".join( + f"{mode}: {len(selectors)} element(s)" + for mode, selectors in report.invisible_in_modes.items() + ) + raise ForcedColorsModeError(f"elements became invisible — {parts}") diff --git a/je_web_runner/utils/git_bisect_flake/__init__.py b/je_web_runner/utils/git_bisect_flake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/git_bisect_flake/bisect.py b/je_web_runner/utils/git_bisect_flake/bisect.py new file mode 100644 index 0000000..1f6b387 --- /dev/null +++ b/je_web_runner/utils/git_bisect_flake/bisect.py @@ -0,0 +1,242 @@ +""" +自動 bisect ledger,找出造成某 test 開始失敗的 regression commit。 +Manual ``git bisect`` is great but slow: you check out, run, mark good/ +bad, repeat. When the ledger already has per-commit pass/fail rows, the +bisect can be data-driven — pick the boundary commit just before the +first failure, with no checkouts at all. + +Two modes: + +* **Data-only bisect** (offline) — looks at the ledger; useful when the + failing test ran on every commit (CI matrix). No git access needed. +* **Re-run bisect** (online) — needs a ``CommitProbe`` callable that can + check out a commit and re-run the test. Classic git-bisect with the + ledger guiding initial bounds so fewer probes are needed. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class GitBisectFlakeError(WebRunnerException): + """Raised on malformed ledger, missing test history, or probe failure.""" + + +# ---------- ledger model ----------------------------------------------- + +@dataclass(frozen=True) +class LedgerEntry: + """One pass/fail row from the run ledger.""" + + commit: str + test_id: str + passed: bool + time: str = "" + + def __post_init__(self) -> None: + if not self.commit or not isinstance(self.commit, str): + raise GitBisectFlakeError("LedgerEntry.commit must be non-empty string") + if not self.test_id or not isinstance(self.test_id, str): + raise GitBisectFlakeError("LedgerEntry.test_id must be non-empty string") + + +def load_ledger(path: Union[str, Path]) -> List[LedgerEntry]: + """Read the standard ledger JSON. Schema: ``{"runs": [{commit, path/test_id, passed}]}``.""" + p = Path(path) + if not p.exists(): + raise GitBisectFlakeError(f"ledger not found: {p}") + try: + data = json.loads(p.read_text(encoding="utf-8")) + except ValueError as error: + raise GitBisectFlakeError(f"ledger not JSON: {error}") from error + if not isinstance(data, dict) or "runs" not in data: + raise GitBisectFlakeError("ledger missing 'runs' key") + runs = data["runs"] + if not isinstance(runs, list): + raise GitBisectFlakeError("ledger 'runs' must be a list") + entries: List[LedgerEntry] = [] + for raw in runs: + if not isinstance(raw, dict): + continue + test_id = raw.get("test_id") or raw.get("path") + commit = raw.get("commit") + if not isinstance(test_id, str) or not isinstance(commit, str): + continue + entries.append(LedgerEntry( + commit=commit, + test_id=test_id, + passed=bool(raw.get("passed")), + time=str(raw.get("time") or ""), + )) + return entries + + +# ---------- data-only bisect ------------------------------------------- + +@dataclass +class BisectResult: + """Outcome of either bisect mode.""" + + test_id: str + last_good_commit: Optional[str] + first_bad_commit: Optional[str] + probes: int = 0 + method: str = "ledger" # 'ledger' | 'probe' + history: List[Dict[str, Any]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def bisect_from_ledger( + entries: Sequence[LedgerEntry], + commit_order: Sequence[str], + test_id: str, +) -> BisectResult: + """ + Walk ``commit_order`` (oldest → newest) and return the boundary where + ``test_id`` flips from pass to fail. If the test was already failing at + the oldest known commit, ``last_good_commit`` is ``None``. + """ + if not entries: + raise GitBisectFlakeError("entries must be a non-empty sequence") + if not commit_order: + raise GitBisectFlakeError("commit_order must be a non-empty sequence") + if not test_id: + raise GitBisectFlakeError("test_id must be a non-empty string") + by_commit: Dict[str, LedgerEntry] = {} + for entry in entries: + if entry.test_id != test_id: + continue + by_commit[entry.commit] = entry + if not by_commit: + raise GitBisectFlakeError(f"no ledger rows for test_id {test_id!r}") + last_good: Optional[str] = None + first_bad: Optional[str] = None + history: List[Dict[str, Any]] = [] + for commit in commit_order: + entry = by_commit.get(commit) + if entry is None: + continue + history.append({"commit": commit, "passed": entry.passed}) + if entry.passed: + last_good = commit + first_bad = None + elif last_good is not None: + first_bad = commit + break + elif first_bad is None: + first_bad = commit # earliest known failure when no good commit seen + return BisectResult( + test_id=test_id, + last_good_commit=last_good, + first_bad_commit=first_bad, + probes=0, + method="ledger", + history=history, + ) + + +# ---------- probe-driven bisect ---------------------------------------- + +CommitProbe = Callable[[str], bool] +"""Callable: commit-sha → True if the test passes when run at that commit.""" + + +def bisect_with_probe( + commit_order: Sequence[str], + test_id: str, + probe: CommitProbe, + *, + known_good: Optional[str] = None, + known_bad: Optional[str] = None, +) -> BisectResult: + """ + Classic bisect using ``probe``. ``known_good`` / ``known_bad`` clamp + the search window (typical use: feed them from a prior ledger bisect + so we converge faster). + """ + if len(commit_order) < 2: + raise GitBisectFlakeError("commit_order needs at least 2 commits") + if not test_id: + raise GitBisectFlakeError("test_id must be non-empty") + indices_by_commit = {c: i for i, c in enumerate(commit_order)} + low = 0 + high = len(commit_order) - 1 + if known_good is not None: + if known_good not in indices_by_commit: + raise GitBisectFlakeError(f"known_good {known_good!r} not in commit_order") + low = indices_by_commit[known_good] + if known_bad is not None: + if known_bad not in indices_by_commit: + raise GitBisectFlakeError(f"known_bad {known_bad!r} not in commit_order") + high = indices_by_commit[known_bad] + if low >= high: + raise GitBisectFlakeError("known_good must come before known_bad in commit_order") + + probes = 0 + history: List[Dict[str, Any]] = [] + while high - low > 1: + mid = (low + high) // 2 + commit = commit_order[mid] + try: + passed = bool(probe(commit)) + except Exception as error: + raise GitBisectFlakeError( + f"probe failed at {commit}: {error!r}" + ) from error + probes += 1 + history.append({"commit": commit, "passed": passed}) + web_runner_logger.info( + f"git_bisect_flake probe {probes}: {commit[:8]} passed={passed}" + ) + if passed: + low = mid + else: + high = mid + last_good = commit_order[low] + first_bad = commit_order[high] + return BisectResult( + test_id=test_id, + last_good_commit=last_good, + first_bad_commit=first_bad, + probes=probes, + method="probe", + history=history, + ) + + +# ---------- reporting -------------------------------------------------- + +def report_markdown(result: BisectResult) -> str: + """Render the result as a small markdown block for PR comments.""" + if not isinstance(result, BisectResult): + raise GitBisectFlakeError("report_markdown expects BisectResult") + lines = [ + f"### git-bisect for `{result.test_id}` ({result.method}, {result.probes} probes)", + "", + ] + if result.first_bad_commit is None: + lines.append("_Test has not flipped in the observed range._") + else: + if result.last_good_commit: + lines.append(f"- Last good commit: `{result.last_good_commit}`") + else: + lines.append("- No good commit observed in window.") + lines.append(f"- First bad commit: `{result.first_bad_commit}`") + if result.history: + lines.append("") + lines.append("| Commit | Passed |") + lines.append("|--------|--------|") + for entry in result.history[:10]: + mark = "✓" if entry.get("passed") else "✗" + lines.append(f"| `{str(entry.get('commit'))[:10]}` | {mark} |") + if len(result.history) > 10: + lines.append(f"_({len(result.history) - 10} earlier rows hidden)_") + return "\n".join(lines) + "\n" diff --git a/je_web_runner/utils/grpc_tester/__init__.py b/je_web_runner/utils/grpc_tester/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/grpc_tester/client.py b/je_web_runner/utils/grpc_tester/client.py new file mode 100644 index 0000000..deb8397 --- /dev/null +++ b/je_web_runner/utils/grpc_tester/client.py @@ -0,0 +1,249 @@ +""" +gRPC / gRPC-Web client harness with request/response capture for E2E +integration testing. +Two paths: + +* **Real gRPC** — wraps an injectable ``grpc`` channel; you provide the + service stub (generated from .proto) and we record every call. +* **gRPC-Web** — pure-HTTP using ``requests``: build a length-prefixed + payload, decode the trailer. Useful when the SUT is a browser-style + client; doesn't need protoc generation for raw byte tests. + +Both expose the same :class:`GrpcCall` recorder + asserts so suites are +transport-portable. +""" +from __future__ import annotations + +import base64 +import struct +import time +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class GrpcTesterError(WebRunnerException): + """Raised on malformed input / call failure / failed assertion.""" + + +# ---------- model ------------------------------------------------------ + +class GrpcStatus(int, Enum): + """Standard gRPC status codes.""" + + OK = 0 + CANCELLED = 1 + UNKNOWN = 2 + INVALID_ARGUMENT = 3 + DEADLINE_EXCEEDED = 4 + NOT_FOUND = 5 + ALREADY_EXISTS = 6 + PERMISSION_DENIED = 7 + UNAUTHENTICATED = 16 + + +@dataclass +class GrpcCall: + """One recorded gRPC / gRPC-Web call.""" + + method: str + request: Any + response: Any + status: GrpcStatus + duration_ms: float + metadata: Dict[str, str] = field(default_factory=dict) + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "status": self.status.name} + + +# ---------- recorder ---------------------------------------------------- + +class GrpcCallRecorder: + """In-memory recorder of calls.""" + + def __init__(self) -> None: + self._calls: List[GrpcCall] = [] + + def __len__(self) -> int: + return len(self._calls) + + def record(self, call: GrpcCall) -> None: + if not isinstance(call, GrpcCall): + raise GrpcTesterError( + f"record() expects GrpcCall, got {type(call).__name__}" + ) + self._calls.append(call) + + def clear(self) -> None: + self._calls.clear() + + def calls( + self, + *, + method: Optional[str] = None, + status: Optional[GrpcStatus] = None, + ) -> List[GrpcCall]: + out: List[GrpcCall] = [] + for c in self._calls: + if method is not None and c.method != method: + continue + if status is not None and c.status != status: + continue + out.append(c) + return out + + +# ---------- callable wrapper ------------------------------------------ + +def call( + method: str, + stub_method: Callable[..., Any], + request: Any, + *, + recorder: Optional[GrpcCallRecorder] = None, + metadata: Optional[Sequence[Tuple[str, str]]] = None, + timeout: Optional[float] = None, +) -> GrpcCall: + """ + Call a generated gRPC stub method, capturing response / status. + Returns the :class:`GrpcCall` whether the call succeeded or raised. + """ + if not isinstance(method, str) or not method: + raise GrpcTesterError("method must be non-empty string") + if not callable(stub_method): + raise GrpcTesterError("stub_method must be callable") + metadata = list(metadata or []) + started = time.monotonic() + status = GrpcStatus.OK + response = None + error: Optional[str] = None + try: + kwargs: Dict[str, Any] = {} + if metadata: + kwargs["metadata"] = metadata + if timeout is not None: + kwargs["timeout"] = timeout + response = stub_method(request, **kwargs) + except Exception as exc: + # Try to read .code() like grpc.RpcError; fall back to UNKNOWN. + code_obj = getattr(exc, "code", None) + code_val = code_obj() if callable(code_obj) else code_obj + status = _coerce_status(code_val) + error = repr(exc) + duration = round((time.monotonic() - started) * 1000.0, 3) + record = GrpcCall( + method=method, request=request, response=response, + status=status, duration_ms=duration, + metadata={k: v for k, v in metadata}, error=error, + ) + if recorder is not None: + recorder.record(record) + return record + + +def _coerce_status(value: Any) -> GrpcStatus: + if value is None: + return GrpcStatus.UNKNOWN + if isinstance(value, GrpcStatus): + return value + if isinstance(value, int): + try: + return GrpcStatus(value) + except ValueError: + return GrpcStatus.UNKNOWN + # grpc.StatusCode has a .value tuple (int, str) + code = getattr(value, "value", None) + if isinstance(code, tuple) and code and isinstance(code[0], int): + try: + return GrpcStatus(code[0]) + except ValueError: + return GrpcStatus.UNKNOWN + return GrpcStatus.UNKNOWN + + +# ---------- gRPC-Web framing ------------------------------------------ + +def encode_grpc_web_message(payload: bytes) -> bytes: + """Length-prefix-frame a raw payload (compression flag 0).""" + if not isinstance(payload, (bytes, bytearray)): + raise GrpcTesterError("payload must be bytes") + return b"\x00" + struct.pack(">I", len(payload)) + bytes(payload) + + +def decode_grpc_web_message(framed: bytes) -> List[Tuple[int, bytes]]: + """Decode a (possibly multi-message) framed gRPC-Web body.""" + if not isinstance(framed, (bytes, bytearray)): + raise GrpcTesterError("framed must be bytes") + out: List[Tuple[int, bytes]] = [] + pos = 0 + buf = bytes(framed) + while pos < len(buf): + if len(buf) - pos < 5: + raise GrpcTesterError(f"truncated frame at offset {pos}") + flag = buf[pos] + length = struct.unpack(">I", buf[pos + 1:pos + 5])[0] + end = pos + 5 + length + if end > len(buf): + raise GrpcTesterError(f"frame length overruns buffer at {pos}") + out.append((flag, buf[pos + 5:end])) + pos = end + return out + + +def parse_trailer(trailer_bytes: bytes) -> Dict[str, str]: + """Parse a ``grpc-status`` / ``grpc-message`` trailer payload.""" + if not isinstance(trailer_bytes, (bytes, bytearray)): + raise GrpcTesterError("trailer_bytes must be bytes") + text = bytes(trailer_bytes).decode("utf-8", errors="replace") + out: Dict[str, str] = {} + for line in text.split("\r\n"): + line = line.strip() + if not line or ":" not in line: + continue + name, _, value = line.partition(":") + out[name.strip().lower()] = value.strip() + return out + + +# ---------- assertions ------------------------------------------------- + +def assert_call_ok(call_record: GrpcCall) -> None: + """Assert ``call_record.status == OK``.""" + if not isinstance(call_record, GrpcCall): + raise GrpcTesterError("expects GrpcCall") + if call_record.status != GrpcStatus.OK: + raise GrpcTesterError( + f"call {call_record.method!r} returned {call_record.status.name}: " + f"{call_record.error or 'no error message'}" + ) + + +def assert_call_fails(call_record: GrpcCall, *, status: GrpcStatus) -> None: + """Assert a specific non-OK status.""" + if not isinstance(call_record, GrpcCall): + raise GrpcTesterError("expects GrpcCall") + if not isinstance(status, GrpcStatus): + raise GrpcTesterError("status must be GrpcStatus") + if call_record.status != status: + raise GrpcTesterError( + f"expected {status.name}, got {call_record.status.name}" + ) + + +def assert_called( + recorder: GrpcCallRecorder, + method: str, + *, + minimum: int = 1, +) -> int: + """Assert a method was invoked at least ``minimum`` times.""" + count = len(recorder.calls(method=method)) + if count < minimum: + raise GrpcTesterError( + f"method {method!r} called {count} times, want >= {minimum}" + ) + return count diff --git a/je_web_runner/utils/hydration_check/__init__.py b/je_web_runner/utils/hydration_check/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/hydration_check/check.py b/je_web_runner/utils/hydration_check/check.py new file mode 100644 index 0000000..6aefec9 --- /dev/null +++ b/je_web_runner/utils/hydration_check/check.py @@ -0,0 +1,159 @@ +""" +SSR hydration mismatch 偵測。 +React 18, Next.js, Nuxt 3, Remix, SvelteKit 都會印 hydration error 到 +console — 但 prod 預設靜默。常見錯誤: + +* 伺服器渲染的 markup 跟 client 第一次 hydration 結果不同 +* ``new Date()`` / ``Math.random()`` 在 server vs client 結果不同 +* Provider 用 `useState(window.x)` 之類 SSR-incompatible 寫法 + +This module: + +* Compares the *raw server HTML* (fetched as bytes) against the + *post-hydration DOM* (innerHTML snapshot). It normalises whitespace, + React data-attribs (``data-reactroot``), and Vue / SvelteKit hashes. +* Parses console messages for known hydration error markers and surfaces + them as findings. +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class HydrationCheckError(WebRunnerException): + """Raised on malformed input or failed assertion.""" + + +# Markers each framework prints on hydration mismatch. +_HYDRATION_MARKERS = ( + "hydration failed", # React 18 + "did not match", # React 17/18 + "expected server html to contain", # React + "hydration mismatch", # Vue 3, Svelte + "skipping hydration", # Astro + "error while hydrating", # Nuxt + "text content does not match server-rendered html", # React 18 +) + + +@dataclass(frozen=True) +class HydrationFinding: + """One detected hydration problem.""" + + kind: str # 'console' | 'dom_diff' + detail: str + source: str = "" + + +# ---------- console scan ---------------------------------------------- + +def scan_console(messages: Iterable[str]) -> List[HydrationFinding]: + """Pull hydration-related lines out of console messages.""" + findings: List[HydrationFinding] = [] + for line in messages: + if not isinstance(line, str): + continue + lower = line.lower() + for marker in _HYDRATION_MARKERS: + if marker in lower: + findings.append(HydrationFinding( + kind="console", detail=line.strip()[:200], source=marker, + )) + break + return findings + + +# ---------- DOM diff -------------------------------------------------- + +_WS = re.compile(r"\s+") +_WS_AROUND_TAGS = re.compile(r"\s*(<[^>]+>)\s*") +_COMMENT = re.compile(r"", re.DOTALL) +_FRAMEWORK_ATTRS = re.compile( + r"\s+(?:data-reactroot|data-reactid|data-react-helmet|data-n-head|" + r"data-v-[a-f0-9]+|data-svelte-h)\b(?:=\"[^\"]*\")?", + re.IGNORECASE, +) +_SCRIPT_BLOCK = re.compile(r"]*>.*?", re.DOTALL | re.IGNORECASE) + + +def _normalise_html(html: str) -> str: + text = _SCRIPT_BLOCK.sub("", html) + text = _COMMENT.sub("", text) # also removes React's / markers + text = _FRAMEWORK_ATTRS.sub("", text) + text = _WS_AROUND_TAGS.sub(r"\1", text) + text = _WS.sub(" ", text).strip().lower() + return text + + +def diff_dom(server_html: str, client_html: str) -> List[HydrationFinding]: + """ + Compare server-rendered HTML to post-hydration HTML. + Returns findings only if the *normalised* representations differ. + """ + if not isinstance(server_html, str) or not isinstance(client_html, str): + raise HydrationCheckError("server_html and client_html must be strings") + server_n = _normalise_html(server_html) + client_n = _normalise_html(client_html) + if server_n == client_n: + return [] + # Find the first diverging chunk for a useful detail string. + common = 0 + while ( + common < len(server_n) and common < len(client_n) + and server_n[common] == client_n[common] + ): + common += 1 + s_excerpt = server_n[common:common + 80] + c_excerpt = client_n[common:common + 80] + return [HydrationFinding( + kind="dom_diff", + detail=f"diverged at char {common}: server={s_excerpt!r} client={c_excerpt!r}", + )] + + +# ---------- combined -------------------------------------------------- + +@dataclass +class HydrationReport: + """Combined console + DOM-diff finding set.""" + + findings: List[HydrationFinding] = field(default_factory=list) + + def passed(self) -> bool: + return not self.findings + + def by_kind(self) -> Dict[str, int]: + out: Dict[str, int] = {} + for f in self.findings: + out[f.kind] = out.get(f.kind, 0) + 1 + return out + + +def audit( + *, + server_html: Optional[str] = None, + client_html: Optional[str] = None, + console_messages: Optional[Iterable[str]] = None, +) -> HydrationReport: + """Run all available checks. Either pair of inputs may be ``None``.""" + findings: List[HydrationFinding] = [] + if server_html is not None and client_html is not None: + findings.extend(diff_dom(server_html, client_html)) + if console_messages is not None: + findings.extend(scan_console(console_messages)) + return HydrationReport(findings=findings) + + +def assert_no_mismatch(report: HydrationReport) -> None: + """Raise unless ``report`` is clean.""" + if not isinstance(report, HydrationReport): + raise HydrationCheckError("assert_no_mismatch expects HydrationReport") + if report.passed(): + return + sample = ", ".join(f"{f.kind}:{f.detail[:60]}" for f in report.findings[:3]) + more = "" if len(report.findings) <= 3 else f" (+{len(report.findings) - 3})" + raise HydrationCheckError(f"hydration mismatch detected: {sample}{more}") diff --git a/je_web_runner/utils/idempotency_check/__init__.py b/je_web_runner/utils/idempotency_check/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/idempotency_check/check.py b/je_web_runner/utils/idempotency_check/check.py new file mode 100644 index 0000000..f8c3ede --- /dev/null +++ b/je_web_runner/utils/idempotency_check/check.py @@ -0,0 +1,175 @@ +""" +同一個請求送兩次,結果應相同(訂單 / 付款 / `POST /transfer` 最常見 bug)。 +Strategy: caller supplies a "request runner" callable. We invoke it +twice (optionally with the same ``Idempotency-Key`` header), then +compare the two :class:`IdemResponse` records on three axes: + +* status code & body shape +* state mutation (an optional state-probe callable) +* side-effect count (downstream rows / webhooks / emails) +""" +from __future__ import annotations + +import hashlib +import json +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class IdempotencyCheckError(WebRunnerException): + """Raised on bad inputs or detected non-idempotency.""" + + +# ---------- model ------------------------------------------------------ + +@dataclass +class IdemResponse: + """Snapshot of one request's response.""" + + status_code: int + body: Any + side_effect_count: int = 0 + + def body_hash(self) -> str: + try: + serialised = json.dumps(self.body, sort_keys=True, default=str) + except (TypeError, ValueError): + serialised = repr(self.body) + return hashlib.sha256(serialised.encode("utf-8")).hexdigest() + + +@dataclass +class IdempotencyReport: + """Outcome of :func:`check`.""" + + first: IdemResponse + second: IdemResponse + state_before_first: Optional[Any] = None + state_after_first: Optional[Any] = None + state_after_second: Optional[Any] = None + violations: List[str] = field(default_factory=list) + + def passed(self) -> bool: + return not self.violations + + def to_dict(self) -> Dict[str, Any]: + return { + "first": asdict(self.first), + "second": asdict(self.second), + "state_before_first": self.state_before_first, + "state_after_first": self.state_after_first, + "state_after_second": self.state_after_second, + "violations": list(self.violations), + "passed": self.passed(), + } + + +# ---------- check ------------------------------------------------------ + +RequestRunner = Callable[[], IdemResponse] +StateProbe = Callable[[], Any] + + +def check( + request_runner: RequestRunner, + *, + state_probe: Optional[StateProbe] = None, + allow_status_change_to: Optional[Sequence[int]] = None, + ignore_body_keys: Sequence[str] = (), +) -> IdempotencyReport: + """ + Run twice + compare. ``allow_status_change_to`` covers servers that + legitimately return 409 / 304 on the second attempt (Stripe-style). + ``ignore_body_keys`` is for non-deterministic fields (timestamps, + request_id) the caller knows to ignore. + """ + if not callable(request_runner): + raise IdempotencyCheckError("request_runner must be callable") + if state_probe is not None and not callable(state_probe): + raise IdempotencyCheckError("state_probe must be callable") + allowed = set(allow_status_change_to or ()) + ignored = set(ignore_body_keys) + + state_before = state_probe() if state_probe else None + first = _safe_call(request_runner, "first") + state_after_first = state_probe() if state_probe else None + second = _safe_call(request_runner, "second") + state_after_second = state_probe() if state_probe else None + + violations: List[str] = [] + if ( + first.status_code != second.status_code + and second.status_code not in allowed + ): + violations.append( + f"status changed {first.status_code} -> {second.status_code}" + ) + if not _bodies_equal(first.body, second.body, ignored): + violations.append("response body differs between calls") + if ( + state_probe is not None + and state_after_first != state_after_second + ): + violations.append("state changed between first and second call") + if first.side_effect_count != second.side_effect_count: + delta = abs(first.side_effect_count - second.side_effect_count) + violations.append( + f"side effect count differs (delta={delta})" + ) + return IdempotencyReport( + first=first, second=second, + state_before_first=state_before, + state_after_first=state_after_first, + state_after_second=state_after_second, + violations=violations, + ) + + +def _safe_call(runner: RequestRunner, label: str) -> IdemResponse: + try: + result = runner() + except Exception as error: + raise IdempotencyCheckError( + f"{label} request raised: {error!r}" + ) from error + if not isinstance(result, IdemResponse): + raise IdempotencyCheckError( + f"runner must return IdemResponse, got {type(result).__name__}" + ) + return result + + +def _strip_keys(payload: Any, ignored: set) -> Any: + if isinstance(payload, dict): + return { + k: _strip_keys(v, ignored) + for k, v in payload.items() if k not in ignored + } + if isinstance(payload, list): + return [_strip_keys(v, ignored) for v in payload] + return payload + + +def _bodies_equal(a: Any, b: Any, ignored: set) -> bool: + return _strip_keys(a, ignored) == _strip_keys(b, ignored) + + +# ---------- helpers ---------------------------------------------------- + +def assert_idempotent(report: IdempotencyReport) -> None: + """Raise unless ``report.passed()``.""" + if not isinstance(report, IdempotencyReport): + raise IdempotencyCheckError("assert_idempotent expects IdempotencyReport") + if report.passed(): + return + raise IdempotencyCheckError( + "non-idempotent: " + "; ".join(report.violations) + ) + + +def generate_idempotency_key(*parts: Any) -> str: + """Stable SHA-256 hex key from arbitrary parts (e.g. user_id + amount + ts).""" + serialised = "|".join(repr(p) for p in parts) + return hashlib.sha256(serialised.encode("utf-8")).hexdigest() diff --git a/je_web_runner/utils/indexed_db_explorer/__init__.py b/je_web_runner/utils/indexed_db_explorer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/indexed_db_explorer/explorer.py b/je_web_runner/utils/indexed_db_explorer/explorer.py new file mode 100644 index 0000000..bb8ee23 --- /dev/null +++ b/je_web_runner/utils/indexed_db_explorer/explorer.py @@ -0,0 +1,252 @@ +""" +IndexedDB 內容快照 + 物件 / store / index 斷言。 +PWA、離線優先的 app、Firebase / Dexie / RxDB 都把狀態放在 IndexedDB。 +傳統 Selenium 測試只看 DOM,根本看不到資料層;這個模組: + +* 提供瀏覽器端 JS snippet,把指定 DB 的內容序列化成可帶回的 JSON + (透過 CDP ``Runtime.evaluate`` 或 Playwright ``page.evaluate``) +* 解析 JSON snapshot 成 :class:`IdbSnapshot`,提供 store / key / index + / 紀錄計數的斷言 + +不直接操作 driver — JS 給你,evaluate 你自己叫。 +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class IndexedDbExplorerError(WebRunnerException): + """Raised on malformed snapshot input or failed assertion.""" + + +# ---------- harvest script --------------------------------------------- + +_HARVEST_TEMPLATE = """ +(async function() { + const dbName = %(db_name)s; + if (!('indexedDB' in window)) { + return {schema_version: 1, name: dbName, exists: false, stores: {}}; + } + return await new Promise(function(resolve, reject) { + const req = indexedDB.open(dbName); + req.onerror = function() { reject(new Error('open failed: ' + req.error)); }; + req.onsuccess = async function() { + const db = req.result; + const out = { + schema_version: 1, name: db.name, exists: true, + version: db.version, + stores: {} + }; + const tx = db.transaction(Array.from(db.objectStoreNames), 'readonly'); + for (const sname of db.objectStoreNames) { + const store = tx.objectStore(sname); + const records = await new Promise(function(r2) { + const all = store.getAll(); + all.onsuccess = function() { r2(all.result); }; + all.onerror = function() { r2([]); }; + }); + const keys = await new Promise(function(r3) { + const all = store.getAllKeys(); + all.onsuccess = function() { r3(all.result); }; + all.onerror = function() { r3([]); }; + }); + out.stores[sname] = { + key_path: store.keyPath, + auto_increment: store.autoIncrement, + index_names: Array.from(store.indexNames), + records: records, + keys: keys + }; + } + db.close(); + resolve(out); + }; + }); +})() +""".strip() + + +def build_harvest_script(db_name: str) -> str: + """Return JS that resolves with the snapshot JSON for ``db_name``.""" + if not isinstance(db_name, str) or not db_name: + raise IndexedDbExplorerError("db_name must be non-empty string") + return _HARVEST_TEMPLATE % {"db_name": json.dumps(db_name)} + + +# ---------- snapshot model --------------------------------------------- + +@dataclass +class StoreSnapshot: + """One object-store snapshot.""" + + name: str + key_path: Any = None + auto_increment: bool = False + index_names: List[str] = field(default_factory=list) + records: List[Any] = field(default_factory=list) + keys: List[Any] = field(default_factory=list) + + def find_one(self, predicate: Callable[[Any], bool]) -> Optional[Any]: + for r in self.records: + try: + if predicate(r): + return r + except Exception: + continue + return None + + +@dataclass +class IdbSnapshot: + """Full DB snapshot.""" + + name: str + exists: bool + version: Optional[int] = None + stores: Dict[str, StoreSnapshot] = field(default_factory=dict) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "IdbSnapshot": + if not isinstance(data, dict): + raise IndexedDbExplorerError( + f"snapshot must be dict, got {type(data).__name__}" + ) + if "stores" in data and not isinstance(data["stores"], dict): + raise IndexedDbExplorerError("snapshot.stores must be a dict") + stores: Dict[str, StoreSnapshot] = {} + for name, raw in (data.get("stores") or {}).items(): + if not isinstance(raw, dict): + continue + stores[str(name)] = StoreSnapshot( + name=str(name), + key_path=raw.get("key_path"), + auto_increment=bool(raw.get("auto_increment", False)), + index_names=[str(i) for i in raw.get("index_names") or []], + records=list(raw.get("records") or []), + keys=list(raw.get("keys") or []), + ) + return cls( + name=str(data.get("name") or ""), + exists=bool(data.get("exists", False)), + version=data.get("version"), + stores=stores, + ) + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "exists": self.exists, + "version": self.version, + "stores": {k: asdict(v) for k, v in self.stores.items()}, + } + + +# ---------- assertions -------------------------------------------------- + +def assert_db_exists(snapshot: IdbSnapshot) -> None: + if not isinstance(snapshot, IdbSnapshot): + raise IndexedDbExplorerError("expected IdbSnapshot") + if not snapshot.exists: + raise IndexedDbExplorerError(f"IndexedDB {snapshot.name!r} does not exist") + + +def assert_store_present(snapshot: IdbSnapshot, store_name: str) -> StoreSnapshot: + if not isinstance(store_name, str) or not store_name: + raise IndexedDbExplorerError("store_name must be a non-empty string") + store = snapshot.stores.get(store_name) + if store is None: + existing = sorted(snapshot.stores) + raise IndexedDbExplorerError( + f"store {store_name!r} not in snapshot; existing: {existing}" + ) + return store + + +def assert_record_count( + snapshot: IdbSnapshot, + store_name: str, + *, + minimum: int = 0, + maximum: Optional[int] = None, +) -> int: + """Assert ``minimum <= len(records) <= maximum``.""" + if minimum < 0: + raise IndexedDbExplorerError("minimum must be >= 0") + if maximum is not None and maximum < minimum: + raise IndexedDbExplorerError("maximum must be >= minimum") + store = assert_store_present(snapshot, store_name) + count = len(store.records) + if count < minimum or (maximum is not None and count > maximum): + raise IndexedDbExplorerError( + f"store {store_name!r} has {count} records, want " + f"[{minimum}, {maximum if maximum is not None else 'inf'}]" + ) + return count + + +def assert_key_present(snapshot: IdbSnapshot, store_name: str, key: Any) -> None: + store = assert_store_present(snapshot, store_name) + if key not in store.keys: + raise IndexedDbExplorerError( + f"key {key!r} not present in store {store_name!r}" + ) + + +def assert_record_matching( + snapshot: IdbSnapshot, + store_name: str, + predicate: Callable[[Any], bool], + *, + description: str = "predicate", +) -> Any: + """Assert at least one record satisfies ``predicate``; return it.""" + store = assert_store_present(snapshot, store_name) + found = store.find_one(predicate) + if found is None: + raise IndexedDbExplorerError( + f"no record in {store_name!r} matched: {description}" + ) + return found + + +def assert_index_present( + snapshot: IdbSnapshot, store_name: str, index_name: str, +) -> None: + store = assert_store_present(snapshot, store_name) + if index_name not in store.index_names: + raise IndexedDbExplorerError( + f"index {index_name!r} not on store {store_name!r}; " + f"existing: {sorted(store.index_names)}" + ) + + +# ---------- diff -------------------------------------------------------- + +@dataclass +class SnapshotDiff: + """High-level diff between two snapshots.""" + + added_stores: List[str] = field(default_factory=list) + removed_stores: List[str] = field(default_factory=list) + record_count_changes: Dict[str, Dict[str, int]] = field(default_factory=dict) + + +def diff_snapshots(before: IdbSnapshot, after: IdbSnapshot) -> SnapshotDiff: + """Compute a coarse diff (added / removed stores, per-store count delta).""" + if not isinstance(before, IdbSnapshot) or not isinstance(after, IdbSnapshot): + raise IndexedDbExplorerError("both arguments must be IdbSnapshot") + diff = SnapshotDiff() + before_names = set(before.stores) + after_names = set(after.stores) + diff.added_stores = sorted(after_names - before_names) + diff.removed_stores = sorted(before_names - after_names) + for name in sorted(before_names & after_names): + a = len(before.stores[name].records) + b = len(after.stores[name].records) + if a != b: + diff.record_count_changes[name] = {"before": a, "after": b} + return diff diff --git a/je_web_runner/utils/inp_tracker/__init__.py b/je_web_runner/utils/inp_tracker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/inp_tracker/tracker.py b/je_web_runner/utils/inp_tracker/tracker.py new file mode 100644 index 0000000..1d8cc00 --- /dev/null +++ b/je_web_runner/utils/inp_tracker/tracker.py @@ -0,0 +1,201 @@ +""" +Interaction to Next Paint (INP) tracker。INP 是 Google 2024 取代 FID 的 +Core Web Vital,衡量「使用者點 / 鍵盤輸入 / tap → 下次 paint」的延遲。 + +This module: + +* Generates a JS snippet that uses the ``event-timing`` PerformanceObserver + to record every interaction's duration into ``window.__wr_inp_log__``. +* Parses the harvested array. +* Reports per-interaction breakdown + p75 / p98 percentiles (the + thresholds Google uses for "good" / "poor"). +* Asserts page-level budget. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class InpTrackerError(WebRunnerException): + """Raised on malformed log input or budget breach.""" + + +class InpRating(str, Enum): + """Google's INP rating thresholds.""" + + GOOD = "good" # <= 200ms + NEEDS_WORK = "needs_improvement" # 201..500ms + POOR = "poor" # > 500ms + + +_GOOD_THRESHOLD_MS = 200.0 +_POOR_THRESHOLD_MS = 500.0 + + +# ---------- instrumentation ------------------------------------------- + +_INSTALL = """ +(function() { + if (window.__wr_inp_installed__) return; + window.__wr_inp_installed__ = true; + window.__wr_inp_log__ = []; + if (!('PerformanceObserver' in window)) return; + try { + const obs = new PerformanceObserver(function(list) { + list.getEntries().forEach(function(entry) { + if (entry.duration === undefined) return; + window.__wr_inp_log__.push({ + name: entry.name, + interactionId: entry.interactionId || 0, + duration_ms: entry.duration, + processingStart: entry.processingStart, + processingEnd: entry.processingEnd, + startTime: entry.startTime, + targetTag: entry.target ? entry.target.tagName : null + }); + }); + }); + obs.observe({type: 'event', buffered: true, durationThreshold: 16}); + obs.observe({type: 'first-input', buffered: true}); + } catch (e) { /* unsupported */ } +})(); +""".strip() + + +def build_install_script() -> str: + return _INSTALL + + +HARVEST_SCRIPT = "return window.__wr_inp_log__ || [];" + + +# ---------- data -------------------------------------------------------- + +@dataclass +class InteractionEvent: + """One event-timing entry.""" + + name: str + interaction_id: int + duration_ms: float + target_tag: Optional[str] = None + start_time: float = 0.0 + processing_start: float = 0.0 + processing_end: float = 0.0 + + def rating(self) -> InpRating: + return _rate(self.duration_ms) + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "rating": self.rating().value} + + +def _rate(duration_ms: float) -> InpRating: + if duration_ms <= _GOOD_THRESHOLD_MS: + return InpRating.GOOD + if duration_ms <= _POOR_THRESHOLD_MS: + return InpRating.NEEDS_WORK + return InpRating.POOR + + +def parse_log(payload: Any) -> List[InteractionEvent]: + """Convert the harvested ``__wr_inp_log__`` array into typed events.""" + if not isinstance(payload, list): + raise InpTrackerError( + f"payload must be list, got {type(payload).__name__}" + ) + out: List[InteractionEvent] = [] + for raw in payload: + if not isinstance(raw, dict): + continue + try: + duration = float(raw.get("duration_ms") or 0.0) + except (TypeError, ValueError): + continue + if duration < 0: + continue + out.append(InteractionEvent( + name=str(raw.get("name") or ""), + interaction_id=int(raw.get("interactionId") or 0), + duration_ms=duration, + target_tag=raw.get("targetTag"), + start_time=float(raw.get("startTime") or 0.0), + processing_start=float(raw.get("processingStart") or 0.0), + processing_end=float(raw.get("processingEnd") or 0.0), + )) + return out + + +# ---------- aggregation ------------------------------------------------ + +@dataclass +class InpReport: + """Rolled-up view of the events captured in a page session.""" + + events: List[InteractionEvent] = field(default_factory=list) + + def filtered(self) -> List[InteractionEvent]: + """Discard zero-id non-interaction entries (mouse-move, raw events).""" + return [e for e in self.events if e.interaction_id > 0] + + def inp(self) -> Optional[float]: + """ + Returns Google's INP: 98th percentile if 50+ interactions, else worst. + ``None`` if no interactions observed. + """ + interactions = sorted(e.duration_ms for e in self.filtered()) + if not interactions: + return None + if len(interactions) >= 50: + index = int(round(0.98 * (len(interactions) - 1))) + return interactions[index] + return interactions[-1] + + def rating(self) -> InpRating: + value = self.inp() + if value is None: + return InpRating.GOOD + return _rate(value) + + def percentile(self, pct: float) -> Optional[float]: + """Arbitrary percentile (0..100) over interaction durations.""" + if not 0 <= pct <= 100: + raise InpTrackerError("pct must be in [0, 100]") + interactions = sorted(e.duration_ms for e in self.filtered()) + if not interactions: + return None + index = int(round((pct / 100.0) * (len(interactions) - 1))) + return interactions[index] + + +# ---------- assertions ------------------------------------------------- + +def assert_inp_under(report: InpReport, *, max_ms: float) -> None: + """Assert the report's INP is under ``max_ms``.""" + if not isinstance(report, InpReport): + raise InpTrackerError("assert_inp_under expects InpReport") + if max_ms <= 0: + raise InpTrackerError("max_ms must be > 0") + value = report.inp() + if value is None: + return + if value > max_ms: + raise InpTrackerError( + f"INP {value:.1f}ms exceeds budget {max_ms}ms " + f"({report.rating().value})" + ) + + +def assert_no_poor_interactions(report: InpReport) -> None: + """Assert no single interaction crossed the POOR threshold.""" + if not isinstance(report, InpReport): + raise InpTrackerError("expects InpReport") + bad = [e for e in report.filtered() if e.rating() == InpRating.POOR] + if bad: + sample = ", ".join(f"{e.name}({e.duration_ms:.0f}ms)" for e in bad[:3]) + more = "" if len(bad) <= 3 else f" (+{len(bad) - 3} more)" + raise InpTrackerError(f"poor interactions: {sample}{more}") diff --git a/je_web_runner/utils/live_dashboard/__init__.py b/je_web_runner/utils/live_dashboard/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/live_dashboard/server.py b/je_web_runner/utils/live_dashboard/server.py new file mode 100644 index 0000000..609ccaa --- /dev/null +++ b/je_web_runner/utils/live_dashboard/server.py @@ -0,0 +1,505 @@ +""" +Live Dashboard:把既有 run_ledger / flake_detector / locator_health / +failure_triage / test_scheduler / quarantine registry 的資料整合在一個 +本地 web UI。 + +純 stdlib(``http.server``),不加新依賴。預設只 bind 127.0.0.1, +單 process,跟 socket_server 共存無干擾。 + +Routes: + +* ``GET /`` HTML overview with summary cards +* ``GET /runs`` HTML table of recent ledger entries +* ``GET /flake`` HTML flake leaderboard +* ``GET /quarantine`` HTML quarantine list +* ``GET /locators`` HTML locator health summary +* ``GET /api/summary`` JSON aggregate counts +* ``GET /api/runs`` JSON recent runs (``?limit=N``) +* ``GET /api/flake`` JSON flake scores +* ``GET /api/quarantine`` JSON quarantine entries +* ``GET /api/locators`` JSON locator findings + +Every request re-reads the underlying files so the dashboard always +reflects the latest state — no caching, no daemon process needed. +""" +from __future__ import annotations + +import json +import threading +import urllib.parse +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.flake_detector.detector import ( + QuarantineRegistry, + compute_flake_scores, +) +from je_web_runner.utils.logging.loggin_instance import web_runner_logger +from je_web_runner.utils.run_ledger.ledger import LedgerError + + +class LiveDashboardError(WebRunnerException): + """Raised on configuration / startup failures.""" + + +# ---------- config ------------------------------------------------------- + +@dataclass +class DashboardConfig: + """ + 指定每個資料來源檔案路徑。任何一個 None 就會在 UI 上顯示成空白。 + """ + ledger_path: Optional[Union[str, Path]] = None + quarantine_path: Optional[Union[str, Path]] = None + locator_findings_path: Optional[Union[str, Path]] = None + schedule_path: Optional[Union[str, Path]] = None + triage_report_path: Optional[Union[str, Path]] = None + bind_host: str = "127.0.0.1" + bind_port: int = 0 + + def __post_init__(self) -> None: + for attr in ( + "ledger_path", "quarantine_path", "locator_findings_path", + "schedule_path", "triage_report_path", + ): + value = getattr(self, attr) + if value is not None and not isinstance(value, Path): + setattr(self, attr, Path(value)) + + +# ---------- data loaders ------------------------------------------------- + +def _load_runs(ledger_path: Optional[Path], limit: int = 50) -> List[Dict[str, Any]]: + if ledger_path is None or not ledger_path.exists(): + return [] + try: + with open(ledger_path, encoding="utf-8") as fp: + data = json.load(fp) + except (OSError, ValueError) as error: + web_runner_logger.warning(f"dashboard _load_runs: {error!r}") + return [] + runs = data.get("runs") if isinstance(data, dict) else None + if not isinstance(runs, list): + return [] + return [r for r in runs[-limit:][::-1] if isinstance(r, dict)] + + +def _load_flake_scores(ledger_path: Optional[Path]) -> List[Dict[str, Any]]: + if ledger_path is None or not ledger_path.exists(): + return [] + try: + scores = compute_flake_scores(ledger_path) + except (LedgerError, OSError, ValueError) as error: + web_runner_logger.warning(f"dashboard _load_flake_scores: {error!r}") + return [] + entries = [s.to_dict() for s in scores.values()] + entries.sort(key=lambda e: (-e["flake_score"], e["path"])) + return entries + + +def _load_quarantine(quarantine_path: Optional[Path]) -> List[Dict[str, Any]]: + if quarantine_path is None or not quarantine_path.exists(): + return [] + try: + registry = QuarantineRegistry(quarantine_path) + except WebRunnerException as error: + web_runner_logger.warning(f"dashboard _load_quarantine: {error!r}") + return [] + return [e.to_dict() for e in registry.list()] + + +def _load_locator_report(path: Optional[Path]) -> Dict[str, Any]: + if path is None or not path.exists(): + return {} + try: + with open(path, encoding="utf-8") as fp: + return json.load(fp) + except (OSError, ValueError) as error: + web_runner_logger.warning(f"dashboard _load_locator_report: {error!r}") + return {} + + +def _load_schedule(path: Optional[Path]) -> Dict[str, Any]: + if path is None or not path.exists(): + return {} + try: + with open(path, encoding="utf-8") as fp: + return json.load(fp) + except (OSError, ValueError) as error: + web_runner_logger.warning(f"dashboard _load_schedule: {error!r}") + return {} + + +def _load_triage(path: Optional[Path]) -> Dict[str, Any]: + if path is None or not path.exists(): + return {} + try: + with open(path, encoding="utf-8") as fp: + return json.load(fp) + except (OSError, ValueError) as error: + web_runner_logger.warning(f"dashboard _load_triage: {error!r}") + return {} + + +# ---------- summary ------------------------------------------------------ + +def build_summary(config: DashboardConfig) -> Dict[str, Any]: + """One-shot snapshot used by ``/`` and ``/api/summary``.""" + runs = _load_runs(config.ledger_path, limit=10_000) + total = len(runs) + passed = sum(1 for r in runs if r.get("passed")) + failed = total - passed + pass_rate = (passed / total) if total else 0.0 + flake_entries = _load_flake_scores(config.ledger_path) + flake_count = sum(1 for f in flake_entries if f.get("is_flaky")) + quarantine = _load_quarantine(config.quarantine_path) + locator_report = _load_locator_report(config.locator_findings_path) + return { + "total_runs": total, + "passed": passed, + "failed": failed, + "pass_rate": round(pass_rate, 4), + "flaky_tests": flake_count, + "quarantined_tests": len(quarantine), + "weak_locators": locator_report.get("weak", 0) if isinstance(locator_report, dict) else 0, + "average_locator_score": ( + locator_report.get("average_score", 0) + if isinstance(locator_report, dict) else 0 + ), + } + + +# ---------- HTML rendering ----------------------------------------------- + +_BASE_CSS = """ +body { font-family: -apple-system, BlinkMacSystemFont, sans-serif; + margin: 0; background: #f5f5f7; color: #1d1d1f; } +nav { background: #1d1d1f; color: #fff; padding: 12px 24px; } +nav a { color: #fff; margin-right: 16px; text-decoration: none; } +nav a:hover { text-decoration: underline; } +main { padding: 24px; max-width: 1200px; margin: 0 auto; } +h1 { margin-top: 0; } +.cards { display: grid; grid-template-columns: repeat(auto-fit, minmax(180px, 1fr)); + gap: 16px; margin-bottom: 32px; } +.card { background: #fff; padding: 16px; border-radius: 8px; + box-shadow: 0 1px 3px rgba(0,0,0,0.05); } +.card .label { color: #6e6e73; font-size: 12px; text-transform: uppercase; } +.card .value { font-size: 28px; font-weight: 600; margin-top: 4px; } +table { width: 100%; border-collapse: collapse; background: #fff; + border-radius: 8px; overflow: hidden; + box-shadow: 0 1px 3px rgba(0,0,0,0.05); } +th, td { padding: 12px 16px; text-align: left; border-bottom: 1px solid #f0f0f3; } +th { background: #fafafa; font-size: 13px; color: #6e6e73; } +tr:last-child td { border-bottom: none; } +.bad { color: #c9302c; font-weight: 600; } +.good { color: #1d8348; font-weight: 600; } +.muted { color: #6e6e73; } +.empty { color: #6e6e73; padding: 32px; text-align: center; } +code { background: #f0f0f3; padding: 2px 6px; border-radius: 4px; + font-family: 'SF Mono', Consolas, monospace; font-size: 12px; } +""" + + +def _html_escape(value: Any) -> str: + text = str(value if value is not None else "") + return ( + text.replace("&", "&").replace("<", "<") + .replace(">", ">").replace('"', """) + ) + + +def _layout(title: str, body: str) -> str: + return ( + "" + f"{_html_escape(title)} — WebRunner" + f"" + "" + f"
{body}
" + ) + + +def _render_overview(summary: Dict[str, Any]) -> str: + pass_rate_pct = f"{summary['pass_rate'] * 100:.1f}%" + cards = [ + ("Total runs", summary["total_runs"]), + ("Pass rate", pass_rate_pct), + ("Passed", summary["passed"]), + ("Failed", summary["failed"]), + ("Flaky tests", summary["flaky_tests"]), + ("Quarantined", summary["quarantined_tests"]), + ("Weak locators", summary["weak_locators"]), + ("Avg locator score", summary["average_locator_score"]), + ] + card_html = "".join( + f"
{_html_escape(label)}
" + f"
{_html_escape(value)}
" + for label, value in cards + ) + body = ( + "

WebRunner overview

" + f"
{card_html}
" + ) + return _layout("Overview", body) + + +def _render_runs(runs: List[Dict[str, Any]]) -> str: + if not runs: + return _layout("Runs", "

Runs

No runs recorded yet.
") + rows = [] + for run in runs: + cls = "good" if run.get("passed") else "bad" + label = "PASS" if run.get("passed") else "FAIL" + rows.append( + f"{_html_escape(run.get('path'))}" + f"{label}" + f"{_html_escape(run.get('time', ''))}" + ) + body = ( + "

Recent runs

" + "" + f"{''.join(rows)}
TestResultWhen
" + ) + return _layout("Runs", body) + + +def _render_flake(entries: List[Dict[str, Any]]) -> str: + flaky_only = [e for e in entries if e.get("is_flaky")] + if not flaky_only: + return _layout("Flake", "

Flake leaderboard

No flaky tests detected.
") + rows = [] + for entry in flaky_only[:50]: + rows.append( + f"{_html_escape(entry.get('path'))}" + f"{entry.get('flake_score', 0):.2f}" + f"{entry.get('runs', 0)}" + f"{entry.get('fails', 0)}" + f"{_html_escape(entry.get('last_run', ''))}" + ) + body = ( + "

Flake leaderboard

" + "" + "" + f"{''.join(rows)}
TestScoreRunsFailsLast
" + ) + return _layout("Flake", body) + + +def _render_quarantine(entries: List[Dict[str, Any]]) -> str: + if not entries: + return _layout("Quarantine", "

Quarantine

Registry is empty.
") + rows = [] + for entry in entries: + rows.append( + f"{_html_escape(entry.get('test_id'))}" + f"{entry.get('flake_score', 0):.2f}" + f"{_html_escape(entry.get('reason', ''))}" + f"{_html_escape(entry.get('quarantined_at', ''))}" + ) + body = ( + "

Quarantined tests

" + "" + "" + f"{''.join(rows)}
TestScoreReasonSince
" + ) + return _layout("Quarantine", body) + + +def _render_locators(report: Dict[str, Any]) -> str: + if not report: + return _layout("Locators", "

Locators

No locator report loaded.
") + summary_cards = [ + ("Total", report.get("total", 0)), + ("Weak", report.get("weak", 0)), + ("Strong", report.get("strong", 0)), + ("Avg score", report.get("average_score", 0)), + ] + card_html = "".join( + f"
{_html_escape(label)}
" + f"
{_html_escape(value)}
" + for label, value in summary_cards + ) + weakest = report.get("weakest") or [] + rows = [] + for entry in weakest[:30]: + reasons = ", ".join(entry.get("reasons") or []) or "—" + value = entry.get("value", "") + if isinstance(value, str) and len(value) > 60: + value = value[:57] + "…" + rows.append( + f"{_html_escape(entry.get('file_path'))}" + f"{entry.get('action_index', '')}" + f"{_html_escape(entry.get('strategy', ''))}" + f"{_html_escape(value)}" + f"{entry.get('score', 0)}" + f"{_html_escape(reasons)}" + ) + rows_html = ( + "" + "" + f"{''.join(rows)}
FileIdxStrategyValueScoreReasons
" + if rows else "
No weak locators.
" + ) + body = ( + "

Locator health

" + f"
{card_html}
" + "

Weakest

" + rows_html + ) + return _layout("Locators", body) + + +# ---------- request handler --------------------------------------------- + +def _make_handler(config: DashboardConfig) -> Type[BaseHTTPRequestHandler]: + """Bind ``config`` into a fresh handler class so each server is isolated.""" + + class DashboardHandler(BaseHTTPRequestHandler): + protocol_version = "HTTP/1.1" + + def log_message(self, fmt: str, *args: Any) -> None: # noqa: A003 — base override + web_runner_logger.info(f"dashboard: {fmt % args}") + + def _send(self, status: int, content_type: str, body: bytes) -> None: + self.send_response(status) + self.send_header("Content-Type", content_type) + self.send_header("Content-Length", str(len(body))) + self.send_header("Cache-Control", "no-store") + self.end_headers() + self.wfile.write(body) + + def _send_html(self, html: str, status: int = 200) -> None: + self._send(status, "text/html; charset=utf-8", html.encode("utf-8")) + + def _send_json(self, payload: Any, status: int = 200) -> None: + body = json.dumps(payload, ensure_ascii=False, indent=2).encode("utf-8") + self._send(status, "application/json; charset=utf-8", body) + + def _query_limit(self, parsed) -> int: + params = urllib.parse.parse_qs(parsed.query) + raw = params.get("limit", ["50"])[0] + try: + value = int(raw) + except ValueError: + value = 50 + return max(1, min(value, 5000)) + + def do_GET(self) -> None: # noqa: N802 — http.server requires camelCase + parsed = urllib.parse.urlparse(self.path) + path = parsed.path + try: + if path == "/": + self._send_html(_render_overview(build_summary(config))) + elif path == "/runs": + self._send_html(_render_runs(_load_runs(config.ledger_path))) + elif path == "/flake": + self._send_html(_render_flake(_load_flake_scores(config.ledger_path))) + elif path == "/quarantine": + self._send_html(_render_quarantine(_load_quarantine(config.quarantine_path))) + elif path == "/locators": + self._send_html(_render_locators(_load_locator_report(config.locator_findings_path))) + elif path == "/api/summary": + self._send_json(build_summary(config)) + elif path == "/api/runs": + self._send_json(_load_runs(config.ledger_path, self._query_limit(parsed))) + elif path == "/api/flake": + self._send_json(_load_flake_scores(config.ledger_path)) + elif path == "/api/quarantine": + self._send_json(_load_quarantine(config.quarantine_path)) + elif path == "/api/locators": + self._send_json(_load_locator_report(config.locator_findings_path)) + elif path == "/api/schedule": + self._send_json(_load_schedule(config.schedule_path)) + elif path == "/api/triage": + self._send_json(_load_triage(config.triage_report_path)) + elif path == "/healthz": + self._send(200, "text/plain", b"ok") + else: + self._send_html( + _layout( + "Not found", + f"

Not found

No route for {_html_escape(path)}

", + ), + status=404, + ) + except Exception as error: # noqa: BLE001 — surface to caller, not stderr + web_runner_logger.warning(f"dashboard handler error: {error!r}") + self._send_json({"error": repr(error)}, status=500) + + return DashboardHandler + + +# ---------- server wrapper ----------------------------------------------- + +class DashboardServer: + """ + 包 ThreadingHTTPServer 的薄殼,start/stop/url。``start`` 不阻塞, + 所以可以從測試 / shell 直接用。 + """ + + def __init__(self, config: Optional[DashboardConfig] = None) -> None: + self.config = config or DashboardConfig() + self._httpd: Optional[ThreadingHTTPServer] = None + self._thread: Optional[threading.Thread] = None + self._bound: Optional[Tuple[str, int]] = None + + def start(self) -> str: + """Bind + spawn a daemon thread serving requests. Returns the URL.""" + if self._httpd is not None: + raise LiveDashboardError("server already started") + handler_cls = _make_handler(self.config) + try: + self._httpd = ThreadingHTTPServer( + (self.config.bind_host, self.config.bind_port), handler_cls, + ) + except OSError as error: + raise LiveDashboardError( + f"cannot bind {self.config.bind_host}:{self.config.bind_port}: {error!r}" + ) from error + self._bound = self._httpd.server_address + self._thread = threading.Thread( + target=self._httpd.serve_forever, + name="webrunner-dashboard", + daemon=True, + ) + self._thread.start() + web_runner_logger.info(f"dashboard listening on {self.url}") + return self.url + + def stop(self, *, timeout: float = 5.0) -> None: + """Shut down the server and join the thread.""" + if self._httpd is None: + return + try: + self._httpd.shutdown() + self._httpd.server_close() + except OSError as error: + web_runner_logger.warning(f"dashboard stop: {error!r}") + if self._thread is not None: + self._thread.join(timeout=timeout) + self._httpd = None + self._thread = None + self._bound = None + + @property + def url(self) -> str: + if self._bound is None: + raise LiveDashboardError("server not started") + host, port = self._bound + if host in {"0.0.0.0", "::"}: + host = "127.0.0.1" + return f"http://{host}:{port}" + + def __enter__(self) -> "DashboardServer": + self.start() + return self + + def __exit__(self, *_exc: Any) -> None: + self.stop() diff --git a/je_web_runner/utils/locator_hardener/__init__.py b/je_web_runner/utils/locator_hardener/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/locator_hardener/hardener.py b/je_web_runner/utils/locator_hardener/hardener.py new file mode 100644 index 0000000..3475bb1 --- /dev/null +++ b/je_web_runner/utils/locator_hardener/hardener.py @@ -0,0 +1,255 @@ +""" +讀脆弱 locator(來自 ``locator_health`` 報告)+ 周圍 DOM,LLM 建議更穩的 selector。 +Common bad locators we want to harden: + +* nth-of-type / nth-child selectors that drift when content reorders +* deeply nested CSS descendants +* class-only selectors when classes are CSS-modules-hashed +* XPath that depends on text content + +The "smart" part is delegated to a :class:`HardenerClient`; the module +does: + +1. Score-based pre-classification (which locators are *worth* hardening + vs already-fine). +2. Prompt assembly (DOM excerpt + the locator + recommended-style hints). +3. Strict response validation (every suggestion must be valid CSS or + XPath syntax; we don't trust the LLM blindly). +""" +from __future__ import annotations + +import json +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, List, Optional, Protocol, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class LocatorHardenerError(WebRunnerException): + """Raised on bad inputs / parse failures / failed assertion.""" + + +# ---------- inputs ------------------------------------------------------ + +class LocatorStrategy(str, Enum): + """Allowed WR locator strategies.""" + + ID = "id" + NAME = "name" + CSS = "css selector" + XPATH = "xpath" + LINK_TEXT = "link text" + PARTIAL_LINK_TEXT = "partial link text" + CLASS_NAME = "class name" + TAG_NAME = "tag name" + + +_PREFERRED_STRATEGIES = ( + LocatorStrategy.ID, LocatorStrategy.NAME, LocatorStrategy.CSS, +) + + +@dataclass +class FragileLocator: + """One locator candidate worth hardening.""" + + test_id: str + strategy: LocatorStrategy + value: str + dom_excerpt: str = "" + failure_history: int = 0 + + def __post_init__(self) -> None: + if not isinstance(self.test_id, str) or not self.test_id: + raise LocatorHardenerError("test_id must be non-empty string") + if not isinstance(self.value, str) or not self.value: + raise LocatorHardenerError("value must be non-empty string") + if self.failure_history < 0: + raise LocatorHardenerError("failure_history must be >= 0") + + +# ---------- heuristic pre-classifier ----------------------------------- + +_NTH_PATTERN = re.compile(r":nth-(?:of-type|child)\(\d+\)", re.IGNORECASE) +_DEEP_DESCENDANT = re.compile(r"\s+\S+\s+\S+\s+\S+") # 3+ descendant levels +_HASHED_CLASS = re.compile(r"[._][A-Za-z][\w-]*?-_?\w{4,}\b") +_TEXT_XPATH = re.compile(r"text\s*\(\s*\)", re.IGNORECASE) + + +@dataclass +class FragilityScore: + """Heuristic locator-fragility score (0..1).""" + + score: float + reasons: List[str] = field(default_factory=list) + + +def score_fragility(locator: FragileLocator) -> FragilityScore: + """Quick non-LLM check. Anything ``score >= 0.5`` is worth hardening.""" + if not isinstance(locator, FragileLocator): + raise LocatorHardenerError("expects FragileLocator") + reasons: List[str] = [] + score = 0.0 + if locator.strategy == LocatorStrategy.XPATH: + score += 0.2 + reasons.append("xpath locator") + if _TEXT_XPATH.search(locator.value): + score += 0.3 + reasons.append("uses text() predicate") + if locator.strategy == LocatorStrategy.CSS: + if _NTH_PATTERN.search(locator.value): + score += 0.4 + reasons.append("uses :nth-of-type/child") + if _DEEP_DESCENDANT.search(locator.value): + score += 0.2 + reasons.append("deeply nested CSS") + if _HASHED_CLASS.search(locator.value): + score += 0.4 + reasons.append("hashed class names") + if locator.strategy not in _PREFERRED_STRATEGIES: + score += 0.1 + reasons.append("non-preferred strategy") + if locator.failure_history >= 3: + score += 0.3 + reasons.append(f"failed {locator.failure_history} times historically") + if locator.strategy == LocatorStrategy.CLASS_NAME and " " in locator.value: + score += 0.2 + reasons.append("multi-class CLASS_NAME (treated as single class)") + return FragilityScore(score=min(score, 1.0), reasons=reasons) + + +# ---------- LLM client ------------------------------------------------ + +class HardenerClient(Protocol): + """LLM client interface.""" + + def suggest(self, prompt: str) -> str: ... + + +# ---------- prompt ----------------------------------------------------- + +PROMPT_TEMPLATE = """\ +You are improving an end-to-end test locator to make it more resilient. + +# Current locator +- strategy: {strategy} +- value: {value} +- found in test: {test_id} +- failure history (recent): {failure_history} + +# DOM excerpt around the target element +```html +{dom_excerpt} +``` + +# Constraints +- Prefer ID > name > stable test attributes > short CSS. +- Reject :nth-of-type / :nth-child selectors. +- Reject locators that depend on visible text unless no stable attribute exists. +- Return strictly a JSON array of suggestion objects sorted best-first, + each with keys: "strategy" (one of {strategies}), + "value" (string), "rationale" (string). +""" + + +def build_prompt(locator: FragileLocator) -> str: + if not isinstance(locator, FragileLocator): + raise LocatorHardenerError("build_prompt expects FragileLocator") + return PROMPT_TEMPLATE.format( + strategy=locator.strategy.value, + value=locator.value, + test_id=locator.test_id, + failure_history=locator.failure_history, + dom_excerpt=locator.dom_excerpt or "(none)", + strategies=[s.value for s in LocatorStrategy], + ) + + +# ---------- response parsing ------------------------------------------- + +@dataclass +class LocatorSuggestion: + """One suggested replacement locator.""" + + strategy: LocatorStrategy + value: str + rationale: str + + def to_dict(self) -> Dict[str, Any]: + return {"strategy": self.strategy.value, "value": self.value, + "rationale": self.rationale} + + +def parse_suggestions(raw: str) -> List[LocatorSuggestion]: + """Decode the LLM's JSON array; reject malformed entries.""" + if not isinstance(raw, str) or not raw.strip(): + raise LocatorHardenerError("LLM returned empty response") + start = raw.find("[") + end = raw.rfind("]") + if start == -1 or end == -1 or end <= start: + raise LocatorHardenerError(f"no JSON array in response: {raw[:160]!r}") + try: + obj = json.loads(raw[start:end + 1]) + except ValueError as error: + raise LocatorHardenerError( + f"suggestions not JSON ({error}): {raw[:160]!r}" + ) from error + if not isinstance(obj, list): + raise LocatorHardenerError("suggestions must be a list") + out: List[LocatorSuggestion] = [] + for index, raw_item in enumerate(obj): + if not isinstance(raw_item, dict): + continue + strategy_str = raw_item.get("strategy") or "" + value = raw_item.get("value") or "" + rationale = raw_item.get("rationale") or "" + try: + strategy = LocatorStrategy(strategy_str) + except ValueError: + continue + if not isinstance(value, str) or not value: + continue + if not _looks_safe(strategy, value): + continue + out.append(LocatorSuggestion( + strategy=strategy, value=str(value), rationale=str(rationale), + )) + if not out: + raise LocatorHardenerError("no valid suggestions in LLM response") + return out + + +def _looks_safe(strategy: LocatorStrategy, value: str) -> bool: + if strategy == LocatorStrategy.CSS: + if _NTH_PATTERN.search(value): + return False + if strategy == LocatorStrategy.XPATH: + if _TEXT_XPATH.search(value): + return False + return True + + +# ---------- end-to-end ------------------------------------------------- + +def harden( + locator: FragileLocator, + client: HardenerClient, + *, + min_fragility: float = 0.5, +) -> List[LocatorSuggestion]: + """Score → maybe-skip → ask LLM → parse → return.""" + if not 0.0 <= min_fragility <= 1.0: + raise LocatorHardenerError("min_fragility must be in [0, 1]") + fragility = score_fragility(locator) + if fragility.score < min_fragility: + return [] + prompt = build_prompt(locator) + try: + raw = client.suggest(prompt) + except Exception as error: + raise LocatorHardenerError( + f"hardener client failed: {error!r}" + ) from error + return parse_suggestions(raw) diff --git a/je_web_runner/utils/locator_health/__init__.py b/je_web_runner/utils/locator_health/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/locator_health/health_report.py b/je_web_runner/utils/locator_health/health_report.py new file mode 100644 index 0000000..db9ff31 --- /dev/null +++ b/je_web_runner/utils/locator_health/health_report.py @@ -0,0 +1,459 @@ +""" +Locator 健康度報告 + 自動升級建議。 + +Project-wide locator audit built on top of +:mod:`je_web_runner.utils.linter.locator_strength`: + +* Walk a directory of action JSON files, score every locator. +* Combine static scores with runtime ``FallbackHitTracker`` counts so + locators that *actually* trigger self-healing get flagged loudest. +* Suggest upgrades: prefer ``data-testid`` / ``ID`` over CSS over deep + XPath. The upgrade function is conservative — it never silently rewrites + files; you have to call :func:`apply_upgrades` to mutate an action list. +""" +from __future__ import annotations + +import json +import threading +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.linter.locator_strength import ( + LocatorStrengthError, + score_locator, +) +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class LocatorHealthError(WebRunnerException): + """Raised on scan / report / upgrade failures.""" + + +# ---------- runtime tracker --------------------------------------------- + +class FallbackHitTracker: + """ + 記錄 self-healing fallback 觸發次數,供報告交叉比對。 + Thread-safe counter that self-healing callers can poke whenever a + fallback locator matches instead of the primary. The report can then + rank weak locators by how often they actually misfire. + """ + + def __init__(self) -> None: + self._lock = threading.Lock() + self._hits: Dict[str, int] = {} + self._fallback_used: Dict[str, int] = {} + + def track_primary(self, name: str) -> None: + with self._lock: + self._hits[name] = self._hits.get(name, 0) + 1 + + def track_fallback(self, name: str) -> None: + with self._lock: + self._hits[name] = self._hits.get(name, 0) + 1 + self._fallback_used[name] = self._fallback_used.get(name, 0) + 1 + + def stats(self) -> Dict[str, Dict[str, int]]: + with self._lock: + return { + name: { + "hits": self._hits.get(name, 0), + "fallback_used": self._fallback_used.get(name, 0), + } + for name in self._hits + } + + def clear(self) -> None: + with self._lock: + self._hits.clear() + self._fallback_used.clear() + + +fallback_hit_tracker = FallbackHitTracker() + + +# ---------- scanning ----------------------------------------------------- + +@dataclass +class LocatorFinding: + """One locator discovered while scanning an action file.""" + + file_path: str + action_index: int + strategy: str + value: str + score: int + reasons: List[str] = field(default_factory=list) + name: Optional[str] = None + hits: int = 0 + fallback_used: int = 0 + + @property + def fallback_rate(self) -> float: + return (self.fallback_used / self.hits) if self.hits else 0.0 + + def to_dict(self) -> Dict[str, Any]: + out = asdict(self) + out["fallback_rate"] = round(self.fallback_rate, 4) + return out + + +def _walk_actions(payload: Any) -> Iterable[List[Any]]: + """Yield every action list inside ``payload`` (top-level list or nested).""" + if isinstance(payload, list): + for item in payload: + if isinstance(item, list) and item and isinstance(item[0], str): + yield item + + +def _extract_locator(action: List[Any]) -> Optional[Dict[str, Any]]: + if len(action) < 2: + return None + kwargs = None + if len(action) >= 3 and isinstance(action[2], dict): + kwargs = action[2] + elif isinstance(action[1], dict): + kwargs = action[1] + else: + return None + strategy = kwargs.get("object_type") or kwargs.get("strategy") + value = kwargs.get("test_object_name") or kwargs.get("value") + name = kwargs.get("element_name") or kwargs.get("name") or kwargs.get("test_object_name") + if strategy is None or value is None: + return None + return {"strategy": str(strategy), "value": str(value), "name": str(name) if name else None} + + +def scan_action_file(file_path: Union[str, Path]) -> List[LocatorFinding]: + """Score every locator inside one action JSON file.""" + path = Path(file_path) + if not path.is_file(): + raise LocatorHealthError(f"action file not found: {path}") + try: + with open(path, encoding="utf-8") as fp: + payload = json.load(fp) + except (OSError, ValueError) as error: + raise LocatorHealthError(f"cannot parse {path}: {error!r}") from error + + findings: List[LocatorFinding] = [] + hit_stats = fallback_hit_tracker.stats() + for index, action in enumerate(_walk_actions(payload)): + locator = _extract_locator(action) + if locator is None: + continue + try: + score = score_locator(locator["strategy"], locator["value"]) + except LocatorStrengthError as error: + web_runner_logger.warning( + f"scan_action_file: cannot score {path}#{index}: {error!r}" + ) + continue + name = locator["name"] + hits_info = hit_stats.get(name or "", {"hits": 0, "fallback_used": 0}) + findings.append(LocatorFinding( + file_path=str(path), + action_index=index, + strategy=score.strategy, + value=score.value, + score=score.score, + reasons=list(score.reasons), + name=name, + hits=hits_info["hits"], + fallback_used=hits_info["fallback_used"], + )) + return findings + + +def scan_project( + root: Union[str, Path], + pattern: str = "**/*.json", +) -> List[LocatorFinding]: + """ + 掃整個專案的 action JSON、收集所有 locator finding。 + Walk ``root`` for files matching ``pattern`` and score every locator. + Files that don't decode as JSON are skipped with a warning so a stray + config file doesn't kill the whole scan. + """ + root = Path(root) + if not root.is_dir(): + raise LocatorHealthError(f"project root is not a directory: {root}") + findings: List[LocatorFinding] = [] + for file_path in sorted(root.glob(pattern)): + if not file_path.is_file(): + continue + try: + findings.extend(scan_action_file(file_path)) + except LocatorHealthError as error: + web_runner_logger.warning(f"scan_project skip {file_path}: {error!r}") + return findings + + +# ---------- report ------------------------------------------------------- + +@dataclass +class LocatorHealthReport: + """Aggregate health report rendered for humans or CI dashboards.""" + + total: int + weak: int + strong: int + average_score: float + findings: List[LocatorFinding] = field(default_factory=list) + weakest: List[LocatorFinding] = field(default_factory=list) + fallback_offenders: List[LocatorFinding] = field(default_factory=list) + threshold: int = 60 + + def to_dict(self) -> Dict[str, Any]: + return { + "total": self.total, + "weak": self.weak, + "strong": self.strong, + "average_score": self.average_score, + "threshold": self.threshold, + "findings": [f.to_dict() for f in self.findings], + "weakest": [f.to_dict() for f in self.weakest], + "fallback_offenders": [f.to_dict() for f in self.fallback_offenders], + } + + +def build_health_report( + findings: Iterable[LocatorFinding], + *, + threshold: int = 60, + weakest_limit: int = 10, + fallback_min_rate: float = 0.2, +) -> LocatorHealthReport: + """ + 把 finding list 整合成 report,包含弱定位排行 + fallback 觸發排行。 + Aggregate findings into a report with two ranked sub-lists: + + * ``weakest`` — locators with the lowest static scores. + * ``fallback_offenders`` — locators whose self-healing fallback fired + at least ``fallback_min_rate`` of the time at runtime (only matters + if callers have been poking ``fallback_hit_tracker``). + """ + materialised = list(findings) + total = len(materialised) + if total == 0: + return LocatorHealthReport( + total=0, weak=0, strong=0, average_score=0.0, threshold=threshold, + ) + weak = sum(1 for f in materialised if f.score < threshold) + strong = total - weak + avg = sum(f.score for f in materialised) / total + weakest = sorted(materialised, key=lambda f: f.score)[:weakest_limit] + fallback_offenders = sorted( + (f for f in materialised if f.fallback_rate >= fallback_min_rate), + key=lambda f: (-f.fallback_rate, f.score), + ) + return LocatorHealthReport( + total=total, + weak=weak, + strong=strong, + average_score=round(avg, 2), + findings=materialised, + weakest=weakest, + fallback_offenders=fallback_offenders, + threshold=threshold, + ) + + +# ---------- upgrade suggestions ------------------------------------------ + +@dataclass +class UpgradeSuggestion: + """A proposed replacement for one weak locator.""" + + file_path: str + action_index: int + from_strategy: str + from_value: str + to_strategy: str + to_value: str + rationale: str + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +_STRATEGY_PRIORITY: Dict[str, int] = { + "ID": 100, "id": 100, + "NAME": 70, "name": 70, + "CSS_SELECTOR": 75, "css selector": 75, "css": 75, + "XPATH": 55, "xpath": 55, + "CLASS_NAME": 45, "class name": 45, + "TAG_NAME": 25, "tag name": 25, + "LINK_TEXT": 35, "link text": 35, + "PARTIAL_LINK_TEXT": 30, "partial link text": 30, +} + + +def _suggest_for_xpath(finding: LocatorFinding) -> Optional[UpgradeSuggestion]: + """Heuristic: if an XPath anchors on ``@id='X'``, suggest using ID directly.""" + value = finding.value + # //*[@id='foo'] or //tag[@id="foo"] + import re as _re + match = _re.search(r"@id\s*=\s*['\"]([^'\"]+)['\"]", value) + if match: + new_value = match.group(1) + return UpgradeSuggestion( + file_path=finding.file_path, + action_index=finding.action_index, + from_strategy=finding.strategy, + from_value=value, + to_strategy="ID", + to_value=new_value, + rationale=f"XPath anchored on @id={new_value!r}; ID strategy is stabler", + ) + match = _re.search(r"@data-testid\s*=\s*['\"]([^'\"]+)['\"]", value) + if match: + return UpgradeSuggestion( + file_path=finding.file_path, + action_index=finding.action_index, + from_strategy=finding.strategy, + from_value=value, + to_strategy="CSS_SELECTOR", + to_value=f"[data-testid='{match.group(1)}']", + rationale="XPath uses data-testid; CSS selector reads cleaner", + ) + return None + + +def _suggest_for_css(finding: LocatorFinding) -> Optional[UpgradeSuggestion]: + """Heuristic: if a CSS selector anchors on ``#id`` alone, suggest ID.""" + value = finding.value.strip() + if value.startswith("#") and " " not in value and ">" not in value: + return UpgradeSuggestion( + file_path=finding.file_path, + action_index=finding.action_index, + from_strategy=finding.strategy, + from_value=value, + to_strategy="ID", + to_value=value[1:], + rationale="CSS selector is a single #id; switch to ID strategy", + ) + return None + + +def suggest_upgrade(finding: LocatorFinding) -> Optional[UpgradeSuggestion]: + """ + 回傳 finding 的一個升級建議;找不到合理建議回 None。 + Look for a structural pattern that points at a better strategy. Returns + None when the finding is already strong or we can't find a clear win. + """ + strategy = finding.strategy + if strategy in {"XPATH", "xpath"}: + return _suggest_for_xpath(finding) + if strategy in {"CSS_SELECTOR", "css selector", "css"}: + return _suggest_for_css(finding) + return None + + +def suggest_upgrades( + findings: Iterable[LocatorFinding], + *, + only_below: Optional[int] = None, +) -> List[UpgradeSuggestion]: + """ + 對一批 finding 收集所有可行的升級建議。 + Walk a finding list and return every upgrade suggestion. Pass + ``only_below`` to skip findings whose static score is already above + a chosen threshold. + """ + suggestions: List[UpgradeSuggestion] = [] + for finding in findings: + if only_below is not None and finding.score >= only_below: + continue + suggestion = suggest_upgrade(finding) + if suggestion is not None: + suggestions.append(suggestion) + return suggestions + + +def apply_upgrades( + actions: List[Any], + suggestions: Iterable[UpgradeSuggestion], +) -> List[Any]: + """ + 根據 suggestion 把 action list 內的 locator 改寫,回傳新的 list。 + Non-mutating: returns a deep-copied action list with the chosen + suggestions applied. Suggestions whose ``action_index`` is out of range + are skipped with a warning. + """ + import copy as _copy + new_actions = _copy.deepcopy(actions) + by_index: Dict[int, UpgradeSuggestion] = {} + for s in suggestions: + by_index[s.action_index] = s + for index, action in enumerate(new_actions): + suggestion = by_index.get(index) + if suggestion is None: + continue + if not isinstance(action, list) or len(action) < 2: + continue + kwargs = action[2] if len(action) >= 3 and isinstance(action[2], dict) else ( + action[1] if isinstance(action[1], dict) else None + ) + if kwargs is None: + continue + if "object_type" in kwargs: + kwargs["object_type"] = suggestion.to_strategy + if "strategy" in kwargs: + kwargs["strategy"] = suggestion.to_strategy + if "test_object_name" in kwargs: + kwargs["test_object_name"] = suggestion.to_value + if "value" in kwargs: + kwargs["value"] = suggestion.to_value + return new_actions + + +# ---------- rendering ---------------------------------------------------- + +def render_health_markdown(report: LocatorHealthReport) -> str: + """Render the report as markdown suitable for PR comments.""" + pieces = [ + "## Locator health report", + "", + f"- **Total locators:** {report.total}", + f"- **Weak (< {report.threshold}):** {report.weak}", + f"- **Strong:** {report.strong}", + f"- **Average score:** {report.average_score}", + "", + ] + if report.weakest: + pieces.append("### Weakest locators") + pieces.append("| File | Idx | Strategy | Value | Score | Reasons |") + pieces.append("|------|-----|----------|-------|-------|---------|") + for f in report.weakest: + value = (f.value[:60] + "…") if len(f.value) > 60 else f.value + pieces.append( + f"| `{Path(f.file_path).name}` | {f.action_index} | `{f.strategy}` " + f"| `{value}` | {f.score} | {'; '.join(f.reasons) or '—'} |" + ) + pieces.append("") + if report.fallback_offenders: + pieces.append("### Self-healing offenders (fallback fired at runtime)") + pieces.append("| File | Strategy | Hits | Fallback used | Rate |") + pieces.append("|------|----------|------|---------------|------|") + for f in report.fallback_offenders: + pieces.append( + f"| `{Path(f.file_path).name}` | `{f.strategy}` | {f.hits} | " + f"{f.fallback_used} | {f.fallback_rate:.0%} |" + ) + pieces.append("") + return "\n".join(pieces).rstrip() + "\n" + + +def save_health_report( + report: LocatorHealthReport, + output_path: Union[str, Path], +) -> Path: + """Persist the JSON form of the report next to a CI artifact.""" + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as fp: + json.dump(report.to_dict(), fp, ensure_ascii=False, indent=2) + web_runner_logger.info(f"save_health_report: wrote {path}") + return path diff --git a/je_web_runner/utils/long_animation_frame/__init__.py b/je_web_runner/utils/long_animation_frame/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/long_animation_frame/frames.py b/je_web_runner/utils/long_animation_frame/frames.py new file mode 100644 index 0000000..2b0b515 --- /dev/null +++ b/je_web_runner/utils/long_animation_frame/frames.py @@ -0,0 +1,207 @@ +""" +Long Animation Frame (LoAF) API。 Chrome 在 2024 推出的新觀測 API,取代 +``longtask`` —— 不只是 50ms+ task,而是包含整個 rAF + style/layout/paint +的「卡幀」週期,讓你能找出實際造成 jank 的真兇(哪個 script 跑太久、 +style/layout 重算成本、forced reflow 細節)。 + +This module: + +* Generates the JS to subscribe via ``PerformanceObserver({type: + 'long-animation-frame'})``. +* Parses the harvested log into structured records. +* Reports per-script attribution and asserts a budget. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class LongAnimationFrameError(WebRunnerException): + """Raised on malformed log input or budget breach.""" + + +# ---------- instrumentation ------------------------------------------- + +_INSTALL = """ +(function() { + if (window.__wr_loaf_installed__) return; + window.__wr_loaf_installed__ = true; + window.__wr_loaf_log__ = []; + if (!('PerformanceObserver' in window)) return; + try { + const obs = new PerformanceObserver(function(list) { + list.getEntries().forEach(function(e) { + const scripts = (e.scripts || []).map(function(s) { + return { + name: s.name || '', + invoker: s.invoker || '', + invoker_type: s.invokerType || '', + source_url: s.sourceURL || '', + duration_ms: s.duration, + forced_style_layout_duration_ms: s.forcedStyleAndLayoutDuration || 0, + pause_duration_ms: s.pauseDuration || 0 + }; + }); + window.__wr_loaf_log__.push({ + duration_ms: e.duration, + render_start_ms: e.renderStart || 0, + style_layout_start_ms: e.styleAndLayoutStart || 0, + start_time_ms: e.startTime, + blocking_duration_ms: e.blockingDuration || 0, + scripts: scripts + }); + }); + }); + obs.observe({type: 'long-animation-frame', buffered: true}); + } catch (e) { /* unsupported */ } +})(); +""".strip() + + +def build_install_script() -> str: + return _INSTALL + + +HARVEST_SCRIPT = "return window.__wr_loaf_log__ || [];" + + +# ---------- data -------------------------------------------------------- + +@dataclass +class ScriptAttribution: + """Per-script breakdown inside a long animation frame.""" + + name: str + invoker: str + invoker_type: str + source_url: str + duration_ms: float + forced_style_layout_duration_ms: float = 0.0 + pause_duration_ms: float = 0.0 + + +@dataclass +class LongFrame: + """One long-animation-frame entry.""" + + start_time_ms: float + duration_ms: float + render_start_ms: float + style_layout_start_ms: float + blocking_duration_ms: float + scripts: List[ScriptAttribution] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def parse_log(payload: Any) -> List[LongFrame]: + """Convert the harvested ``__wr_loaf_log__`` array into typed frames.""" + if not isinstance(payload, list): + raise LongAnimationFrameError( + f"payload must be list, got {type(payload).__name__}" + ) + out: List[LongFrame] = [] + for raw in payload: + if not isinstance(raw, dict): + continue + try: + scripts = [] + for sraw in raw.get("scripts") or []: + if not isinstance(sraw, dict): + continue + scripts.append(ScriptAttribution( + name=str(sraw.get("name") or ""), + invoker=str(sraw.get("invoker") or ""), + invoker_type=str(sraw.get("invoker_type") or ""), + source_url=str(sraw.get("source_url") or ""), + duration_ms=float(sraw.get("duration_ms") or 0.0), + forced_style_layout_duration_ms=float( + sraw.get("forced_style_layout_duration_ms") or 0.0, + ), + pause_duration_ms=float(sraw.get("pause_duration_ms") or 0.0), + )) + out.append(LongFrame( + start_time_ms=float(raw.get("start_time_ms") or 0.0), + duration_ms=float(raw.get("duration_ms") or 0.0), + render_start_ms=float(raw.get("render_start_ms") or 0.0), + style_layout_start_ms=float(raw.get("style_layout_start_ms") or 0.0), + blocking_duration_ms=float(raw.get("blocking_duration_ms") or 0.0), + scripts=scripts, + )) + except (TypeError, ValueError) as error: + raise LongAnimationFrameError( + f"malformed loaf entry {raw!r}: {error}" + ) from error + return out + + +# ---------- aggregation / reports -------------------------------------- + +@dataclass +class LoafReport: + """Rolled-up view across all frames.""" + + frames: List[LongFrame] = field(default_factory=list) + + def worst_frame_ms(self) -> float: + return max((f.duration_ms for f in self.frames), default=0.0) + + def total_blocking_ms(self) -> float: + return sum(f.blocking_duration_ms for f in self.frames) + + def top_scripts(self, *, n: int = 5) -> List[ScriptAttribution]: + """Top N scripts by aggregated duration across all frames.""" + bucket: Dict[str, ScriptAttribution] = {} + for frame in self.frames: + for s in frame.scripts: + key = s.source_url or s.name or s.invoker + existing = bucket.get(key) + if existing is None: + bucket[key] = ScriptAttribution( + name=s.name, invoker=s.invoker, + invoker_type=s.invoker_type, source_url=s.source_url, + duration_ms=s.duration_ms, + forced_style_layout_duration_ms=s.forced_style_layout_duration_ms, + pause_duration_ms=s.pause_duration_ms, + ) + else: + existing.duration_ms += s.duration_ms + existing.forced_style_layout_duration_ms += ( + s.forced_style_layout_duration_ms + ) + existing.pause_duration_ms += s.pause_duration_ms + return sorted(bucket.values(), key=lambda s: -s.duration_ms)[:n] + + +# ---------- assertions ------------------------------------------------- + +def assert_no_frame_over(report: LoafReport, *, max_ms: float) -> None: + """Assert every frame's duration is ``<= max_ms``.""" + if not isinstance(report, LoafReport): + raise LongAnimationFrameError("expects LoafReport") + if max_ms <= 0: + raise LongAnimationFrameError("max_ms must be > 0") + bad = [f for f in report.frames if f.duration_ms > max_ms] + if bad: + sample = ", ".join(f"{f.duration_ms:.0f}ms" for f in bad[:3]) + more = "" if len(bad) <= 3 else f" (+{len(bad) - 3})" + raise LongAnimationFrameError( + f"long animation frames over {max_ms}ms: {sample}{more}" + ) + + +def assert_total_blocking_under(report: LoafReport, *, max_ms: float) -> None: + """Assert total blocking time across all frames is ``<= max_ms``.""" + if not isinstance(report, LoafReport): + raise LongAnimationFrameError("expects LoafReport") + if max_ms < 0: + raise LongAnimationFrameError("max_ms must be >= 0") + total = report.total_blocking_ms() + if total > max_ms: + raise LongAnimationFrameError( + f"total blocking {total:.1f}ms exceeds budget {max_ms}ms" + ) diff --git a/je_web_runner/utils/mixed_content_audit/__init__.py b/je_web_runner/utils/mixed_content_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/mixed_content_audit/audit.py b/je_web_runner/utils/mixed_content_audit/audit.py new file mode 100644 index 0000000..ec25183 --- /dev/null +++ b/je_web_runner/utils/mixed_content_audit/audit.py @@ -0,0 +1,232 @@ +""" +偵測 HTTPS 頁面內載入的 HTTP 資源(瀏覽器靜默 block,平常測不出來)。 +Modern browsers block "active" mixed content (script/iframe/xhr) silently +and downgrade-upgrade "passive" content (img/video/audio) — both end up +as broken UX. This module parses HAR / console-error / response-header +sources and flags everything that doesn't match the page's secure +origin. + +Classification follows MDN's split: + +* **Active** — script, link rel=stylesheet, iframe, fetch/XHR, WebSocket + → BLOCKED outright +* **Passive** — image, audio, video, font → loaded but flagged +""" +from __future__ import annotations + +import json +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union +from urllib.parse import urlparse + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class MixedContentAuditError(WebRunnerException): + """Raised on malformed HAR / failed assertion.""" + + +class Severity(str, Enum): + """Mixed-content severity buckets.""" + + ACTIVE = "active" # blocked / broken + PASSIVE = "passive" # works but flagged + UPGRADE = "upgrade" # browser upgraded http→https automatically + + +_ACTIVE_TYPES = { + "script", "stylesheet", "iframe", "subdocument", "xhr", "fetch", + "websocket", "manifest", "preflight", +} +_PASSIVE_TYPES = { + "image", "imageset", "media", "video", "audio", "font", "track", +} + +# Sites that auto-redirect http→https; this module still flags them as +# "upgrade" so devs can fix the source link, but they're a softer signal. +_HSTS_AUTO_DOMAINS = {"www.google.com", "www.youtube.com", "fonts.googleapis.com"} + + +# ---------- findings ---------------------------------------------------- + +@dataclass +class MixedFinding: + """One http resource on an https page.""" + + url: str + resource_type: str + severity: Severity + source_url: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "severity": self.severity.value} + + +# ---------- classification --------------------------------------------- + +def _classify(resource_type: str) -> Severity: + rt = (resource_type or "").lower() + if rt in _ACTIVE_TYPES: + return Severity.ACTIVE + if rt in _PASSIVE_TYPES: + return Severity.PASSIVE + return Severity.ACTIVE # default to strict: unknown = active + + +def _is_http(url: str) -> bool: + try: + return urlparse(url).scheme.lower() == "http" + except (ValueError, AttributeError): + return False + + +def _is_https_origin(url: str) -> bool: + try: + return urlparse(url).scheme.lower() == "https" + except (ValueError, AttributeError): + return False + + +def _hostname(url: str) -> str: + try: + return (urlparse(url).hostname or "").lower() + except (ValueError, AttributeError): + return "" + + +# ---------- scanners --------------------------------------------------- + +def scan_har( + har: Union[str, Dict[str, Any]], + *, + page_url: Optional[str] = None, +) -> List[MixedFinding]: + """ + Parse a HAR object/string, returning one finding per http request on + an https page. When ``page_url`` is None, we assume the first entry's + page URL is the document. + """ + har_obj = _coerce_har(har) + entries = ((har_obj.get("log") or {}).get("entries")) or [] + if not isinstance(entries, list): + raise MixedContentAuditError("har log.entries must be a list") + + document_url = page_url or _first_page_url(har_obj) or "" + if document_url and not _is_https_origin(document_url): + return [] # no risk if the page itself is http + + findings: List[MixedFinding] = [] + for entry in entries: + if not isinstance(entry, dict): + continue + request_url = ((entry.get("request") or {}).get("url")) or "" + if not _is_http(request_url): + continue + resource_type = str( + entry.get("_resourceType") + or entry.get("resourceType") + or "" + ) + severity = _classify(resource_type) + hostname = _hostname(request_url) + if hostname in _HSTS_AUTO_DOMAINS: + severity = Severity.UPGRADE + findings.append(MixedFinding( + url=request_url, + resource_type=resource_type or "unknown", + severity=severity, + source_url=document_url, + )) + return findings + + +def _coerce_har(har: Union[str, Dict[str, Any]]) -> Dict[str, Any]: + if isinstance(har, str): + try: + parsed = json.loads(har) + except ValueError as error: + raise MixedContentAuditError(f"har not JSON: {error}") from error + if not isinstance(parsed, dict): + raise MixedContentAuditError("har JSON must be an object") + return parsed + if isinstance(har, dict): + return har + raise MixedContentAuditError( + f"scan_har expects str/dict, got {type(har).__name__}" + ) + + +def _first_page_url(har: Dict[str, Any]) -> Optional[str]: + pages = ((har.get("log") or {}).get("pages")) or [] + if isinstance(pages, list) and pages: + first = pages[0] + if isinstance(first, dict): + url = first.get("title") or first.get("id") + if isinstance(url, str) and url.startswith("http"): + return url + return None + + +_MIXED_CONTENT_CONSOLE_RE = re.compile(r"mixed content", re.IGNORECASE) +_ACTIVE_HINT_RE = re.compile( + r"\b(active|blocked|insecure script|insecure stylesheet|insecure xhr|insecure fetch|insecure iframe)\b", + re.IGNORECASE, +) + + +def scan_console_errors( + messages: Iterable[str], + *, + page_url: str = "", +) -> List[MixedFinding]: + """Heuristic scan over console errors for ``Mixed Content:`` lines.""" + out: List[MixedFinding] = [] + for line in messages: + if not isinstance(line, str): + continue + if not _MIXED_CONTENT_CONSOLE_RE.search(line): + continue + http_urls = [u.rstrip(".,;)\"'") for u in re.findall(r"https?://\S+", line) + if _is_http(u.rstrip(".,;)\"'"))] + if not http_urls: + continue + url = http_urls[0] + severity = ( + Severity.ACTIVE if _ACTIVE_HINT_RE.search(line) else Severity.PASSIVE + ) + out.append(MixedFinding( + url=url or "(unknown)", + resource_type="console", + severity=severity, + source_url=page_url, + )) + return out + + +# ---------- assertions ------------------------------------------------- + +def assert_no_active(findings: Sequence[MixedFinding]) -> None: + """Raise if any active-mixed-content finding is present.""" + actives = [f for f in findings if f.severity == Severity.ACTIVE] + if actives: + sample = ", ".join(f.url for f in actives[:3]) + more = "" if len(actives) <= 3 else f" (+{len(actives) - 3} more)" + raise MixedContentAuditError(f"active mixed content: {sample}{more}") + + +def assert_clean(findings: Sequence[MixedFinding]) -> None: + """Raise if any finding is present (strictest).""" + if findings: + sample = ", ".join(f"{f.severity.value}:{f.url}" for f in findings[:3]) + more = "" if len(findings) <= 3 else f" (+{len(findings) - 3} more)" + raise MixedContentAuditError(f"mixed content detected: {sample}{more}") + + +def summary(findings: Sequence[MixedFinding]) -> Dict[str, int]: + """Return ``{severity: count}`` summary.""" + out: Dict[str, int] = {} + for f in findings: + out[f.severity.value] = out.get(f.severity.value, 0) + 1 + return out diff --git a/je_web_runner/utils/multimodal_qa/__init__.py b/je_web_runner/utils/multimodal_qa/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/multimodal_qa/qa.py b/je_web_runner/utils/multimodal_qa/qa.py new file mode 100644 index 0000000..178426f --- /dev/null +++ b/je_web_runner/utils/multimodal_qa/qa.py @@ -0,0 +1,210 @@ +""" +給 vision LLM 看截圖 + 問題,解析「對嗎?」的結構化回答。 +Use cases that snapshot-diff can't catch on its own: + +* "Is the error toast styled like the design system spec?" +* "Does the chart axis labelling make sense?" +* "Are there any visual artifacts (cropped text, overlapping elements)?" + +The LLM call is hidden behind :class:`VisionClient`, so this module +stays unit-testable without a model. Production code plugs in Claude +Vision / GPT-4o / a local VLM. + +Response handling is defensive: the JSON envelope is required, but +malformed responses degrade to a clear failure rather than a silent +pass. +""" +from __future__ import annotations + +import base64 +import json +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Dict, List, Optional, Protocol, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class MultimodalQaError(WebRunnerException): + """Raised on bad image input, client failure, or unparseable response.""" + + +# ---------- enums ------------------------------------------------------- + +class Verdict(str, Enum): + """Final outcome of a single Q.""" + + PASS = "pass" + FAIL = "fail" + UNCERTAIN = "uncertain" + + +# ---------- data -------------------------------------------------------- + +@dataclass +class QaRequest: + """One screenshot + question to send to the LLM.""" + + image_bytes: bytes + question: str + rubric: List[str] = field(default_factory=list) + image_label: str = "" + + def __post_init__(self) -> None: + if not isinstance(self.image_bytes, (bytes, bytearray)): + raise MultimodalQaError( + f"image_bytes must be bytes, got {type(self.image_bytes).__name__}" + ) + if not self.image_bytes: + raise MultimodalQaError("image_bytes must be non-empty") + if not isinstance(self.question, str) or not self.question.strip(): + raise MultimodalQaError("question must be a non-empty string") + + def b64_image(self) -> str: + return base64.b64encode(bytes(self.image_bytes)).decode("ascii") + + +@dataclass +class QaResponse: + """Parsed vision-LLM response.""" + + verdict: Verdict + confidence: float + rationale: str + issues: List[str] = field(default_factory=list) + raw: str = "" + + def is_pass(self) -> bool: + return self.verdict == Verdict.PASS + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "verdict": self.verdict.value} + + +# ---------- client protocol -------------------------------------------- + +class VisionClient(Protocol): + """LLM client interface.""" + + def ask(self, prompt: str, image_b64: str) -> str: ... + + +# ---------- prompt ----------------------------------------------------- + +def build_prompt(request: QaRequest) -> str: + """Render a deterministic prompt the vision model will see.""" + parts: List[str] = [ + "You are reviewing a UI screenshot. Answer the question strictly.", + "Respond with ONLY a JSON object on a single line with keys:", + ' "verdict": one of "pass" | "fail" | "uncertain"', + ' "confidence": number in [0, 1]', + ' "rationale": short string explaining the verdict', + ' "issues": list of strings, one per concrete issue (empty if none)', + "", + f"Question: {request.question.strip()}", + ] + if request.rubric: + parts.append("Rubric (each item must be true for a 'pass'):") + parts.extend(f"- {item}" for item in request.rubric) + return "\n".join(parts) + + +# ---------- parsing ---------------------------------------------------- + +_JSON_LINE_RE = re.compile(r"\{.*\}", re.DOTALL) + + +def parse_response(raw: str) -> QaResponse: + """Parse the model's text into a :class:`QaResponse`.""" + if not isinstance(raw, str) or not raw.strip(): + raise MultimodalQaError("model response was empty") + match = _JSON_LINE_RE.search(raw) + if not match: + raise MultimodalQaError(f"no JSON object in response: {raw[:200]!r}") + try: + obj = json.loads(match.group(0)) + except ValueError as error: + raise MultimodalQaError( + f"response was not valid JSON ({error}): {raw[:200]!r}" + ) from error + if not isinstance(obj, dict): + raise MultimodalQaError(f"response JSON must be an object, got {type(obj).__name__}") + try: + verdict = Verdict(str(obj.get("verdict") or "").lower()) + except ValueError as error: + raise MultimodalQaError(f"unknown verdict in response: {error}") from error + confidence_raw = obj.get("confidence") + if not isinstance(confidence_raw, (int, float)): + raise MultimodalQaError("response missing numeric 'confidence'") + confidence = max(0.0, min(1.0, float(confidence_raw))) + issues_raw = obj.get("issues") or [] + if not isinstance(issues_raw, list): + raise MultimodalQaError("response 'issues' must be a list") + return QaResponse( + verdict=verdict, + confidence=confidence, + rationale=str(obj.get("rationale") or ""), + issues=[str(i) for i in issues_raw], + raw=raw, + ) + + +# ---------- ask -------------------------------------------------------- + +def ask( + request: QaRequest, + client: VisionClient, +) -> QaResponse: + """Build prompt → ask client → parse → return :class:`QaResponse`.""" + prompt = build_prompt(request) + try: + raw = client.ask(prompt, request.b64_image()) + except Exception as error: + raise MultimodalQaError(f"vision client failed: {error!r}") from error + return parse_response(raw) + + +def ask_path( + path: Union[str, Path], + question: str, + client: VisionClient, + *, + rubric: Sequence[str] = (), +) -> QaResponse: + """Convenience: load an image off disk and call :func:`ask`.""" + p = Path(path) + if not p.exists(): + raise MultimodalQaError(f"image not found: {p}") + request = QaRequest( + image_bytes=p.read_bytes(), + question=question, + rubric=list(rubric), + image_label=str(p), + ) + return ask(request, client) + + +# ---------- assertion -------------------------------------------------- + +def assert_passes( + response: QaResponse, + *, + min_confidence: float = 0.6, +) -> None: + """Raise unless the response is a confident pass.""" + if not isinstance(response, QaResponse): + raise MultimodalQaError("assert_passes expects QaResponse") + if not 0.0 <= min_confidence <= 1.0: + raise MultimodalQaError("min_confidence must be in [0, 1]") + if response.verdict != Verdict.PASS: + joined = ", ".join(response.issues) or response.rationale + raise MultimodalQaError( + f"verdict={response.verdict.value} (confidence={response.confidence:.2f}): {joined}" + ) + if response.confidence < min_confidence: + raise MultimodalQaError( + f"verdict=pass but confidence {response.confidence:.2f} " + f"< min {min_confidence:.2f}: {response.rationale}" + ) diff --git a/je_web_runner/utils/mutation_testing/__init__.py b/je_web_runner/utils/mutation_testing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/mutation_testing/mutator.py b/je_web_runner/utils/mutation_testing/mutator.py new file mode 100644 index 0000000..bfc8fbb --- /dev/null +++ b/je_web_runner/utils/mutation_testing/mutator.py @@ -0,0 +1,446 @@ +""" +Mutation testing for action JSON:對 action 套用變異後執行,反向驗證測試本身的偵測能力。 + +Mutation testing of WebRunner action JSON files. Given a passing test, we +apply a catalogue of mutations (locator swap, timeout shrink, URL change, +assertion flip, action removal, adjacent reorder) and re-run. A mutation +is "killed" when the mutated test fails. The mutation score is +``killed / total`` — high scores mean the test is sensitive, low scores +mean the test passes regardless of obvious sabotage. + +The executor is caller-supplied (``Callable[[List[Any]], bool]``) so the +module stays decoupled from the Selenium/Playwright runtime and is easy +to test offline. +""" +from __future__ import annotations + +import copy +import json +import random +from dataclasses import asdict, dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class MutationTestingError(WebRunnerException): + """Raised on invalid input or executor protocol violations.""" + + +class MutationType(str, Enum): + LOCATOR_SWAP = "locator_swap" + TIMEOUT_SHRINK = "timeout_shrink" + URL_CHANGE = "url_change" + ASSERTION_FLIP = "assertion_flip" + ACTION_REMOVAL = "action_removal" + ADJACENT_REORDER = "adjacent_reorder" + + +_DEFAULT_MUTATION_TYPES: Sequence[MutationType] = tuple(MutationType) + +_LOCATOR_KEYS = ("test_object_name", "value", "locator", "selector") +_URL_KEYS = ("url", "target_url", "expected_url") +_TIMEOUT_KEYS = ("timeout", "wait_seconds", "delay") +_ASSERTION_KEYS = ("expected", "expected_value", "expected_text") +_NON_REMOVABLE_PREFIXES = ("WR_set_", "WR_quit", "WR_init", "WR_to_url") + + +@dataclass +class Mutation: + """One mutation applied to a copy of the action list.""" + + type: MutationType + action_index: int + description: str + original: Any = None + mutated: Any = None + + def to_dict(self) -> Dict[str, Any]: + out = asdict(self) + out["type"] = self.type.value + return out + + +@dataclass +class MutationResult: + """Outcome of running one mutation through the executor.""" + + mutation: Mutation + killed: bool + error: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + return { + "mutation": self.mutation.to_dict(), + "killed": self.killed, + "error": self.error, + } + + +@dataclass +class MutationScore: + """Aggregate mutation score for a single action file.""" + + total: int + killed: int + survived: int + score: float + results: List[MutationResult] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "total": self.total, + "killed": self.killed, + "survived": self.survived, + "score": self.score, + "results": [r.to_dict() for r in self.results], + } + + +# ---------- helpers ------------------------------------------------------ + +def _kwargs_of(action: List[Any]) -> Optional[Dict[str, Any]]: + if not isinstance(action, list) or not action: + return None + if len(action) >= 3 and isinstance(action[2], dict): + return action[2] + if len(action) >= 2 and isinstance(action[1], dict): + return action[1] + return None + + +def _action_command(action: List[Any]) -> str: + if isinstance(action, list) and action and isinstance(action[0], str): + return action[0] + return "" + + +# ---------- mutation generators ----------------------------------------- + +def _gen_locator_swap(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx, action in enumerate(actions): + kwargs = _kwargs_of(action) + if not kwargs: + continue + for key in _LOCATOR_KEYS: + if key in kwargs and isinstance(kwargs[key], str): + mutations.append(Mutation( + type=MutationType.LOCATOR_SWAP, + action_index=idx, + description=f"swap {key} → '__mutated_{key}__'", + original=kwargs[key], + mutated=f"__mutated_{key}__", + )) + break + return mutations + + +def _gen_timeout_shrink(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx, action in enumerate(actions): + kwargs = _kwargs_of(action) + if not kwargs: + continue + for key in _TIMEOUT_KEYS: + if key in kwargs and isinstance(kwargs[key], (int, float)): + mutations.append(Mutation( + type=MutationType.TIMEOUT_SHRINK, + action_index=idx, + description=f"shrink {key} → 0.001", + original=kwargs[key], + mutated=0.001, + )) + break + return mutations + + +def _gen_url_change(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx, action in enumerate(actions): + kwargs = _kwargs_of(action) + if not kwargs: + continue + for key in _URL_KEYS: + if key in kwargs and isinstance(kwargs[key], str): + mutations.append(Mutation( + type=MutationType.URL_CHANGE, + action_index=idx, + description=f"swap {key} → 'https://example.invalid/mut'", + original=kwargs[key], + mutated="https://example.invalid/mut", + )) + break + return mutations + + +def _flip_assertion_value(value: Any) -> Any: + if isinstance(value, bool): + return not value + if isinstance(value, str): + return value + "__MUTATED__" + if isinstance(value, (int, float)): + return value + 1 + if isinstance(value, list): + return list(reversed(value)) + return f"__MUTATED__{value!r}" + + +def _gen_assertion_flip(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx, action in enumerate(actions): + kwargs = _kwargs_of(action) + if not kwargs: + continue + for key in _ASSERTION_KEYS: + if key in kwargs: + flipped = _flip_assertion_value(kwargs[key]) + mutations.append(Mutation( + type=MutationType.ASSERTION_FLIP, + action_index=idx, + description=f"flip {key}", + original=kwargs[key], + mutated=flipped, + )) + break + return mutations + + +def _gen_action_removal(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx, action in enumerate(actions): + command = _action_command(action) + if not command: + continue + if any(command.startswith(prefix) for prefix in _NON_REMOVABLE_PREFIXES): + continue + mutations.append(Mutation( + type=MutationType.ACTION_REMOVAL, + action_index=idx, + description=f"remove {command}", + original=action, + mutated=None, + )) + return mutations + + +def _gen_adjacent_reorder(actions: List[Any]) -> List[Mutation]: + mutations: List[Mutation] = [] + for idx in range(len(actions) - 1): + if not isinstance(actions[idx], list) or not isinstance(actions[idx + 1], list): + continue + if not actions[idx] or not actions[idx + 1]: + continue + cmd_a = _action_command(actions[idx]) + cmd_b = _action_command(actions[idx + 1]) + if any(cmd_a.startswith(p) for p in _NON_REMOVABLE_PREFIXES): + continue + if any(cmd_b.startswith(p) for p in _NON_REMOVABLE_PREFIXES): + continue + mutations.append(Mutation( + type=MutationType.ADJACENT_REORDER, + action_index=idx, + description=f"swap actions {idx} and {idx + 1}", + original=(cmd_a, cmd_b), + mutated=(cmd_b, cmd_a), + )) + return mutations + + +_GENERATORS: Dict[MutationType, Callable[[List[Any]], List[Mutation]]] = { + MutationType.LOCATOR_SWAP: _gen_locator_swap, + MutationType.TIMEOUT_SHRINK: _gen_timeout_shrink, + MutationType.URL_CHANGE: _gen_url_change, + MutationType.ASSERTION_FLIP: _gen_assertion_flip, + MutationType.ACTION_REMOVAL: _gen_action_removal, + MutationType.ADJACENT_REORDER: _gen_adjacent_reorder, +} + + +def generate_mutations( + actions: List[Any], + types: Sequence[MutationType] = _DEFAULT_MUTATION_TYPES, + *, + seed: Optional[int] = None, + max_per_type: Optional[int] = None, +) -> List[Mutation]: + """ + 依 mutation type 對 action list 生出可能的變異。 + Run every configured generator and concatenate. ``max_per_type`` caps + each type's contribution (deterministic when ``seed`` is set) so very + large suites don't generate hundreds of mutations. + """ + if not isinstance(actions, list): + raise MutationTestingError(f"actions must be a list, got {type(actions).__name__}") + rng = random.Random(seed) if seed is not None else random + all_mutations: List[Mutation] = [] + for mt in types: + generator = _GENERATORS.get(mt) + if generator is None: + continue + generated = generator(actions) + if max_per_type is not None and len(generated) > max_per_type: + generated = rng.sample(generated, max_per_type) + all_mutations.extend(generated) + return all_mutations + + +# ---------- apply --------------------------------------------------------- + +def apply_mutation(actions: List[Any], mutation: Mutation) -> List[Any]: + """ + 產生一份套了 mutation 的 actions(不修改原 list)。 + Return a deep-copied action list with ``mutation`` applied. Mutations + targeting an out-of-range index raise :class:`MutationTestingError` + so the executor never receives a malformed list. + """ + if mutation.action_index < 0 or mutation.action_index >= len(actions): + raise MutationTestingError( + f"mutation index {mutation.action_index} out of range for {len(actions)} actions" + ) + new_actions = copy.deepcopy(actions) + if mutation.type is MutationType.ACTION_REMOVAL: + del new_actions[mutation.action_index] + return new_actions + if mutation.type is MutationType.ADJACENT_REORDER: + idx = mutation.action_index + if idx + 1 >= len(new_actions): + raise MutationTestingError("reorder requires a following action") + new_actions[idx], new_actions[idx + 1] = new_actions[idx + 1], new_actions[idx] + return new_actions + kwargs = _kwargs_of(new_actions[mutation.action_index]) + if kwargs is None: + raise MutationTestingError( + f"action at index {mutation.action_index} has no kwargs to mutate" + ) + if mutation.type is MutationType.LOCATOR_SWAP: + for key in _LOCATOR_KEYS: + if key in kwargs: + kwargs[key] = mutation.mutated + return new_actions + if mutation.type is MutationType.TIMEOUT_SHRINK: + for key in _TIMEOUT_KEYS: + if key in kwargs: + kwargs[key] = mutation.mutated + return new_actions + if mutation.type is MutationType.URL_CHANGE: + for key in _URL_KEYS: + if key in kwargs: + kwargs[key] = mutation.mutated + return new_actions + if mutation.type is MutationType.ASSERTION_FLIP: + for key in _ASSERTION_KEYS: + if key in kwargs: + kwargs[key] = mutation.mutated + return new_actions + raise MutationTestingError( + f"could not apply mutation {mutation.type.value} at {mutation.action_index}" + ) + + +# ---------- runner ------------------------------------------------------- + +ExecutorFn = Callable[[List[Any]], bool] + + +def run_mutation_testing( + actions: List[Any], + executor: ExecutorFn, + *, + types: Sequence[MutationType] = _DEFAULT_MUTATION_TYPES, + seed: Optional[int] = None, + max_per_type: Optional[int] = None, + stop_on_first_survivor: bool = False, +) -> MutationScore: + """ + 對每個 mutation 跑一次 executor,計算 kill rate。 + ``executor(mutated_actions)`` must return ``True`` if the mutated + suite still passed (mutation survived) or ``False`` if the run failed + (mutation was killed — the desired outcome). Exceptions raised by the + executor are caught and treated as kills (failures). + """ + mutations = generate_mutations(actions, types, seed=seed, max_per_type=max_per_type) + results: List[MutationResult] = [] + for mutation in mutations: + mutated = apply_mutation(actions, mutation) + try: + passed = bool(executor(mutated)) + except Exception as error: # noqa: BLE001 — executor may raise + results.append(MutationResult( + mutation=mutation, killed=True, error=repr(error), + )) + continue + results.append(MutationResult(mutation=mutation, killed=not passed)) + if stop_on_first_survivor and passed: + web_runner_logger.info( + f"mutation survived early: {mutation.type.value} at {mutation.action_index}" + ) + break + killed = sum(1 for r in results if r.killed) + survived = sum(1 for r in results if not r.killed) + score = (killed / len(results)) if results else 0.0 + return MutationScore( + total=len(results), + killed=killed, + survived=survived, + score=round(score, 4), + results=results, + ) + + +def run_mutation_testing_on_file( + action_path: Union[str, Path], + executor: ExecutorFn, + **kwargs: Any, +) -> MutationScore: + """Convenience: load an action JSON file then run mutation testing.""" + path = Path(action_path) + if not path.is_file(): + raise MutationTestingError(f"action file not found: {path}") + try: + with open(path, encoding="utf-8") as fp: + actions = json.load(fp) + except (OSError, ValueError) as error: + raise MutationTestingError(f"cannot parse {path}: {error!r}") from error + if not isinstance(actions, list): + raise MutationTestingError(f"top-level JSON must be a list: {path}") + return run_mutation_testing(actions, executor, **kwargs) + + +# ---------- rendering ---------------------------------------------------- + +def render_mutation_markdown(score: MutationScore) -> str: + """Render a mutation score as markdown for PR comments.""" + pieces = [ + "## Mutation testing report", + "", + f"- **Mutation score:** {score.score:.0%}", + f"- **Total mutations:** {score.total}", + f"- **Killed:** {score.killed}", + f"- **Survived:** {score.survived}", + "", + ] + survivors = [r for r in score.results if not r.killed] + if survivors: + pieces.append("### Surviving mutations (the test couldn't detect these)") + pieces.append("| Type | Index | Description |") + pieces.append("|------|-------|-------------|") + for r in survivors: + pieces.append( + f"| `{r.mutation.type.value}` | {r.mutation.action_index} | " + f"{r.mutation.description} |" + ) + pieces.append("") + return "\n".join(pieces).rstrip() + "\n" + + +def assert_min_score(score: MutationScore, minimum: float = 0.8) -> None: + """Raise ``MutationTestingError`` when ``score.score`` is below ``minimum``.""" + if score.score < minimum: + raise MutationTestingError( + f"mutation score {score.score:.2f} below minimum {minimum:.2f} " + f"({score.survived} survivors)" + ) diff --git a/je_web_runner/utils/notifications_audit/__init__.py b/je_web_runner/utils/notifications_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/notifications_audit/audit.py b/je_web_runner/utils/notifications_audit/audit.py new file mode 100644 index 0000000..892b086 --- /dev/null +++ b/je_web_runner/utils/notifications_audit/audit.py @@ -0,0 +1,263 @@ +""" +追蹤 `Notification.requestPermission()` 的呼叫時機 + 顯示的 notifications。 +Browsers shame UX bugs around notifications (auto-prompt on page load, +spam after rejection, prompt without user gesture). This module installs +a JS shim that records every permission request and every ``new +Notification(...)`` call, then exposes asserts for common policy +violations. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class NotificationsAuditError(WebRunnerException): + """Raised on malformed harvest payload / failed assertion.""" + + +class PermissionResult(str, Enum): + GRANTED = "granted" + DENIED = "denied" + DEFAULT = "default" + + +# ---------- model ------------------------------------------------------- + +@dataclass +class PermissionRequest: + """One ``Notification.requestPermission`` call.""" + + timestamp_ms: float + user_gesture: bool + result: PermissionResult + page_age_ms: float = 0.0 + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "result": self.result.value} + + +@dataclass +class NotificationShown: + """One ``new Notification(...)`` call.""" + + timestamp_ms: float + title: str + body: str = "" + tag: Optional[str] = None + require_interaction: bool = False + silent: bool = False + + +@dataclass +class NotificationsLog: + """Combined audit log.""" + + permission_requests: List[PermissionRequest] = field(default_factory=list) + notifications: List[NotificationShown] = field(default_factory=list) + + +# ---------- script generation ------------------------------------------ + +_INSTALL_TEMPLATE = """ +(function() { + if (window.__wr_notif_installed__) return; + window.__wr_notif_installed__ = true; + window.__wr_notif_log__ = {permission_requests: [], notifications: []}; + const pageStart = performance.now(); + + let lastGesture = 0; + ['click', 'keydown', 'pointerup', 'touchend'].forEach(function(t) { + document.addEventListener(t, function() { + lastGesture = performance.now(); + }, true); + }); + + const _realRequest = Notification && Notification.requestPermission + ? Notification.requestPermission.bind(Notification) + : null; + if (_realRequest) { + Notification.requestPermission = function() { + const ts = performance.now(); + const withinGesture = (ts - lastGesture) < 1000; + return _realRequest().then(function(result) { + window.__wr_notif_log__.permission_requests.push({ + timestamp_ms: ts, + user_gesture: withinGesture, + result: String(result), + page_age_ms: ts - pageStart + }); + return result; + }); + }; + } + + const _RealNotification = window.Notification; + if (_RealNotification) { + function FakeNotification(title, opts) { + opts = opts || {}; + window.__wr_notif_log__.notifications.push({ + timestamp_ms: performance.now(), + title: String(title || ''), + body: String(opts.body || ''), + tag: opts.tag != null ? String(opts.tag) : null, + require_interaction: !!opts.requireInteraction, + silent: !!opts.silent + }); + return new _RealNotification(title, opts); + } + FakeNotification.permission = _RealNotification.permission; + FakeNotification.requestPermission = Notification.requestPermission; + Object.setPrototypeOf(FakeNotification, _RealNotification); + window.Notification = FakeNotification; + } +})(); +""".strip() + + +def build_install_script() -> str: + return _INSTALL_TEMPLATE + + +HARVEST_SCRIPT = "return window.__wr_notif_log__ || {permission_requests: [], notifications: []};" + + +# ---------- parsing ---------------------------------------------------- + +def parse_log(payload: Any) -> NotificationsLog: + """Convert the harvested JSON into typed records.""" + if not isinstance(payload, dict): + raise NotificationsAuditError( + f"payload must be dict, got {type(payload).__name__}" + ) + requests: List[PermissionRequest] = [] + for raw in payload.get("permission_requests") or []: + if not isinstance(raw, dict): + continue + try: + result = PermissionResult(str(raw.get("result") or "default")) + except ValueError: + result = PermissionResult.DEFAULT + try: + requests.append(PermissionRequest( + timestamp_ms=float(raw.get("timestamp_ms") or 0), + user_gesture=bool(raw.get("user_gesture", False)), + result=result, + page_age_ms=float(raw.get("page_age_ms") or 0), + )) + except (TypeError, ValueError) as error: + raise NotificationsAuditError( + f"bad permission_request entry {raw!r}: {error}" + ) from error + notifications: List[NotificationShown] = [] + for raw in payload.get("notifications") or []: + if not isinstance(raw, dict): + continue + try: + notifications.append(NotificationShown( + timestamp_ms=float(raw.get("timestamp_ms") or 0), + title=str(raw.get("title") or ""), + body=str(raw.get("body") or ""), + tag=raw.get("tag"), + require_interaction=bool(raw.get("require_interaction", False)), + silent=bool(raw.get("silent", False)), + )) + except (TypeError, ValueError) as error: + raise NotificationsAuditError( + f"bad notification entry {raw!r}: {error}" + ) from error + return NotificationsLog( + permission_requests=requests, + notifications=notifications, + ) + + +# ---------- assertions ------------------------------------------------- + +def assert_no_prompt_without_gesture(log: NotificationsLog) -> None: + """Assert every permission request happened within a user-gesture window.""" + for req in log.permission_requests: + if not req.user_gesture: + raise NotificationsAuditError( + f"Notification.requestPermission called without user gesture " + f"at page age {req.page_age_ms:.0f}ms" + ) + + +def assert_no_prompt_before( + log: NotificationsLog, + *, + min_page_age_ms: float, +) -> None: + """Assert no prompt fires before ``min_page_age_ms`` (avoids auto-prompt on load).""" + if min_page_age_ms < 0: + raise NotificationsAuditError("min_page_age_ms must be >= 0") + for req in log.permission_requests: + if req.page_age_ms < min_page_age_ms: + raise NotificationsAuditError( + f"prompt fired at {req.page_age_ms:.0f}ms, want >= {min_page_age_ms}ms" + ) + + +def assert_no_spam_after_deny(log: NotificationsLog) -> None: + """Assert no further prompts or notifications appear after a 'denied'.""" + deny_time: Optional[float] = None + for req in log.permission_requests: + if req.result == PermissionResult.DENIED: + deny_time = req.timestamp_ms + continue + if deny_time is not None and req.timestamp_ms > deny_time: + raise NotificationsAuditError( + f"re-prompted after denial at {req.timestamp_ms:.0f}ms" + ) + if deny_time is None: + return + for notif in log.notifications: + if notif.timestamp_ms > deny_time: + raise NotificationsAuditError( + f"notification shown after denial: {notif.title!r}" + ) + + +def assert_notification_shown( + log: NotificationsLog, + *, + title_contains: Optional[str] = None, + body_contains: Optional[str] = None, + tag: Optional[str] = None, +) -> NotificationShown: + """Assert at least one notification matches the given filters.""" + if title_contains is None and body_contains is None and tag is None: + raise NotificationsAuditError( + "provide at least one of title_contains / body_contains / tag" + ) + for notif in log.notifications: + if title_contains is not None and title_contains not in notif.title: + continue + if body_contains is not None and body_contains not in notif.body: + continue + if tag is not None and notif.tag != tag: + continue + return notif + raise NotificationsAuditError( + f"no notification matched title_contains={title_contains!r} " + f"body_contains={body_contains!r} tag={tag!r}" + ) + + +def assert_unique_tags(log: NotificationsLog) -> None: + """Assert no tag was reused (would silently replace earlier notification).""" + seen: Dict[str, int] = {} + for notif in log.notifications: + if notif.tag is None: + continue + seen[notif.tag] = seen.get(notif.tag, 0) + 1 + duplicates = [tag for tag, count in seen.items() if count > 1] + if duplicates: + raise NotificationsAuditError( + f"notification tags reused: {sorted(duplicates)}" + ) diff --git a/je_web_runner/utils/ocr_assert/__init__.py b/je_web_runner/utils/ocr_assert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/ocr_assert/ocr.py b/je_web_runner/utils/ocr_assert/ocr.py new file mode 100644 index 0000000..ed55e6a --- /dev/null +++ b/je_web_runner/utils/ocr_assert/ocr.py @@ -0,0 +1,237 @@ +""" +Canvas / WebGL / 圖片內文字 OCR 斷言。補 ``visual_ai`` 只做感知雜湊的缺口 +─ 當你想斷言「圖表標籤是 'Q4 2025'」而不是「兩張圖看起來一樣」時用這個。 + +Thin wrapper around `pytesseract` + Pillow that normalises whitespace, +strips diacritics, and offers a couple of comparison modes (exact, +contains, fuzzy ratio). Tesseract is the only realistic pure-Python +option that runs offline; cloud OCR adapters can be added later via +the same :class:`OcrBackend` protocol. +""" +from __future__ import annotations + +import difflib +import re +import unicodedata +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Callable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class OcrAssertError(WebRunnerException): + """Raised on missing OCR backend, unreadable image, or failed assertion.""" + + +_WHITESPACE_RE = re.compile(r"\s+") + + +# ---------- normalisation ------------------------------------------------ + +def normalise_text(text: str, *, lowercase: bool = True, strip_accents: bool = True) -> str: + """Collapse whitespace, optionally lowercase + strip combining marks.""" + if not isinstance(text, str): + raise OcrAssertError(f"normalise_text expected str, got {type(text).__name__}") + out = text + if strip_accents: + out = "".join( + ch for ch in unicodedata.normalize("NFKD", out) + if not unicodedata.combining(ch) + ) + if lowercase: + out = out.lower() + out = _WHITESPACE_RE.sub(" ", out).strip() + return out + + +def fuzzy_ratio(a: str, b: str) -> float: + """0..1 similarity ratio after :func:`normalise_text`.""" + return difflib.SequenceMatcher(None, normalise_text(a), normalise_text(b)).ratio() + + +# ---------- OCR backend -------------------------------------------------- + +OcrBackend = Callable[[Any], str] + + +def _require_pytesseract() -> Any: + try: + import pytesseract # type: ignore[import-not-found] + return pytesseract + except ImportError as error: + raise OcrAssertError( + "pytesseract is required for ocr_assert. " + "Install: pip install pytesseract Pillow (and the tesseract binary)." + ) from error + + +def _require_pil() -> Any: + try: + from PIL import Image # type: ignore[import-not-found] + return Image + except ImportError as error: + raise OcrAssertError( + "Pillow is required for ocr_assert. Install: pip install Pillow" + ) from error + + +def _open_image(source: Union[bytes, str, Path, Any]) -> Any: + image_cls = _require_pil() + if isinstance(source, (bytes, bytearray)): + from io import BytesIO + return image_cls.open(BytesIO(source)) + if isinstance(source, (str, Path)): + path = Path(source) + if not path.exists(): + raise OcrAssertError(f"image not found: {path}") + return image_cls.open(path) + if hasattr(source, "convert"): + return source + raise OcrAssertError( + f"ocr source must be bytes/path/PIL.Image, got {type(source).__name__}" + ) + + +def tesseract_backend( + *, + lang: str = "eng", + config: str = "--psm 6", +) -> OcrBackend: + """Return a callable that extracts text from an image using Tesseract.""" + pytesseract = _require_pytesseract() + + def _extract(source: Any) -> str: + image = _open_image(source) + try: + return pytesseract.image_to_string(image, lang=lang, config=config) + except Exception as error: # pytesseract raises a custom error class + raise OcrAssertError(f"tesseract failed: {error!r}") from error + + return _extract + + +def extract_text( + source: Union[bytes, str, Path, Any], + *, + backend: Optional[OcrBackend] = None, +) -> str: + """Run OCR on a screenshot / image and return the raw recognised text.""" + runner = backend or tesseract_backend() + text = runner(source) + if not isinstance(text, str): + raise OcrAssertError( + f"OCR backend returned {type(text).__name__}, expected str" + ) + return text + + +# ---------- assertions -------------------------------------------------- + +@dataclass +class OcrMatchResult: + """Outcome of an OCR assertion.""" + + matched: bool + mode: str + needle: str + haystack: str + score: float = 0.0 + notes: List[str] = field(default_factory=list) + + def raise_if_failed(self) -> None: + if not self.matched: + preview = self.haystack[:200].replace("\n", "\\n") + raise OcrAssertError( + f"OCR assertion failed (mode={self.mode}, score={self.score:.2f}). " + f"needle={self.needle!r} haystack[:200]={preview!r}" + ) + + +def assert_text_contains( + source: Union[bytes, str, Path, Any], + needle: str, + *, + backend: Optional[OcrBackend] = None, + case_sensitive: bool = False, +) -> OcrMatchResult: + """Assert that ``needle`` appears in the OCR output (whitespace-collapsed).""" + if not isinstance(needle, str) or not needle: + raise OcrAssertError("needle must be a non-empty string") + raw = extract_text(source, backend=backend) + if case_sensitive: + haystack_n = _WHITESPACE_RE.sub(" ", raw).strip() + needle_n = _WHITESPACE_RE.sub(" ", needle).strip() + else: + haystack_n = normalise_text(raw) + needle_n = normalise_text(needle) + matched = needle_n in haystack_n + result = OcrMatchResult( + matched=matched, + mode="contains", + needle=needle_n, + haystack=haystack_n, + score=1.0 if matched else 0.0, + ) + if not matched: + web_runner_logger.warning( + f"ocr_assert.contains miss: needle={needle_n!r}" + ) + return result + + +def assert_text_fuzzy( + source: Union[bytes, str, Path, Any], + expected: str, + *, + min_ratio: float = 0.8, + backend: Optional[OcrBackend] = None, +) -> OcrMatchResult: + """Assert that the OCR output is ``min_ratio``-similar to ``expected``.""" + if not 0.0 < min_ratio <= 1.0: + raise OcrAssertError("min_ratio must be in (0, 1]") + raw = extract_text(source, backend=backend) + haystack_n = normalise_text(raw) + expected_n = normalise_text(expected) + score = difflib.SequenceMatcher(None, expected_n, haystack_n).ratio() + matched = score >= min_ratio + return OcrMatchResult( + matched=matched, + mode="fuzzy", + needle=expected_n, + haystack=haystack_n, + score=round(score, 4), + notes=[f"min_ratio={min_ratio}"], + ) + + +def assert_text_any( + source: Union[bytes, str, Path, Any], + candidates: Sequence[str], + *, + backend: Optional[OcrBackend] = None, +) -> OcrMatchResult: + """Assert that at least one ``candidate`` appears in the OCR output.""" + if not candidates: + raise OcrAssertError("candidates must be a non-empty sequence") + raw = extract_text(source, backend=backend) + haystack_n = normalise_text(raw) + for needle in candidates: + if normalise_text(needle) in haystack_n: + return OcrMatchResult( + matched=True, + mode="any", + needle=needle, + haystack=haystack_n, + score=1.0, + notes=[f"matched 1 of {len(candidates)}"], + ) + return OcrMatchResult( + matched=False, + mode="any", + needle=" | ".join(candidates), + haystack=haystack_n, + score=0.0, + notes=[f"none of {len(candidates)} candidates matched"], + ) diff --git a/je_web_runner/utils/open_redirect_detector/__init__.py b/je_web_runner/utils/open_redirect_detector/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/open_redirect_detector/detector.py b/je_web_runner/utils/open_redirect_detector/detector.py new file mode 100644 index 0000000..49029a2 --- /dev/null +++ b/je_web_runner/utils/open_redirect_detector/detector.py @@ -0,0 +1,242 @@ +""" +偵測 ``?redirect=`` / ``?next=`` / ``?return_to=`` 類參數的 open-redirect 弱點。 +The classic phishing primitive: legitimate domain → user trusts the URL +→ click → server 302s to attacker.com. Hard to catch without explicit +testing. + +This module: + +1. Generates a payload set covering same-origin variants, scheme-relative + (`//evil.com`), absolute external URLs, mixed-case bypass, path-only + tricks (`/\\evil.com`), data: / javascript: schemes. +2. Given a callable that performs the request and returns the final + Location header (without auto-following), classifies the result as + ALLOWED (vulnerable), BLOCKED (safe), or AMBIGUOUS. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Sequence +from urllib.parse import urlparse + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class OpenRedirectError(WebRunnerException): + """Raised on malformed inputs or invalid probe result.""" + + +class ProbeOutcome(str, Enum): + BLOCKED = "blocked" # safe — redirect stayed same-origin or rejected + ALLOWED = "allowed" # vulnerable — went off-origin + AMBIGUOUS = "ambiguous" # unclear — neither rejected nor clearly off-origin + + +# ---------- payload generation ----------------------------------------- + +@dataclass(frozen=True) +class RedirectPayload: + """One probe value to inject into the redirect parameter.""" + + label: str + value: str + expected_off_origin_host: Optional[str] = None + + +def default_payloads(attacker_host: str = "evil.example") -> List[RedirectPayload]: + """Return a representative payload set for the given attacker host.""" + if not isinstance(attacker_host, str) or "." not in attacker_host: + raise OpenRedirectError("attacker_host must look like a domain") + return [ + RedirectPayload("absolute_http", + f"http://{attacker_host}/", + attacker_host), + RedirectPayload("absolute_https", + f"https://{attacker_host}/", + attacker_host), + RedirectPayload("scheme_relative", + f"//{attacker_host}/", + attacker_host), + RedirectPayload("backslash_bypass", + f"/\\\\{attacker_host}/", + attacker_host), + RedirectPayload("mixed_case", + f"HtTpS://{attacker_host}/", + attacker_host), + RedirectPayload("at_sign_userinfo", + f"https://trusted.com@{attacker_host}/", + attacker_host), + RedirectPayload("data_uri", + "data:text/html,", + None), + RedirectPayload("javascript_uri", + "javascript:alert(1)", + None), + ] + + +# ---------- classification --------------------------------------------- + +@dataclass +class ProbeResult: + """One payload → one outcome.""" + + payload: RedirectPayload + final_location: Optional[str] + status_code: int + outcome: ProbeOutcome + note: str = "" + + def to_dict(self) -> Dict[str, Any]: + return { + "payload_label": self.payload.label, + "payload_value": self.payload.value, + "final_location": self.final_location, + "status_code": self.status_code, + "outcome": self.outcome.value, + "note": self.note, + } + + +def classify_response( + payload: RedirectPayload, + final_location: Optional[str], + status_code: int, + *, + legitimate_host: str, +) -> ProbeResult: + """Decide if the response indicates an open redirect.""" + if not isinstance(legitimate_host, str) or not legitimate_host: + raise OpenRedirectError("legitimate_host must be non-empty string") + if not isinstance(status_code, int): + raise OpenRedirectError("status_code must be int") + if status_code < 300 or status_code >= 400: + return ProbeResult( + payload=payload, + final_location=final_location, + status_code=status_code, + outcome=ProbeOutcome.BLOCKED, + note=f"non-redirect status {status_code}", + ) + if not final_location: + return ProbeResult( + payload=payload, + final_location=None, + status_code=status_code, + outcome=ProbeOutcome.AMBIGUOUS, + note="redirect status with empty Location", + ) + scheme, host = _parse_target(final_location) + if scheme in ("javascript", "data"): + return ProbeResult( + payload=payload, + final_location=final_location, + status_code=status_code, + outcome=ProbeOutcome.ALLOWED, + note=f"redirected to {scheme}: scheme", + ) + if host and not _is_same_host(host, legitimate_host): + return ProbeResult( + payload=payload, + final_location=final_location, + status_code=status_code, + outcome=ProbeOutcome.ALLOWED, + note=f"redirected to {host}", + ) + return ProbeResult( + payload=payload, + final_location=final_location, + status_code=status_code, + outcome=ProbeOutcome.BLOCKED, + ) + + +def _parse_target(location: str) -> tuple: + if location.startswith("//"): + parsed = urlparse(f"http:{location}") + return "http", (parsed.hostname or "").lower() + try: + parsed = urlparse(location) + except ValueError: + return "", "" + return parsed.scheme.lower(), (parsed.hostname or "").lower() + + +def _is_same_host(actual: str, legitimate: str) -> bool: + actual = actual.lower() + legitimate = legitimate.lower() + if actual == legitimate: + return True + return actual.endswith("." + legitimate) + + +# ---------- probe driver ------------------------------------------------ + +ProbeFn = Callable[[str], "ProbeResponse"] + + +@dataclass(frozen=True) +class ProbeResponse: + """What the probe callable must return.""" + + status_code: int + location: Optional[str] + + +@dataclass +class ProbeReport: + """Aggregate over all payloads.""" + + legitimate_host: str + results: List[ProbeResult] = field(default_factory=list) + + def vulnerable(self) -> List[ProbeResult]: + return [r for r in self.results if r.outcome == ProbeOutcome.ALLOWED] + + def passed(self) -> bool: + return not self.vulnerable() + + +def probe_all( + payloads: Sequence[RedirectPayload], + probe: ProbeFn, + *, + legitimate_host: str, +) -> ProbeReport: + """Drive ``probe`` once per payload, classify, return report.""" + if not payloads: + raise OpenRedirectError("payloads must be non-empty") + if not callable(probe): + raise OpenRedirectError("probe must be callable") + report = ProbeReport(legitimate_host=legitimate_host) + for payload in payloads: + try: + response = probe(payload.value) + except Exception as error: + raise OpenRedirectError( + f"probe failed for {payload.label!r}: {error!r}" + ) from error + if not isinstance(response, ProbeResponse): + raise OpenRedirectError( + f"probe must return ProbeResponse, got {type(response).__name__}" + ) + report.results.append(classify_response( + payload, response.location, response.status_code, + legitimate_host=legitimate_host, + )) + return report + + +# ---------- assertion -------------------------------------------------- + +def assert_safe(report: ProbeReport) -> None: + """Raise if any payload was classified ALLOWED.""" + if not isinstance(report, ProbeReport): + raise OpenRedirectError("assert_safe expects ProbeReport") + if report.passed(): + return + labels = ", ".join(r.payload.label for r in report.vulnerable()) + raise OpenRedirectError( + f"open redirect vulnerable to: {labels}" + ) diff --git a/je_web_runner/utils/openapi_to_e2e/__init__.py b/je_web_runner/utils/openapi_to_e2e/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/openapi_to_e2e/generator.py b/je_web_runner/utils/openapi_to_e2e/generator.py new file mode 100644 index 0000000..2f32fbd --- /dev/null +++ b/je_web_runner/utils/openapi_to_e2e/generator.py @@ -0,0 +1,553 @@ +""" +OpenAPI / Swagger → WebRunner action JSON generator。 +讀 OpenAPI 3.x 或 Swagger 2.0 spec,對每個 endpoint 產出 happy-path ++ 4xx 邊界的 ``WR_http_*`` action JSON。 + +Decisions: + +* No external `openapi-spec-validator` dependency — we tolerate + partially-valid specs the way real-world swagger files are. +* Examples come from (in priority): explicit ``example``/``examples`` + in the schema, then ``default``, then a type-driven faker + (``"string"`` → ``"sample"``, ``"integer"`` → ``1``…). Keeps output + deterministic. +* Auth: if the spec declares Bearer / API-key, we drop a placeholder + header so the generated file is runnable after the user injects a + real token via env-var expansion (``${ENV_VAR}``). + +Output is plain ``WR_http_*`` action lists — runnable by the existing +executor without any extra glue. +""" +from __future__ import annotations + +import copy +import json +import re +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class OpenAPIGeneratorError(WebRunnerException): + """Raised when the spec is unreadable or required fields are missing.""" + + +SUPPORTED_METHODS: Tuple[str, ...] = ( + "get", "post", "put", "patch", "delete", "head", "options", +) + + +@dataclass +class GeneratedTest: + """One generated test scenario.""" + + name: str + method: str + path: str + expected_status: int + actions: List[Any] + scenario: str # "happy" | "missing_body" | "bad_path_param" | etc. + + def to_dict(self) -> Dict[str, Any]: + return { + "name": self.name, + "method": self.method, + "path": self.path, + "expected_status": self.expected_status, + "actions": self.actions, + "scenario": self.scenario, + } + + +@dataclass +class GenerationResult: + """Aggregate result for one spec.""" + + spec_title: str + base_url: str + tests: List[GeneratedTest] = field(default_factory=list) + skipped: List[Dict[str, str]] = field(default_factory=list) + + def to_dict(self) -> Dict[str, Any]: + return { + "spec_title": self.spec_title, + "base_url": self.base_url, + "tests": [t.to_dict() for t in self.tests], + "skipped": list(self.skipped), + } + + +# ---------- spec loading ------------------------------------------------ + +def load_spec(spec_path: Union[str, Path]) -> Dict[str, Any]: + """ + 讀 JSON 或 YAML 格式的 OpenAPI spec。 + YAML support is soft-dependency on ``PyYAML``; JSON specs work without. + """ + path = Path(spec_path) + if not path.is_file(): + raise OpenAPIGeneratorError(f"spec file not found: {path}") + text = path.read_text(encoding="utf-8") + try: + return json.loads(text) + except ValueError: + pass + try: + import yaml # type: ignore[import-not-found] + except ImportError as error: + raise OpenAPIGeneratorError( + f"YAML spec but PyYAML not installed: pip install pyyaml ({path})" + ) from error + try: + loaded = yaml.safe_load(text) + except yaml.YAMLError as error: # type: ignore[attr-defined] + raise OpenAPIGeneratorError(f"cannot parse YAML {path}: {error}") from error + if not isinstance(loaded, dict): + raise OpenAPIGeneratorError(f"top-level YAML must be a mapping: {path}") + return loaded + + +# ---------- $ref resolution -------------------------------------------- + +_REF_RE = re.compile(r"^#/(.+)$") + + +def _resolve_ref(spec: Dict[str, Any], ref: str) -> Any: + match = _REF_RE.match(ref or "") + if not match: + return None + parts = match.group(1).split("/") + node: Any = spec + for part in parts: + if not isinstance(node, dict) or part not in node: + return None + node = node[part] + return node + + +def _maybe_resolve(spec: Dict[str, Any], schema: Any, *, depth: int = 0) -> Any: + if depth > 6 or not isinstance(schema, dict): + return schema + if "$ref" in schema: + resolved = _resolve_ref(spec, schema["$ref"]) + if resolved is None: + return {} + return _maybe_resolve(spec, resolved, depth=depth + 1) + return schema + + +# ---------- example synthesis ------------------------------------------ + +_TYPE_DEFAULTS: Dict[str, Any] = { + "string": "sample", + "integer": 1, + "number": 1.0, + "boolean": True, + "array": [], + "object": {}, +} + + +def synthesize_example( + spec: Dict[str, Any], + schema: Any, + *, + depth: int = 0, +) -> Any: + """ + 從 schema 推一個範例值,依序試 example → default → type 預設。 + Deterministic so repeated runs produce the same output. Recursion + depth is bounded to keep cyclic refs safe. + """ + if depth > 5: + return None + schema = _maybe_resolve(spec, schema, depth=depth) + if not isinstance(schema, dict): + return None + if "example" in schema: + return copy.deepcopy(schema["example"]) + if "examples" in schema: + examples = schema["examples"] + if isinstance(examples, dict): + first = next(iter(examples.values()), None) + if isinstance(first, dict) and "value" in first: + return copy.deepcopy(first["value"]) + if first is not None: + return copy.deepcopy(first) + if isinstance(examples, list) and examples: + return copy.deepcopy(examples[0]) + if "default" in schema: + return copy.deepcopy(schema["default"]) + schema_type = schema.get("type") + if schema_type == "object" or "properties" in schema: + out: Dict[str, Any] = {} + properties = schema.get("properties") or {} + required = set(schema.get("required") or []) + for key, prop in properties.items(): + if required and key not in required: + continue + out[key] = synthesize_example(spec, prop, depth=depth + 1) + return out + if schema_type == "array": + items = schema.get("items") + if items: + return [synthesize_example(spec, items, depth=depth + 1)] + return [] + if isinstance(schema.get("enum"), list) and schema["enum"]: + return schema["enum"][0] + if isinstance(schema_type, str) and schema_type in _TYPE_DEFAULTS: + return copy.deepcopy(_TYPE_DEFAULTS[schema_type]) + return None + + +# ---------- url assembly ------------------------------------------------ + +def _base_url(spec: Dict[str, Any]) -> str: + """Honour OpenAPI 3 ``servers`` first, then Swagger 2 ``host`` + ``basePath``.""" + servers = spec.get("servers") + if isinstance(servers, list) and servers: + first = servers[0] + if isinstance(first, dict) and "url" in first: + return str(first["url"]).rstrip("/") + host = spec.get("host") + if isinstance(host, str) and host: + scheme = "https" + schemes = spec.get("schemes") + if isinstance(schemes, list) and schemes: + scheme = str(schemes[0]) + base = f"{scheme}://{host}" + base_path = spec.get("basePath") or "" + if base_path and not base_path.startswith("/"): + base_path = "/" + base_path + return (base + base_path).rstrip("/") + return "" + + +_PATH_PARAM_RE = re.compile(r"\{([^{}]+)\}") + + +def _expand_path( + template: str, + parameters: List[Dict[str, Any]], + spec: Dict[str, Any], + *, + invalid_param: Optional[str] = None, +) -> Tuple[str, Dict[str, Any]]: + """Returns ``(expanded_path, query_params)``.""" + resolved = template + query: Dict[str, Any] = {} + for raw_param in parameters: + param = _maybe_resolve(spec, raw_param) + if not isinstance(param, dict): + continue + name = param.get("name") + if not isinstance(name, str): + continue + location = param.get("in") + if location == "path": + example = synthesize_example(spec, param.get("schema") or param) or "1" + if invalid_param and invalid_param == name: + example = "" # forces /foo// — server returns 404 or 400 + resolved = resolved.replace( + "{" + name + "}", str(example), + ) + elif location == "query": + example = synthesize_example(spec, param.get("schema") or param) + if example is not None: + query[name] = example + return resolved, query + + +def _action_command(method: str) -> str: + return f"WR_http_{method.lower()}" + + +# ---------- auth heuristics -------------------------------------------- + +def _auth_headers(spec: Dict[str, Any]) -> Dict[str, str]: + """ + 粗略偵測 Bearer / API-key,塞 ``${TOKEN}`` placeholder 讓 env_loader 補。 + """ + components = spec.get("components") or {} + security_schemes = components.get("securitySchemes") or spec.get("securityDefinitions") or {} + headers: Dict[str, str] = {} + if not isinstance(security_schemes, dict): + return headers + for scheme in security_schemes.values(): + if not isinstance(scheme, dict): + continue + kind = (scheme.get("type") or "").lower() + if kind in {"http", "bearer"} and (scheme.get("scheme") or "").lower() == "bearer": + headers["Authorization"] = "Bearer ${API_TOKEN}" + elif kind in {"apikey", "api_key"} and scheme.get("in") == "header": + header_name = str(scheme.get("name") or "X-API-Key") + headers[header_name] = "${API_TOKEN}" + return headers + + +# ---------- per-endpoint generation ------------------------------------ + +def _build_action( + method: str, + path: str, + base_url: str, + *, + body: Any = None, + query: Optional[Dict[str, Any]] = None, + headers: Optional[Dict[str, str]] = None, + timeout: int = 15, +) -> List[Any]: + kwargs: Dict[str, Any] = {"url": f"{base_url}{path}", "timeout": timeout} + if query: + kwargs["params"] = query + if headers: + kwargs["headers"] = headers + if body is not None and method.lower() in {"post", "put", "patch", "delete"}: + kwargs["json_body"] = body + return [_action_command(method), kwargs] + + +def _request_body_example(spec: Dict[str, Any], operation: Dict[str, Any]) -> Any: + body = operation.get("requestBody") + if isinstance(body, dict): + body = _maybe_resolve(spec, body) + content = body.get("content") if isinstance(body, dict) else None + if isinstance(content, dict): + json_payload = content.get("application/json") or next(iter(content.values()), None) + if isinstance(json_payload, dict): + schema = json_payload.get("schema") + if schema is not None: + return synthesize_example(spec, schema) + if "example" in json_payload: + return copy.deepcopy(json_payload["example"]) + # swagger 2 — `parameters` with `in: body` + parameters = operation.get("parameters") or [] + if isinstance(parameters, list): + for raw in parameters: + param = _maybe_resolve(spec, raw) + if isinstance(param, dict) and param.get("in") == "body": + schema = param.get("schema") or param + return synthesize_example(spec, schema) + return None + + +def _success_status(operation: Dict[str, Any]) -> int: + responses = operation.get("responses") or {} + if not isinstance(responses, dict): + return 200 + for code in responses: + if isinstance(code, str) and code.startswith("2"): + try: + return int(code) + except ValueError: + continue + return 200 + + +def _operation_name(method: str, path: str, operation: Dict[str, Any]) -> str: + op_id = operation.get("operationId") + if isinstance(op_id, str) and op_id: + return op_id + sanitised = re.sub(r"[^A-Za-z0-9]+", "_", path).strip("_") or "root" + return f"{method.lower()}_{sanitised}" + + +def _build_happy_test( + spec: Dict[str, Any], + base_url: str, + method: str, + path: str, + operation: Dict[str, Any], + extra_headers: Dict[str, str], +) -> GeneratedTest: + parameters = list(operation.get("parameters") or []) + parameters.extend(operation.get("parameters", []) if False else []) + expanded_path, query = _expand_path(path, parameters, spec) + body = _request_body_example(spec, operation) + status = _success_status(operation) + name = _operation_name(method, path, operation) + return GeneratedTest( + name=f"{name}__happy", + method=method.upper(), + path=expanded_path, + expected_status=status, + scenario="happy", + actions=[ + _build_action( + method, expanded_path, base_url, + body=body, query=query or None, + headers=extra_headers or None, + ), + ["WR_http_assert_status", {"expected": status}], + ], + ) + + +def _build_missing_body_test( + spec: Dict[str, Any], + base_url: str, + method: str, + path: str, + operation: Dict[str, Any], + extra_headers: Dict[str, str], +) -> Optional[GeneratedTest]: + if method.lower() not in {"post", "put", "patch"}: + return None + if not operation.get("requestBody") and not any( + (_maybe_resolve(spec, p) or {}).get("in") == "body" + for p in (operation.get("parameters") or []) + ): + return None + parameters = list(operation.get("parameters") or []) + expanded_path, query = _expand_path(path, parameters, spec) + name = _operation_name(method, path, operation) + return GeneratedTest( + name=f"{name}__missing_body", + method=method.upper(), + path=expanded_path, + expected_status=400, + scenario="missing_body", + actions=[ + _build_action( + method, expanded_path, base_url, + body=None, query=query or None, + headers=extra_headers or None, + ), + ["WR_http_assert_status", {"expected": 400}], + ], + ) + + +def _build_bad_path_param_test( + spec: Dict[str, Any], + base_url: str, + method: str, + path: str, + operation: Dict[str, Any], + extra_headers: Dict[str, str], +) -> Optional[GeneratedTest]: + path_params = _PATH_PARAM_RE.findall(path) + if not path_params: + return None + bad_param = path_params[0] + parameters = list(operation.get("parameters") or []) + expanded_path, query = _expand_path(path, parameters, spec, invalid_param=bad_param) + body = _request_body_example(spec, operation) + name = _operation_name(method, path, operation) + return GeneratedTest( + name=f"{name}__bad_path_param", + method=method.upper(), + path=expanded_path, + expected_status=404, + scenario="bad_path_param", + actions=[ + _build_action( + method, expanded_path, base_url, + body=body, query=query or None, + headers=extra_headers or None, + ), + ["WR_http_assert_status", {"expected": 404}], + ], + ) + + +# ---------- public entry points ---------------------------------------- + +def generate_tests_from_spec( + spec: Dict[str, Any], + *, + include_negative: bool = True, + method_filter: Optional[Iterable[str]] = None, + path_prefix_filter: Optional[str] = None, +) -> GenerationResult: + """ + 從已 load 的 spec 直接產出 GenerationResult。 + ``method_filter`` (e.g. ``{"get", "post"}``) and ``path_prefix_filter`` + let callers narrow the surface during big-spec exploration. + """ + if not isinstance(spec, dict): + raise OpenAPIGeneratorError("spec must be a dict") + paths = spec.get("paths") + if not isinstance(paths, dict): + raise OpenAPIGeneratorError("spec missing 'paths' object") + base_url = _base_url(spec) + title = ((spec.get("info") or {}).get("title")) if isinstance(spec.get("info"), dict) else "" + extra_headers = _auth_headers(spec) + methods_lower = ( + {m.lower() for m in method_filter} if method_filter else set(SUPPORTED_METHODS) + ) + tests: List[GeneratedTest] = [] + skipped: List[Dict[str, str]] = [] + for path, operations in paths.items(): + if not isinstance(path, str) or not isinstance(operations, dict): + continue + if path_prefix_filter and not path.startswith(path_prefix_filter): + continue + for method, operation in operations.items(): + if method.lower() not in SUPPORTED_METHODS: + continue + if method.lower() not in methods_lower: + continue + if not isinstance(operation, dict): + skipped.append({"path": path, "method": method, "reason": "operation not a dict"}) + continue + tests.append(_build_happy_test(spec, base_url, method, path, operation, extra_headers)) + if include_negative: + missing = _build_missing_body_test( + spec, base_url, method, path, operation, extra_headers, + ) + if missing: + tests.append(missing) + bad_path = _build_bad_path_param_test( + spec, base_url, method, path, operation, extra_headers, + ) + if bad_path: + tests.append(bad_path) + web_runner_logger.info( + f"generate_tests_from_spec: title={title!r} produced={len(tests)} " + f"skipped={len(skipped)}" + ) + return GenerationResult( + spec_title=str(title or ""), + base_url=base_url, + tests=tests, + skipped=skipped, + ) + + +def generate_tests_from_file( + spec_path: Union[str, Path], + *, + include_negative: bool = True, + method_filter: Optional[Iterable[str]] = None, + path_prefix_filter: Optional[str] = None, +) -> GenerationResult: + """Convenience: load + generate in one shot.""" + spec = load_spec(spec_path) + return generate_tests_from_spec( + spec, + include_negative=include_negative, + method_filter=method_filter, + path_prefix_filter=path_prefix_filter, + ) + + +def write_tests_to_dir( + result: GenerationResult, + output_dir: Union[str, Path], +) -> List[Path]: + """One JSON file per generated test (slug-named, sorted by name).""" + target = Path(output_dir) + target.mkdir(parents=True, exist_ok=True) + written: List[Path] = [] + for test in result.tests: + slug = re.sub(r"[^A-Za-z0-9_-]+", "_", test.name).strip("_") + path = target / f"{slug}.json" + with open(path, "w", encoding="utf-8") as fp: + json.dump(test.actions, fp, ensure_ascii=False, indent=2) + written.append(path) + web_runner_logger.info(f"write_tests_to_dir: wrote {len(written)} files to {target}") + return written diff --git a/je_web_runner/utils/otel_bridge/__init__.py b/je_web_runner/utils/otel_bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/otel_bridge/trace_bridge.py b/je_web_runner/utils/otel_bridge/trace_bridge.py new file mode 100644 index 0000000..412caae --- /dev/null +++ b/je_web_runner/utils/otel_bridge/trace_bridge.py @@ -0,0 +1,289 @@ +""" +把 OpenTelemetry 的 trace context 注入到瀏覽器發出的 HTTP request header, +讓前端 → 後端的 distributed trace 串成一條。 + +Inject W3C ``traceparent`` / ``tracestate`` headers into every browser +request so a frontend action and the backend span it triggers land in +the same trace tree. Supports: + +* **Selenium 4+ Chromium** — via CDP ``Network.setExtraHTTPHeaders``. +* **Playwright** — via ``page.set_extra_http_headers``. + +If ``opentelemetry-api`` isn't installed (it's a soft dep), the helpers +fall back to caller-provided ``trace_id`` / ``span_id`` so callers using +some other tracing library can still bridge. +""" +from __future__ import annotations + +import secrets +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Dict, Iterator, Optional + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class TraceBridgeError(WebRunnerException): + """Raised when header injection or context propagation fails.""" + + +# ---------- W3C traceparent helpers -------------------------------------- + +@dataclass(frozen=True) +class TraceContext: + """W3C traceparent fields decoupled from any specific SDK.""" + + trace_id: str # 32 hex chars + span_id: str # 16 hex chars + sampled: bool = True + version: str = "00" + tracestate: Optional[str] = None + + def to_traceparent(self) -> str: + flags = "01" if self.sampled else "00" + return f"{self.version}-{self.trace_id}-{self.span_id}-{flags}" + + def as_headers(self) -> Dict[str, str]: + headers = {"traceparent": self.to_traceparent()} + if self.tracestate: + headers["tracestate"] = self.tracestate + return headers + + +_HEX = "0123456789abcdef" + + +def _is_hex(value: str, length: int) -> bool: + return ( + isinstance(value, str) + and len(value) == length + and all(ch in _HEX for ch in value) + ) + + +def _validate_context(ctx: TraceContext) -> None: + if not _is_hex(ctx.trace_id, 32) or ctx.trace_id == "0" * 32: + raise TraceBridgeError(f"invalid trace_id: {ctx.trace_id!r}") + if not _is_hex(ctx.span_id, 16) or ctx.span_id == "0" * 16: + raise TraceBridgeError(f"invalid span_id: {ctx.span_id!r}") + if not _is_hex(ctx.version, 2): + raise TraceBridgeError(f"invalid version: {ctx.version!r}") + + +def random_trace_context(sampled: bool = True) -> TraceContext: + """Generate a fresh W3C-compliant trace context (for synthetic traces).""" + return TraceContext( + trace_id=secrets.token_hex(16), + span_id=secrets.token_hex(8), + sampled=sampled, + ) + + +def parse_traceparent(header: str) -> TraceContext: + """Parse a ``traceparent`` header back into a :class:`TraceContext`.""" + if not isinstance(header, str): + raise TraceBridgeError(f"traceparent must be str, got {type(header).__name__}") + parts = header.strip().split("-") + if len(parts) != 4: + raise TraceBridgeError(f"malformed traceparent: {header!r}") + version, trace_id, span_id, flags = parts + sampled = bool(int(flags, 16) & 1) if _is_hex(flags, 2) else False + ctx = TraceContext( + trace_id=trace_id, span_id=span_id, sampled=sampled, version=version, + ) + _validate_context(ctx) + return ctx + + +# ---------- pull context from active OpenTelemetry span ------------------ + +def current_otel_context() -> Optional[TraceContext]: + """ + 若有 active OTel span 就把它包成 TraceContext;沒 OTel 或無 active span 回 None。 + Return the active OpenTelemetry span's context as a :class:`TraceContext` + if OpenTelemetry is installed and a span is currently active. Returns + ``None`` otherwise — callers should fall back to a synthetic context. + """ + try: + from opentelemetry import trace # type: ignore[import-not-found] + except ImportError: + return None + span = trace.get_current_span() + if span is None: + return None + ctx = span.get_span_context() + if not ctx or not getattr(ctx, "is_valid", False): + return None + trace_id_int = ctx.trace_id + span_id_int = ctx.span_id + if not trace_id_int or not span_id_int: + return None + flags = ctx.trace_flags if hasattr(ctx, "trace_flags") else 0 + return TraceContext( + trace_id=format(trace_id_int, "032x"), + span_id=format(span_id_int, "016x"), + sampled=bool(int(flags) & 1), + ) + + +# ---------- header injection --------------------------------------------- + +def inject_headers_selenium(driver: Any, context: TraceContext) -> None: + """ + 透過 CDP 把 traceparent / tracestate 加進 Chrome 每個 request。 + Use ``Network.setExtraHTTPHeaders`` via Selenium's CDP bridge to add + the trace context to every outgoing browser request. Idempotent — + calling again with a new context simply replaces the previous headers. + """ + if driver is None: + raise TraceBridgeError("driver is None") + _validate_context(context) + cdp = getattr(driver, "execute_cdp_cmd", None) + if cdp is None: + raise TraceBridgeError( + "driver does not expose execute_cdp_cmd (need Selenium 4 + Chromium)" + ) + headers = context.as_headers() + try: + cdp("Network.enable", {}) + cdp("Network.setExtraHTTPHeaders", {"headers": headers}) + except Exception as error: # noqa: BLE001 — CDP errors are driver-specific + raise TraceBridgeError(f"CDP header injection failed: {error!r}") from error + web_runner_logger.info( + f"inject_headers_selenium: trace_id={context.trace_id} span_id={context.span_id}" + ) + + +def clear_headers_selenium(driver: Any) -> None: + """Remove the extra headers from the active Chrome session.""" + if driver is None: + return + cdp = getattr(driver, "execute_cdp_cmd", None) + if cdp is None: + return + try: + cdp("Network.setExtraHTTPHeaders", {"headers": {}}) + except Exception as error: # noqa: BLE001 + web_runner_logger.warning(f"clear_headers_selenium failed: {error!r}") + + +def inject_headers_playwright(page: Any, context: TraceContext) -> None: + """ + 對 Playwright page 設 ``set_extra_http_headers``,附 traceparent。 + Equivalent to :func:`inject_headers_selenium` but for Playwright. + """ + if page is None: + raise TraceBridgeError("page is None") + _validate_context(context) + setter = getattr(page, "set_extra_http_headers", None) + if setter is None: + raise TraceBridgeError("page has no set_extra_http_headers method") + try: + setter(context.as_headers()) + except Exception as error: # noqa: BLE001 + raise TraceBridgeError(f"Playwright header injection failed: {error!r}") from error + web_runner_logger.info( + f"inject_headers_playwright: trace_id={context.trace_id}" + ) + + +def clear_headers_playwright(page: Any) -> None: + """Reset extra headers on a Playwright page.""" + if page is None: + return + setter = getattr(page, "set_extra_http_headers", None) + if setter is None: + return + try: + setter({}) + except Exception as error: # noqa: BLE001 + web_runner_logger.warning(f"clear_headers_playwright failed: {error!r}") + + +# ---------- context managers -------------------------------------------- + +@contextmanager +def bridged_span_selenium( + driver: Any, + span_name: str, + *, + fallback_context: Optional[TraceContext] = None, +) -> Iterator[TraceContext]: + """ + 用 OTel span 包住一段 selenium 動作,並把 traceparent 注入瀏覽器。 + Start an OpenTelemetry span (when available) and inject its context + into the active Chrome session for the duration of the ``with`` block. + If OTel isn't installed, ``fallback_context`` is used (or a fresh + synthetic context if both are missing). + """ + span_ctx: Optional[Any] = None + try: + from opentelemetry import trace # type: ignore[import-not-found] + tracer = trace.get_tracer("je_web_runner.otel_bridge") + span_ctx = tracer.start_as_current_span(span_name) + span_ctx.__enter__() + context = current_otel_context() or fallback_context or random_trace_context() + except ImportError: + context = fallback_context or random_trace_context() + inject_headers_selenium(driver, context) + try: + yield context + finally: + clear_headers_selenium(driver) + if span_ctx is not None: + try: + span_ctx.__exit__(None, None, None) + except Exception as error: # noqa: BLE001 + web_runner_logger.warning(f"span exit failed: {error!r}") + + +@contextmanager +def bridged_span_playwright( + page: Any, + span_name: str, + *, + fallback_context: Optional[TraceContext] = None, +) -> Iterator[TraceContext]: + """Playwright twin of :func:`bridged_span_selenium`.""" + span_ctx: Optional[Any] = None + try: + from opentelemetry import trace # type: ignore[import-not-found] + tracer = trace.get_tracer("je_web_runner.otel_bridge") + span_ctx = tracer.start_as_current_span(span_name) + span_ctx.__enter__() + context = current_otel_context() or fallback_context or random_trace_context() + except ImportError: + context = fallback_context or random_trace_context() + inject_headers_playwright(page, context) + try: + yield context + finally: + clear_headers_playwright(page) + if span_ctx is not None: + try: + span_ctx.__exit__(None, None, None) + except Exception as error: # noqa: BLE001 + web_runner_logger.warning(f"span exit failed: {error!r}") + + +# ---------- report helpers ---------------------------------------------- + +def trace_link( + context: TraceContext, + *, + jaeger_base: Optional[str] = None, + tempo_base: Optional[str] = None, +) -> Optional[str]: + """ + 給定 trace context 與後端 base URL,回傳可點擊的 trace 連結。 + Build a direct UI link to the trace in Jaeger / Tempo. Returns the + first base URL that's provided. ``None`` if no base is supplied. + """ + if jaeger_base: + base = jaeger_base.rstrip("/") + return f"{base}/trace/{context.trace_id}" + if tempo_base: + base = tempo_base.rstrip("/") + return f"{base}/explore?orgId=1&left=%7B%22queries%22:%5B%7B%22query%22:%22{context.trace_id}%22%7D%5D%7D" + return None diff --git a/je_web_runner/utils/otp_interceptor/__init__.py b/je_web_runner/utils/otp_interceptor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/otp_interceptor/interceptor.py b/je_web_runner/utils/otp_interceptor/interceptor.py new file mode 100644 index 0000000..90e22d2 --- /dev/null +++ b/je_web_runner/utils/otp_interceptor/interceptor.py @@ -0,0 +1,489 @@ +""" +OTP / Email / SMS 攔截器:整合 MailHog / Mailpit / IMAP / 自建 webhook, +讓 E2E 測試可以等待並抽取一次性驗證碼,不再卡在 2FA 流程。 + +OTP interception backends shared by a single +:func:`wait_for_otp` helper. Built-in providers: + +* :class:`MailHogProvider` — http://mailhog/api/v2 style inbox +* :class:`MailpitProvider` — http://mailpit/api/v1 style inbox +* :class:`ImapProvider` — production-style IMAP fetch +* :class:`WebhookSmsProvider` — local SMS webhook (e.g. Twilio sandbox + forwarder) +* :class:`InMemoryProvider` — for offline tests and dry-runs + +Providers expose a single :meth:`fetch_messages` method that returns a +list of :class:`InterceptedMessage` newest-first. Polling logic and +regex extraction are implemented once in :func:`wait_for_otp`. +""" +from __future__ import annotations + +import json +import re +import ssl +import time +import urllib.parse +import urllib.request +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Pattern, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class OtpInterceptError(WebRunnerException): + """Raised on provider config / fetch / extraction problems.""" + + +# ---------- data ---------------------------------------------------------- + +@dataclass +class InterceptedMessage: + """One email / SMS message normalised across providers.""" + + message_id: str + sender: str + recipient: str + subject: str + body: str + received_at: float + headers: Dict[str, str] = field(default_factory=dict) + + +# ---------- abstract provider -------------------------------------------- + +class OtpProvider(ABC): + """Base class — concrete providers normalise raw inbox data.""" + + @abstractmethod + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + """Return messages newest-first, optionally filtered by recipient/since.""" + + +# ---------- HTTP helper -------------------------------------------------- + +def _http_get_json(url: str, timeout: float = 10.0) -> Any: + if not (url.startswith("http://") or url.startswith("https://")): + raise OtpInterceptError(f"refusing non-http URL: {url!r}") + req = urllib.request.Request(url, method="GET") + req.add_header("Accept", "application/json") + context = None if url.startswith("http://") else ssl.create_default_context() + try: + with urllib.request.urlopen( # nosec B310 — scheme allow-listed + req, timeout=timeout, context=context, + ) as response: + body = response.read().decode("utf-8") + except (OSError, ValueError) as error: + raise OtpInterceptError(f"HTTP GET failed for {url}: {error!r}") from error + if not body: + return None + try: + return json.loads(body) + except ValueError as error: + raise OtpInterceptError(f"non-JSON response from {url}: {error}") from error + + +# ---------- MailHog ------------------------------------------------------ + +class MailHogProvider(OtpProvider): + """Talks to a MailHog ``/api/v2/messages`` endpoint.""" + + def __init__(self, base_url: str, *, http_fetcher: Optional[Callable[[str], Any]] = None) -> None: + self.base_url = base_url.rstrip("/") + self._fetch = http_fetcher or _http_get_json + + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + url = f"{self.base_url}/api/v2/messages?limit={limit}" + payload = self._fetch(url) + if not isinstance(payload, dict): + raise OtpInterceptError("MailHog payload is not an object") + items = payload.get("items") + if not isinstance(items, list): + return [] + out: List[InterceptedMessage] = [] + for raw in items: + msg = _mailhog_to_message(raw) + if msg is None: + continue + if recipient and msg.recipient.lower() != recipient.lower(): + continue + if since and msg.received_at < since: + continue + out.append(msg) + out.sort(key=lambda m: m.received_at, reverse=True) + return out + + +def _mailhog_to_message(raw: Any) -> Optional[InterceptedMessage]: + if not isinstance(raw, dict): + return None + content = raw.get("Content") or {} + headers = content.get("Headers") or {} + from_list = headers.get("From") or [] + to_list = headers.get("To") or [] + subject_list = headers.get("Subject") or [] + message_id = raw.get("ID") or "" + body = content.get("Body") or "" + received_at = _parse_time(raw.get("Created")) + return InterceptedMessage( + message_id=str(message_id), + sender=from_list[0] if from_list else "", + recipient=to_list[0] if to_list else "", + subject=subject_list[0] if subject_list else "", + body=str(body), + received_at=received_at, + headers={k: ", ".join(v) if isinstance(v, list) else str(v) for k, v in headers.items()}, + ) + + +# ---------- Mailpit ------------------------------------------------------ + +class MailpitProvider(OtpProvider): + """Talks to a Mailpit ``/api/v1/messages`` endpoint.""" + + def __init__(self, base_url: str, *, http_fetcher: Optional[Callable[[str], Any]] = None) -> None: + self.base_url = base_url.rstrip("/") + self._fetch = http_fetcher or _http_get_json + + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + url = f"{self.base_url}/api/v1/messages?limit={limit}" + payload = self._fetch(url) + if not isinstance(payload, dict): + raise OtpInterceptError("Mailpit payload is not an object") + items = payload.get("messages") or payload.get("Messages") or [] + if not isinstance(items, list): + return [] + out: List[InterceptedMessage] = [] + for raw in items: + msg = _mailpit_to_message(raw) + if msg is None: + continue + if recipient and msg.recipient.lower() != recipient.lower(): + continue + if since and msg.received_at < since: + continue + out.append(msg) + out.sort(key=lambda m: m.received_at, reverse=True) + return out + + +def _mailpit_to_message(raw: Any) -> Optional[InterceptedMessage]: + if not isinstance(raw, dict): + return None + to_list = raw.get("To") or [] + first_to = to_list[0] if to_list else {} + return InterceptedMessage( + message_id=str(raw.get("ID") or ""), + sender=str((raw.get("From") or {}).get("Address") or ""), + recipient=str(first_to.get("Address") if isinstance(first_to, dict) else first_to), + subject=str(raw.get("Subject") or ""), + body=str(raw.get("Text") or raw.get("Snippet") or ""), + received_at=_parse_time(raw.get("Created")), + ) + + +# ---------- IMAP --------------------------------------------------------- + +class ImapProvider(OtpProvider): + """Real IMAP fetch — used when MailHog/Mailpit isn't available.""" + + def __init__( + self, + host: str, + port: int = 993, + *, + username: str, + password: str, + mailbox: str = "INBOX", + use_ssl: bool = True, + connector: Optional[Callable[..., Any]] = None, + ) -> None: + if not host or not username or not password: + raise OtpInterceptError("IMAP host/username/password are all required") + self.host = host + self.port = port + self.username = username + self.password = password + self.mailbox = mailbox + self.use_ssl = use_ssl + self._connector = connector + + def _connect(self): + if self._connector is not None: + return self._connector(self.host, self.port) + import imaplib # local import — IMAP is rarely needed + return (imaplib.IMAP4_SSL if self.use_ssl else imaplib.IMAP4)(self.host, self.port) + + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + conn = self._connect() + try: + conn.login(self.username, self.password) + conn.select(self.mailbox) + criteria = "ALL" if not recipient else f'(TO "{recipient}")' + _typ, ids_data = conn.search(None, criteria) + ids = (ids_data[0].split() if ids_data and ids_data[0] else [])[-limit:] + messages: List[InterceptedMessage] = [] + for raw_id in reversed(ids): + _typ, msg_data = conn.fetch(raw_id, "(RFC822)") + if not msg_data or not msg_data[0]: + continue + payload = msg_data[0] + raw_bytes = payload[1] if isinstance(payload, tuple) else payload + msg = _imap_bytes_to_message(raw_id.decode(), raw_bytes) + if msg is None: + continue + if since and msg.received_at < since: + continue + messages.append(msg) + return messages + finally: + try: + conn.close() + except Exception: # noqa: BLE001 + pass + try: + conn.logout() + except Exception: # noqa: BLE001 + pass + + +def _imap_bytes_to_message(message_id: str, raw_bytes: bytes) -> Optional[InterceptedMessage]: + import email + from email import policy + + try: + msg = email.message_from_bytes(raw_bytes, policy=policy.default) + except Exception: # noqa: BLE001 + return None + body = "" + if msg.is_multipart(): + for part in msg.walk(): + if part.get_content_type() == "text/plain": + body = part.get_content() + break + else: + try: + body = msg.get_content() + except Exception: # noqa: BLE001 + body = "" + return InterceptedMessage( + message_id=message_id, + sender=str(msg.get("From") or ""), + recipient=str(msg.get("To") or ""), + subject=str(msg.get("Subject") or ""), + body=str(body), + received_at=_parse_time(msg.get("Date")), + headers={k: v for k, v in msg.items()}, + ) + + +# ---------- SMS webhook -------------------------------------------------- + +class WebhookSmsProvider(OtpProvider): + """ + Poll a local webhook that aggregates SMS into a list endpoint + (``GET /messages?to=+15551234567``). Useful with Twilio sandbox or + a self-hosted bridge. + """ + + def __init__( + self, + base_url: str, + *, + endpoint: str = "/messages", + http_fetcher: Optional[Callable[[str], Any]] = None, + ) -> None: + self.base_url = base_url.rstrip("/") + self.endpoint = "/" + endpoint.lstrip("/") + self._fetch = http_fetcher or _http_get_json + + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + query = f"limit={limit}" + if recipient: + query += "&to=" + urllib.parse.quote(recipient, safe="") + url = f"{self.base_url}{self.endpoint}?{query}" + payload = self._fetch(url) + if not isinstance(payload, list): + raise OtpInterceptError("SMS webhook payload must be a JSON list") + out: List[InterceptedMessage] = [] + for raw in payload: + if not isinstance(raw, dict): + continue + msg = InterceptedMessage( + message_id=str(raw.get("id") or raw.get("sid") or ""), + sender=str(raw.get("from") or ""), + recipient=str(raw.get("to") or ""), + subject="", + body=str(raw.get("body") or raw.get("text") or ""), + received_at=_parse_time(raw.get("received_at") or raw.get("created_at")), + ) + if since and msg.received_at < since: + continue + out.append(msg) + out.sort(key=lambda m: m.received_at, reverse=True) + return out + + +# ---------- in-memory (for tests) ---------------------------------------- + +class InMemoryProvider(OtpProvider): + """Tests and dry-runs: hand it a list of messages.""" + + def __init__(self) -> None: + self.messages: List[InterceptedMessage] = [] + + def push(self, message: InterceptedMessage) -> None: + self.messages.append(message) + + def clear(self) -> None: + self.messages.clear() + + def fetch_messages( + self, + recipient: Optional[str] = None, + *, + since: Optional[float] = None, + limit: int = 25, + ) -> List[InterceptedMessage]: + out = list(self.messages) + if recipient: + out = [m for m in out if m.recipient.lower() == recipient.lower()] + if since: + out = [m for m in out if m.received_at >= since] + out.sort(key=lambda m: m.received_at, reverse=True) + return out[:limit] + + +# ---------- time parsing ------------------------------------------------- + +def _parse_time(value: Any) -> float: + if value is None: + return time.time() + if isinstance(value, (int, float)): + return float(value) + if isinstance(value, str) and value: + # ISO-8601: 2026-05-24T10:00:00Z or 2026-05-24T10:00:00.000Z + from datetime import datetime, timezone + + text = value + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + dt = datetime.fromisoformat(text) + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.timestamp() + except ValueError: + return time.time() + return time.time() + + +# ---------- OTP extraction & polling ------------------------------------- + +_DEFAULT_OTP_REGEX = re.compile(r"\b(\d{4,8})\b") + + +def extract_otp_from_text( + text: str, + pattern: Union[str, Pattern[str], None] = None, +) -> Optional[str]: + """ + 從文字中抽出 OTP code。預設 4–8 位數字。 + Apply ``pattern`` (defaults to 4–8 digits) and return the first match. + Returns ``None`` when no match is found. + """ + if not isinstance(text, str) or not text: + return None + if pattern is None: + regex: Pattern[str] = _DEFAULT_OTP_REGEX + elif isinstance(pattern, str): + regex = re.compile(pattern) + else: + regex = pattern + match = regex.search(text) + if match is None: + return None + if match.groups(): + return match.group(1) + return match.group(0) + + +def wait_for_otp( + provider: OtpProvider, + recipient: str, + *, + pattern: Union[str, Pattern[str], None] = None, + timeout: float = 30.0, + poll_interval: float = 1.0, + since: Optional[float] = None, + subject_contains: Optional[str] = None, + sleep_fn: Callable[[float], None] = time.sleep, + time_fn: Callable[[], float] = time.time, +) -> str: + """ + 輪詢 provider 直到收到含 OTP 的訊息或 timeout。 + Poll the provider every ``poll_interval`` seconds until a message that + matches ``recipient`` (and optionally ``subject_contains``) arrives AND + contains an OTP matching ``pattern``. Returns the extracted OTP string. + + Raises :class:`OtpInterceptError` on timeout. ``since`` defaults to + "now" so messages already in the inbox don't accidentally match. + """ + if not isinstance(provider, OtpProvider): + raise OtpInterceptError(f"provider must be an OtpProvider, got {type(provider).__name__}") + if not recipient: + raise OtpInterceptError("recipient is required") + if timeout <= 0: + raise OtpInterceptError("timeout must be positive") + if poll_interval <= 0: + raise OtpInterceptError("poll_interval must be positive") + start = time_fn() + if since is None: + since = start + while True: + messages = provider.fetch_messages(recipient=recipient, since=since) + for msg in messages: + if subject_contains and subject_contains.lower() not in msg.subject.lower(): + continue + code = extract_otp_from_text(msg.body, pattern) or extract_otp_from_text(msg.subject, pattern) + if code: + web_runner_logger.info( + f"wait_for_otp: matched {recipient} subject={msg.subject!r}" + ) + return code + if time_fn() - start >= timeout: + raise OtpInterceptError( + f"timeout waiting for OTP for {recipient!r} after {timeout}s" + ) + sleep_fn(poll_interval) diff --git a/je_web_runner/utils/pagination_audit/__init__.py b/je_web_runner/utils/pagination_audit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/pagination_audit/audit.py b/je_web_runner/utils/pagination_audit/audit.py new file mode 100644 index 0000000..44e0c7f --- /dev/null +++ b/je_web_runner/utils/pagination_audit/audit.py @@ -0,0 +1,229 @@ +""" +遍歷所有頁,斷言無重複 / 無遺漏 / cursor 穩定。 +Common pagination bugs: + +* Off-by-one: missing 1 row at every page boundary +* Duplicate item across pages (sort key not stable under concurrent + writes) +* Cursor changes meaning when the result set mutates +* "Empty next page" never terminates (infinite loop) + +This module drives a user-supplied :class:`PageFetcher` through every +page until exhaustion (or hits ``max_pages`` safety limit) and reports +counts + violations. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, Hashable, Iterable, List, Optional, Protocol, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class PaginationAuditError(WebRunnerException): + """Raised on bad inputs or detected pagination issues.""" + + +# ---------- model ------------------------------------------------------ + +@dataclass +class Page: + """One fetched page.""" + + items: List[Any] + next_cursor: Optional[Any] = None + + def __post_init__(self) -> None: + if not isinstance(self.items, list): + raise PaginationAuditError("Page.items must be a list") + + +class PageFetcher(Protocol): + """Caller-supplied fetcher.""" + + def __call__(self, cursor: Optional[Any]) -> Page: ... + + +KeyFn = Callable[[Any], Hashable] +"""Function: item → hashable identity (e.g. ``lambda r: r['id']``).""" + + +# ---------- audit ------------------------------------------------------ + +@dataclass +class PaginationFindings: + """Result of :func:`walk_all_pages`.""" + + page_count: int = 0 + total_items: int = 0 + unique_items: int = 0 + duplicates: List[Hashable] = field(default_factory=list) + duplicate_pages: Dict[Hashable, List[int]] = field(default_factory=dict) + empty_pages: List[int] = field(default_factory=list) + cursor_loop: bool = False + hit_max_pages: bool = False + item_keys_by_page: List[List[Hashable]] = field(default_factory=list) + + def passed(self) -> bool: + return not self.duplicates and not self.cursor_loop and not self.hit_max_pages + + +def walk_all_pages( + fetcher: PageFetcher, + key_fn: KeyFn, + *, + max_pages: int = 1_000, + initial_cursor: Optional[Any] = None, +) -> PaginationFindings: + """ + Iterate pages until ``next_cursor`` is None (or ``max_pages`` reached), + accumulating duplicates, empty-page indices, and cursor-loop detection. + """ + if not callable(fetcher): + raise PaginationAuditError("fetcher must be callable") + if not callable(key_fn): + raise PaginationAuditError("key_fn must be callable") + if max_pages <= 0: + raise PaginationAuditError("max_pages must be > 0") + + findings = PaginationFindings() + seen_cursors: set = set() + seen_items: Dict[Hashable, List[int]] = {} + cursor: Optional[Any] = initial_cursor + page_index = 0 + while page_index < max_pages: + try: + page = fetcher(cursor) + except Exception as error: + raise PaginationAuditError( + f"fetcher raised at page {page_index}: {error!r}" + ) from error + if not isinstance(page, Page): + raise PaginationAuditError( + f"fetcher must return Page, got {type(page).__name__}" + ) + findings.page_count = page_index + 1 + page_keys: List[Hashable] = [] + for item in page.items: + try: + key = key_fn(item) + except Exception as error: + raise PaginationAuditError( + f"key_fn failed on page {page_index}: {error!r}" + ) from error + seen_items.setdefault(key, []).append(page_index) + page_keys.append(key) + findings.total_items += len(page.items) + findings.item_keys_by_page.append(page_keys) + if not page.items: + findings.empty_pages.append(page_index) + if page.next_cursor is None: + break + cursor_key = _hashable_cursor(page.next_cursor) + if cursor_key in seen_cursors: + findings.cursor_loop = True + break + seen_cursors.add(cursor_key) + cursor = page.next_cursor + page_index += 1 + else: + findings.hit_max_pages = True + + for key, page_list in seen_items.items(): + if len(page_list) > 1: + findings.duplicates.append(key) + findings.duplicate_pages[key] = page_list + findings.unique_items = len(seen_items) + return findings + + +def _hashable_cursor(cursor: Any) -> Hashable: + if isinstance(cursor, (str, int, float, bool, type(None))): + return cursor + try: + return repr(cursor) + except Exception: + return id(cursor) + + +# ---------- assertions ------------------------------------------------- + +def assert_no_duplicates(findings: PaginationFindings) -> None: + """Raise if any item key appeared on more than one page.""" + if findings.duplicates: + sample = ", ".join(repr(k) for k in findings.duplicates[:5]) + more = ( + "" if len(findings.duplicates) <= 5 + else f" (+{len(findings.duplicates) - 5})" + ) + raise PaginationAuditError(f"duplicate items across pages: {sample}{more}") + + +def assert_no_cursor_loop(findings: PaginationFindings) -> None: + """Raise if a cursor was reused (would loop forever).""" + if findings.cursor_loop: + raise PaginationAuditError("cursor loop detected") + + +def assert_terminated(findings: PaginationFindings) -> None: + """Raise if ``max_pages`` was hit before exhaustion.""" + if findings.hit_max_pages: + raise PaginationAuditError( + f"pagination did not terminate within {findings.page_count} pages" + ) + + +def assert_expected_total( + findings: PaginationFindings, *, expected_total: int, +) -> None: + """Assert ``unique_items`` matches ``expected_total``.""" + if expected_total < 0: + raise PaginationAuditError("expected_total must be >= 0") + if findings.unique_items != expected_total: + raise PaginationAuditError( + f"unique items {findings.unique_items} != expected {expected_total}" + ) + + +def assert_clean(findings: PaginationFindings) -> None: + """All of the above in one go.""" + if not isinstance(findings, PaginationFindings): + raise PaginationAuditError("assert_clean expects PaginationFindings") + assert_no_duplicates(findings) + assert_no_cursor_loop(findings) + assert_terminated(findings) + + +# ---------- ordering check -------------------------------------------- + +def assert_sorted_by( + findings: PaginationFindings, + items_by_page_key: KeyFn, + *, + reverse: bool = False, +) -> None: + """ + Assert each page (and inter-page boundary) is sorted by ``items_by_page_key``. + Different from :func:`assert_no_duplicates` — this catches "page 3 + items come BEFORE page 2 items" bugs that look fine within a page. + """ + if not callable(items_by_page_key): + raise PaginationAuditError("items_by_page_key must be callable") + flattened: List[Hashable] = [ + key for page_keys in findings.item_keys_by_page for key in page_keys + ] + if not flattened: + return + last = flattened[0] + for current in flattened[1:]: + if reverse: + if not (current <= last): + raise PaginationAuditError( + f"order violation: {current!r} > {last!r} but reverse=True" + ) + else: + if not (current >= last): + raise PaginationAuditError( + f"order violation: {current!r} < {last!r}" + ) + last = current diff --git a/je_web_runner/utils/persona_runner/__init__.py b/je_web_runner/utils/persona_runner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/persona_runner/runner.py b/je_web_runner/utils/persona_runner/runner.py new file mode 100644 index 0000000..9d0182c --- /dev/null +++ b/je_web_runner/utils/persona_runner/runner.py @@ -0,0 +1,211 @@ +""" +Persona matrix: 同一份 suite × N 種角色(admin / free / enterprise / guest)。 +Most apps have feature gates by role. Re-running every action JSON file +under every persona catches "free user saw an admin-only button" or +"enterprise user can't see the feature they paid for". + +Inputs: + +* :class:`Persona` — name + auth-state hook + optional flag overrides +* :class:`PersonaRunner.run` — iterate (persona × action_file) and call + the user's runner callable for each pair + +Outputs: + +* :class:`PersonaCaseResult` per pair, and a :class:`MatrixSummary` + helping the reader see "all failures are on persona=guest" at a glance +""" +from __future__ import annotations + +import time +from dataclasses import asdict, dataclass, field +from typing import Any, Callable, Dict, Iterable, List, Optional, Protocol, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class PersonaRunnerError(WebRunnerException): + """Raised on bad persona / action-file inputs.""" + + +# ---------- persona model ----------------------------------------------- + +@dataclass +class Persona: + """One identity under test.""" + + name: str + auth_state: Dict[str, Any] = field(default_factory=dict) + flags: Dict[str, Any] = field(default_factory=dict) + tags: List[str] = field(default_factory=list) + + def __post_init__(self) -> None: + if not self.name or not isinstance(self.name, str): + raise PersonaRunnerError("Persona.name must be non-empty string") + + +# ---------- result model ------------------------------------------------ + +@dataclass +class PersonaCaseResult: + """Outcome of one (persona, action_file) pair.""" + + persona: str + action_file: str + passed: bool + duration_seconds: float = 0.0 + error: Optional[str] = None + notes: List[str] = field(default_factory=list) + + +@dataclass +class MatrixSummary: + """Roll-up across all pairs.""" + + total: int + passed: int + failed: int + by_persona: Dict[str, Dict[str, int]] = field(default_factory=dict) + persona_only_failures: List[str] = field(default_factory=list) + file_only_failures: List[str] = field(default_factory=list) + + +# ---------- runner protocol -------------------------------------------- + +class PersonaCaseRunner(Protocol): + """Implementations execute one action file under one persona.""" + + def __call__(self, persona: Persona, action_file: str) -> None: ... + + +# ---------- the runner -------------------------------------------------- + +@dataclass +class PersonaRunner: + """Drive the persona × file matrix.""" + + personas: Sequence[Persona] + action_files: Sequence[str] + case_runner: PersonaCaseRunner + stop_on_first_failure: bool = False + + def __post_init__(self) -> None: + if not self.personas: + raise PersonaRunnerError("at least one persona is required") + names = [p.name for p in self.personas] + if len(set(names)) != len(names): + raise PersonaRunnerError(f"duplicate persona names: {names}") + if not self.action_files: + raise PersonaRunnerError("at least one action file is required") + if len(set(self.action_files)) != len(self.action_files): + raise PersonaRunnerError("duplicate action_files in matrix") + + def run(self) -> List[PersonaCaseResult]: + results: List[PersonaCaseResult] = [] + for persona in self.personas: + for action_file in self.action_files: + started = time.monotonic() + error: Optional[str] = None + try: + self.case_runner(persona, action_file) + passed = True + except PersonaRunnerError: + raise + except Exception as exc: + passed = False + error = repr(exc) + web_runner_logger.warning( + f"persona={persona.name!r} file={action_file!r} failed: {exc!r}" + ) + duration = round(time.monotonic() - started, 4) + results.append(PersonaCaseResult( + persona=persona.name, + action_file=action_file, + passed=passed, + duration_seconds=duration, + error=error, + )) + if not passed and self.stop_on_first_failure: + return results + return results + + +# ---------- summary ----------------------------------------------------- + +def summarise(results: Iterable[PersonaCaseResult]) -> MatrixSummary: + """Build a :class:`MatrixSummary` from a result iterable.""" + total = 0 + passed_count = 0 + by_persona: Dict[str, Dict[str, int]] = {} + failures_by_persona: Dict[str, List[str]] = {} + failures_by_file: Dict[str, List[str]] = {} + seen_personas: set = set() + seen_files: set = set() + for result in results: + if not isinstance(result, PersonaCaseResult): + raise PersonaRunnerError( + f"summarise expects PersonaCaseResult, got {type(result).__name__}" + ) + total += 1 + seen_personas.add(result.persona) + seen_files.add(result.action_file) + bucket = by_persona.setdefault(result.persona, {"passed": 0, "failed": 0}) + if result.passed: + bucket["passed"] += 1 + passed_count += 1 + else: + bucket["failed"] += 1 + failures_by_persona.setdefault(result.persona, []).append(result.action_file) + failures_by_file.setdefault(result.action_file, []).append(result.persona) + persona_only: List[str] = [] + for persona, failed_files in failures_by_persona.items(): + # A persona "only" fails if every other persona passes the same files + if all( + persona not in failures_by_file.get(file, []) + or len(failures_by_file.get(file, [])) == 1 + for file in failed_files + ): + persona_only.append(persona) + file_only: List[str] = [] + for file, failing_personas in failures_by_file.items(): + if len(set(failing_personas)) >= len(seen_personas): + file_only.append(file) + return MatrixSummary( + total=total, + passed=passed_count, + failed=total - passed_count, + by_persona=by_persona, + persona_only_failures=sorted(persona_only), + file_only_failures=sorted(file_only), + ) + + +# ---------- formatting -------------------------------------------------- + +def summary_markdown(summary: MatrixSummary) -> str: + """Render a small markdown table for PR comments.""" + if summary.total == 0: + return "_No persona matrix results._\n" + lines = [ + f"### Persona matrix: {summary.passed}/{summary.total} passed", + "", + "| Persona | Passed | Failed |", + "|---------|--------|--------|", + ] + for persona in sorted(summary.by_persona): + bucket = summary.by_persona[persona] + lines.append(f"| {persona} | {bucket['passed']} | {bucket['failed']} |") + if summary.persona_only_failures: + lines.append("") + lines.append( + "**Persona-specific regressions:** " + + ", ".join(summary.persona_only_failures) + ) + if summary.file_only_failures: + lines.append("") + lines.append( + "**Files failing on every persona:** " + + ", ".join(summary.file_only_failures) + ) + return "\n".join(lines) + "\n" diff --git a/je_web_runner/utils/pii_in_screenshot/__init__.py b/je_web_runner/utils/pii_in_screenshot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/pii_in_screenshot/scanner.py b/je_web_runner/utils/pii_in_screenshot/scanner.py new file mode 100644 index 0000000..277a602 --- /dev/null +++ b/je_web_runner/utils/pii_in_screenshot/scanner.py @@ -0,0 +1,241 @@ +""" +Screenshot 內個資掃描:OCR → regex,抓出截圖意外洩漏的 email / 信用卡 / 身分證等。 +Many staging environments anonymise the DOM but forget images / charts / +PDF previews / 3rd-party iframes. When ``visual_regression`` snapshots +are uploaded to a shared dashboard or attached to a public bundle, those +unredacted PII pieces leak. + +This module reuses :mod:`ocr_assert` for text extraction and runs a +focused PII regex set against the output. The PII rules are deliberately +narrower than ``pii_scanner`` (which scans repos / structured payloads) +to keep false positives low on free-form OCR. +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Pattern, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.ocr_assert.ocr import OcrBackend, extract_text, normalise_text + + +class PiiInScreenshotError(WebRunnerException): + """Raised on bad inputs or OCR failure during scan.""" + + +# ---------- PII rule catalogue ------------------------------------------ + +@dataclass(frozen=True) +class PiiRule: + """One PII pattern + label.""" + + name: str + pattern: Pattern[str] + severity: str = "high" + # If validator returns False, the match is discarded (e.g. Luhn check). + validator: Optional[Callable[[str], bool]] = None + + +def _luhn(card: str) -> bool: + digits = [int(d) for d in card if d.isdigit()] + if len(digits) < 12 or len(digits) > 19: + return False + checksum = 0 + for index, digit in enumerate(reversed(digits)): + if index % 2 == 1: + doubled = digit * 2 + checksum += doubled - 9 if doubled > 9 else doubled + else: + checksum += digit + return checksum % 10 == 0 + + +_EMAIL = PiiRule( + name="email", + pattern=re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"), + severity="medium", +) +_PHONE_E164 = PiiRule( + name="phone_e164", + pattern=re.compile(r"\+\d{1,3}[\s\-.]?\(?\d{1,4}\)?[\s\-.]?\d{2,4}[\s\-.]?\d{2,4}"), + severity="medium", +) +_CREDIT_CARD = PiiRule( + name="credit_card", + pattern=re.compile(r"\b(?:\d[ -]*?){13,19}\b"), + severity="critical", + validator=_luhn, +) +_SSN_US = PiiRule( + name="ssn_us", + pattern=re.compile(r"\b\d{3}-\d{2}-\d{4}\b"), + severity="critical", +) +_TWID = PiiRule( # Taiwan ID + name="tw_national_id", + pattern=re.compile(r"\b[A-Z][12]\d{8}\b"), + severity="critical", +) +_IBAN = PiiRule( + name="iban", + pattern=re.compile(r"\b[A-Z]{2}\d{2}[A-Z0-9]{11,30}\b"), + severity="high", +) +_IPV4 = PiiRule( + name="ipv4", + pattern=re.compile( + r"\b(?:(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\.){3}" + r"(?:25[0-5]|2[0-4]\d|1\d\d|[1-9]?\d)\b" + ), + severity="low", +) + + +DEFAULT_RULES: Sequence[PiiRule] = ( + _EMAIL, _PHONE_E164, _CREDIT_CARD, _SSN_US, _TWID, _IBAN, _IPV4, +) + + +# ---------- findings ---------------------------------------------------- + +@dataclass +class PiiFinding: + """One PII occurrence in a screenshot.""" + + rule: str + severity: str + redacted_match: str + image: str = "" + raw_excerpt: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +def _redact_match(value: str) -> str: + cleaned = value.strip() + if len(cleaned) <= 4: + return "***" + return f"{cleaned[:2]}…{cleaned[-2:]}" + + +# ---------- scan -------------------------------------------------------- + +def scan_image( + source: Union[bytes, str, Path, Any], + *, + backend: Optional[OcrBackend] = None, + rules: Sequence[PiiRule] = DEFAULT_RULES, + image_label: str = "", +) -> List[PiiFinding]: + """OCR the image and return one :class:`PiiFinding` per (rule, match).""" + try: + raw_text = extract_text(source, backend=backend) + except Exception as error: + raise PiiInScreenshotError(f"OCR failed: {error!r}") from error + return _scan_text(raw_text, rules=rules, image_label=image_label) + + +def scan_text_only( + text: str, + *, + rules: Sequence[PiiRule] = DEFAULT_RULES, + image_label: str = "", +) -> List[PiiFinding]: + """Variant for callers that already have OCR'd text in hand.""" + if not isinstance(text, str): + raise PiiInScreenshotError( + f"scan_text_only expects str, got {type(text).__name__}" + ) + return _scan_text(text, rules=rules, image_label=image_label) + + +def _scan_text( + text: str, + *, + rules: Sequence[PiiRule], + image_label: str, +) -> List[PiiFinding]: + findings: List[PiiFinding] = [] + seen: set = set() + normalised = normalise_text(text, lowercase=False, strip_accents=False) + for rule in rules: + for match in rule.pattern.finditer(text): + value = match.group(0) + if rule.validator is not None and not rule.validator(value): + continue + redacted = _redact_match(value) + dedup_key = (rule.name, redacted) + if dedup_key in seen: + continue + seen.add(dedup_key) + findings.append(PiiFinding( + rule=rule.name, + severity=rule.severity, + redacted_match=redacted, + image=image_label, + raw_excerpt=_excerpt_around(normalised, value), + )) + return findings + + +def _excerpt_around(text: str, value: str) -> str: + idx = text.find(value) + if idx == -1: + return "" + start = max(0, idx - 24) + end = min(len(text), idx + len(value) + 24) + excerpt = text[start:end] + return excerpt.replace(value, "<>") + + +# ---------- bulk + assertions ------------------------------------------- + +@dataclass +class ScanReport: + """Aggregate over many screenshots.""" + + scanned: int = 0 + findings: List[PiiFinding] = field(default_factory=list) + by_severity: Dict[str, int] = field(default_factory=dict) + + def passed(self) -> bool: + return not self.findings + + +def scan_screenshots( + sources: Sequence[Union[bytes, str, Path, Any]], + *, + backend: Optional[OcrBackend] = None, + rules: Sequence[PiiRule] = DEFAULT_RULES, +) -> ScanReport: + """Scan a batch of screenshots and return a :class:`ScanReport`.""" + if not sources: + raise PiiInScreenshotError("sources must be non-empty") + report = ScanReport() + for index, source in enumerate(sources): + label = source if isinstance(source, (str, Path)) else f"image_{index}" + report.scanned += 1 + for finding in scan_image( + source, backend=backend, rules=rules, image_label=str(label), + ): + report.findings.append(finding) + report.by_severity[finding.severity] = ( + report.by_severity.get(finding.severity, 0) + 1 + ) + return report + + +def assert_clean(report: ScanReport) -> None: + """Raise unless ``report.passed()``.""" + if not isinstance(report, ScanReport): + raise PiiInScreenshotError("assert_clean expects ScanReport") + if report.passed(): + return + sample = ", ".join( + f"{f.rule}({f.severity})@{f.image}" for f in report.findings[:5] + ) + more = "" if len(report.findings) <= 5 else f" (+{len(report.findings) - 5})" + raise PiiInScreenshotError(f"PII detected in screenshots: {sample}{more}") diff --git a/je_web_runner/utils/pr_risk_score/__init__.py b/je_web_runner/utils/pr_risk_score/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/pr_risk_score/scorer.py b/je_web_runner/utils/pr_risk_score/scorer.py new file mode 100644 index 0000000..1c23354 --- /dev/null +++ b/je_web_runner/utils/pr_risk_score/scorer.py @@ -0,0 +1,250 @@ +""" +融合 flake / impact / locator_health / coverage 訊號,給 PR 一個 0-100 風險分數。 +The pieces this depends on (``flake_detector``, ``impact_analysis``, +``locator_health``, ``coverage_map``) already exist; this module just +combines their per-PR rollups into a single decision-friendly number plus +a human-readable reason list. + +Risk model: weighted sum, each signal clipped to ``[0, 1]``. Default +weights are tuned to roughly match what humans flag as "scary PR" in our +own retrospective sampling, but they're parameterised so teams can tune +them per repo without forking. +""" +from __future__ import annotations + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class PrRiskScoreError(WebRunnerException): + """Raised on invalid inputs or weight totals that won't normalise.""" + + +# ---------- inputs ------------------------------------------------------ + +@dataclass +class PrSignals: + """ + Per-PR aggregate signals. Each field is optional so a partial signal + set still produces a (less-confident) score. + """ + + # Test stability + flaky_tests_touched: int = 0 + total_tests_touched: int = 0 + avg_flake_score: float = 0.0 + + # Blast radius + impacted_modules: int = 0 + repo_modules: int = 0 + impacted_critical_paths: int = 0 + + # Locator hygiene + fragile_locators_touched: int = 0 + total_locators_touched: int = 0 + + # Coverage + lines_added: int = 0 + lines_covered: int = 0 + + # Other contextual signals + migration_files_changed: int = 0 + security_files_changed: int = 0 + + def __post_init__(self) -> None: + for name in ( + "flaky_tests_touched", "total_tests_touched", + "impacted_modules", "repo_modules", "impacted_critical_paths", + "fragile_locators_touched", "total_locators_touched", + "lines_added", "lines_covered", + "migration_files_changed", "security_files_changed", + ): + if getattr(self, name) < 0: + raise PrRiskScoreError(f"{name} must be >= 0") + if not 0.0 <= self.avg_flake_score <= 1.0: + raise PrRiskScoreError("avg_flake_score must be in [0, 1]") + + +@dataclass(frozen=True) +class RiskWeights: + """Per-signal contribution. Should be non-negative; need not sum to 1.""" + + flake: float = 2.0 + blast_radius: float = 2.0 + critical_path: float = 1.5 + locator_fragility: float = 1.0 + coverage_gap: float = 1.5 + migration: float = 1.0 + security: float = 1.0 + + def total(self) -> float: + return sum(asdict(self).values()) + + +@dataclass +class RiskReport: + """Result of :func:`score_pr`.""" + + score: float # 0..100 + level: str # "low" | "medium" | "high" | "critical" + reasons: List[str] = field(default_factory=list) + contributions: Dict[str, float] = field(default_factory=dict) + + def is_blocking(self, block_at: float = 75.0) -> bool: + return self.score >= block_at + + +# ---------- scoring ----------------------------------------------------- + +def _ratio(numerator: int, denominator: int) -> float: + if denominator <= 0: + return 0.0 + return max(0.0, min(1.0, numerator / denominator)) + + +def _clip(value: float) -> float: + return max(0.0, min(1.0, value)) + + +_SIGNAL_NAMES = ( + "flake", "blast_radius", "critical_path", + "locator_fragility", "coverage_gap", "migration", "security", +) + + +def _signal_components(signals: PrSignals) -> Dict[str, float]: + flake_touched_ratio = _ratio( + signals.flaky_tests_touched, signals.total_tests_touched, + ) + flake = _clip(0.7 * flake_touched_ratio + 0.3 * signals.avg_flake_score) + blast = _ratio(signals.impacted_modules, signals.repo_modules) + # Critical path: each impacted critical path adds 0.25, capped at 1.0 + critical = _clip(0.25 * signals.impacted_critical_paths) + locator = _ratio(signals.fragile_locators_touched, signals.total_locators_touched) + if signals.lines_added <= 0: + coverage_gap = 0.0 + else: + coverage_gap = _clip(1.0 - _ratio(signals.lines_covered, signals.lines_added)) + migration = _clip(0.5 * signals.migration_files_changed) + security = _clip(0.5 * signals.security_files_changed) + return { + "flake": flake, + "blast_radius": blast, + "critical_path": critical, + "locator_fragility": locator, + "coverage_gap": coverage_gap, + "migration": migration, + "security": security, + } + + +def _format_reason(name: str, component: float, weight: float) -> Optional[str]: + if component <= 0: + return None + pct = round(component * 100) + return f"{name.replace('_', ' ')}: {pct}% signal × weight {weight:.1f}" + + +def _level_for(score: float) -> str: + if score >= 75: + return "critical" + if score >= 50: + return "high" + if score >= 25: + return "medium" + return "low" + + +def score_pr( + signals: PrSignals, + weights: Optional[RiskWeights] = None, +) -> RiskReport: + """Combine ``signals`` × ``weights`` into a 0–100 :class:`RiskReport`.""" + if not isinstance(signals, PrSignals): + raise PrRiskScoreError("signals must be a PrSignals instance") + weights = weights or RiskWeights() + weight_total = weights.total() + if weight_total <= 0: + raise PrRiskScoreError("at least one weight must be > 0") + components = _signal_components(signals) + contributions: Dict[str, float] = {} + weighted_sum = 0.0 + for name in _SIGNAL_NAMES: + weight = getattr(weights, name) + component = components[name] + contrib = component * weight + weighted_sum += contrib + contributions[name] = round(contrib, 4) + normalised = weighted_sum / weight_total + score = round(_clip(normalised) * 100.0, 2) + reasons = sorted( + ( + r for r in ( + _format_reason(name, components[name], getattr(weights, name)) + for name in _SIGNAL_NAMES + ) if r is not None + ), + key=lambda s: -float(s.split("%")[0].rsplit(":", 1)[-1].strip()), + ) + return RiskReport( + score=score, + level=_level_for(score), + reasons=reasons, + contributions=contributions, + ) + + +# ---------- formatting -------------------------------------------------- + +def report_markdown(report: RiskReport) -> str: + """Render a :class:`RiskReport` as a small markdown block for PR comments.""" + lines = [ + f"### PR risk: **{report.score:.1f} / 100** ({report.level})", + "", + ] + if report.reasons: + lines.append("Top contributing signals:") + lines.extend(f"- {reason}" for reason in report.reasons) + else: + lines.append("_No risk signals tripped._") + return "\n".join(lines) + "\n" + + +def aggregate_signals(per_file: Sequence[Dict[str, Any]]) -> PrSignals: + """ + Reduce a per-file signal list (from upstream tools) into one + :class:`PrSignals`. Unknown keys are ignored so callers can pass + richer dicts without breaking. + """ + totals: Dict[str, int] = { + name: 0 for name in ( + "flaky_tests_touched", "total_tests_touched", + "impacted_modules", "impacted_critical_paths", + "fragile_locators_touched", "total_locators_touched", + "lines_added", "lines_covered", + "migration_files_changed", "security_files_changed", + ) + } + flake_scores: List[float] = [] + repo_modules = 0 + for entry in per_file: + if not isinstance(entry, dict): + continue + for key in totals: + value = entry.get(key) + if isinstance(value, int) and value >= 0: + totals[key] += value + score = entry.get("avg_flake_score") + if isinstance(score, (int, float)) and 0 <= score <= 1: + flake_scores.append(float(score)) + rm = entry.get("repo_modules") + if isinstance(rm, int) and rm > repo_modules: + repo_modules = rm + avg_flake = sum(flake_scores) / len(flake_scores) if flake_scores else 0.0 + return PrSignals( + avg_flake_score=round(avg_flake, 4), + repo_modules=repo_modules, + **totals, + ) diff --git a/je_web_runner/utils/prompt_drift_monitor/__init__.py b/je_web_runner/utils/prompt_drift_monitor/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/prompt_drift_monitor/monitor.py b/je_web_runner/utils/prompt_drift_monitor/monitor.py new file mode 100644 index 0000000..b73e572 --- /dev/null +++ b/je_web_runner/utils/prompt_drift_monitor/monitor.py @@ -0,0 +1,270 @@ +""" +追蹤 app 內 LLM 輸出隨時間飄移。 +應用情境:你的 app 自己有 LLM 功能(客服 bot、文章摘要、智能搜尋),你要監測它的回答品質是否隨時間下滑。 +Two complementary signals: + +* **Embedding similarity to a frozen baseline** — drift > threshold flags + the run for review. +* **Lexical anchors** — list of phrases that *must* appear (a brand name, + a disclaimer) or *must not* appear (a forbidden competitor name, a + banned topic). Lexical checks complement embeddings: drift below the + similarity threshold still fails if the disclaimer disappeared. + +State is stored as a tiny JSON baseline file so a CI job can compare +today's output against last week's snapshot without a database. +""" +from __future__ import annotations + +import json +import math +import re +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class PromptDriftError(WebRunnerException): + """Raised on malformed baseline / embeddings / config.""" + + +Embedder = Callable[[str], Sequence[float]] +"""Callable: text → embedding vector.""" + + +# ---------- baseline ---------------------------------------------------- + +@dataclass +class BaselineSample: + """One frozen reference answer.""" + + prompt_id: str + prompt: str + answer: str + embedding: List[float] + must_include: List[str] = field(default_factory=list) + must_exclude: List[str] = field(default_factory=list) + captured_at: str = "" + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class Baseline: + """Set of frozen reference samples, persisted as JSON.""" + + samples: List[BaselineSample] = field(default_factory=list) + captured_at: str = "" + + def by_id(self) -> Dict[str, BaselineSample]: + return {s.prompt_id: s for s in self.samples} + + +def capture_baseline( + prompts: Sequence[Dict[str, Any]], + embedder: Embedder, + answerer: Callable[[str], str], +) -> Baseline: + """ + Walk ``prompts`` (each ``{id, prompt, must_include?, must_exclude?}``), + ask ``answerer`` for the current answer, embed it, package as Baseline. + """ + if not prompts: + raise PromptDriftError("prompts must be non-empty") + samples: List[BaselineSample] = [] + now = datetime.now(tz=timezone.utc).isoformat(timespec="seconds") + for raw in prompts: + if not isinstance(raw, dict): + raise PromptDriftError("each prompt must be a dict") + prompt_id = str(raw.get("id") or "").strip() + prompt_text = str(raw.get("prompt") or "").strip() + if not prompt_id or not prompt_text: + raise PromptDriftError("each prompt needs non-empty 'id' and 'prompt'") + try: + answer = str(answerer(prompt_text)) + except Exception as error: + raise PromptDriftError( + f"answerer failed for {prompt_id!r}: {error!r}" + ) from error + vec = _embed_or_raise(embedder, answer, label=prompt_id) + samples.append(BaselineSample( + prompt_id=prompt_id, + prompt=prompt_text, + answer=answer, + embedding=list(vec), + must_include=[str(v) for v in raw.get("must_include") or []], + must_exclude=[str(v) for v in raw.get("must_exclude") or []], + captured_at=now, + )) + return Baseline(samples=samples, captured_at=now) + + +def save_baseline(baseline: Baseline, path: Union[str, Path]) -> Path: + """Persist baseline to JSON.""" + if not isinstance(baseline, Baseline): + raise PromptDriftError("save_baseline expects Baseline") + p = Path(path) + p.parent.mkdir(parents=True, exist_ok=True) + payload = { + "captured_at": baseline.captured_at, + "samples": [s.to_dict() for s in baseline.samples], + } + with open(p, "w", encoding="utf-8") as fp: + json.dump(payload, fp, ensure_ascii=False, indent=2) + return p + + +def load_baseline(path: Union[str, Path]) -> Baseline: + """Read baseline JSON back into a :class:`Baseline`.""" + p = Path(path) + if not p.exists(): + raise PromptDriftError(f"baseline file not found: {p}") + try: + data = json.loads(p.read_text(encoding="utf-8")) + except ValueError as error: + raise PromptDriftError(f"baseline file invalid: {error}") from error + if not isinstance(data, dict) or not isinstance(data.get("samples"), list): + raise PromptDriftError("baseline JSON missing 'samples' list") + samples: List[BaselineSample] = [] + for raw in data["samples"]: + if not isinstance(raw, dict): + continue + try: + samples.append(BaselineSample( + prompt_id=str(raw["prompt_id"]), + prompt=str(raw["prompt"]), + answer=str(raw["answer"]), + embedding=[float(x) for x in raw["embedding"]], + must_include=[str(v) for v in raw.get("must_include") or []], + must_exclude=[str(v) for v in raw.get("must_exclude") or []], + captured_at=str(raw.get("captured_at") or ""), + )) + except (KeyError, TypeError, ValueError) as error: + raise PromptDriftError(f"malformed sample: {error}") from error + return Baseline( + samples=samples, + captured_at=str(data.get("captured_at") or ""), + ) + + +# ---------- monitoring -------------------------------------------------- + +@dataclass +class DriftFinding: + """Per-prompt drift verdict.""" + + prompt_id: str + similarity: float + drifted: bool + missing_required: List[str] = field(default_factory=list) + forbidden_present: List[str] = field(default_factory=list) + current_answer: str = "" + + +@dataclass +class DriftReport: + """Roll-up returned by :func:`check_drift`.""" + + threshold: float + findings: List[DriftFinding] = field(default_factory=list) + + def drifted_findings(self) -> List[DriftFinding]: + return [f for f in self.findings + if f.drifted or f.missing_required or f.forbidden_present] + + def passed(self) -> bool: + return not self.drifted_findings() + + +def check_drift( + baseline: Baseline, + embedder: Embedder, + answerer: Callable[[str], str], + *, + similarity_threshold: float = 0.85, +) -> DriftReport: + """ + For each baseline sample, ask the current model, embed, compare. + Any sample below ``similarity_threshold`` or missing/including a + forbidden anchor is reported as drifted. + """ + if not isinstance(baseline, Baseline): + raise PromptDriftError("check_drift expects Baseline") + if not 0.0 < similarity_threshold <= 1.0: + raise PromptDriftError("similarity_threshold must be in (0, 1]") + report = DriftReport(threshold=similarity_threshold) + for sample in baseline.samples: + try: + current = str(answerer(sample.prompt)) + except Exception as error: + raise PromptDriftError( + f"answerer failed for {sample.prompt_id!r}: {error!r}" + ) from error + vec = _embed_or_raise(embedder, current, label=sample.prompt_id) + similarity = _cosine(sample.embedding, list(vec)) + missing = [phrase for phrase in sample.must_include + if phrase and phrase.lower() not in current.lower()] + forbidden = [phrase for phrase in sample.must_exclude + if phrase and phrase.lower() in current.lower()] + drifted = similarity < similarity_threshold + report.findings.append(DriftFinding( + prompt_id=sample.prompt_id, + similarity=round(similarity, 4), + drifted=drifted, + missing_required=missing, + forbidden_present=forbidden, + current_answer=current, + )) + if drifted or missing or forbidden: + web_runner_logger.warning( + f"prompt_drift: {sample.prompt_id} sim={similarity:.3f} " + f"missing={missing} forbidden={forbidden}" + ) + return report + + +# ---------- helpers ----------------------------------------------------- + +def assert_no_drift(report: DriftReport) -> None: + """Raise unless the report has no drifted findings.""" + if not isinstance(report, DriftReport): + raise PromptDriftError("assert_no_drift expects DriftReport") + if report.passed(): + return + parts = [ + f"{f.prompt_id}(sim={f.similarity:.2f})" + for f in report.drifted_findings()[:5] + ] + more = ( + "" + if len(report.drifted_findings()) <= 5 + else f" (+{len(report.drifted_findings()) - 5})" + ) + raise PromptDriftError(f"prompt drift detected: {', '.join(parts)}{more}") + + +def _embed_or_raise(embedder: Embedder, text: str, *, label: str) -> Sequence[float]: + try: + vec = embedder(text) + except Exception as error: + raise PromptDriftError( + f"embedder failed for {label!r}: {error!r}" + ) from error + if not isinstance(vec, (list, tuple)) or not vec: + raise PromptDriftError(f"embedder returned bad vector for {label!r}: {vec!r}") + return vec + + +def _cosine(a: Sequence[float], b: Sequence[float]) -> float: + if len(a) != len(b) or not a: + raise PromptDriftError("embeddings must be non-empty and equal-length") + dot = sum(float(x) * float(y) for x, y in zip(a, b)) + norm_a = math.sqrt(sum(float(x) * float(x) for x in a)) + norm_b = math.sqrt(sum(float(x) * float(x) for x in b)) + if norm_a == 0 or norm_b == 0: + return 0.0 + return dot / (norm_a * norm_b) diff --git a/je_web_runner/utils/pseudo_localization/__init__.py b/je_web_runner/utils/pseudo_localization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/pseudo_localization/pseudo.py b/je_web_runner/utils/pseudo_localization/pseudo.py new file mode 100644 index 0000000..2ecfe64 --- /dev/null +++ b/je_web_runner/utils/pseudo_localization/pseudo.py @@ -0,0 +1,199 @@ +""" +字串 → pseudo-localised 變體,專抓 hard-coded text、截斷、RTL bug。 +Classic trick: translate "Sign in" → "[!! Šîgñ în ──]". A real engineer +glancing at the page can tell which strings *weren't* translated (still +ASCII) and which UI elements truncate when text grows ~40%. + +Three independent transforms (toggleable): + +* **accent_map** — ASCII letters → look-alike Unicode (still readable). +* **expansion** — pad the string to simulate longer translations. +* **bracket** — wrap with markers to make untranslated leakage obvious. + +Plus a tiny scanner that diffs original vs pseudo and flags strings that +came back unchanged (= probably hard-coded, not from i18n catalogue). +""" +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Dict, Iterable, List, Mapping, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class PseudoLocalizationError(WebRunnerException): + """Raised on invalid config or string input.""" + + +_ACCENT_MAP: Mapping[str, str] = { + "a": "ä", "b": "ƀ", "c": "ç", "d": "ð", "e": "é", "f": "ƒ", "g": "ğ", + "h": "ĥ", "i": "î", "j": "ĵ", "k": "ķ", "l": "ł", "m": "ɱ", "n": "ñ", + "o": "ö", "p": "þ", "q": "ǫ", "r": "ŕ", "s": "š", "t": "ţ", "u": "ü", + "v": "ṽ", "w": "ŵ", "x": "ẋ", "y": "ÿ", "z": "ž", + "A": "Ä", "B": "Ɓ", "C": "Ç", "D": "Ð", "E": "É", "F": "Ƒ", "G": "Ğ", + "H": "Ĥ", "I": "Î", "J": "Ĵ", "K": "Ķ", "L": "Ł", "M": "Ṁ", "N": "Ñ", + "O": "Ö", "P": "Þ", "Q": "Ǫ", "R": "Ŕ", "S": "Š", "T": "Ţ", "U": "Ü", + "V": "Ṽ", "W": "Ŵ", "X": "Ẋ", "Y": "Ÿ", "Z": "Ž", +} + +_PADDING_CHAR = "─" + +_PLACEHOLDER_RE = re.compile( + r"\{[^}]+\}" # {name}, {0} + r"|%(?:\([^)]+\))?[diouxXeEfFgGcrs%]" # printf-style + r"|%[a-zA-Z]" # %d, %s + r"|<[^>]+>" # ,
+) + + +# ---------- transforms -------------------------------------------------- + +@dataclass +class PseudoConfig: + """Knobs for :func:`pseudo_localize`.""" + + accent: bool = True + expansion_ratio: float = 0.4 # 40% padding + bracket: bool = True + left_marker: str = "⟦" + right_marker: str = "⟧" + preserve_placeholders: bool = True + + def __post_init__(self) -> None: + if self.expansion_ratio < 0: + raise PseudoLocalizationError("expansion_ratio must be >= 0") + if not isinstance(self.left_marker, str) or not isinstance(self.right_marker, str): + raise PseudoLocalizationError("markers must be strings") + + +def _accent(text: str) -> str: + return "".join(_ACCENT_MAP.get(ch, ch) for ch in text) + + +def _pad_to_ratio(text: str, ratio: float) -> str: + if ratio <= 0 or not text: + return text + add = max(1, int(round(len(text) * ratio))) + left = add // 2 + right = add - left + return f"{_PADDING_CHAR * left} {text} {_PADDING_CHAR * right}" + + +def _split_around_placeholders(text: str) -> List[tuple]: + """Return list of (segment, is_placeholder) tuples preserving order.""" + parts: List[tuple] = [] + last = 0 + for match in _PLACEHOLDER_RE.finditer(text): + if match.start() > last: + parts.append((text[last:match.start()], False)) + parts.append((match.group(0), True)) + last = match.end() + if last < len(text): + parts.append((text[last:], False)) + if not parts: + parts.append((text, False)) + return parts + + +def pseudo_localize( + text: str, + config: Optional[PseudoConfig] = None, +) -> str: + """Return a pseudo-localised version of ``text``.""" + if not isinstance(text, str): + raise PseudoLocalizationError( + f"pseudo_localize expects str, got {type(text).__name__}" + ) + cfg = config or PseudoConfig() + if not text: + return text + if cfg.preserve_placeholders: + chunks: List[str] = [] + for segment, is_placeholder in _split_around_placeholders(text): + if is_placeholder: + chunks.append(segment) + else: + chunks.append(_accent(segment) if cfg.accent else segment) + accented = "".join(chunks) + else: + accented = _accent(text) if cfg.accent else text + padded = _pad_to_ratio(accented, cfg.expansion_ratio) + if cfg.bracket: + return f"{cfg.left_marker}{padded}{cfg.right_marker}" + return padded + + +# ---------- bulk + JSON dict translation -------------------------------- + +def pseudo_localize_dict( + catalogue: Mapping[str, str], + config: Optional[PseudoConfig] = None, +) -> Dict[str, str]: + """Apply :func:`pseudo_localize` to every value in a {key: string} map.""" + if not isinstance(catalogue, Mapping): + raise PseudoLocalizationError("catalogue must be a mapping") + out: Dict[str, str] = {} + for key, value in catalogue.items(): + if not isinstance(value, str): + raise PseudoLocalizationError( + f"catalogue value for {key!r} must be str, got {type(value).__name__}" + ) + out[key] = pseudo_localize(value, config) + return out + + +# ---------- hard-coded string scanner ----------------------------------- + +@dataclass +class HardcodedHit: + """One string that appeared verbatim in rendered output despite being pseudo'd.""" + + string: str + occurrences: int = 1 + + +@dataclass +class PseudoAuditReport: + """Roll-up of :func:`scan_for_hardcoded`.""" + + rendered_chars: int = 0 + hits: List[HardcodedHit] = field(default_factory=list) + + def passed(self) -> bool: + return not self.hits + + +def scan_for_hardcoded( + rendered_text: str, + *, + catalogue: Mapping[str, str], + min_length: int = 3, +) -> PseudoAuditReport: + """ + Look for any catalogue value that still appears verbatim (i.e. + untranslated) in ``rendered_text``. Strings shorter than + ``min_length`` are ignored to cut noise from single letters / + punctuation. + """ + if not isinstance(rendered_text, str): + raise PseudoLocalizationError("rendered_text must be str") + if min_length < 1: + raise PseudoLocalizationError("min_length must be >= 1") + report = PseudoAuditReport(rendered_chars=len(rendered_text)) + seen: Dict[str, int] = {} + for value in catalogue.values(): + if not isinstance(value, str) or len(value) < min_length: + continue + if not _contains_ascii_letters(value): + continue + count = rendered_text.count(value) + if count > 0: + seen[value] = seen.get(value, 0) + count + for string, occurrences in sorted(seen.items(), key=lambda kv: -kv[1]): + report.hits.append(HardcodedHit(string=string, occurrences=occurrences)) + return report + + +def _contains_ascii_letters(value: str) -> bool: + return any(ch.isascii() and ch.isalpha() for ch in value) diff --git a/je_web_runner/utils/quarantine_age_report/__init__.py b/je_web_runner/utils/quarantine_age_report/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/quarantine_age_report/report.py b/je_web_runner/utils/quarantine_age_report/report.py new file mode 100644 index 0000000..023835e --- /dev/null +++ b/je_web_runner/utils/quarantine_age_report/report.py @@ -0,0 +1,204 @@ +""" +Quarantine 條目加上 age + 自動 escalation tier。 +After ``flake_detector`` puts a test in quarantine the *real* danger is +it sits there forever. This module reads the quarantine registry, +computes how long each entry has been parked, and assigns an escalation +tier so dashboards can highlight "this has been quarantined 90+ days, +delete or fix it". +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from enum import Enum +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class QuarantineAgeReportError(WebRunnerException): + """Raised on malformed registry / inputs.""" + + +class EscalationTier(str, Enum): + """How urgently each quarantined test needs attention.""" + + FRESH = "fresh" # < 7 days + LINGERING = "lingering" # 7..30 days + STALE = "stale" # 30..90 days + ABANDONED = "abandoned" # >= 90 days + + +_TIER_THRESHOLDS = ( + (7, EscalationTier.FRESH), + (30, EscalationTier.LINGERING), + (90, EscalationTier.STALE), +) + + +@dataclass +class AgedEntry: + """One quarantine entry + escalation metadata.""" + + test_id: str + reason: str + flake_score: float + quarantined_at: str + age_days: float + tier: EscalationTier + triage_url: Optional[str] = None + runs_when_added: int = 0 + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "tier": self.tier.value} + + +@dataclass +class AgeReport: + """Roll-up over a whole quarantine registry.""" + + total_entries: int = 0 + by_tier: Dict[str, int] = field(default_factory=dict) + entries: List[AgedEntry] = field(default_factory=list) + abandoned: List[str] = field(default_factory=list) + + +def _parse_iso(value: str) -> datetime: + if not isinstance(value, str) or not value: + raise QuarantineAgeReportError("timestamp must be a non-empty string") + text = value.strip() + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + dt = datetime.fromisoformat(text) + except ValueError as error: + raise QuarantineAgeReportError(f"bad timestamp {value!r}: {error}") from error + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +def _tier_for(age_days: float) -> EscalationTier: + for threshold, tier in _TIER_THRESHOLDS: + if age_days < threshold: + return tier + return EscalationTier.ABANDONED + + +def _load_registry(path: Union[str, Path]) -> List[Dict[str, Any]]: + p = Path(path) + if not p.exists(): + raise QuarantineAgeReportError(f"registry not found: {p}") + try: + data = json.loads(p.read_text(encoding="utf-8")) + except ValueError as error: + raise QuarantineAgeReportError(f"registry not JSON: {error}") from error + if not isinstance(data, dict) or "entries" not in data: + raise QuarantineAgeReportError("registry missing 'entries' key") + entries = data["entries"] + if not isinstance(entries, list): + raise QuarantineAgeReportError("registry 'entries' must be a list") + return entries + + +def age_entries( + entries: Sequence[Dict[str, Any]], + *, + now: Optional[datetime] = None, +) -> List[AgedEntry]: + """Convert raw registry rows into typed entries with age + tier.""" + moment = now if now is not None else datetime.now(tz=timezone.utc) + if moment.tzinfo is None: + raise QuarantineAgeReportError("now must be tz-aware") + out: List[AgedEntry] = [] + for raw in entries: + if not isinstance(raw, dict): + continue + test_id = raw.get("test_id") + if not isinstance(test_id, str) or not test_id: + continue + timestamp = raw.get("quarantined_at") + if not isinstance(timestamp, str) or not timestamp: + continue + added = _parse_iso(timestamp) + age_days = max(0.0, (moment - added).total_seconds() / 86400.0) + out.append(AgedEntry( + test_id=test_id, + reason=str(raw.get("reason") or ""), + flake_score=float(raw.get("flake_score") or 0.0), + quarantined_at=timestamp, + age_days=round(age_days, 2), + tier=_tier_for(age_days), + triage_url=raw.get("triage_url"), + runs_when_added=int(raw.get("runs_when_added") or 0), + )) + return out + + +def build_report(entries: Iterable[AgedEntry]) -> AgeReport: + """Aggregate a list of aged entries into a :class:`AgeReport`.""" + report = AgeReport() + for entry in entries: + if not isinstance(entry, AgedEntry): + raise QuarantineAgeReportError( + f"expects AgedEntry, got {type(entry).__name__}" + ) + report.total_entries += 1 + tier = entry.tier.value + report.by_tier[tier] = report.by_tier.get(tier, 0) + 1 + report.entries.append(entry) + if entry.tier == EscalationTier.ABANDONED: + report.abandoned.append(entry.test_id) + return report + + +def load_and_age( + registry_path: Union[str, Path], + *, + now: Optional[datetime] = None, +) -> AgeReport: + """One-shot: load JSON registry, age every row, build report.""" + return build_report(age_entries(_load_registry(registry_path), now=now)) + + +# ---------- formatting ------------------------------------------------ + +def report_markdown(report: AgeReport, *, top_n: int = 10) -> str: + """Render a small markdown summary suitable for dashboards / PR comments.""" + if not isinstance(report, AgeReport): + raise QuarantineAgeReportError("expects AgeReport") + if top_n < 0: + raise QuarantineAgeReportError("top_n must be >= 0") + lines = [ + f"### Quarantine age report ({report.total_entries} entries)", + "", + ] + if report.by_tier: + lines.append("| Tier | Count |") + lines.append("|------|-------|") + for tier in EscalationTier: + count = report.by_tier.get(tier.value, 0) + lines.append(f"| {tier.value} | {count} |") + if report.abandoned: + lines.append("") + lines.append("**Abandoned (90+ days):**") + for tid in report.abandoned[:top_n]: + lines.append(f"- `{tid}`") + if len(report.abandoned) > top_n: + lines.append(f"- _+{len(report.abandoned) - top_n} more_") + return "\n".join(lines) + "\n" + + +def assert_no_abandoned(report: AgeReport) -> None: + """Raise if any test has been quarantined past the ABANDONED threshold.""" + if not isinstance(report, AgeReport): + raise QuarantineAgeReportError("expects AgeReport") + if not report.abandoned: + return + sample = ", ".join(report.abandoned[:5]) + more = "" if len(report.abandoned) <= 5 else f" (+{len(report.abandoned) - 5})" + raise QuarantineAgeReportError( + f"{len(report.abandoned)} abandoned quarantine entries: {sample}{more}" + ) diff --git a/je_web_runner/utils/repro_minimizer/__init__.py b/je_web_runner/utils/repro_minimizer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/repro_minimizer/minimizer.py b/je_web_runner/utils/repro_minimizer/minimizer.py new file mode 100644 index 0000000..83ef315 --- /dev/null +++ b/je_web_runner/utils/repro_minimizer/minimizer.py @@ -0,0 +1,167 @@ +""" +把失敗的 action list 縮到最小可重現 — delta-debugging (ddmin) 演算法。 +Given a list of N actions that fails, ddmin finds the smallest +subsequence that *still* fails by partitioning + binary-style elimination. +For E2E tests this is enormously useful: a 60-action recorder dump +shrinks to 4 actions that reproduce the bug. + +The runner callable is supplied by the caller (it knows how to execute +WR action JSON). It should return ``True`` when the test *passes* and +``False`` when it fails — minimizer is hunting for the failure-preserving +minimal subset. +""" +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any, Callable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class ReproMinimizerError(WebRunnerException): + """Raised on bad inputs or runner failure.""" + + +# ---------- result model ----------------------------------------------- + +@dataclass +class MinimizationResult: + """Outcome returned by :func:`minimize`.""" + + original_size: int + minimized_actions: List[Any] + minimized_size: int + iterations: int = 0 + eval_count: int = 0 + duration_seconds: float = 0.0 + + @property + def reduction_pct(self) -> float: + if self.original_size <= 0: + return 0.0 + return (1.0 - self.minimized_size / self.original_size) * 100.0 + + +# ---------- runner protocol -------------------------------------------- + +# Runner returns True if the (sub)sequence still *passes* the test +# (i.e. doesn't reproduce the failure), False if it still fails. +ActionRunner = Callable[[List[Any]], bool] + + +# ---------- ddmin ------------------------------------------------------- + +def minimize( + actions: Sequence[Any], + runner: ActionRunner, + *, + max_iterations: int = 200, + verify_failing: bool = True, +) -> MinimizationResult: + """ + Classic ddmin. Returns the smallest contiguous-or-not subsequence of + ``actions`` that still causes ``runner`` to return ``False``. + """ + if not isinstance(actions, (list, tuple)): + raise ReproMinimizerError( + f"actions must be list/tuple, got {type(actions).__name__}" + ) + if not actions: + raise ReproMinimizerError("actions must be non-empty") + if not callable(runner): + raise ReproMinimizerError("runner must be callable") + if max_iterations <= 0: + raise ReproMinimizerError("max_iterations must be > 0") + + counter = {"evals": 0} + + def _evaluate(subset: List[Any]) -> bool: + counter["evals"] += 1 + try: + return bool(runner(subset)) + except Exception as error: + raise ReproMinimizerError( + f"runner raised at size {len(subset)}: {error!r}" + ) from error + + full = list(actions) + if verify_failing: + if _evaluate(full): + raise ReproMinimizerError( + "runner says the original action list PASSES; nothing to minimize" + ) + + started = time.monotonic() + current = full + n = 2 + iterations = 0 + while len(current) >= 2 and iterations < max_iterations: + iterations += 1 + chunk_size = max(1, len(current) // n) + chunks = [current[i:i + chunk_size] + for i in range(0, len(current), chunk_size)] + # Try removing complement of each chunk first (granularity = n). + reduced = False + for index, chunk in enumerate(chunks): + complement = [ + a for j, c in enumerate(chunks) if j != index for a in c + ] + if not complement: + continue + if not _evaluate(complement): + current = complement + n = max(n - 1, 2) + reduced = True + break + if not reduced: + if n >= len(current): + break + n = min(n * 2, len(current)) + duration = round(time.monotonic() - started, 4) + if iterations >= max_iterations: + web_runner_logger.warning( + f"repro_minimizer hit max_iterations={max_iterations}; " + "result may not be locally minimal" + ) + return MinimizationResult( + original_size=len(full), + minimized_actions=current, + minimized_size=len(current), + iterations=iterations, + eval_count=counter["evals"], + duration_seconds=duration, + ) + + +# ---------- helpers ---------------------------------------------------- + +def assert_minimized( + result: MinimizationResult, + *, + max_remaining: int, +) -> None: + """Assert ``minimized_size <= max_remaining``.""" + if not isinstance(result, MinimizationResult): + raise ReproMinimizerError("assert_minimized expects MinimizationResult") + if max_remaining < 0: + raise ReproMinimizerError("max_remaining must be >= 0") + if result.minimized_size > max_remaining: + raise ReproMinimizerError( + f"minimized to {result.minimized_size} actions, " + f"wanted <= {max_remaining}" + ) + + +def report_markdown(result: MinimizationResult) -> str: + """Render a small markdown summary.""" + if not isinstance(result, MinimizationResult): + raise ReproMinimizerError("report_markdown expects MinimizationResult") + return ( + f"### Minimal repro: {result.minimized_size} / {result.original_size} " + f"actions ({result.reduction_pct:.0f}% reduction)\n\n" + f"- iterations: {result.iterations}\n" + f"- runner evaluations: {result.eval_count}\n" + f"- duration: {result.duration_seconds:.2f}s\n" + ) diff --git a/je_web_runner/utils/screen_reader_runner/__init__.py b/je_web_runner/utils/screen_reader_runner/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/screen_reader_runner/reader.py b/je_web_runner/utils/screen_reader_runner/reader.py new file mode 100644 index 0000000..180fbf0 --- /dev/null +++ b/je_web_runner/utils/screen_reader_runner/reader.py @@ -0,0 +1,269 @@ +""" +從 accessibility tree 模擬 NVDA / VoiceOver 朗讀順序與斷句。 +Real-screen-reader testing on CI is fragile (audio capture, OS quirks); +this module skips the audio loop entirely and walks the accessibility +tree to reproduce *what* a screen reader would say and *in what order*. + +Two outputs: + +* **utterances** — the sequence of strings a SR would announce as you + press Tab / arrow-down through the page. +* **violations** — common a11y red flags that surface during the walk + (interactive element with no accessible name, heading skip, image + without alt, link text "click here"). + +Driver-agnostic: feed in a JSON accessibility tree (CDP's +``Accessibility.getFullAXTree``, Playwright's ``page.accessibility.snapshot()``, +or Selenium's WebDriver BiDi `browsingContext.captureAccessibilityTree`). +""" +from __future__ import annotations + +import re +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class ScreenReaderError(WebRunnerException): + """Raised on malformed accessibility tree input.""" + + +# Roles SRs typically announce / focus on. Anything else is treated as +# decorative / structural. +_INTERACTIVE_ROLES = { + "button", "link", "checkbox", "radio", "textbox", "combobox", + "menuitem", "tab", "switch", "spinbutton", "slider", "searchbox", +} + +_GROUPING_ROLES = { + "heading", "list", "listitem", "navigation", "main", "banner", + "region", "complementary", "contentinfo", "form", "article", +} + +# Phrases banned in link names (well-known a11y anti-patterns) +_BANNED_LINK_TEXT = ("click here", "here", "more", "read more", "link") + + +# ---------- enums ------------------------------------------------------- + +class ViolationKind(str, Enum): + """Categories of a11y issues this module surfaces.""" + + UNNAMED_INTERACTIVE = "unnamed_interactive" + HEADING_SKIP = "heading_skip" + MISSING_ALT = "missing_alt" + GENERIC_LINK_TEXT = "generic_link_text" + EMPTY_BUTTON = "empty_button" + + +# ---------- data -------------------------------------------------------- + +@dataclass +class Utterance: + """One thing a screen reader would speak.""" + + text: str + role: str + node_index: int + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + +@dataclass +class Violation: + """One a11y violation discovered during the walk.""" + + kind: ViolationKind + role: str + node_index: int + detail: str = "" + + def to_dict(self) -> Dict[str, Any]: + return {**asdict(self), "kind": self.kind.value} + + +@dataclass +class ScreenReaderTranscript: + """Result of :func:`walk_tree`.""" + + utterances: List[Utterance] = field(default_factory=list) + violations: List[Violation] = field(default_factory=list) + + def speech(self) -> str: + """Joined transcript as a single string.""" + return " ".join(u.text for u in self.utterances if u.text) + + def passed(self) -> bool: + return not self.violations + + +# ---------- walker ------------------------------------------------------ + +def walk_tree( + root: Dict[str, Any], + *, + include_decorative: bool = False, +) -> ScreenReaderTranscript: + """ + Walk an accessibility tree (Playwright snapshot or CDP-shaped) and + return the SR transcript. Schema: nodes have ``role`` and ``name`` + strings plus an optional ``children`` list. + """ + if not isinstance(root, dict): + raise ScreenReaderError( + f"walk_tree expects dict, got {type(root).__name__}" + ) + transcript = ScreenReaderTranscript() + state = {"index": 0, "last_heading_level": 0} + _walk_node(root, transcript, state, include_decorative=include_decorative) + return transcript + + +def _walk_node( + node: Dict[str, Any], + transcript: ScreenReaderTranscript, + state: Dict[str, int], + *, + include_decorative: bool, +) -> None: + role = str(node.get("role") or "").lower() + name = str(node.get("name") or "").strip() + index = state["index"] + state["index"] += 1 + + if role == "heading": + level = int(node.get("level") or 1) + _check_heading_level(level, index, state, transcript) + if name: + transcript.utterances.append(Utterance( + text=f"heading level {level}: {name}", + role=role, node_index=index, + )) + state["last_heading_level"] = level + elif role == "image": + alt = name or str(node.get("description") or "").strip() + if not alt and not _is_decorative(node): + transcript.violations.append(Violation( + kind=ViolationKind.MISSING_ALT, + role=role, node_index=index, + )) + elif alt: + transcript.utterances.append(Utterance( + text=f"image: {alt}", role=role, node_index=index, + )) + elif role in _INTERACTIVE_ROLES: + _emit_interactive(node, role, name, index, transcript) + elif role in _GROUPING_ROLES: + if name: + transcript.utterances.append(Utterance( + text=f"{role}: {name}", role=role, node_index=index, + )) + elif name and (include_decorative or role in {"text", "static_text", "statictext", ""}): + if name: + transcript.utterances.append(Utterance( + text=name, role=role or "text", node_index=index, + )) + + children = node.get("children") or [] + if not isinstance(children, list): + return + for child in children: + if isinstance(child, dict): + _walk_node(child, transcript, state, + include_decorative=include_decorative) + + +def _emit_interactive( + node: Dict[str, Any], + role: str, + name: str, + index: int, + transcript: ScreenReaderTranscript, +) -> None: + if not name: + transcript.violations.append(Violation( + kind=ViolationKind.UNNAMED_INTERACTIVE, + role=role, node_index=index, + detail=f"{role} has no accessible name", + )) + if role == "button": + transcript.violations.append(Violation( + kind=ViolationKind.EMPTY_BUTTON, + role=role, node_index=index, + )) + return + if role == "link" and _is_generic_link(name): + transcript.violations.append(Violation( + kind=ViolationKind.GENERIC_LINK_TEXT, + role=role, node_index=index, + detail=f"link text {name!r} is non-descriptive", + )) + transcript.utterances.append(Utterance( + text=f"{role}: {name}", role=role, node_index=index, + )) + + +def _is_generic_link(name: str) -> bool: + cleaned = re.sub(r"[\W_]+", " ", name).strip().lower() + return cleaned in _BANNED_LINK_TEXT + + +def _is_decorative(node: Dict[str, Any]) -> bool: + if node.get("decorative") is True: + return True + properties = node.get("properties") or {} + if isinstance(properties, dict): + if properties.get("presentational") is True: + return True + if properties.get("role") == "presentation": + return True + return False + + +def _check_heading_level( + level: int, + index: int, + state: Dict[str, int], + transcript: ScreenReaderTranscript, +) -> None: + last = state["last_heading_level"] + if last and level > last + 1: + transcript.violations.append(Violation( + kind=ViolationKind.HEADING_SKIP, + role="heading", + node_index=index, + detail=f"jumped from h{last} to h{level}", + )) + + +# ---------- assertion helpers ------------------------------------------- + +def assert_no_violations(transcript: ScreenReaderTranscript) -> None: + """Raise unless the transcript has zero violations.""" + if not isinstance(transcript, ScreenReaderTranscript): + raise ScreenReaderError("assert_no_violations expects ScreenReaderTranscript") + if transcript.passed(): + return + parts = ", ".join( + f"{v.kind.value}@{v.node_index}" for v in transcript.violations[:5] + ) + more = "" if len(transcript.violations) <= 5 else f" (+{len(transcript.violations) - 5})" + raise ScreenReaderError(f"a11y violations: {parts}{more}") + + +def assert_reads( + transcript: ScreenReaderTranscript, + expected_phrase: str, +) -> Utterance: + """Raise unless ``expected_phrase`` appears in any utterance.""" + if not isinstance(expected_phrase, str) or not expected_phrase: + raise ScreenReaderError("expected_phrase must be a non-empty string") + for utterance in transcript.utterances: + if expected_phrase.lower() in utterance.text.lower(): + return utterance + raise ScreenReaderError( + f"SR transcript never contained {expected_phrase!r}" + ) diff --git a/je_web_runner/utils/session_to_test/__init__.py b/je_web_runner/utils/session_to_test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/session_to_test/converter.py b/je_web_runner/utils/session_to_test/converter.py new file mode 100644 index 0000000..b6b8262 --- /dev/null +++ b/je_web_runner/utils/session_to_test/converter.py @@ -0,0 +1,304 @@ +""" +把 rrweb / 通用 session event 串流轉成 WR action JSON。 +Production session replays (rrweb, Pendo, Hotjar) carry rich event +streams. This converter normalises them into the WebRunner action JSON +that ``executor`` already understands, so a real user flow becomes a +reproducible test. + +Supports two input shapes: + +* **rrweb events** — the public ``[{type, data, timestamp}]`` shape from + ``rrweb.record``. We handle ``IncrementalSnapshot`` mouse / input / scroll + events plus the top-level page-load metadata. +* **Generic events** — provider-agnostic ``{kind, target, value?, url?, + timestamp}`` dicts. This is the format you'd produce when scraping + Pendo/Hotjar/custom telemetry. + +The converter is deliberately conservative: events it cannot map cleanly +become ``WR_comment`` action lines instead of being silently dropped, so +the engineer reviewing the output sees what was skipped. +""" +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException +from je_web_runner.utils.logging.loggin_instance import web_runner_logger + + +class SessionToTestError(WebRunnerException): + """Raised on unreadable input, malformed events, or empty conversions.""" + + +# rrweb event type ids +_RRWEB_FULL_SNAPSHOT = 2 +_RRWEB_INCREMENTAL = 3 +_RRWEB_META = 4 + +# rrweb incremental-source ids +_RRWEB_SRC_MOUSE_INTERACTION = 2 +_RRWEB_SRC_SCROLL = 3 +_RRWEB_SRC_INPUT = 5 + +# rrweb mouse-interaction kinds +_RRWEB_MI_CLICK = 2 +_RRWEB_MI_DOUBLE_CLICK = 4 + +# pragmatic threshold: drop micro-mouse-moves rarer than this many ms apart +_MIN_INTER_EVENT_MS = 50 + + +# ---------- public model ------------------------------------------------ + +@dataclass +class ConversionStats: + """Roll-up returned by :func:`convert_events`.""" + + input_events: int = 0 + actions_emitted: int = 0 + skipped_events: int = 0 + comment_actions: int = 0 + reasons: Dict[str, int] = field(default_factory=dict) + + def note_skip(self, reason: str) -> None: + self.skipped_events += 1 + self.reasons[reason] = self.reasons.get(reason, 0) + 1 + + +@dataclass +class ConversionResult: + """Output of :func:`convert_events`: actions plus stats.""" + + actions: List[Dict[str, Any]] + stats: ConversionStats + + +# ---------- entry points ------------------------------------------------ + +def convert_rrweb_events(events: Sequence[Dict[str, Any]]) -> ConversionResult: + """Convert an rrweb event list into WR action JSON.""" + if not isinstance(events, list): + raise SessionToTestError("rrweb events must be a list") + stats = ConversionStats(input_events=len(events)) + actions: List[Dict[str, Any]] = [] + last_ts: Optional[int] = None + + for event in events: + if not isinstance(event, dict): + stats.note_skip("non-dict event") + continue + kind = event.get("type") + timestamp = event.get("timestamp") + if kind == _RRWEB_META and isinstance(event.get("data"), dict): + url = event["data"].get("href") + if isinstance(url, str) and url: + actions.append({"WR_to_url": [url]}) + stats.actions_emitted += 1 + else: + stats.note_skip("meta without href") + last_ts = timestamp + continue + if kind == _RRWEB_FULL_SNAPSHOT: + stats.note_skip("full snapshot (no action)") + last_ts = timestamp + continue + if kind != _RRWEB_INCREMENTAL: + stats.note_skip(f"unknown rrweb type {kind!r}") + continue + + if last_ts is not None and isinstance(timestamp, (int, float)): + if timestamp - last_ts < _MIN_INTER_EVENT_MS: + # Don't drop, but don't generate per-event waits either. + last_ts = timestamp + else: + last_ts = timestamp + + emitted = _convert_rrweb_incremental(event, stats) + if emitted is not None: + actions.append(emitted) + stats.actions_emitted += 1 + + if not actions: + raise SessionToTestError( + f"no actions produced from {len(events)} rrweb events; " + "input may be unsupported or empty" + ) + return ConversionResult(actions=actions, stats=stats) + + +def convert_generic_events(events: Sequence[Dict[str, Any]]) -> ConversionResult: + """Convert a provider-agnostic event list into WR action JSON.""" + if not isinstance(events, list): + raise SessionToTestError("generic events must be a list") + stats = ConversionStats(input_events=len(events)) + actions: List[Dict[str, Any]] = [] + for event in events: + if not isinstance(event, dict): + stats.note_skip("non-dict event") + continue + emitted = _convert_generic_event(event, stats) + if emitted is not None: + actions.append(emitted) + stats.actions_emitted += 1 + if not actions: + raise SessionToTestError( + f"no actions produced from {len(events)} generic events" + ) + return ConversionResult(actions=actions, stats=stats) + + +def convert_events(payload: Union[str, Path, Sequence[Dict[str, Any]]]) -> ConversionResult: + """ + Sniff the input: file → list / list → list. rrweb vs generic is + detected by the presence of an integer ``type`` field on the events. + """ + events = _load_events(payload) + if not events: + raise SessionToTestError("event list is empty") + if isinstance(events[0], dict) and isinstance(events[0].get("type"), int): + return convert_rrweb_events(events) + return convert_generic_events(events) + + +def _load_events(payload: Union[str, Path, Sequence[Dict[str, Any]]]) -> List[Dict[str, Any]]: + if isinstance(payload, (list, tuple)): + return list(payload) + if isinstance(payload, (str, Path)): + path = Path(payload) + if not path.exists(): + raise SessionToTestError(f"events file not found: {path}") + try: + data = json.loads(path.read_text(encoding="utf-8")) + except ValueError as error: + raise SessionToTestError(f"events file is not JSON: {error}") from error + if isinstance(data, dict) and "events" in data: + data = data["events"] + if not isinstance(data, list): + raise SessionToTestError("events file did not contain a list") + return data + raise SessionToTestError( + f"convert_events expects path or list, got {type(payload).__name__}" + ) + + +# ---------- rrweb mappings ---------------------------------------------- + +def _convert_rrweb_incremental( + event: Dict[str, Any], + stats: ConversionStats, +) -> Optional[Dict[str, Any]]: + data = event.get("data") if isinstance(event.get("data"), dict) else None + if not data: + stats.note_skip("incremental without data") + return None + source = data.get("source") + if source == _RRWEB_SRC_MOUSE_INTERACTION: + kind = data.get("type") + selector = _selector_for_node(data.get("id")) + if selector is None: + stats.note_skip("mouse without node id") + return None + if kind == _RRWEB_MI_CLICK: + return {"WR_click_element": ["css selector", selector]} + if kind == _RRWEB_MI_DOUBLE_CLICK: + return {"WR_double_click_element": ["css selector", selector]} + stats.note_skip(f"mouse type {kind!r}") + return None + if source == _RRWEB_SRC_INPUT: + selector = _selector_for_node(data.get("id")) + value = data.get("text", "") + if selector is None: + stats.note_skip("input without node id") + return None + return {"WR_input_to_element": ["css selector", selector, str(value)]} + if source == _RRWEB_SRC_SCROLL: + x = data.get("x", 0) + y = data.get("y", 0) + return {"WR_comment": [f"scroll to {x},{y}"]} + stats.note_skip(f"incremental source {source!r}") + return None + + +def _selector_for_node(node_id: Any) -> Optional[str]: + """ + rrweb identifies nodes by integer id from its DOM mirror. Without the + full snapshot we can't recover a stable CSS path, so we emit a custom + attribute selector that the test harness can rewrite later. + """ + if not isinstance(node_id, int) or node_id < 0: + return None + return f'[data-rrweb-id="{node_id}"]' + + +# ---------- generic mappings -------------------------------------------- + +def _convert_generic_event( + event: Dict[str, Any], + stats: ConversionStats, +) -> Optional[Dict[str, Any]]: + kind = str(event.get("kind") or "").lower() + target = event.get("target") + locator = _coerce_locator(target) + if kind == "navigate": + url = event.get("url") + if isinstance(url, str) and url: + return {"WR_to_url": [url]} + stats.note_skip("navigate without url") + return None + if kind == "click": + if locator is None: + stats.note_skip("click without target") + return None + return {"WR_click_element": list(locator)} + if kind == "input": + if locator is None: + stats.note_skip("input without target") + return None + value = event.get("value", "") + return {"WR_input_to_element": [*locator, str(value)]} + if kind == "submit": + if locator is None: + return {"WR_comment": ["submit form (no target)"]} + return {"WR_submit_element": list(locator)} + if kind == "wait": + try: + seconds = float(event.get("seconds", 0)) + except (TypeError, ValueError): + stats.note_skip("wait with non-numeric seconds") + return None + return {"WR_implicitly_wait": [seconds]} + stats.note_skip(f"unknown generic kind {kind!r}") + return None + + +def _coerce_locator(target: Any) -> Optional[Tuple[str, str]]: + if isinstance(target, dict): + by = target.get("by") or "css selector" + value = target.get("value") + if isinstance(value, str) and value: + return str(by), value + return None + if isinstance(target, str) and target: + return "css selector", target + return None + + +# ---------- output helpers ---------------------------------------------- + +def write_actions_json( + result: ConversionResult, + output_path: Union[str, Path], +) -> Path: + """Persist the converted actions list to disk in WR action format.""" + path = Path(output_path) + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as fp: + json.dump(result.actions, fp, ensure_ascii=False, indent=2) + web_runner_logger.info( + f"session_to_test: wrote {result.stats.actions_emitted} actions to {path} " + f"(skipped {result.stats.skipped_events})" + ) + return path diff --git a/je_web_runner/utils/sla_tracker/__init__.py b/je_web_runner/utils/sla_tracker/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/sla_tracker/tracker.py b/je_web_runner/utils/sla_tracker/tracker.py new file mode 100644 index 0000000..3a69d7e --- /dev/null +++ b/je_web_runner/utils/sla_tracker/tracker.py @@ -0,0 +1,230 @@ +""" +SLA 達成率追蹤:「Y% 的 suite 在 X 分鐘內跑完」加趨勢。 +Engineering teams set targets like "95% of CI runs finish in under 10 +minutes" but rarely have a number to point at. This module reads the +run ledger, groups by ISO week, and computes the rolling SLA-met +percentage so dashboards can show "we're at 91% this week, down from +97% two weeks ago, that's the regression". +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Union + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class SlaTrackerError(WebRunnerException): + """Raised on bad ledger / SLA inputs.""" + + +# ---------- input ------------------------------------------------------ + +@dataclass +class SuiteRun: + """One suite run.""" + + suite: str + started_at: datetime + duration_seconds: float + passed: bool + + def __post_init__(self) -> None: + if not isinstance(self.suite, str) or not self.suite: + raise SlaTrackerError("suite must be non-empty string") + if self.duration_seconds < 0: + raise SlaTrackerError("duration_seconds must be >= 0") + if not isinstance(self.started_at, datetime): + raise SlaTrackerError("started_at must be datetime") + if self.started_at.tzinfo is None: + raise SlaTrackerError("started_at must be tz-aware") + + +def _parse_iso(value: str) -> datetime: + text = value.strip() + if text.endswith("Z"): + text = text[:-1] + "+00:00" + try: + dt = datetime.fromisoformat(text) + except ValueError as error: + raise SlaTrackerError(f"bad timestamp {value!r}: {error}") from error + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt + + +def load_runs(path: Union[str, Path]) -> List[SuiteRun]: + """Read a ledger JSON file. Skips rows missing the fields we need.""" + p = Path(path) + if not p.exists(): + raise SlaTrackerError(f"ledger not found: {p}") + try: + data = json.loads(p.read_text(encoding="utf-8")) + except ValueError as error: + raise SlaTrackerError(f"ledger not JSON: {error}") from error + if not isinstance(data, dict) or "runs" not in data: + raise SlaTrackerError("ledger missing 'runs' key") + out: List[SuiteRun] = [] + for raw in data["runs"]: + if not isinstance(raw, dict): + continue + suite = raw.get("suite") or raw.get("path") + timestamp = raw.get("time") or raw.get("started_at") + duration = raw.get("duration_seconds") + if not isinstance(suite, str) or duration is None or not isinstance(timestamp, str): + continue + try: + run = SuiteRun( + suite=suite, + started_at=_parse_iso(timestamp), + duration_seconds=float(duration), + passed=bool(raw.get("passed", True)), + ) + except SlaTrackerError: + continue + out.append(run) + return out + + +# ---------- SLA model -------------------------------------------------- + +@dataclass(frozen=True) +class SlaTarget: + """Definition of the SLA.""" + + max_duration_seconds: float + target_pass_pct: float + + def __post_init__(self) -> None: + if self.max_duration_seconds <= 0: + raise SlaTrackerError("max_duration_seconds must be > 0") + if not 0 < self.target_pass_pct <= 100: + raise SlaTrackerError("target_pass_pct must be in (0, 100]") + + +@dataclass +class BucketResult: + """One bucket (week or day) of SLA stats.""" + + label: str + runs: int + met: int + pct: float + target_met: bool + + +@dataclass +class SlaReport: + """Outcome of :func:`compute_sla`.""" + + target: SlaTarget + buckets: List[BucketResult] = field(default_factory=list) + overall_pct: float = 0.0 + overall_runs: int = 0 + + def passed(self) -> bool: + return self.overall_pct >= self.target.target_pass_pct + + def to_dict(self) -> Dict[str, Any]: + return { + "target": asdict(self.target), + "buckets": [asdict(b) for b in self.buckets], + "overall_pct": self.overall_pct, + "overall_runs": self.overall_runs, + "passed": self.passed(), + } + + +# ---------- bucketing -------------------------------------------------- + +def _week_label(dt: datetime) -> str: + iso = dt.isocalendar() + return f"{iso[0]:04d}-W{iso[1]:02d}" + + +def _day_label(dt: datetime) -> str: + return dt.strftime("%Y-%m-%d") + + +def compute_sla( + runs: Sequence[SuiteRun], + target: SlaTarget, + *, + bucket: str = "week", + suite: Optional[str] = None, +) -> SlaReport: + """Group runs into buckets, compute met-percentage, aggregate.""" + if bucket not in ("week", "day"): + raise SlaTrackerError("bucket must be 'week' or 'day'") + label_fn = _week_label if bucket == "week" else _day_label + buckets_by_label: Dict[str, List[SuiteRun]] = {} + for run in runs: + if not isinstance(run, SuiteRun): + raise SlaTrackerError( + f"runs entry must be SuiteRun, got {type(run).__name__}" + ) + if suite is not None and run.suite != suite: + continue + buckets_by_label.setdefault(label_fn(run.started_at), []).append(run) + bucket_results: List[BucketResult] = [] + total_runs = 0 + total_met = 0 + for label in sorted(buckets_by_label): + runs_in_bucket = buckets_by_label[label] + met = sum(1 for r in runs_in_bucket + if r.duration_seconds <= target.max_duration_seconds) + pct = (met / len(runs_in_bucket)) * 100.0 + bucket_results.append(BucketResult( + label=label, + runs=len(runs_in_bucket), + met=met, + pct=round(pct, 2), + target_met=pct >= target.target_pass_pct, + )) + total_runs += len(runs_in_bucket) + total_met += met + overall_pct = (total_met / total_runs * 100.0) if total_runs else 0.0 + return SlaReport( + target=target, + buckets=bucket_results, + overall_pct=round(overall_pct, 2), + overall_runs=total_runs, + ) + + +# ---------- formatting ------------------------------------------------- + +def report_markdown(report: SlaReport) -> str: + """Render a small markdown table for dashboards / Slack.""" + if not isinstance(report, SlaReport): + raise SlaTrackerError("expects SlaReport") + lines = [ + f"### SLA: {report.target.target_pass_pct:.0f}% of suites in " + f"<= {report.target.max_duration_seconds:.0f}s", + "", + f"- Overall: **{report.overall_pct:.1f}%** " + f"({report.overall_runs} runs)", + "", + "| Bucket | Runs | Met | % |", + "|--------|------|-----|---|", + ] + for b in report.buckets: + mark = "✓" if b.target_met else "✗" + lines.append(f"| {b.label} | {b.runs} | {b.met} | {b.pct:.1f}% {mark} |") + return "\n".join(lines) + "\n" + + +def assert_meets_sla(report: SlaReport) -> None: + """Raise if the overall percentage is below the target.""" + if not isinstance(report, SlaReport): + raise SlaTrackerError("expects SlaReport") + if report.passed(): + return + raise SlaTrackerError( + f"SLA breach: {report.overall_pct:.1f}% < target " + f"{report.target.target_pass_pct:.1f}% " + f"({report.overall_runs} runs)" + ) diff --git a/je_web_runner/utils/slack_digest/__init__.py b/je_web_runner/utils/slack_digest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/slack_digest/digest.py b/je_web_runner/utils/slack_digest/digest.py new file mode 100644 index 0000000..cd136a9 --- /dev/null +++ b/je_web_runner/utils/slack_digest/digest.py @@ -0,0 +1,222 @@ +""" +週報 / 日報:quarantine 進出、top-risk PR、flake 趨勢、cost 變化,推到 Slack / Teams。 +A digest is just a single Slack Block-Kit payload (or a small Teams card) +that the existing :mod:`notifier` module can post. This module's job is +to *render* the digest from upstream module outputs — pure formatting, +no HTTP. + +Tested by snapshot-style assertions on the produced block list. +""" +from __future__ import annotations + +import json +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from typing import Any, Dict, Iterable, List, Optional, Sequence + +from je_web_runner.utils.exception.exceptions import WebRunnerException + + +class SlackDigestError(WebRunnerException): + """Raised on bad input shapes.""" + + +# ---------- inputs ------------------------------------------------------ + +@dataclass +class FlakeStat: + """One quarantine list change in the digest window.""" + + test_id: str + action: str # 'added' | 'released' | 'still_in' + flake_score: float = 0.0 + + +@dataclass +class RiskyPr: + """A high-risk PR from :mod:`pr_risk_score` (or any upstream).""" + + number: int + title: str + score: float + url: str = "" + + +@dataclass +class CostTrend: + """Period-over-period cost (USD).""" + + current_usd: float + previous_usd: float + + def delta_pct(self) -> float: + if self.previous_usd <= 0: + return 0.0 if self.current_usd <= 0 else 100.0 + return ((self.current_usd - self.previous_usd) / self.previous_usd) * 100.0 + + +@dataclass +class DigestInputs: + """Everything a digest can include. Each field is optional.""" + + period_label: str = "last 7 days" + flake_changes: List[FlakeStat] = field(default_factory=list) + risky_prs: List[RiskyPr] = field(default_factory=list) + cost: Optional[CostTrend] = None + suite_pass_rate: Optional[float] = None # 0..1 + suite_pass_rate_previous: Optional[float] = None + extra_lines: List[str] = field(default_factory=list) + + def __post_init__(self) -> None: + if self.suite_pass_rate is not None and not 0.0 <= self.suite_pass_rate <= 1.0: + raise SlackDigestError("suite_pass_rate must be in [0, 1]") + if (self.suite_pass_rate_previous is not None + and not 0.0 <= self.suite_pass_rate_previous <= 1.0): + raise SlackDigestError("suite_pass_rate_previous must be in [0, 1]") + + +# ---------- rendering -------------------------------------------------- + +def _header_block(period_label: str) -> Dict[str, Any]: + today = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d") + return { + "type": "header", + "text": {"type": "plain_text", + "text": f"Test digest — {period_label} (as of {today})"}, + } + + +def _suite_health_block(inputs: DigestInputs) -> Optional[Dict[str, Any]]: + if inputs.suite_pass_rate is None: + return None + pct = inputs.suite_pass_rate * 100 + line = f"*Suite pass rate:* {pct:.1f}%" + if inputs.suite_pass_rate_previous is not None: + delta = (inputs.suite_pass_rate - inputs.suite_pass_rate_previous) * 100 + sign = "▲" if delta >= 0 else "▼" + line += f" ({sign}{abs(delta):.1f} pts vs prev)" + return {"type": "section", "text": {"type": "mrkdwn", "text": line}} + + +def _flake_block(stats: Sequence[FlakeStat]) -> Optional[Dict[str, Any]]: + if not stats: + return None + added = [s for s in stats if s.action == "added"] + released = [s for s in stats if s.action == "released"] + still_in = [s for s in stats if s.action == "still_in"] + parts: List[str] = ["*Quarantine activity:*"] + parts.append(f"• Added: {len(added)}") + parts.append(f"• Released: {len(released)}") + parts.append(f"• Still in quarantine: {len(still_in)}") + for stat in added[:5]: + parts.append(f" • `{stat.test_id}` (score {stat.flake_score:.2f})") + if len(added) > 5: + parts.append(f" • +{len(added) - 5} more added") + return {"type": "section", "text": {"type": "mrkdwn", "text": "\n".join(parts)}} + + +def _risky_pr_block(prs: Sequence[RiskyPr]) -> Optional[Dict[str, Any]]: + if not prs: + return None + lines = ["*High-risk PRs:*"] + for pr in sorted(prs, key=lambda p: -p.score)[:5]: + url = pr.url or f"#{pr.number}" + lines.append(f"• <{url}|#{pr.number}> {pr.title} — risk {pr.score:.1f}") + return {"type": "section", "text": {"type": "mrkdwn", "text": "\n".join(lines)}} + + +def _cost_block(cost: Optional[CostTrend]) -> Optional[Dict[str, Any]]: + if cost is None: + return None + delta = cost.delta_pct() + sign = "▲" if delta >= 0 else "▼" + line = ( + f"*Estimated test cost:* ${cost.current_usd:,.2f} " + f"({sign}{abs(delta):.1f}% vs prev ${cost.previous_usd:,.2f})" + ) + return {"type": "section", "text": {"type": "mrkdwn", "text": line}} + + +def _extra_block(lines: Sequence[str]) -> Optional[Dict[str, Any]]: + if not lines: + return None + text = "\n".join(f"• {line}" for line in lines) + return {"type": "section", "text": {"type": "mrkdwn", "text": text}} + + +def build_slack_blocks(inputs: DigestInputs) -> List[Dict[str, Any]]: + """Render the digest as a Slack Block-Kit ``blocks`` list.""" + if not isinstance(inputs, DigestInputs): + raise SlackDigestError("build_slack_blocks expects DigestInputs") + candidates = [ + _header_block(inputs.period_label), + _suite_health_block(inputs), + _flake_block(inputs.flake_changes), + _risky_pr_block(inputs.risky_prs), + _cost_block(inputs.cost), + _extra_block(inputs.extra_lines), + ] + blocks = [b for b in candidates if b] + if len(blocks) == 1: # only header → nothing to report + blocks.append({ + "type": "section", + "text": {"type": "mrkdwn", + "text": "_Nothing notable to report in this period._"}, + }) + return blocks + + +def build_slack_payload( + inputs: DigestInputs, + *, + channel: Optional[str] = None, +) -> Dict[str, Any]: + """Wrap the blocks in a complete ``chat.postMessage`` payload.""" + payload: Dict[str, Any] = {"blocks": build_slack_blocks(inputs)} + if channel: + if not isinstance(channel, str): + raise SlackDigestError("channel must be a string") + payload["channel"] = channel + return payload + + +# ---------- teams card ------------------------------------------------- + +def build_teams_card(inputs: DigestInputs) -> Dict[str, Any]: + """Render a simple Adaptive Card body for Microsoft Teams webhooks.""" + blocks = build_slack_blocks(inputs) + body: List[Dict[str, Any]] = [] + for block in blocks: + text = "" + block_text = block.get("text") or {} + if isinstance(block_text, dict): + text = str(block_text.get("text") or "") + if not text: + continue + body.append({ + "type": "TextBlock", + "text": text, + "wrap": True, + "weight": "Bolder" if block.get("type") == "header" else "Default", + }) + return { + "type": "AdaptiveCard", + "$schema": "http://adaptivecards.io/schemas/adaptive-card.json", + "version": "1.5", + "body": body, + } + + +# ---------- helpers ---------------------------------------------------- + +def render_plain_text(inputs: DigestInputs) -> str: + """Render a fallback plain-text digest (email / Markdown alike).""" + blocks = build_slack_blocks(inputs) + lines: List[str] = [] + for block in blocks: + block_text = block.get("text") or {} + if isinstance(block_text, dict): + text = str(block_text.get("text") or "") + if text: + lines.append(text) + return "\n\n".join(lines) + "\n" diff --git a/je_web_runner/utils/sri_verify/__init__.py b/je_web_runner/utils/sri_verify/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/je_web_runner/utils/sri_verify/verify.py b/je_web_runner/utils/sri_verify/verify.py new file mode 100644 index 0000000..d67118d --- /dev/null +++ b/je_web_runner/utils/sri_verify/verify.py @@ -0,0 +1,240 @@ +""" +Subresource Integrity (SRI) hash 缺失偵測 + 正確性驗證。 +SRI 是防 CDN 被竄改最便宜的招式。但常見三個錯誤: + +1. 完全沒設(``integrity`` 屬性缺失) +2. 有設但 hash 過時(資源變了 hash 沒更新 → 載入失敗) +3. 設了弱算法(sha1 / md5 — 規範要求 sha256+) + +This module: + +* Parses ``', '
x
'), + [], + ) + + def test_real_divergence_flagged(self): + findings = diff_dom("
server text
", "
client text
") + self.assertEqual(len(findings), 1) + self.assertIn("diverged", findings[0].detail) + + def test_rejects_non_string(self): + with self.assertRaises(HydrationCheckError): + diff_dom(123, "
x
") # type: ignore[arg-type] + + +class TestAudit(unittest.TestCase): + + def test_clean(self): + report = audit( + server_html="
x
", client_html="
x
", + console_messages=["Some unrelated log"], + ) + self.assertTrue(report.passed()) + + def test_dom_only(self): + report = audit(server_html="
a
", client_html="
b
") + self.assertFalse(report.passed()) + self.assertEqual(report.by_kind(), {"dom_diff": 1}) + + def test_console_only(self): + report = audit(console_messages=["Hydration failed"]) + self.assertFalse(report.passed()) + self.assertEqual(report.by_kind(), {"console": 1}) + + def test_both(self): + report = audit( + server_html="
a
", client_html="
b
", + console_messages=["Hydration failed"], + ) + self.assertEqual(sum(report.by_kind().values()), 2) + + def test_empty_audit(self): + self.assertTrue(audit().passed()) + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_no_mismatch(HydrationReport()) + + def test_fail(self): + with self.assertRaises(HydrationCheckError): + assert_no_mismatch(HydrationReport(findings=[ + HydrationFinding(kind="dom_diff", detail="x"), + ])) + + def test_rejects_non_report(self): + with self.assertRaises(HydrationCheckError): + assert_no_mismatch("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_idempotency_check.py b/test/unit_test/test_idempotency_check.py new file mode 100644 index 0000000..0bd94ff --- /dev/null +++ b/test/unit_test/test_idempotency_check.py @@ -0,0 +1,163 @@ +"""Unit tests for je_web_runner.utils.idempotency_check.""" +import unittest + +from je_web_runner.utils.idempotency_check.check import ( + IdemResponse, + IdempotencyCheckError, + IdempotencyReport, + assert_idempotent, + check, + generate_idempotency_key, +) + + +class _StateBox: + def __init__(self): + self.value = 0 + + +def _idempotent_runner(state): + def _run(): + if state.value == 0: + state.value = 1 + return IdemResponse(status_code=200, body={"id": 42, "ok": True}, + side_effect_count=state.value) + return _run + + +def _non_idempotent_runner(state): + def _run(): + state.value += 1 + return IdemResponse(status_code=200, body={"id": state.value}, + side_effect_count=state.value) + return _run + + +class TestCheck(unittest.TestCase): + + def test_idempotent_passes(self): + state = _StateBox() + report = check(_idempotent_runner(state), state_probe=lambda: state.value) + self.assertTrue(report.passed()) + + def test_non_idempotent_caught(self): + state = _StateBox() + report = check(_non_idempotent_runner(state)) + self.assertFalse(report.passed()) + joined = "; ".join(report.violations) + self.assertIn("body differs", joined) + + def test_status_change_caught(self): + calls = [ + IdemResponse(200, {"id": 1}), + IdemResponse(409, {"id": 1}), + ] + def runner(): + return calls.pop(0) + report = check(runner) + self.assertFalse(report.passed()) + + def test_status_change_allowed(self): + calls = [ + IdemResponse(200, {"id": 1}), + IdemResponse(409, {"id": 1}), + ] + def runner(): + return calls.pop(0) + report = check(runner, allow_status_change_to=[409]) + self.assertTrue(report.passed()) + + def test_ignore_body_keys(self): + calls = [ + IdemResponse(200, {"id": 1, "ts": "2026-01-01"}), + IdemResponse(200, {"id": 1, "ts": "2026-01-02"}), + ] + def runner(): + return calls.pop(0) + report = check(runner, ignore_body_keys=["ts"]) + self.assertTrue(report.passed()) + + def test_state_diff_caught(self): + state = _StateBox() + def runner(): + state.value += 1 + return IdemResponse(200, {"id": 1}) + report = check(runner, state_probe=lambda: state.value) + self.assertFalse(report.passed()) + self.assertTrue(any("state changed" in v for v in report.violations)) + + def test_side_effect_count_diff(self): + responses = [ + IdemResponse(200, {"id": 1}, side_effect_count=1), + IdemResponse(200, {"id": 1}, side_effect_count=2), + ] + def runner(): + return responses.pop(0) + report = check(runner) + self.assertFalse(report.passed()) + + def test_runner_must_be_callable(self): + with self.assertRaises(IdempotencyCheckError): + check("not callable") # type: ignore[arg-type] + + def test_state_probe_must_be_callable(self): + with self.assertRaises(IdempotencyCheckError): + check(lambda: IdemResponse(200, {}), state_probe="x") # type: ignore[arg-type] + + def test_runner_must_return_idem_response(self): + with self.assertRaises(IdempotencyCheckError): + check(lambda: "nope") + + def test_runner_exception_wrapped(self): + def boom(): + raise RuntimeError("net") + with self.assertRaises(IdempotencyCheckError): + check(boom) + + +class TestAssertIdempotent(unittest.TestCase): + + def test_pass(self): + report = IdempotencyReport( + first=IdemResponse(200, {}), second=IdemResponse(200, {}), + ) + assert_idempotent(report) + + def test_fail(self): + report = IdempotencyReport( + first=IdemResponse(200, {}), second=IdemResponse(200, {}), + violations=["x"], + ) + with self.assertRaises(IdempotencyCheckError): + assert_idempotent(report) + + def test_rejects_non_report(self): + with self.assertRaises(IdempotencyCheckError): + assert_idempotent("nope") # type: ignore[arg-type] + + +class TestKeyGen(unittest.TestCase): + + def test_stable(self): + self.assertEqual( + generate_idempotency_key("user", 42), + generate_idempotency_key("user", 42), + ) + + def test_changes_with_parts(self): + self.assertNotEqual( + generate_idempotency_key("user", 42), + generate_idempotency_key("user", 43), + ) + + +class TestIdemResponse(unittest.TestCase): + + def test_body_hash_stable(self): + a = IdemResponse(200, {"id": 1, "x": 2}) + b = IdemResponse(200, {"x": 2, "id": 1}) # different dict key order + self.assertEqual(a.body_hash(), b.body_hash()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_indexed_db_explorer.py b/test/unit_test/test_indexed_db_explorer.py new file mode 100644 index 0000000..e531203 --- /dev/null +++ b/test/unit_test/test_indexed_db_explorer.py @@ -0,0 +1,188 @@ +"""Unit tests for je_web_runner.utils.indexed_db_explorer.""" +import unittest + +from je_web_runner.utils.indexed_db_explorer.explorer import ( + IdbSnapshot, + IndexedDbExplorerError, + SnapshotDiff, + StoreSnapshot, + assert_db_exists, + assert_index_present, + assert_key_present, + assert_record_count, + assert_record_matching, + assert_store_present, + build_harvest_script, + diff_snapshots, +) + + +def _snap_dict(stores=None, exists=True, name="myDb", version=1): + return { + "name": name, "exists": exists, "version": version, + "stores": stores or {}, + } + + +def _store(records, keys=None, indexes=None): + return { + "key_path": "id", "auto_increment": False, + "index_names": list(indexes or []), + "records": list(records), + "keys": list(keys or [r.get("id") for r in records]), + } + + +class TestHarvestScript(unittest.TestCase): + + def test_embeds_db_name(self): + js = build_harvest_script("MyApp") + self.assertIn('"MyApp"', js) + self.assertIn("indexedDB.open", js) + + def test_rejects_empty_name(self): + with self.assertRaises(IndexedDbExplorerError): + build_harvest_script("") + + +class TestSnapshotParse(unittest.TestCase): + + def test_basic(self): + snap = IdbSnapshot.from_dict(_snap_dict({ + "users": _store([{"id": 1, "name": "alice"}]), + })) + self.assertTrue(snap.exists) + self.assertEqual(snap.version, 1) + self.assertIn("users", snap.stores) + self.assertEqual(snap.stores["users"].records[0]["name"], "alice") + + def test_missing_db(self): + snap = IdbSnapshot.from_dict({"name": "x", "exists": False}) + self.assertFalse(snap.exists) + self.assertEqual(snap.stores, {}) + + def test_rejects_non_dict(self): + with self.assertRaises(IndexedDbExplorerError): + IdbSnapshot.from_dict("nope") # type: ignore[arg-type] + + def test_rejects_bad_stores(self): + with self.assertRaises(IndexedDbExplorerError): + IdbSnapshot.from_dict({"stores": "not a dict"}) + + def test_ignores_bad_store_entries(self): + snap = IdbSnapshot.from_dict({"stores": {"x": "not a dict"}}) + self.assertEqual(snap.stores, {}) + + def test_round_trip_dict(self): + snap = IdbSnapshot.from_dict(_snap_dict({"u": _store([])})) + data = snap.to_dict() + self.assertEqual(data["name"], "myDb") + self.assertIn("u", data["stores"]) + + +class TestAssertions(unittest.TestCase): + + def _snap(self): + return IdbSnapshot.from_dict(_snap_dict({ + "users": _store( + [{"id": 1, "name": "alice"}, {"id": 2, "name": "bob"}], + indexes=["by_name"], + ), + "todos": _store([{"id": "a", "done": False}]), + })) + + def test_db_exists(self): + assert_db_exists(self._snap()) + + def test_db_not_exists(self): + with self.assertRaises(IndexedDbExplorerError): + assert_db_exists(IdbSnapshot.from_dict({"exists": False})) + + def test_db_exists_rejects_non_snapshot(self): + with self.assertRaises(IndexedDbExplorerError): + assert_db_exists("nope") # type: ignore[arg-type] + + def test_store_present(self): + store = assert_store_present(self._snap(), "users") + self.assertIsInstance(store, StoreSnapshot) + + def test_store_missing(self): + with self.assertRaises(IndexedDbExplorerError): + assert_store_present(self._snap(), "missing") + + def test_store_empty_name(self): + with self.assertRaises(IndexedDbExplorerError): + assert_store_present(self._snap(), "") + + def test_record_count_in_range(self): + self.assertEqual( + assert_record_count(self._snap(), "users", minimum=1, maximum=10), 2, + ) + + def test_record_count_out_of_range(self): + with self.assertRaises(IndexedDbExplorerError): + assert_record_count(self._snap(), "users", minimum=5) + with self.assertRaises(IndexedDbExplorerError): + assert_record_count(self._snap(), "users", maximum=1) + + def test_record_count_bad_bounds(self): + with self.assertRaises(IndexedDbExplorerError): + assert_record_count(self._snap(), "users", minimum=-1) + with self.assertRaises(IndexedDbExplorerError): + assert_record_count(self._snap(), "users", minimum=5, maximum=1) + + def test_key_present(self): + assert_key_present(self._snap(), "users", 1) + + def test_key_missing(self): + with self.assertRaises(IndexedDbExplorerError): + assert_key_present(self._snap(), "users", 999) + + def test_record_matching_pass(self): + record = assert_record_matching( + self._snap(), "users", lambda r: r["name"] == "bob", + ) + self.assertEqual(record["id"], 2) + + def test_record_matching_fail(self): + with self.assertRaises(IndexedDbExplorerError): + assert_record_matching(self._snap(), "users", lambda r: False) + + def test_record_matching_predicate_error_ignored(self): + def bad(_): + raise RuntimeError("oops") + with self.assertRaises(IndexedDbExplorerError): + assert_record_matching(self._snap(), "users", bad) + + def test_index_present(self): + assert_index_present(self._snap(), "users", "by_name") + + def test_index_missing(self): + with self.assertRaises(IndexedDbExplorerError): + assert_index_present(self._snap(), "users", "missing") + + +class TestDiff(unittest.TestCase): + + def test_added_removed_changed(self): + before = IdbSnapshot.from_dict(_snap_dict({ + "a": _store([{"id": 1}]), + "b": _store([]), + })) + after = IdbSnapshot.from_dict(_snap_dict({ + "a": _store([{"id": 1}, {"id": 2}]), + "c": _store([]), + })) + diff = diff_snapshots(before, after) + self.assertEqual(diff.added_stores, ["c"]) + self.assertEqual(diff.removed_stores, ["b"]) + self.assertEqual(diff.record_count_changes, + {"a": {"before": 1, "after": 2}}) + + def test_rejects_non_snapshot(self): + with self.assertRaises(IndexedDbExplorerError): + diff_snapshots("a", "b") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_inp_tracker.py b/test/unit_test/test_inp_tracker.py new file mode 100644 index 0000000..36b37e9 --- /dev/null +++ b/test/unit_test/test_inp_tracker.py @@ -0,0 +1,153 @@ +"""Unit tests for je_web_runner.utils.inp_tracker.""" +import unittest + +from je_web_runner.utils.inp_tracker.tracker import ( + HARVEST_SCRIPT, + InpRating, + InpReport, + InpTrackerError, + InteractionEvent, + assert_inp_under, + assert_no_poor_interactions, + build_install_script, + parse_log, +) + + +def _raw(duration, iid=1, name="click"): + return { + "name": name, "interactionId": iid, "duration_ms": duration, + "startTime": 0, "processingStart": 0, "processingEnd": 0, + "targetTag": "BUTTON", + } + + +class TestScripts(unittest.TestCase): + + def test_install_guard(self): + js = build_install_script() + self.assertIn("__wr_inp_installed__", js) + self.assertIn("PerformanceObserver", js) + + def test_harvest_constant(self): + self.assertIn("__wr_inp_log__", HARVEST_SCRIPT) + + +class TestParseLog(unittest.TestCase): + + def test_basic(self): + events = parse_log([_raw(120), _raw(300, iid=2)]) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].interaction_id, 2) + + def test_skips_non_dict(self): + self.assertEqual(parse_log(["x", None]), []) # type: ignore[list-item] + + def test_skips_bad_duration(self): + out = parse_log([{"duration_ms": "not a number"}]) + self.assertEqual(out, []) + + def test_skips_negative_duration(self): + out = parse_log([_raw(-1)]) + self.assertEqual(out, []) + + def test_rejects_non_list_payload(self): + with self.assertRaises(InpTrackerError): + parse_log("nope") # type: ignore[arg-type] + + +class TestInpReport(unittest.TestCase): + + def test_inp_uses_worst_when_few(self): + report = InpReport(events=parse_log([_raw(120), _raw(300), _raw(450)])) + self.assertEqual(report.inp(), 450) + self.assertEqual(report.rating(), InpRating.NEEDS_WORK) + + def test_inp_uses_p98_when_many(self): + events = parse_log([_raw(50, iid=i) for i in range(1, 51)]) + report = InpReport(events=events) + self.assertEqual(report.inp(), 50) + self.assertEqual(report.rating(), InpRating.GOOD) + + def test_no_events_inp_none(self): + self.assertIsNone(InpReport().inp()) + self.assertEqual(InpReport().rating(), InpRating.GOOD) + + def test_filtered_drops_zero_id(self): + events = [InteractionEvent(name="x", interaction_id=0, duration_ms=10), + InteractionEvent(name="y", interaction_id=1, duration_ms=10)] + self.assertEqual(len(InpReport(events=events).filtered()), 1) + + def test_percentile(self): + report = InpReport(events=parse_log([ + _raw(d, iid=i) for i, d in enumerate([10, 20, 30, 40], start=1) + ])) + self.assertEqual(report.percentile(50), 30) + self.assertEqual(report.percentile(100), 40) + + def test_percentile_bad_input(self): + with self.assertRaises(InpTrackerError): + InpReport().percentile(-1) + with self.assertRaises(InpTrackerError): + InpReport().percentile(150) + + +class TestAssertInpUnder(unittest.TestCase): + + def test_pass(self): + assert_inp_under( + InpReport(events=parse_log([_raw(100), _raw(150)])), + max_ms=200, + ) + + def test_fail(self): + with self.assertRaises(InpTrackerError): + assert_inp_under( + InpReport(events=parse_log([_raw(300)])), + max_ms=200, + ) + + def test_empty_passes(self): + assert_inp_under(InpReport(), max_ms=100) + + def test_bad_budget(self): + with self.assertRaises(InpTrackerError): + assert_inp_under(InpReport(), max_ms=0) + + def test_rejects_non_report(self): + with self.assertRaises(InpTrackerError): + assert_inp_under("nope", max_ms=100) # type: ignore[arg-type] + + +class TestAssertNoPoor(unittest.TestCase): + + def test_pass(self): + assert_no_poor_interactions( + InpReport(events=parse_log([_raw(100), _raw(450)])), + ) + + def test_fail(self): + with self.assertRaises(InpTrackerError): + assert_no_poor_interactions( + InpReport(events=parse_log([_raw(600)])), + ) + + def test_rejects_non_report(self): + with self.assertRaises(InpTrackerError): + assert_no_poor_interactions("nope") # type: ignore[arg-type] + + +class TestEventRating(unittest.TestCase): + + def test_ratings(self): + self.assertEqual(InteractionEvent("x", 1, 100).rating(), InpRating.GOOD) + self.assertEqual(InteractionEvent("x", 1, 250).rating(), InpRating.NEEDS_WORK) + self.assertEqual(InteractionEvent("x", 1, 600).rating(), InpRating.POOR) + + def test_to_dict_includes_rating(self): + d = InteractionEvent("x", 1, 100).to_dict() + self.assertEqual(d["rating"], "good") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_live_dashboard_server.py b/test/unit_test/test_live_dashboard_server.py new file mode 100644 index 0000000..65a7b2f --- /dev/null +++ b/test/unit_test/test_live_dashboard_server.py @@ -0,0 +1,221 @@ +"""Unit tests for je_web_runner.utils.live_dashboard.server (aggregator UI).""" +import json +import tempfile +import unittest +import urllib.error +import urllib.request +from datetime import datetime, timezone +from pathlib import Path + +from je_web_runner.utils.flake_detector.detector import ( + QuarantineEntry, + QuarantineRegistry, +) +from je_web_runner.utils.live_dashboard.server import ( + DashboardConfig, + DashboardServer, + LiveDashboardError, + build_summary, +) + + +def _iso(dt): + return dt.replace(tzinfo=timezone.utc).isoformat(timespec="seconds") + + +def _write_ledger(path: Path, runs): + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(json.dumps({"runs": runs}), encoding="utf-8") + + +def _make_seeded_config(tmpdir: Path) -> DashboardConfig: + ledger = tmpdir / "ledger.json" + now = datetime.now(timezone.utc) + _write_ledger(ledger, [ + {"path": "a.json", "passed": True, "time": _iso(now)}, + {"path": "a.json", "passed": False, "time": _iso(now)}, + {"path": "a.json", "passed": True, "time": _iso(now)}, + {"path": "a.json", "passed": False, "time": _iso(now)}, + {"path": "b.json", "passed": True, "time": _iso(now)}, + ]) + quarantine = tmpdir / "quarantine.json" + reg = QuarantineRegistry(quarantine) + reg.add(QuarantineEntry( + test_id="a.json", reason="auto", flake_score=0.6, + quarantined_at=_iso(now), + )) + locator = tmpdir / "locator.json" + locator.write_text(json.dumps({ + "total": 10, "weak": 3, "strong": 7, + "average_score": 75.5, "threshold": 60, + "weakest": [ + {"file_path": "actions/login.json", "action_index": 1, + "strategy": "XPATH", "value": "//div/div/span", + "score": 30, "reasons": ["deep selector"]}, + ], + }), encoding="utf-8") + return DashboardConfig( + ledger_path=ledger, + quarantine_path=quarantine, + locator_findings_path=locator, + ) + + +class TestBuildSummary(unittest.TestCase): + + def test_aggregates_counts(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_seeded_config(Path(tmpdir)) + summary = build_summary(config) + self.assertEqual(summary["total_runs"], 5) + self.assertEqual(summary["passed"], 3) + self.assertEqual(summary["failed"], 2) + self.assertAlmostEqual(summary["pass_rate"], 0.6, places=2) + self.assertEqual(summary["quarantined_tests"], 1) + self.assertEqual(summary["weak_locators"], 3) + + def test_empty_config_is_safe(self): + summary = build_summary(DashboardConfig()) + self.assertEqual(summary["total_runs"], 0) + self.assertEqual(summary["pass_rate"], 0.0) + self.assertEqual(summary["weak_locators"], 0) + + def test_missing_files_skipped(self): + config = DashboardConfig( + ledger_path=Path("/no/such/ledger.json"), + quarantine_path=Path("/no/such/q.json"), + locator_findings_path=Path("/no/such/l.json"), + ) + summary = build_summary(config) + self.assertEqual(summary["total_runs"], 0) + + +class TestServerLifecycle(unittest.TestCase): + + def test_url_before_start_raises(self): + with self.assertRaises(LiveDashboardError): + _ = DashboardServer().url + + def test_start_stop(self): + with tempfile.TemporaryDirectory() as tmpdir: + config = _make_seeded_config(Path(tmpdir)) + server = DashboardServer(config) + url = server.start() + self.assertTrue(url.startswith("http://127.0.0.1:")) + server.stop() + + def test_double_start_raises(self): + server = DashboardServer() + server.start() + try: + with self.assertRaises(LiveDashboardError): + server.start() + finally: + server.stop() + + def test_context_manager(self): + with DashboardServer() as server: + self.assertTrue(server.url.startswith("http://")) + + def test_stop_when_not_started_is_noop(self): + DashboardServer().stop() + + +def _http_get(url: str) -> bytes: + with urllib.request.urlopen(url, timeout=5) as resp: # nosec B310 — localhost only + return resp.read() + + +def _http_get_json(url: str): + return json.loads(_http_get(url).decode("utf-8")) + + +class TestHttpEndpoints(unittest.TestCase): + + def setUp(self): + self.tmp = tempfile.TemporaryDirectory() + self.addCleanup(self.tmp.cleanup) + self.config = _make_seeded_config(Path(self.tmp.name)) + self.server = DashboardServer(self.config) + self.server.start() + self.addCleanup(self.server.stop) + + def test_overview_returns_html(self): + body = _http_get(self.server.url + "/").decode("utf-8") + self.assertIn("", body) + self.assertIn("WebRunner overview", body) + self.assertIn("Total runs", body) + + def test_runs_page_lists_recent(self): + body = _http_get(self.server.url + "/runs").decode("utf-8") + self.assertIn("a.json", body) + self.assertIn("FAIL", body) + self.assertIn("PASS", body) + + def test_flake_page_lists_flaky_tests(self): + body = _http_get(self.server.url + "/flake").decode("utf-8") + self.assertIn("a.json", body) + + def test_quarantine_page_shows_entries(self): + body = _http_get(self.server.url + "/quarantine").decode("utf-8") + self.assertIn("a.json", body) + + def test_locators_page_shows_weakest(self): + body = _http_get(self.server.url + "/locators").decode("utf-8") + self.assertIn("XPATH", body) + self.assertIn("deep selector", body) + + def test_unknown_route_returns_404(self): + url = self.server.url + "/no-such-route" + with self.assertRaises(urllib.error.HTTPError) as cm: + _http_get(url) + self.assertEqual(cm.exception.code, 404) + + def test_api_summary_returns_json(self): + payload = _http_get_json(self.server.url + "/api/summary") + self.assertEqual(payload["total_runs"], 5) + self.assertEqual(payload["quarantined_tests"], 1) + + def test_api_runs_respects_limit(self): + payload = _http_get_json(self.server.url + "/api/runs?limit=2") + self.assertLessEqual(len(payload), 2) + + def test_api_runs_bad_limit_falls_back(self): + payload = _http_get_json(self.server.url + "/api/runs?limit=junk") + self.assertGreater(len(payload), 0) + + def test_api_flake_payload_shape(self): + payload = _http_get_json(self.server.url + "/api/flake") + self.assertTrue(any(e["path"] == "a.json" for e in payload)) + + def test_api_quarantine_payload(self): + payload = _http_get_json(self.server.url + "/api/quarantine") + self.assertEqual(payload[0]["test_id"], "a.json") + + def test_api_locators_payload(self): + payload = _http_get_json(self.server.url + "/api/locators") + self.assertEqual(payload["total"], 10) + + def test_healthz(self): + body = _http_get(self.server.url + "/healthz") + self.assertEqual(body, b"ok") + + +class TestEmptyServer(unittest.TestCase): + + def test_endpoints_work_with_no_data(self): + with DashboardServer() as server: + overview = _http_get(server.url + "/").decode("utf-8") + self.assertIn("WebRunner overview", overview) + empty_runs = _http_get(server.url + "/runs").decode("utf-8") + self.assertIn("No runs", empty_runs) + empty_flake = _http_get(server.url + "/flake").decode("utf-8") + self.assertIn("No flaky", empty_flake) + empty_q = _http_get(server.url + "/quarantine").decode("utf-8") + self.assertIn("empty", empty_q.lower()) + empty_l = _http_get(server.url + "/locators").decode("utf-8") + self.assertIn("No locator", empty_l) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_locator_hardener.py b/test/unit_test/test_locator_hardener.py new file mode 100644 index 0000000..555b46d --- /dev/null +++ b/test/unit_test/test_locator_hardener.py @@ -0,0 +1,231 @@ +"""Unit tests for je_web_runner.utils.locator_hardener.""" +import json +import unittest + +from je_web_runner.utils.locator_hardener.hardener import ( + FragileLocator, + LocatorHardenerError, + LocatorStrategy, + LocatorSuggestion, + build_prompt, + harden, + parse_suggestions, + score_fragility, +) + + +class StubClient: + def __init__(self, response): + self.response = response + self.last_prompt = None + + def suggest(self, prompt): + self.last_prompt = prompt + if isinstance(self.response, Exception): + raise self.response + return self.response + + +def _good_response(): + return json.dumps([ + {"strategy": "id", "value": "submit-btn", + "rationale": "id is unique and stable"}, + {"strategy": "css selector", "value": "[data-test=submit]", + "rationale": "stable test attribute"}, + ]) + + +class TestFragileLocator(unittest.TestCase): + + def test_rejects_empty(self): + with self.assertRaises(LocatorHardenerError): + FragileLocator(test_id="", strategy=LocatorStrategy.CSS, value="x") + with self.assertRaises(LocatorHardenerError): + FragileLocator(test_id="t", strategy=LocatorStrategy.CSS, value="") + + def test_rejects_negative_history(self): + with self.assertRaises(LocatorHardenerError): + FragileLocator(test_id="t", strategy=LocatorStrategy.CSS, value="x", + failure_history=-1) + + +class TestScoreFragility(unittest.TestCase): + + def test_id_is_low(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.ID, value="submit", + )) + self.assertLess(score.score, 0.5) + + def test_xpath_with_text_high(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.XPATH, + value="//button[text()='Submit']", + )) + self.assertGreater(score.score, 0.4) + self.assertTrue(any("text" in r for r in score.reasons)) + + def test_nth_of_type_high(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.CSS, + value=".table tr:nth-of-type(3) td", + )) + self.assertGreaterEqual(score.score, 0.4) + self.assertTrue(any("nth-of-type" in r for r in score.reasons)) + + def test_hashed_class(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.CSS, + value=".Button_button-_a1b2c3", + )) + self.assertTrue(any("hashed" in r for r in score.reasons)) + + def test_failure_history_boost(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.ID, value="x", + failure_history=5, + )) + self.assertTrue(any("failed" in r for r in score.reasons)) + + def test_class_name_with_spaces(self): + score = score_fragility(FragileLocator( + test_id="t", strategy=LocatorStrategy.CLASS_NAME, + value="btn primary", + )) + self.assertTrue(any("multi-class" in r for r in score.reasons)) + + def test_rejects_non_locator(self): + with self.assertRaises(LocatorHardenerError): + score_fragility("nope") # type: ignore[arg-type] + + +class TestBuildPrompt(unittest.TestCase): + + def test_includes_locator(self): + prompt = build_prompt(FragileLocator( + test_id="login.json", strategy=LocatorStrategy.CSS, + value=".x .y nth-of-type(2)", dom_excerpt="
...
", + )) + self.assertIn("login.json", prompt) + self.assertIn(".x .y nth-of-type(2)", prompt) + self.assertIn("
", prompt) + + def test_rejects_non_locator(self): + with self.assertRaises(LocatorHardenerError): + build_prompt("nope") # type: ignore[arg-type] + + +class TestParseSuggestions(unittest.TestCase): + + def test_parses_clean(self): + suggestions = parse_suggestions(_good_response()) + self.assertEqual(len(suggestions), 2) + self.assertEqual(suggestions[0].strategy, LocatorStrategy.ID) + + def test_drops_unsafe_nth(self): + raw = json.dumps([ + {"strategy": "css selector", + "value": "tr:nth-of-type(2)", "rationale": "x"}, + {"strategy": "id", "value": "good", "rationale": "y"}, + ]) + suggestions = parse_suggestions(raw) + self.assertEqual(len(suggestions), 1) + self.assertEqual(suggestions[0].value, "good") + + def test_drops_unsafe_xpath_text(self): + raw = json.dumps([ + {"strategy": "xpath", + "value": "//button[text()='Save']", "rationale": "x"}, + {"strategy": "id", "value": "save", "rationale": "y"}, + ]) + suggestions = parse_suggestions(raw) + self.assertEqual(suggestions[0].strategy, LocatorStrategy.ID) + + def test_extracts_from_text(self): + wrapped = "Here you go: " + _good_response() + " thanks" + self.assertEqual(len(parse_suggestions(wrapped)), 2) + + def test_skip_unknown_strategy(self): + raw = json.dumps([ + {"strategy": "fancy", "value": "x", "rationale": "y"}, + {"strategy": "id", "value": "good", "rationale": "y"}, + ]) + self.assertEqual(len(parse_suggestions(raw)), 1) + + def test_skip_non_dict(self): + raw = json.dumps([ + "not a dict", + {"strategy": "id", "value": "good", "rationale": "y"}, + ]) + self.assertEqual(len(parse_suggestions(raw)), 1) + + def test_skip_empty_value(self): + raw = json.dumps([ + {"strategy": "id", "value": "", "rationale": "y"}, + {"strategy": "id", "value": "good", "rationale": "y"}, + ]) + self.assertEqual(len(parse_suggestions(raw)), 1) + + def test_no_valid_raises(self): + raw = json.dumps([{"strategy": "fancy", "value": "x", "rationale": ""}]) + with self.assertRaises(LocatorHardenerError): + parse_suggestions(raw) + + def test_empty(self): + with self.assertRaises(LocatorHardenerError): + parse_suggestions("") + + def test_no_array(self): + with self.assertRaises(LocatorHardenerError): + parse_suggestions("just text") + + def test_bad_json(self): + with self.assertRaises(LocatorHardenerError): + parse_suggestions("[not json]") + + +class TestHarden(unittest.TestCase): + + def test_skips_when_below_threshold(self): + locator = FragileLocator( + test_id="t", strategy=LocatorStrategy.ID, value="x", + ) + client = StubClient(_good_response()) + result = harden(locator, client, min_fragility=0.5) + self.assertEqual(result, []) + self.assertIsNone(client.last_prompt) + + def test_calls_when_fragile(self): + locator = FragileLocator( + test_id="t", strategy=LocatorStrategy.CSS, + value=".a .b .c .d:nth-of-type(2)", + ) + result = harden(locator, StubClient(_good_response()), min_fragility=0.3) + self.assertGreaterEqual(len(result), 1) + + def test_client_error_wrapped(self): + locator = FragileLocator( + test_id="t", strategy=LocatorStrategy.CSS, + value=".a:nth-of-type(2)", + ) + with self.assertRaises(LocatorHardenerError): + harden(locator, StubClient(RuntimeError("rate limit")), + min_fragility=0.3) + + def test_bad_threshold(self): + locator = FragileLocator( + test_id="t", strategy=LocatorStrategy.CSS, value=".a", + ) + with self.assertRaises(LocatorHardenerError): + harden(locator, StubClient(_good_response()), min_fragility=2.0) + + +class TestSuggestionDict(unittest.TestCase): + + def test_to_dict(self): + s = LocatorSuggestion(strategy=LocatorStrategy.ID, value="x", rationale="y") + self.assertEqual(s.to_dict()["strategy"], "id") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_locator_health.py b/test/unit_test/test_locator_health.py new file mode 100644 index 0000000..9000339 --- /dev/null +++ b/test/unit_test/test_locator_health.py @@ -0,0 +1,295 @@ +"""Unit tests for je_web_runner.utils.locator_health.""" +import json +import tempfile +import threading +import unittest +from pathlib import Path + +from je_web_runner.utils.locator_health.health_report import ( + FallbackHitTracker, + LocatorFinding, + LocatorHealthError, + LocatorHealthReport, + UpgradeSuggestion, + apply_upgrades, + build_health_report, + fallback_hit_tracker, + render_health_markdown, + save_health_report, + scan_action_file, + scan_project, + suggest_upgrade, + suggest_upgrades, +) + + +def _write_action_file(path: Path, actions): + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "w", encoding="utf-8") as fp: + json.dump(actions, fp) + + +class TestFallbackHitTracker(unittest.TestCase): + + def test_counts_primary_and_fallback(self): + tracker = FallbackHitTracker() + tracker.track_primary("login_btn") + tracker.track_primary("login_btn") + tracker.track_fallback("login_btn") + stats = tracker.stats() + self.assertEqual(stats["login_btn"]["hits"], 3) + self.assertEqual(stats["login_btn"]["fallback_used"], 1) + + def test_thread_safe(self): + tracker = FallbackHitTracker() + N = 200 + + def hammer(): + for _ in range(N): + tracker.track_primary("x") + tracker.track_fallback("x") + + threads = [threading.Thread(target=hammer) for _ in range(4)] + for t in threads: + t.start() + for t in threads: + t.join() + stats = tracker.stats() + self.assertEqual(stats["x"]["hits"], 4 * N * 2) + self.assertEqual(stats["x"]["fallback_used"], 4 * N) + + def test_clear_resets(self): + tracker = FallbackHitTracker() + tracker.track_primary("a") + tracker.clear() + self.assertEqual(tracker.stats(), {}) + + +class TestScanActionFile(unittest.TestCase): + + def test_missing_file_raises(self): + with self.assertRaises(LocatorHealthError): + scan_action_file("/nope/missing.json") + + def test_malformed_json_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "bad.json" + path.write_text("{not json", encoding="utf-8") + with self.assertRaises(LocatorHealthError): + scan_action_file(path) + + def test_scores_each_locator(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "a.json" + _write_action_file(path, [ + ["WR_save_test_object", {"object_type": "ID", "test_object_name": "login_btn"}], + ["WR_save_test_object", {"object_type": "XPATH", "test_object_name": "//div/div/div/div/div/span"}], + ["WR_to_url", {"url": "https://x"}], + ]) + findings = scan_action_file(path) + self.assertEqual(len(findings), 2) + ids = {(f.strategy, f.score) for f in findings} + id_score = next(f for f in findings if f.strategy == "ID").score + xpath_score = next(f for f in findings if f.strategy == "XPATH").score + self.assertGreater(id_score, xpath_score) + + +class TestScanProject(unittest.TestCase): + + def test_walks_directory(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + _write_action_file(root / "a.json", [ + ["WR_save", {"object_type": "ID", "test_object_name": "x"}], + ]) + _write_action_file(root / "sub" / "b.json", [ + ["WR_save", {"object_type": "XPATH", "test_object_name": "//a"}], + ]) + (root / "notes.txt").write_text("hi", encoding="utf-8") + findings = scan_project(root) + self.assertEqual(len(findings), 2) + + def test_missing_root_raises(self): + with self.assertRaises(LocatorHealthError): + scan_project("/no/such/dir") + + def test_skips_unparseable_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + root = Path(tmpdir) + (root / "broken.json").write_text("{{{", encoding="utf-8") + _write_action_file(root / "ok.json", [ + ["WR_save", {"object_type": "ID", "test_object_name": "x"}], + ]) + findings = scan_project(root) + self.assertEqual(len(findings), 1) + + +class TestBuildHealthReport(unittest.TestCase): + + def _findings(self, *triples): + return [ + LocatorFinding( + file_path="x.json", action_index=i, + strategy=s, value=v, score=score, + ) + for i, (s, v, score) in enumerate(triples) + ] + + def test_empty_returns_zeros(self): + report = build_health_report([]) + self.assertEqual(report.total, 0) + self.assertEqual(report.average_score, 0.0) + + def test_aggregates_correctly(self): + report = build_health_report(self._findings( + ("ID", "ok", 90), + ("XPATH", "//a", 30), + ("CSS_SELECTOR", ".btn", 70), + ), threshold=60) + self.assertEqual(report.total, 3) + self.assertEqual(report.weak, 1) + self.assertEqual(report.strong, 2) + self.assertAlmostEqual(report.average_score, (90 + 30 + 70) / 3, places=2) + self.assertEqual(report.weakest[0].score, 30) + + def test_fallback_offenders_sorted_by_rate(self): + findings = [ + LocatorFinding( + file_path="x.json", action_index=0, strategy="ID", value="a", + score=80, hits=10, fallback_used=5, + ), + LocatorFinding( + file_path="x.json", action_index=1, strategy="ID", value="b", + score=80, hits=10, fallback_used=8, + ), + LocatorFinding( + file_path="x.json", action_index=2, strategy="ID", value="c", + score=80, hits=10, fallback_used=1, # below threshold + ), + ] + report = build_health_report(findings, fallback_min_rate=0.4) + self.assertEqual(len(report.fallback_offenders), 2) + self.assertEqual(report.fallback_offenders[0].value, "b") + + +class TestSuggestUpgrade(unittest.TestCase): + + def test_xpath_with_id_attr_suggests_id(self): + finding = LocatorFinding( + file_path="x.json", action_index=0, + strategy="XPATH", value="//div[@id='login']", score=40, + ) + sug = suggest_upgrade(finding) + self.assertIsNotNone(sug) + self.assertEqual(sug.to_strategy, "ID") + self.assertEqual(sug.to_value, "login") + + def test_xpath_with_testid_suggests_css(self): + finding = LocatorFinding( + file_path="x.json", action_index=0, + strategy="XPATH", value="//button[@data-testid='go']", score=40, + ) + sug = suggest_upgrade(finding) + self.assertIsNotNone(sug) + self.assertEqual(sug.to_strategy, "CSS_SELECTOR") + self.assertEqual(sug.to_value, "[data-testid='go']") + + def test_css_single_id_suggests_id(self): + finding = LocatorFinding( + file_path="x.json", action_index=0, + strategy="CSS_SELECTOR", value="#login-btn", score=70, + ) + sug = suggest_upgrade(finding) + self.assertEqual(sug.to_strategy, "ID") + self.assertEqual(sug.to_value, "login-btn") + + def test_id_no_suggestion(self): + finding = LocatorFinding( + file_path="x.json", action_index=0, + strategy="ID", value="ok", score=90, + ) + self.assertIsNone(suggest_upgrade(finding)) + + +class TestSuggestUpgrades(unittest.TestCase): + + def test_filters_below_threshold(self): + findings = [ + LocatorFinding(file_path="x", action_index=0, strategy="XPATH", + value="//div[@id='a']", score=40), + LocatorFinding(file_path="x", action_index=1, strategy="XPATH", + value="//div[@id='b']", score=80), + ] + sugs = suggest_upgrades(findings, only_below=50) + self.assertEqual(len(sugs), 1) + + +class TestApplyUpgrades(unittest.TestCase): + + def test_rewrites_in_place_safely(self): + actions = [ + ["WR_save", {"object_type": "XPATH", "test_object_name": "//div[@id='x']"}], + ["WR_click", {"object_type": "ID", "test_object_name": "x"}], + ] + sugs = [UpgradeSuggestion( + file_path="x", action_index=0, + from_strategy="XPATH", from_value="//div[@id='x']", + to_strategy="ID", to_value="x", + rationale="r", + )] + new_actions = apply_upgrades(actions, sugs) + self.assertEqual(new_actions[0][1]["object_type"], "ID") + self.assertEqual(new_actions[0][1]["test_object_name"], "x") + # Original is untouched. + self.assertEqual(actions[0][1]["object_type"], "XPATH") + # Second action unaffected. + self.assertEqual(new_actions[1][1]["object_type"], "ID") + + def test_out_of_range_index_is_ignored(self): + actions = [["WR_save", {"object_type": "ID", "test_object_name": "x"}]] + sugs = [UpgradeSuggestion( + file_path="x", action_index=99, + from_strategy="X", from_value="x", + to_strategy="Y", to_value="y", rationale="r", + )] + new_actions = apply_upgrades(actions, sugs) + self.assertEqual(new_actions, actions) + + +class TestRendering(unittest.TestCase): + + def test_markdown_contains_required_sections(self): + findings = [ + LocatorFinding(file_path="a.json", action_index=0, + strategy="XPATH", value="//x", score=30, + reasons=["deep"]), + ] + report = build_health_report(findings, threshold=60) + md = render_health_markdown(report) + self.assertIn("Locator health report", md) + self.assertIn("Weakest locators", md) + self.assertIn("`XPATH`", md) + + def test_save_round_trips(self): + report = build_health_report([ + LocatorFinding(file_path="a.json", action_index=0, + strategy="ID", value="x", score=90), + ]) + with tempfile.TemporaryDirectory() as tmpdir: + path = save_health_report(report, Path(tmpdir) / "r.json") + data = json.loads(path.read_text(encoding="utf-8")) + self.assertEqual(data["total"], 1) + self.assertEqual(data["strong"], 1) + + +class TestModuleSingleton(unittest.TestCase): + + def test_module_level_tracker_isolated_between_tests(self): + fallback_hit_tracker.clear() + fallback_hit_tracker.track_primary("foo") + self.assertEqual(fallback_hit_tracker.stats()["foo"]["hits"], 1) + fallback_hit_tracker.clear() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_long_animation_frame.py b/test/unit_test/test_long_animation_frame.py new file mode 100644 index 0000000..7617942 --- /dev/null +++ b/test/unit_test/test_long_animation_frame.py @@ -0,0 +1,140 @@ +"""Unit tests for je_web_runner.utils.long_animation_frame.""" +import unittest + +from je_web_runner.utils.long_animation_frame.frames import ( + HARVEST_SCRIPT, + LoafReport, + LongAnimationFrameError, + LongFrame, + ScriptAttribution, + assert_no_frame_over, + assert_total_blocking_under, + build_install_script, + parse_log, +) + + +def _frame(duration=100, blocking=80, scripts=None): + return { + "duration_ms": duration, + "render_start_ms": 10, + "style_layout_start_ms": 20, + "start_time_ms": 0, + "blocking_duration_ms": blocking, + "scripts": scripts or [], + } + + +def _script(name, duration=50, source_url=""): + return { + "name": name, "invoker": "click", "invoker_type": "event-listener", + "source_url": source_url or f"https://x/{name}.js", + "duration_ms": duration, + "forced_style_layout_duration_ms": 5, + "pause_duration_ms": 0, + } + + +class TestScripts(unittest.TestCase): + + def test_install_guard(self): + js = build_install_script() + self.assertIn("__wr_loaf_installed__", js) + self.assertIn("long-animation-frame", js) + + def test_harvest_constant(self): + self.assertIn("__wr_loaf_log__", HARVEST_SCRIPT) + + +class TestParseLog(unittest.TestCase): + + def test_basic(self): + frames = parse_log([_frame(150, 100, [_script("react", 50)])]) + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].duration_ms, 150) + self.assertEqual(len(frames[0].scripts), 1) + self.assertEqual(frames[0].scripts[0].name, "react") + + def test_skips_non_dict_frame(self): + self.assertEqual(parse_log(["string", None]), []) # type: ignore[list-item] + + def test_skips_non_dict_script(self): + frames = parse_log([_frame(100, 80, ["not dict"])]) + self.assertEqual(len(frames), 1) + self.assertEqual(frames[0].scripts, []) + + def test_rejects_non_list(self): + with self.assertRaises(LongAnimationFrameError): + parse_log({"x": 1}) # type: ignore[arg-type] + + +class TestReport(unittest.TestCase): + + def test_worst_frame(self): + report = LoafReport(frames=parse_log([ + _frame(100), _frame(200), _frame(50), + ])) + self.assertEqual(report.worst_frame_ms(), 200) + + def test_worst_frame_empty(self): + self.assertEqual(LoafReport().worst_frame_ms(), 0.0) + + def test_total_blocking(self): + report = LoafReport(frames=parse_log([ + _frame(100, blocking=60), _frame(100, blocking=40), + ])) + self.assertEqual(report.total_blocking_ms(), 100) + + def test_top_scripts_aggregates_by_url(self): + report = LoafReport(frames=parse_log([ + _frame(scripts=[_script("a", 50, "https://x/a.js")]), + _frame(scripts=[_script("a-dup", 30, "https://x/a.js"), + _script("b", 100, "https://x/b.js")]), + ])) + top = report.top_scripts(n=2) + self.assertEqual(top[0].source_url, "https://x/b.js") + self.assertEqual(top[0].duration_ms, 100) + # Aggregated 'a' script: 50 + 30 = 80 + self.assertEqual(top[1].source_url, "https://x/a.js") + self.assertEqual(top[1].duration_ms, 80) + + +class TestAssertions(unittest.TestCase): + + def test_no_frame_over_pass(self): + report = LoafReport(frames=parse_log([_frame(40)])) + assert_no_frame_over(report, max_ms=50) + + def test_no_frame_over_fail(self): + report = LoafReport(frames=parse_log([_frame(80)])) + with self.assertRaises(LongAnimationFrameError): + assert_no_frame_over(report, max_ms=50) + + def test_no_frame_over_bad_threshold(self): + with self.assertRaises(LongAnimationFrameError): + assert_no_frame_over(LoafReport(), max_ms=0) + + def test_no_frame_over_rejects_non_report(self): + with self.assertRaises(LongAnimationFrameError): + assert_no_frame_over("nope", max_ms=50) # type: ignore[arg-type] + + def test_total_blocking_pass(self): + report = LoafReport(frames=parse_log([_frame(blocking=50)])) + assert_total_blocking_under(report, max_ms=100) + + def test_total_blocking_fail(self): + report = LoafReport(frames=parse_log([_frame(blocking=200)])) + with self.assertRaises(LongAnimationFrameError): + assert_total_blocking_under(report, max_ms=100) + + def test_total_blocking_bad_threshold(self): + with self.assertRaises(LongAnimationFrameError): + assert_total_blocking_under(LoafReport(), max_ms=-1) + + def test_total_blocking_rejects_non_report(self): + with self.assertRaises(LongAnimationFrameError): + assert_total_blocking_under("nope", max_ms=50) # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_mixed_content_audit.py b/test/unit_test/test_mixed_content_audit.py new file mode 100644 index 0000000..4bd8414 --- /dev/null +++ b/test/unit_test/test_mixed_content_audit.py @@ -0,0 +1,175 @@ +"""Unit tests for je_web_runner.utils.mixed_content_audit.""" +import json +import unittest + +from je_web_runner.utils.mixed_content_audit.audit import ( + MixedContentAuditError, + MixedFinding, + Severity, + assert_clean, + assert_no_active, + scan_console_errors, + scan_har, + summary, +) + + +def _har_entry(url, resource_type="script"): + return { + "_resourceType": resource_type, + "request": {"url": url}, + "response": {"content": {}}, + } + + +def _har(*entries): + return {"log": {"entries": list(entries)}} + + +class TestScanHar(unittest.TestCase): + + def test_no_findings_for_clean_https(self): + findings = scan_har(_har(_har_entry("https://x.com/a.js")), + page_url="https://x.com") + self.assertEqual(findings, []) + + def test_active_finding(self): + findings = scan_har( + _har(_har_entry("http://x.com/bad.js", "script")), + page_url="https://x.com", + ) + self.assertEqual(len(findings), 1) + self.assertEqual(findings[0].severity, Severity.ACTIVE) + + def test_passive_finding(self): + findings = scan_har( + _har(_har_entry("http://x.com/img.png", "image")), + page_url="https://x.com", + ) + self.assertEqual(findings[0].severity, Severity.PASSIVE) + + def test_upgrade_for_hsts_domain(self): + findings = scan_har( + _har(_har_entry("http://fonts.googleapis.com/css", "stylesheet")), + page_url="https://x.com", + ) + self.assertEqual(findings[0].severity, Severity.UPGRADE) + + def test_unknown_resource_type_active(self): + findings = scan_har( + _har(_har_entry("http://x.com/a", "weird")), + page_url="https://x.com", + ) + self.assertEqual(findings[0].severity, Severity.ACTIVE) + + def test_http_page_no_risk(self): + findings = scan_har( + _har(_har_entry("http://x.com/a.js")), + page_url="http://x.com", + ) + self.assertEqual(findings, []) + + def test_str_har(self): + findings = scan_har( + json.dumps(_har(_har_entry("http://x.com/a.js"))), + page_url="https://x.com", + ) + self.assertEqual(len(findings), 1) + + def test_bad_har(self): + with self.assertRaises(MixedContentAuditError): + scan_har("not json") + + def test_bad_har_type(self): + with self.assertRaises(MixedContentAuditError): + scan_har(123) # type: ignore[arg-type] + + def test_bad_har_root(self): + with self.assertRaises(MixedContentAuditError): + scan_har("[]") + + def test_page_url_inferred_from_har_pages(self): + har = { + "log": { + "pages": [{"title": "https://example.com/"}], + "entries": [_har_entry("http://example.com/img.png", "image")], + } + } + findings = scan_har(har) + self.assertEqual(len(findings), 1) + self.assertEqual(findings[0].source_url, "https://example.com/") + + def test_no_page_url_assumes_https(self): + # When page_url is empty AND no pages array, still scans + findings = scan_har(_har(_har_entry("http://x.com/a.png", "image"))) + self.assertEqual(len(findings), 1) + + +class TestScanConsole(unittest.TestCase): + + def test_active_message(self): + msgs = ['Mixed Content: The page at https://x.com requested an insecure script http://x.com/bad.js. This request has been blocked.'] + findings = scan_console_errors(msgs, page_url="https://x.com") + self.assertEqual(findings[0].severity, Severity.ACTIVE) + self.assertIn("bad.js", findings[0].url) + + def test_passive_message(self): + msgs = ['Mixed Content: passive image http://x.com/img.png was loaded over HTTP.'] + findings = scan_console_errors(msgs) + self.assertEqual(findings[0].severity, Severity.PASSIVE) + + def test_ignores_unrelated(self): + msgs = ["TypeError: foo is not a function", "Some other log"] + self.assertEqual(scan_console_errors(msgs), []) + + def test_ignores_non_string(self): + self.assertEqual(scan_console_errors([None, 1]), []) # type: ignore[list-item] + + def test_skips_https_url_in_message(self): + msgs = ["Mixed Content message containing https://x.com/foo"] + # https URL in message → don't classify as finding + self.assertEqual(scan_console_errors(msgs), []) + + +class TestAssertions(unittest.TestCase): + + def test_assert_no_active_pass(self): + assert_no_active([ + MixedFinding(url="http://x", resource_type="image", + severity=Severity.PASSIVE), + ]) + + def test_assert_no_active_fail(self): + with self.assertRaises(MixedContentAuditError): + assert_no_active([ + MixedFinding(url="http://x", resource_type="script", + severity=Severity.ACTIVE), + ]) + + def test_assert_clean_pass(self): + assert_clean([]) + + def test_assert_clean_fail(self): + with self.assertRaises(MixedContentAuditError): + assert_clean([ + MixedFinding(url="http://x", resource_type="image", + severity=Severity.PASSIVE), + ]) + + +class TestSummary(unittest.TestCase): + + def test_counts_severities(self): + s = summary([ + MixedFinding(url="a", resource_type="x", severity=Severity.ACTIVE), + MixedFinding(url="b", resource_type="x", severity=Severity.ACTIVE), + MixedFinding(url="c", resource_type="x", severity=Severity.PASSIVE), + ]) + self.assertEqual(s, {"active": 2, "passive": 1}) + + def test_empty(self): + self.assertEqual(summary([]), {}) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_multimodal_qa.py b/test/unit_test/test_multimodal_qa.py new file mode 100644 index 0000000..52d4625 --- /dev/null +++ b/test/unit_test/test_multimodal_qa.py @@ -0,0 +1,189 @@ +"""Unit tests for je_web_runner.utils.multimodal_qa.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.multimodal_qa.qa import ( + MultimodalQaError, + QaRequest, + QaResponse, + Verdict, + VisionClient, + ask, + ask_path, + assert_passes, + build_prompt, + parse_response, +) + + +class StubClient: + def __init__(self, response): + self.response = response + self.last_prompt = None + self.last_image = None + + def ask(self, prompt, image_b64): + self.last_prompt = prompt + self.last_image = image_b64 + if isinstance(self.response, Exception): + raise self.response + return self.response + + +def _good_response(verdict="pass", confidence=0.9, rationale="looks fine", issues=None): + return json.dumps({ + "verdict": verdict, + "confidence": confidence, + "rationale": rationale, + "issues": issues or [], + }) + + +class TestRequest(unittest.TestCase): + + def test_rejects_empty_bytes(self): + with self.assertRaises(MultimodalQaError): + QaRequest(image_bytes=b"", question="Q") + + def test_rejects_non_bytes(self): + with self.assertRaises(MultimodalQaError): + QaRequest(image_bytes="not bytes", question="Q") # type: ignore[arg-type] + + def test_rejects_blank_question(self): + with self.assertRaises(MultimodalQaError): + QaRequest(image_bytes=b"x", question=" ") + + def test_b64_image(self): + req = QaRequest(image_bytes=b"hello", question="Q") + self.assertEqual(req.b64_image(), "aGVsbG8=") + + +class TestBuildPrompt(unittest.TestCase): + + def test_includes_question(self): + prompt = build_prompt(QaRequest(image_bytes=b"x", question="Is it red?")) + self.assertIn("Is it red?", prompt) + self.assertIn("verdict", prompt) + + def test_includes_rubric(self): + prompt = build_prompt(QaRequest( + image_bytes=b"x", question="Q", + rubric=["button visible", "no overlap"], + )) + self.assertIn("Rubric", prompt) + self.assertIn("button visible", prompt) + + +class TestParseResponse(unittest.TestCase): + + def test_parses_clean_pass(self): + response = parse_response(_good_response()) + self.assertEqual(response.verdict, Verdict.PASS) + self.assertEqual(response.confidence, 0.9) + self.assertTrue(response.is_pass()) + + def test_parses_fail_with_issues(self): + raw = _good_response(verdict="fail", issues=["text cropped", "wrong color"]) + response = parse_response(raw) + self.assertEqual(response.verdict, Verdict.FAIL) + self.assertEqual(len(response.issues), 2) + + def test_extracts_from_surrounding_text(self): + raw = "Sure thing! Here is the analysis:\n" + _good_response() + "\nLet me know." + response = parse_response(raw) + self.assertEqual(response.verdict, Verdict.PASS) + + def test_rejects_empty_response(self): + with self.assertRaises(MultimodalQaError): + parse_response("") + + def test_rejects_no_json(self): + with self.assertRaises(MultimodalQaError): + parse_response("no json here") + + def test_rejects_bad_json(self): + with self.assertRaises(MultimodalQaError): + parse_response("{not really json}") + + def test_rejects_unknown_verdict(self): + raw = json.dumps({"verdict": "maybe", "confidence": 0.5, "rationale": "x"}) + with self.assertRaises(MultimodalQaError): + parse_response(raw) + + def test_rejects_missing_confidence(self): + raw = json.dumps({"verdict": "pass", "rationale": "x"}) + with self.assertRaises(MultimodalQaError): + parse_response(raw) + + def test_clamps_confidence(self): + raw = json.dumps({"verdict": "pass", "confidence": 5.0, "rationale": "x"}) + response = parse_response(raw) + self.assertEqual(response.confidence, 1.0) + + def test_rejects_non_list_issues(self): + raw = json.dumps({ + "verdict": "fail", "confidence": 0.5, "rationale": "x", "issues": "oops", + }) + with self.assertRaises(MultimodalQaError): + parse_response(raw) + + +class TestAsk(unittest.TestCase): + + def test_round_trip(self): + client = StubClient(_good_response()) + response = ask(QaRequest(image_bytes=b"hi", question="Q?"), client) + self.assertTrue(response.is_pass()) + self.assertIn("Q?", client.last_prompt) + self.assertEqual(client.last_image, "aGk=") + + def test_client_error_wrapped(self): + client = StubClient(RuntimeError("rate limit")) + with self.assertRaises(MultimodalQaError): + ask(QaRequest(image_bytes=b"x", question="Q"), client) + + +class TestAskPath(unittest.TestCase): + + def test_reads_file(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "shot.png" + p.write_bytes(b"\x89PNG") + client = StubClient(_good_response()) + response = ask_path(p, "ok?", client) + self.assertTrue(response.is_pass()) + + def test_missing_file(self): + with self.assertRaises(MultimodalQaError): + ask_path("/no/such/file.png", "Q", StubClient(_good_response())) + + +class TestAssertPasses(unittest.TestCase): + + def test_pass(self): + assert_passes(parse_response(_good_response())) + + def test_fail(self): + with self.assertRaises(MultimodalQaError): + assert_passes(parse_response(_good_response(verdict="fail"))) + + def test_pass_low_confidence_fails(self): + with self.assertRaises(MultimodalQaError): + assert_passes( + parse_response(_good_response(confidence=0.2)), + min_confidence=0.5, + ) + + def test_bad_min_confidence(self): + with self.assertRaises(MultimodalQaError): + assert_passes(parse_response(_good_response()), min_confidence=2.0) + + def test_rejects_non_response(self): + with self.assertRaises(MultimodalQaError): + assert_passes("not a response") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_mutation_testing.py b/test/unit_test/test_mutation_testing.py new file mode 100644 index 0000000..98bf4d2 --- /dev/null +++ b/test/unit_test/test_mutation_testing.py @@ -0,0 +1,232 @@ +"""Unit tests for je_web_runner.utils.mutation_testing.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.mutation_testing.mutator import ( + Mutation, + MutationScore, + MutationTestingError, + MutationType, + apply_mutation, + assert_min_score, + generate_mutations, + render_mutation_markdown, + run_mutation_testing, + run_mutation_testing_on_file, +) + + +SAMPLE_ACTIONS = [ + ["WR_to_url", {"url": "https://shop/login"}], + ["WR_save_test_object", {"test_object_name": "user", "object_type": "ID"}], + ["WR_element_input", {"test_object_name": "user", "text": "alice", "timeout": 10}], + ["WR_element_click", {"test_object_name": "submit"}], + ["WR_element_assert", {"test_object_name": "welcome", "expected_text": "Hi alice"}], +] + + +class TestGenerateMutations(unittest.TestCase): + + def test_all_types_produce_at_least_one(self): + mutations = generate_mutations(SAMPLE_ACTIONS) + types = {m.type for m in mutations} + self.assertIn(MutationType.LOCATOR_SWAP, types) + self.assertIn(MutationType.TIMEOUT_SHRINK, types) + self.assertIn(MutationType.URL_CHANGE, types) + self.assertIn(MutationType.ASSERTION_FLIP, types) + self.assertIn(MutationType.ACTION_REMOVAL, types) + self.assertIn(MutationType.ADJACENT_REORDER, types) + + def test_max_per_type_caps(self): + mutations = generate_mutations(SAMPLE_ACTIONS, max_per_type=1, seed=42) + per_type: dict = {} + for m in mutations: + per_type[m.type] = per_type.get(m.type, 0) + 1 + for count in per_type.values(): + self.assertLessEqual(count, 1) + + def test_non_list_raises(self): + with self.assertRaises(MutationTestingError): + generate_mutations("not a list") # type: ignore[arg-type] + + def test_url_change_skips_non_url(self): + actions = [["WR_element_click", {"test_object_name": "x"}]] + mutations = generate_mutations(actions, types=[MutationType.URL_CHANGE]) + self.assertEqual(mutations, []) + + def test_action_removal_skips_quit_init(self): + actions = [ + ["WR_init", {}], + ["WR_element_click", {"test_object_name": "x"}], + ["WR_quit_all"], + ] + mutations = generate_mutations(actions, types=[MutationType.ACTION_REMOVAL]) + self.assertEqual(len(mutations), 1) + self.assertEqual(mutations[0].action_index, 1) + + +class TestApplyMutation(unittest.TestCase): + + def test_locator_swap(self): + m = Mutation( + type=MutationType.LOCATOR_SWAP, action_index=1, + description="x", original="user", mutated="__mutated__", + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(new[1][1]["test_object_name"], "__mutated__") + self.assertEqual(SAMPLE_ACTIONS[1][1]["test_object_name"], "user") + + def test_timeout_shrink(self): + m = Mutation( + type=MutationType.TIMEOUT_SHRINK, action_index=2, + description="x", original=10, mutated=0.001, + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(new[2][1]["timeout"], 0.001) + + def test_url_change(self): + m = Mutation( + type=MutationType.URL_CHANGE, action_index=0, + description="x", original="https://shop/login", + mutated="https://example.invalid/mut", + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(new[0][1]["url"], "https://example.invalid/mut") + + def test_assertion_flip_string(self): + m = Mutation( + type=MutationType.ASSERTION_FLIP, action_index=4, + description="x", original="Hi alice", + mutated="Hi alice__MUTATED__", + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(new[4][1]["expected_text"], "Hi alice__MUTATED__") + + def test_action_removal_shortens_list(self): + m = Mutation( + type=MutationType.ACTION_REMOVAL, action_index=3, + description="remove click", original=SAMPLE_ACTIONS[3], mutated=None, + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(len(new), len(SAMPLE_ACTIONS) - 1) + + def test_adjacent_reorder_swaps(self): + m = Mutation( + type=MutationType.ADJACENT_REORDER, action_index=2, + description="x", original=("WR_element_input", "WR_element_click"), + mutated=("WR_element_click", "WR_element_input"), + ) + new = apply_mutation(SAMPLE_ACTIONS, m) + self.assertEqual(new[2][0], "WR_element_click") + self.assertEqual(new[3][0], "WR_element_input") + + def test_out_of_range_raises(self): + m = Mutation( + type=MutationType.LOCATOR_SWAP, action_index=99, + description="x", original="user", mutated="x", + ) + with self.assertRaises(MutationTestingError): + apply_mutation(SAMPLE_ACTIONS, m) + + def test_reorder_at_last_index_raises(self): + m = Mutation( + type=MutationType.ADJACENT_REORDER, action_index=len(SAMPLE_ACTIONS) - 1, + description="x", original=("a", "b"), mutated=("b", "a"), + ) + with self.assertRaises(MutationTestingError): + apply_mutation(SAMPLE_ACTIONS, m) + + +class TestRunMutationTesting(unittest.TestCase): + + def test_all_killed_when_executor_always_fails(self): + score = run_mutation_testing(SAMPLE_ACTIONS, lambda _a: False) + self.assertEqual(score.killed, score.total) + self.assertEqual(score.survived, 0) + self.assertEqual(score.score, 1.0) + + def test_all_survived_when_executor_always_passes(self): + score = run_mutation_testing(SAMPLE_ACTIONS, lambda _a: True) + self.assertEqual(score.killed, 0) + self.assertEqual(score.survived, score.total) + self.assertEqual(score.score, 0.0) + + def test_executor_exception_counts_as_kill(self): + def boom(_a): + raise RuntimeError("boom") + score = run_mutation_testing( + SAMPLE_ACTIONS, boom, + types=[MutationType.LOCATOR_SWAP], + ) + self.assertTrue(all(r.killed for r in score.results)) + self.assertTrue(all(r.error for r in score.results)) + + def test_stop_on_first_survivor(self): + calls = {"n": 0} + + def survivor(_a): + calls["n"] += 1 + return True # always pass + + score = run_mutation_testing( + SAMPLE_ACTIONS, survivor, + types=[MutationType.LOCATOR_SWAP], + stop_on_first_survivor=True, + ) + self.assertEqual(calls["n"], 1) + self.assertEqual(score.total, 1) + self.assertEqual(score.survived, 1) + + +class TestRunOnFile(unittest.TestCase): + + def test_missing_file_raises(self): + with self.assertRaises(MutationTestingError): + run_mutation_testing_on_file("/no/such.json", lambda _a: False) + + def test_malformed_top_level_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "a.json" + path.write_text('{"not": "a list"}', encoding="utf-8") + with self.assertRaises(MutationTestingError): + run_mutation_testing_on_file(path, lambda _a: False) + + def test_round_trip(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "a.json" + path.write_text(json.dumps(SAMPLE_ACTIONS), encoding="utf-8") + score = run_mutation_testing_on_file(path, lambda _a: False) + self.assertGreater(score.total, 0) + self.assertEqual(score.killed, score.total) + + +class TestRendering(unittest.TestCase): + + def test_markdown_includes_survivors(self): + score = run_mutation_testing(SAMPLE_ACTIONS, lambda _a: True) + md = render_mutation_markdown(score) + self.assertIn("Mutation score", md) + self.assertIn("Surviving mutations", md) + + def test_markdown_omits_survivors_when_none(self): + score = run_mutation_testing(SAMPLE_ACTIONS, lambda _a: False) + md = render_mutation_markdown(score) + self.assertNotIn("Surviving mutations", md) + + +class TestAssertMinScore(unittest.TestCase): + + def test_pass(self): + score = MutationScore(total=10, killed=9, survived=1, score=0.9) + assert_min_score(score, minimum=0.8) + + def test_fail(self): + score = MutationScore(total=10, killed=5, survived=5, score=0.5) + with self.assertRaises(MutationTestingError): + assert_min_score(score, minimum=0.8) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_notifications_audit.py b/test/unit_test/test_notifications_audit.py new file mode 100644 index 0000000..8662731 --- /dev/null +++ b/test/unit_test/test_notifications_audit.py @@ -0,0 +1,204 @@ +"""Unit tests for je_web_runner.utils.notifications_audit.""" +import unittest + +from je_web_runner.utils.notifications_audit.audit import ( + HARVEST_SCRIPT, + NotificationShown, + NotificationsAuditError, + NotificationsLog, + PermissionRequest, + PermissionResult, + assert_no_prompt_before, + assert_no_prompt_without_gesture, + assert_no_spam_after_deny, + assert_notification_shown, + assert_unique_tags, + build_install_script, + parse_log, +) + + +def _payload(requests=None, notifications=None): + return { + "permission_requests": requests or [], + "notifications": notifications or [], + } + + +class TestScripts(unittest.TestCase): + + def test_install_script_install_guard(self): + js = build_install_script() + self.assertIn("__wr_notif_installed__", js) + self.assertIn("requestPermission", js) + + def test_harvest_constant(self): + self.assertIn("__wr_notif_log__", HARVEST_SCRIPT) + + +class TestParseLog(unittest.TestCase): + + def test_basic(self): + log = parse_log(_payload( + requests=[{"timestamp_ms": 100, "user_gesture": True, + "result": "granted", "page_age_ms": 100}], + notifications=[{"timestamp_ms": 200, "title": "hi"}], + )) + self.assertEqual(len(log.permission_requests), 1) + self.assertEqual(log.permission_requests[0].result, PermissionResult.GRANTED) + self.assertEqual(log.notifications[0].title, "hi") + + def test_unknown_result_defaults_to_default(self): + log = parse_log(_payload( + requests=[{"timestamp_ms": 1, "user_gesture": True, "result": "weird"}], + )) + self.assertEqual(log.permission_requests[0].result, PermissionResult.DEFAULT) + + def test_skips_non_dict_entries(self): + log = parse_log(_payload( + requests=["not dict"], notifications=[None], + )) + self.assertEqual(log.permission_requests, []) + self.assertEqual(log.notifications, []) + + def test_rejects_non_dict_payload(self): + with self.assertRaises(NotificationsAuditError): + parse_log("nope") # type: ignore[arg-type] + + +class TestAssertGesture(unittest.TestCase): + + def test_pass(self): + assert_no_prompt_without_gesture(parse_log(_payload( + requests=[{"timestamp_ms": 1, "user_gesture": True, "result": "default"}], + ))) + + def test_fail(self): + with self.assertRaises(NotificationsAuditError): + assert_no_prompt_without_gesture(parse_log(_payload( + requests=[{"timestamp_ms": 1, "user_gesture": False, + "result": "default", "page_age_ms": 50}], + ))) + + +class TestAssertNoPromptBefore(unittest.TestCase): + + def test_pass(self): + assert_no_prompt_before( + parse_log(_payload( + requests=[{"timestamp_ms": 1, "user_gesture": True, + "result": "default", "page_age_ms": 2000}], + )), + min_page_age_ms=1000, + ) + + def test_fail(self): + with self.assertRaises(NotificationsAuditError): + assert_no_prompt_before( + parse_log(_payload( + requests=[{"timestamp_ms": 1, "user_gesture": True, + "result": "default", "page_age_ms": 100}], + )), + min_page_age_ms=1000, + ) + + def test_bad_threshold(self): + with self.assertRaises(NotificationsAuditError): + assert_no_prompt_before(NotificationsLog(), min_page_age_ms=-1) + + +class TestAssertNoSpamAfterDeny(unittest.TestCase): + + def test_no_deny_no_op(self): + assert_no_spam_after_deny(parse_log(_payload( + notifications=[{"timestamp_ms": 1, "title": "x"}], + ))) + + def test_pass(self): + assert_no_spam_after_deny(parse_log(_payload( + requests=[{"timestamp_ms": 100, "user_gesture": True, "result": "denied"}], + ))) + + def test_reprompt_after_deny_fails(self): + with self.assertRaises(NotificationsAuditError): + assert_no_spam_after_deny(parse_log(_payload( + requests=[ + {"timestamp_ms": 100, "user_gesture": True, "result": "denied"}, + {"timestamp_ms": 200, "user_gesture": True, "result": "default"}, + ], + ))) + + def test_notification_after_deny_fails(self): + with self.assertRaises(NotificationsAuditError): + assert_no_spam_after_deny(parse_log(_payload( + requests=[{"timestamp_ms": 100, "user_gesture": True, "result": "denied"}], + notifications=[{"timestamp_ms": 200, "title": "later notif"}], + ))) + + +class TestAssertShown(unittest.TestCase): + + def _log(self): + return parse_log(_payload(notifications=[ + {"timestamp_ms": 1, "title": "Order #1234 shipped", "body": "Track here", "tag": "order"}, + {"timestamp_ms": 2, "title": "New message", "body": "from alice"}, + ])) + + def test_by_title(self): + n = assert_notification_shown(self._log(), title_contains="Order") + self.assertEqual(n.tag, "order") + + def test_by_body(self): + n = assert_notification_shown(self._log(), body_contains="alice") + self.assertEqual(n.title, "New message") + + def test_by_tag(self): + n = assert_notification_shown(self._log(), tag="order") + self.assertIn("Order", n.title) + + def test_combined(self): + n = assert_notification_shown( + self._log(), title_contains="Order", tag="order", + ) + self.assertIsInstance(n, NotificationShown) + + def test_miss(self): + with self.assertRaises(NotificationsAuditError): + assert_notification_shown(self._log(), title_contains="missing") + + def test_no_filter(self): + with self.assertRaises(NotificationsAuditError): + assert_notification_shown(self._log()) + + +class TestUniqueTags(unittest.TestCase): + + def test_pass(self): + log = parse_log(_payload(notifications=[ + {"timestamp_ms": 1, "title": "a", "tag": "x"}, + {"timestamp_ms": 2, "title": "b", "tag": "y"}, + {"timestamp_ms": 3, "title": "c"}, + ])) + assert_unique_tags(log) + + def test_reuse_fails(self): + log = parse_log(_payload(notifications=[ + {"timestamp_ms": 1, "title": "a", "tag": "x"}, + {"timestamp_ms": 2, "title": "b", "tag": "x"}, + ])) + with self.assertRaises(NotificationsAuditError): + assert_unique_tags(log) + + +class TestDictRoundTrip(unittest.TestCase): + + def test_request_to_dict(self): + req = PermissionRequest( + timestamp_ms=1.0, user_gesture=True, + result=PermissionResult.GRANTED, page_age_ms=2.0, + ) + self.assertEqual(req.to_dict()["result"], "granted") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_ocr_assert.py b/test/unit_test/test_ocr_assert.py new file mode 100644 index 0000000..5fa77bf --- /dev/null +++ b/test/unit_test/test_ocr_assert.py @@ -0,0 +1,145 @@ +"""Unit tests for je_web_runner.utils.ocr_assert.""" +import unittest + +from je_web_runner.utils.ocr_assert.ocr import ( + OcrAssertError, + OcrMatchResult, + assert_text_any, + assert_text_contains, + assert_text_fuzzy, + extract_text, + fuzzy_ratio, + normalise_text, +) + + +def _fake_backend(text): + def _b(_source): + return text + return _b + + +class TestNormalisation(unittest.TestCase): + + def test_collapses_whitespace(self): + self.assertEqual(normalise_text("hello world\n\n"), "hello world") + + def test_lowercases_by_default(self): + self.assertEqual(normalise_text("Hello"), "hello") + + def test_keeps_case_when_disabled(self): + self.assertEqual(normalise_text("Hello", lowercase=False), "Hello") + + def test_strips_accents(self): + self.assertEqual(normalise_text("café"), "cafe") + + def test_rejects_non_string(self): + with self.assertRaises(OcrAssertError): + normalise_text(123) + + def test_fuzzy_ratio_identical(self): + self.assertEqual(fuzzy_ratio("hello", "hello"), 1.0) + + def test_fuzzy_ratio_partial(self): + ratio = fuzzy_ratio("hello world", "helo world") + self.assertGreater(ratio, 0.8) + self.assertLess(ratio, 1.0) + + +class TestExtractText(unittest.TestCase): + + def test_uses_custom_backend(self): + text = extract_text(b"", backend=_fake_backend("OCR OK")) + self.assertEqual(text, "OCR OK") + + def test_rejects_non_string_backend_output(self): + with self.assertRaises(OcrAssertError): + extract_text(b"", backend=lambda _: 123) + + +class TestAssertContains(unittest.TestCase): + + def test_match(self): + result = assert_text_contains( + b"", "world", backend=_fake_backend("Hello, World!"), + ) + self.assertTrue(result.matched) + self.assertEqual(result.score, 1.0) + + def test_miss(self): + result = assert_text_contains( + b"", "missing", backend=_fake_backend("Hello, World!"), + ) + self.assertFalse(result.matched) + with self.assertRaises(OcrAssertError): + result.raise_if_failed() + + def test_case_sensitive_miss(self): + result = assert_text_contains( + b"", "WORLD", + backend=_fake_backend("hello world"), + case_sensitive=True, + ) + self.assertFalse(result.matched) + + def test_rejects_empty_needle(self): + with self.assertRaises(OcrAssertError): + assert_text_contains(b"", "", backend=_fake_backend("x")) + + +class TestAssertFuzzy(unittest.TestCase): + + def test_close_match(self): + result = assert_text_fuzzy( + b"", "Quarterly Revenue", + min_ratio=0.7, + backend=_fake_backend("Quartely Revnue"), + ) + self.assertTrue(result.matched) + self.assertGreater(result.score, 0.7) + + def test_far_off_fails(self): + result = assert_text_fuzzy( + b"", "Quarterly Revenue", + min_ratio=0.9, + backend=_fake_backend("totally different text here"), + ) + self.assertFalse(result.matched) + + def test_rejects_bad_ratio(self): + with self.assertRaises(OcrAssertError): + assert_text_fuzzy(b"", "x", min_ratio=0.0, backend=_fake_backend("x")) + with self.assertRaises(OcrAssertError): + assert_text_fuzzy(b"", "x", min_ratio=1.5, backend=_fake_backend("x")) + + +class TestAssertAny(unittest.TestCase): + + def test_finds_first_match(self): + result = assert_text_any( + b"", ["foo", "bar", "baz"], + backend=_fake_backend("There is a bar here"), + ) + self.assertTrue(result.matched) + self.assertEqual(result.needle, "bar") + + def test_no_match(self): + result = assert_text_any( + b"", ["alpha", "beta"], + backend=_fake_backend("nothing relevant"), + ) + self.assertFalse(result.matched) + + def test_rejects_empty(self): + with self.assertRaises(OcrAssertError): + assert_text_any(b"", [], backend=_fake_backend("x")) + + +class TestOcrMatchResult(unittest.TestCase): + + def test_raise_if_failed_noop_when_matched(self): + OcrMatchResult(matched=True, mode="x", needle="n", haystack="h").raise_if_failed() + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_open_redirect_detector.py b/test/unit_test/test_open_redirect_detector.py new file mode 100644 index 0000000..5a8e7bc --- /dev/null +++ b/test/unit_test/test_open_redirect_detector.py @@ -0,0 +1,183 @@ +"""Unit tests for je_web_runner.utils.open_redirect_detector.""" +import unittest + +from je_web_runner.utils.open_redirect_detector.detector import ( + OpenRedirectError, + ProbeOutcome, + ProbeReport, + ProbeResponse, + ProbeResult, + RedirectPayload, + assert_safe, + classify_response, + default_payloads, + probe_all, +) + + +class TestPayloads(unittest.TestCase): + + def test_default_set_includes_classics(self): + names = {p.label for p in default_payloads("evil.example")} + for needed in ("absolute_http", "scheme_relative", "javascript_uri", + "data_uri", "at_sign_userinfo"): + self.assertIn(needed, names) + + def test_attacker_host_required(self): + with self.assertRaises(OpenRedirectError): + default_payloads("no_dot") + + +class TestClassify(unittest.TestCase): + + def _p(self, value): + return RedirectPayload(label="x", value=value) + + def test_off_origin_absolute_allowed(self): + r = classify_response( + self._p("https://evil.example/"), + "https://evil.example/", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.ALLOWED) + + def test_same_origin_blocked(self): + r = classify_response( + self._p("/home"), "https://trusted.com/home", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.BLOCKED) + + def test_subdomain_same_org(self): + r = classify_response( + self._p("//sub.trusted.com/"), + "https://sub.trusted.com/", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.BLOCKED) + + def test_scheme_relative_evil(self): + r = classify_response( + self._p("//evil.example/"), + "//evil.example/", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.ALLOWED) + + def test_javascript_uri(self): + r = classify_response( + self._p("javascript:alert(1)"), + "javascript:alert(1)", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.ALLOWED) + + def test_data_uri(self): + r = classify_response( + self._p("data:text/html,x"), + "data:text/html,x", 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.ALLOWED) + + def test_non_redirect_status_blocked(self): + r = classify_response( + self._p("x"), "https://evil.example/", 200, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.BLOCKED) + + def test_empty_location_ambiguous(self): + r = classify_response( + self._p("x"), None, 302, + legitimate_host="trusted.com", + ) + self.assertEqual(r.outcome, ProbeOutcome.AMBIGUOUS) + + def test_at_sign_resolves_evil(self): + r = classify_response( + self._p("https://trusted.com@evil.example/"), + "https://trusted.com@evil.example/", 302, + legitimate_host="trusted.com", + ) + # Hostname after @ is evil.example + self.assertEqual(r.outcome, ProbeOutcome.ALLOWED) + + def test_bad_status_type(self): + with self.assertRaises(OpenRedirectError): + classify_response(self._p("x"), None, "302", # type: ignore[arg-type] + legitimate_host="x.com") + + def test_bad_host(self): + with self.assertRaises(OpenRedirectError): + classify_response(self._p("x"), None, 302, legitimate_host="") + + +class TestProbeAll(unittest.TestCase): + + def test_runs_all_payloads(self): + def probe(value): + # Naive vulnerable app: always 302s to the input + return ProbeResponse(status_code=302, location=value) + report = probe_all( + default_payloads("evil.example"), + probe, + legitimate_host="trusted.com", + ) + self.assertEqual(len(report.results), len(default_payloads("evil.example"))) + self.assertFalse(report.passed()) + + def test_safe_app(self): + def probe(_value): + return ProbeResponse(status_code=302, location="https://trusted.com/") + report = probe_all( + default_payloads("evil.example"), + probe, legitimate_host="trusted.com", + ) + self.assertTrue(report.passed()) + + def test_empty_payloads(self): + with self.assertRaises(OpenRedirectError): + probe_all([], lambda _v: ProbeResponse(200, None), legitimate_host="x.com") + + def test_non_callable_probe(self): + with self.assertRaises(OpenRedirectError): + probe_all([RedirectPayload("x", "y")], "not callable", # type: ignore[arg-type] + legitimate_host="x.com") + + def test_probe_exception_wrapped(self): + def boom(_): + raise RuntimeError("net") + with self.assertRaises(OpenRedirectError): + probe_all([RedirectPayload("x", "y")], boom, legitimate_host="x.com") + + def test_bad_probe_return(self): + def bad(_): + return "not a probe response" + with self.assertRaises(OpenRedirectError): + probe_all([RedirectPayload("x", "y")], bad, legitimate_host="x.com") + + +class TestAssertSafe(unittest.TestCase): + + def test_pass(self): + assert_safe(ProbeReport(legitimate_host="x")) + + def test_fail(self): + report = ProbeReport(legitimate_host="x", results=[ + ProbeResult( + payload=RedirectPayload("evil", "//evil/"), + final_location="//evil/", status_code=302, + outcome=ProbeOutcome.ALLOWED, + ), + ]) + with self.assertRaises(OpenRedirectError): + assert_safe(report) + + def test_rejects_non_report(self): + with self.assertRaises(OpenRedirectError): + assert_safe("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_openapi_to_e2e.py b/test/unit_test/test_openapi_to_e2e.py new file mode 100644 index 0000000..bf06871 --- /dev/null +++ b/test/unit_test/test_openapi_to_e2e.py @@ -0,0 +1,260 @@ +"""Unit tests for je_web_runner.utils.openapi_to_e2e.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.openapi_to_e2e.generator import ( + GeneratedTest, + GenerationResult, + OpenAPIGeneratorError, + generate_tests_from_file, + generate_tests_from_spec, + load_spec, + synthesize_example, + write_tests_to_dir, +) + + +_PET_SPEC = { + "openapi": "3.0.0", + "info": {"title": "Pet Store", "version": "1.0"}, + "servers": [{"url": "https://api.pets.example/v1"}], + "components": { + "securitySchemes": { + "bearer": {"type": "http", "scheme": "bearer"}, + }, + "schemas": { + "Pet": { + "type": "object", + "required": ["name"], + "properties": { + "id": {"type": "integer", "example": 7}, + "name": {"type": "string", "example": "Rex"}, + "tags": {"type": "array", + "items": {"type": "string", "example": "good"}}, + }, + }, + }, + }, + "paths": { + "/pets": { + "get": { + "operationId": "listPets", + "responses": {"200": {"description": "ok"}}, + }, + "post": { + "operationId": "createPet", + "requestBody": { + "required": True, + "content": {"application/json": { + "schema": {"$ref": "#/components/schemas/Pet"}, + }}, + }, + "responses": {"201": {"description": "created"}}, + }, + }, + "/pets/{petId}": { + "get": { + "operationId": "getPet", + "parameters": [ + {"name": "petId", "in": "path", + "required": True, "schema": {"type": "integer", "example": 1}}, + ], + "responses": {"200": {"description": "ok"}}, + }, + "delete": { + "operationId": "deletePet", + "parameters": [ + {"name": "petId", "in": "path", + "required": True, "schema": {"type": "integer"}}, + ], + "responses": {"204": {"description": "gone"}}, + }, + }, + }, +} + + +class TestSynthesizeExample(unittest.TestCase): + + def test_explicit_example(self): + result = synthesize_example({}, {"type": "string", "example": "Rex"}) + self.assertEqual(result, "Rex") + + def test_default_fallback(self): + result = synthesize_example({}, {"type": "integer", "default": 42}) + self.assertEqual(result, 42) + + def test_type_only(self): + self.assertEqual(synthesize_example({}, {"type": "string"}), "sample") + self.assertEqual(synthesize_example({}, {"type": "integer"}), 1) + + def test_object_with_required(self): + schema = { + "type": "object", + "required": ["name"], + "properties": { + "name": {"type": "string", "example": "Rex"}, + "age": {"type": "integer", "example": 3}, + }, + } + result = synthesize_example({}, schema) + self.assertEqual(result, {"name": "Rex"}) + + def test_array(self): + result = synthesize_example({}, { + "type": "array", + "items": {"type": "string", "example": "x"}, + }) + self.assertEqual(result, ["x"]) + + def test_enum(self): + self.assertEqual(synthesize_example({}, {"enum": ["a", "b"]}), "a") + + def test_ref_resolution(self): + spec = { + "components": {"schemas": {"X": {"type": "string", "example": "ref"}}}, + } + result = synthesize_example(spec, {"$ref": "#/components/schemas/X"}) + self.assertEqual(result, "ref") + + +class TestLoadSpec(unittest.TestCase): + + def test_loads_json(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "spec.json" + path.write_text(json.dumps(_PET_SPEC), encoding="utf-8") + spec = load_spec(path) + self.assertEqual(spec["info"]["title"], "Pet Store") + + def test_missing_file_raises(self): + with self.assertRaises(OpenAPIGeneratorError): + load_spec("/no/such.json") + + def test_invalid_json_yaml_attempt(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "spec.json" + path.write_text("{not json", encoding="utf-8") + # without pyyaml installed this raises; with pyyaml it may parse + # as a scalar but our top-level check kicks in. + try: + import yaml # noqa: F401 + except ImportError: + with self.assertRaises(OpenAPIGeneratorError): + load_spec(path) + + +class TestGenerate(unittest.TestCase): + + def test_happy_path_generated_for_each_method(self): + result = generate_tests_from_spec(_PET_SPEC, include_negative=False) + names = {t.name for t in result.tests} + self.assertIn("listPets__happy", names) + self.assertIn("createPet__happy", names) + self.assertIn("getPet__happy", names) + self.assertIn("deletePet__happy", names) + self.assertEqual(len(result.tests), 4) + + def test_negative_tests_added(self): + result = generate_tests_from_spec(_PET_SPEC, include_negative=True) + names = {t.name for t in result.tests} + # POST should get a missing-body variant + self.assertIn("createPet__missing_body", names) + # GET /pets/{petId} should get a bad path-param variant + self.assertIn("getPet__bad_path_param", names) + + def test_url_includes_base(self): + result = generate_tests_from_spec(_PET_SPEC) + get_pet = next(t for t in result.tests if t.name == "getPet__happy") + url = get_pet.actions[0][1]["url"] + self.assertTrue(url.startswith("https://api.pets.example/v1/pets/")) + self.assertIn("1", url.rsplit("/", 1)[-1]) + + def test_auth_header_injected(self): + result = generate_tests_from_spec(_PET_SPEC) + first = result.tests[0] + headers = first.actions[0][1].get("headers") or {} + self.assertEqual(headers.get("Authorization"), "Bearer ${API_TOKEN}") + + def test_path_prefix_filter(self): + result = generate_tests_from_spec( + _PET_SPEC, path_prefix_filter="/pets/{", include_negative=False, + ) + names = {t.name for t in result.tests} + self.assertNotIn("listPets__happy", names) + self.assertIn("getPet__happy", names) + + def test_method_filter(self): + result = generate_tests_from_spec( + _PET_SPEC, method_filter={"get"}, include_negative=False, + ) + for t in result.tests: + self.assertEqual(t.method, "GET") + + def test_assert_status_action_present(self): + result = generate_tests_from_spec(_PET_SPEC, include_negative=False) + for t in result.tests: + last = t.actions[-1] + self.assertEqual(last[0], "WR_http_assert_status") + + def test_request_body_synthesised(self): + result = generate_tests_from_spec(_PET_SPEC) + create = next(t for t in result.tests if t.name == "createPet__happy") + body = create.actions[0][1].get("json_body") + self.assertEqual(body, {"name": "Rex"}) + + def test_swagger2_style(self): + spec = { + "swagger": "2.0", + "info": {"title": "Old"}, + "host": "api.old.example", + "basePath": "/v2", + "schemes": ["https"], + "paths": { + "/items": { + "get": {"operationId": "listItems", + "responses": {"200": {"description": "ok"}}}, + }, + }, + } + result = generate_tests_from_spec(spec) + self.assertEqual(result.base_url, "https://api.old.example/v2") + self.assertEqual(len(result.tests), 1) + + def test_invalid_spec_raises(self): + with self.assertRaises(OpenAPIGeneratorError): + generate_tests_from_spec({}) # no paths + with self.assertRaises(OpenAPIGeneratorError): + generate_tests_from_spec("not a dict") # type: ignore[arg-type] + + def test_unsupported_method_skipped(self): + spec = { + "openapi": "3.0.0", + "info": {"title": "x"}, + "servers": [{"url": "https://x"}], + "paths": {"/x": {"trace": {"responses": {"200": {"description": "x"}}}}}, + } + result = generate_tests_from_spec(spec) + self.assertEqual(len(result.tests), 0) + + +class TestFromFileAndWrite(unittest.TestCase): + + def test_round_trip_to_files(self): + with tempfile.TemporaryDirectory() as tmpdir: + spec_path = Path(tmpdir) / "spec.json" + spec_path.write_text(json.dumps(_PET_SPEC), encoding="utf-8") + result = generate_tests_from_file(spec_path, include_negative=False) + out_dir = Path(tmpdir) / "out" + written = write_tests_to_dir(result, out_dir) + self.assertEqual(len(written), 4) + for path in written: + self.assertTrue(path.exists()) + payload = json.loads(path.read_text(encoding="utf-8")) + self.assertIsInstance(payload, list) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_otel_bridge.py b/test/unit_test/test_otel_bridge.py new file mode 100644 index 0000000..cac308b --- /dev/null +++ b/test/unit_test/test_otel_bridge.py @@ -0,0 +1,201 @@ +"""Unit tests for je_web_runner.utils.otel_bridge.""" +import unittest +from unittest.mock import MagicMock, patch + +from je_web_runner.utils.otel_bridge.trace_bridge import ( + TraceBridgeError, + TraceContext, + bridged_span_playwright, + bridged_span_selenium, + clear_headers_playwright, + clear_headers_selenium, + current_otel_context, + inject_headers_playwright, + inject_headers_selenium, + parse_traceparent, + random_trace_context, + trace_link, +) + + +class TestTraceContext(unittest.TestCase): + + def test_traceparent_format(self): + ctx = TraceContext( + trace_id="a" * 32, span_id="b" * 16, sampled=True, + ) + self.assertEqual(ctx.to_traceparent(), f"00-{'a'*32}-{'b'*16}-01") + + def test_traceparent_not_sampled(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16, sampled=False) + self.assertTrue(ctx.to_traceparent().endswith("-00")) + + def test_as_headers_includes_tracestate(self): + ctx = TraceContext( + trace_id="a" * 32, span_id="b" * 16, tracestate="vendor=x", + ) + headers = ctx.as_headers() + self.assertIn("traceparent", headers) + self.assertEqual(headers["tracestate"], "vendor=x") + + def test_as_headers_without_tracestate(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16) + self.assertNotIn("tracestate", ctx.as_headers()) + + +class TestRandomContext(unittest.TestCase): + + def test_random_returns_valid_hex(self): + ctx = random_trace_context() + self.assertEqual(len(ctx.trace_id), 32) + self.assertEqual(len(ctx.span_id), 16) + # round-trip through parse_traceparent + parsed = parse_traceparent(ctx.to_traceparent()) + self.assertEqual(parsed.trace_id, ctx.trace_id) + self.assertEqual(parsed.span_id, ctx.span_id) + + +class TestParseTraceparent(unittest.TestCase): + + def test_round_trip(self): + original = TraceContext(trace_id="a" * 32, span_id="b" * 16) + parsed = parse_traceparent(original.to_traceparent()) + self.assertEqual(parsed.trace_id, original.trace_id) + + def test_malformed_raises(self): + with self.assertRaises(TraceBridgeError): + parse_traceparent("garbage") + with self.assertRaises(TraceBridgeError): + parse_traceparent("00-tooshort-tooshort-01") + + +class TestInjectSelenium(unittest.TestCase): + + def test_calls_cdp_with_headers(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16) + driver = MagicMock() + inject_headers_selenium(driver, ctx) + calls = [c.args for c in driver.execute_cdp_cmd.call_args_list] + # at least one Network.setExtraHTTPHeaders call + set_calls = [c for c in calls if c[0] == "Network.setExtraHTTPHeaders"] + self.assertEqual(len(set_calls), 1) + self.assertIn("traceparent", set_calls[0][1]["headers"]) + + def test_no_driver_raises(self): + ctx = random_trace_context() + with self.assertRaises(TraceBridgeError): + inject_headers_selenium(None, ctx) + + def test_driver_without_cdp_raises(self): + ctx = random_trace_context() + driver = MagicMock(spec=["foo"]) # no execute_cdp_cmd + with self.assertRaises(TraceBridgeError): + inject_headers_selenium(driver, ctx) + + def test_cdp_error_wraps(self): + ctx = random_trace_context() + driver = MagicMock() + driver.execute_cdp_cmd.side_effect = RuntimeError("boom") + with self.assertRaises(TraceBridgeError): + inject_headers_selenium(driver, ctx) + + def test_clear_is_noop_on_missing_cdp(self): + clear_headers_selenium(None) + driver = MagicMock(spec=["foo"]) + clear_headers_selenium(driver) + + +class TestInjectPlaywright(unittest.TestCase): + + def test_set_extra_http_headers(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16) + page = MagicMock() + inject_headers_playwright(page, ctx) + page.set_extra_http_headers.assert_called_once() + payload = page.set_extra_http_headers.call_args.args[0] + self.assertIn("traceparent", payload) + + def test_no_page_raises(self): + with self.assertRaises(TraceBridgeError): + inject_headers_playwright(None, random_trace_context()) + + def test_page_without_setter_raises(self): + page = MagicMock(spec=[]) + with self.assertRaises(TraceBridgeError): + inject_headers_playwright(page, random_trace_context()) + + def test_clear(self): + page = MagicMock() + clear_headers_playwright(page) + page.set_extra_http_headers.assert_called_once_with({}) + + +class TestCurrentOtelContext(unittest.TestCase): + + def test_returns_none_without_otel(self): + # Pretend OTel isn't importable. + import builtins + original_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "opentelemetry": + raise ImportError("simulated") + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=fake_import): + self.assertIsNone(current_otel_context()) + + +class TestBridgedSpan(unittest.TestCase): + + def test_selenium_context_manager_injects_and_clears(self): + driver = MagicMock() + with bridged_span_selenium(driver, "test_action") as ctx: + self.assertEqual(len(ctx.trace_id), 32) + # Should be called at least twice: enable + set + clear + names = [c.args[0] for c in driver.execute_cdp_cmd.call_args_list] + self.assertIn("Network.setExtraHTTPHeaders", names) + + def test_playwright_context_manager_clears(self): + page = MagicMock() + with bridged_span_playwright(page, "click"): + pass + # at least one set and one reset call + self.assertGreaterEqual(page.set_extra_http_headers.call_count, 2) + last_call = page.set_extra_http_headers.call_args_list[-1].args[0] + self.assertEqual(last_call, {}) + + def test_fallback_context_used_without_otel(self): + import builtins + original_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "opentelemetry": + raise ImportError("simulated") + return original_import(name, *args, **kwargs) + + fallback = TraceContext(trace_id="c" * 32, span_id="d" * 16) + with patch("builtins.__import__", side_effect=fake_import): + page = MagicMock() + with bridged_span_playwright(page, "x", fallback_context=fallback) as ctx: + self.assertEqual(ctx.trace_id, fallback.trace_id) + + +class TestTraceLink(unittest.TestCase): + + def test_jaeger_link(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16) + link = trace_link(ctx, jaeger_base="https://jaeger.local/") + self.assertEqual(link, f"https://jaeger.local/trace/{'a'*32}") + + def test_tempo_link(self): + ctx = TraceContext(trace_id="a" * 32, span_id="b" * 16) + link = trace_link(ctx, tempo_base="https://tempo.local") + self.assertIn("a" * 32, link) + + def test_no_base_returns_none(self): + self.assertIsNone(trace_link(random_trace_context())) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_otp_interceptor.py b/test/unit_test/test_otp_interceptor.py new file mode 100644 index 0000000..736ea53 --- /dev/null +++ b/test/unit_test/test_otp_interceptor.py @@ -0,0 +1,221 @@ +"""Unit tests for je_web_runner.utils.otp_interceptor.""" +import time +import unittest +from unittest.mock import MagicMock + +from je_web_runner.utils.otp_interceptor.interceptor import ( + InMemoryProvider, + InterceptedMessage, + MailHogProvider, + MailpitProvider, + OtpInterceptError, + WebhookSmsProvider, + extract_otp_from_text, + wait_for_otp, +) + + +def _msg(recipient="a@x", body="Your code is 123456", subject="OTP", + received_at=None): + return InterceptedMessage( + message_id="m1", + sender="bot@x", + recipient=recipient, + subject=subject, + body=body, + received_at=received_at if received_at is not None else time.time(), + ) + + +class TestExtractOtp(unittest.TestCase): + + def test_default_extracts_6_digits(self): + self.assertEqual(extract_otp_from_text("Code: 482910 please"), "482910") + + def test_custom_pattern_with_group(self): + otp = extract_otp_from_text( + "Your one-time code: ABC-9990", + pattern=r"ABC-(\d{4})", + ) + self.assertEqual(otp, "9990") + + def test_no_match_returns_none(self): + self.assertIsNone(extract_otp_from_text("nothing here")) + + def test_empty_input(self): + self.assertIsNone(extract_otp_from_text("")) + self.assertIsNone(extract_otp_from_text(None)) # type: ignore[arg-type] + + +class TestInMemoryProvider(unittest.TestCase): + + def test_filters_by_recipient(self): + p = InMemoryProvider() + p.push(_msg(recipient="alice@x")) + p.push(_msg(recipient="bob@x")) + results = p.fetch_messages(recipient="alice@x") + self.assertEqual(len(results), 1) + self.assertEqual(results[0].recipient, "alice@x") + + def test_filters_by_since(self): + p = InMemoryProvider() + now = time.time() + p.push(_msg(received_at=now - 100)) + p.push(_msg(received_at=now + 100)) + results = p.fetch_messages(since=now) + self.assertEqual(len(results), 1) + + def test_newest_first(self): + p = InMemoryProvider() + p.push(_msg(body="old", received_at=10.0)) + p.push(_msg(body="new", received_at=20.0)) + results = p.fetch_messages() + self.assertEqual(results[0].body, "new") + + +class TestMailHogProvider(unittest.TestCase): + + def test_parses_v2_payload(self): + fake_fetch = MagicMock(return_value={ + "items": [ + { + "ID": "abc", + "Created": "2026-05-24T10:00:00Z", + "Content": { + "Headers": { + "From": ["a@x"], "To": ["b@x"], "Subject": ["Hi"], + }, + "Body": "Code 999111", + }, + } + ] + }) + provider = MailHogProvider("http://mailhog:8025", http_fetcher=fake_fetch) + out = provider.fetch_messages(recipient="b@x") + self.assertEqual(len(out), 1) + self.assertEqual(out[0].subject, "Hi") + self.assertIn("999111", out[0].body) + + def test_non_dict_payload_raises(self): + provider = MailHogProvider("http://x", http_fetcher=lambda _u: []) + with self.assertRaises(OtpInterceptError): + provider.fetch_messages() + + +class TestMailpitProvider(unittest.TestCase): + + def test_parses_messages_key(self): + fake_fetch = MagicMock(return_value={ + "messages": [ + { + "ID": "id1", + "Created": "2026-05-24T10:00:00Z", + "From": {"Address": "a@x"}, + "To": [{"Address": "b@x"}], + "Subject": "verify", + "Text": "Token 224488", + } + ] + }) + provider = MailpitProvider("http://mailpit", http_fetcher=fake_fetch) + out = provider.fetch_messages(recipient="b@x") + self.assertEqual(len(out), 1) + self.assertEqual(out[0].sender, "a@x") + self.assertIn("224488", out[0].body) + + +class TestWebhookSmsProvider(unittest.TestCase): + + def test_parses_list(self): + fake_fetch = MagicMock(return_value=[ + {"id": "s1", "from": "+1000", "to": "+1234", + "body": "Your code 12345", "received_at": "2026-05-24T10:00:00Z"}, + ]) + provider = WebhookSmsProvider("http://sms", http_fetcher=fake_fetch) + out = provider.fetch_messages(recipient="+1234") + self.assertEqual(len(out), 1) + self.assertEqual(out[0].recipient, "+1234") + + def test_non_list_raises(self): + provider = WebhookSmsProvider("http://sms", http_fetcher=lambda _u: {}) + with self.assertRaises(OtpInterceptError): + provider.fetch_messages() + + +class TestWaitForOtp(unittest.TestCase): + + def test_returns_immediately_if_present(self): + provider = InMemoryProvider() + provider.push(_msg(recipient="a@x", body="Code 111222")) + code = wait_for_otp(provider, "a@x", since=0, timeout=2, poll_interval=0.01) + self.assertEqual(code, "111222") + + def test_subject_filter(self): + provider = InMemoryProvider() + provider.push(_msg(recipient="a@x", subject="Welcome", body="Code 333")) + provider.push(_msg(recipient="a@x", subject="OTP", body="Code 444444")) + code = wait_for_otp( + provider, "a@x", + since=0, timeout=2, poll_interval=0.01, + subject_contains="otp", + ) + self.assertEqual(code, "444444") + + def test_subject_extraction(self): + provider = InMemoryProvider() + provider.push(_msg(recipient="a@x", subject="Code 778899 to verify", body="")) + code = wait_for_otp(provider, "a@x", since=0, timeout=2, poll_interval=0.01) + self.assertEqual(code, "778899") + + def test_polls_until_arrives(self): + provider = InMemoryProvider() + clock = {"now": 0.0} + + def fake_time(): + return clock["now"] + + def fake_sleep(seconds): + clock["now"] += seconds + if clock["now"] >= 1.0 and not provider.messages: + provider.push(_msg(recipient="a@x", body="Code 909090")) + + code = wait_for_otp( + provider, "a@x", + since=0, timeout=5, poll_interval=0.5, + sleep_fn=fake_sleep, time_fn=fake_time, + ) + self.assertEqual(code, "909090") + + def test_times_out(self): + provider = InMemoryProvider() + clock = {"now": 0.0} + + def fake_time(): + return clock["now"] + + def fake_sleep(seconds): + clock["now"] += seconds + + with self.assertRaises(OtpInterceptError): + wait_for_otp( + provider, "a@x", + since=0, timeout=2, poll_interval=1.0, + sleep_fn=fake_sleep, time_fn=fake_time, + ) + + def test_skips_messages_before_since(self): + provider = InMemoryProvider() + provider.push(_msg(recipient="a@x", body="Code 999000", received_at=10.0)) + with self.assertRaises(OtpInterceptError): + wait_for_otp( + provider, "a@x", + since=20.0, timeout=0.5, poll_interval=0.1, + ) + + def test_bad_provider_raises(self): + with self.assertRaises(OtpInterceptError): + wait_for_otp("not a provider", "a@x") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pagination_audit.py b/test/unit_test/test_pagination_audit.py new file mode 100644 index 0000000..69931bd --- /dev/null +++ b/test/unit_test/test_pagination_audit.py @@ -0,0 +1,185 @@ +"""Unit tests for je_web_runner.utils.pagination_audit.""" +import unittest + +from je_web_runner.utils.pagination_audit.audit import ( + Page, + PaginationAuditError, + PaginationFindings, + assert_clean, + assert_expected_total, + assert_no_cursor_loop, + assert_no_duplicates, + assert_sorted_by, + assert_terminated, + walk_all_pages, +) + + +def _fetcher(pages): + """pages: list of (items, next_cursor) tuples; first call sees cursor=None.""" + by_cursor = {} + for index, (items, next_cursor) in enumerate(pages): + cursor_in = None if index == 0 else pages[index - 1][1] + by_cursor[cursor_in] = Page(items=list(items), next_cursor=next_cursor) + def _f(cursor): + return by_cursor[cursor] + return _f + + +class TestPage(unittest.TestCase): + + def test_rejects_non_list_items(self): + with self.assertRaises(PaginationAuditError): + Page(items="not list") # type: ignore[arg-type] + + +class TestWalk(unittest.TestCase): + + def test_clean_walk(self): + pages = [ + ([{"id": 1}, {"id": 2}], "c2"), + ([{"id": 3}, {"id": 4}], None), + ] + findings = walk_all_pages(_fetcher(pages), lambda r: r["id"]) + self.assertEqual(findings.page_count, 2) + self.assertEqual(findings.total_items, 4) + self.assertEqual(findings.unique_items, 4) + self.assertTrue(findings.passed()) + + def test_duplicate_caught(self): + pages = [ + ([{"id": 1}, {"id": 2}], "c2"), + ([{"id": 2}, {"id": 3}], None), + ] + findings = walk_all_pages(_fetcher(pages), lambda r: r["id"]) + self.assertIn(2, findings.duplicates) + self.assertFalse(findings.passed()) + + def test_cursor_loop(self): + # c2 → c1 → c2 → ... + responses = iter([ + Page(items=[{"id": 1}], next_cursor="c2"), + Page(items=[{"id": 2}], next_cursor="c1"), + Page(items=[{"id": 1}], next_cursor="c2"), + ]) + def fetcher(_cursor): + return next(responses) + findings = walk_all_pages(fetcher, lambda r: r["id"]) + self.assertTrue(findings.cursor_loop) + + def test_max_pages_hit(self): + # always returns a new cursor → would run forever + def fetcher(cursor): + n = (cursor or 0) + 1 + return Page(items=[{"id": n}], next_cursor=n) + findings = walk_all_pages(fetcher, lambda r: r["id"], max_pages=5) + self.assertTrue(findings.hit_max_pages) + self.assertEqual(findings.page_count, 5) + + def test_empty_pages_recorded(self): + pages = [ + ([], "c2"), + ([{"id": 1}], None), + ] + findings = walk_all_pages(_fetcher(pages), lambda r: r["id"]) + self.assertEqual(findings.empty_pages, [0]) + + def test_fetcher_must_be_callable(self): + with self.assertRaises(PaginationAuditError): + walk_all_pages("not callable", lambda r: r) # type: ignore[arg-type] + + def test_key_fn_must_be_callable(self): + with self.assertRaises(PaginationAuditError): + walk_all_pages(lambda c: Page(items=[]), "not callable") # type: ignore[arg-type] + + def test_bad_max_pages(self): + with self.assertRaises(PaginationAuditError): + walk_all_pages(lambda c: Page(items=[]), lambda r: r, max_pages=0) + + def test_fetcher_must_return_page(self): + with self.assertRaises(PaginationAuditError): + walk_all_pages(lambda c: "not a page", lambda r: r) + + def test_fetcher_exception(self): + def boom(_c): + raise RuntimeError("net") + with self.assertRaises(PaginationAuditError): + walk_all_pages(boom, lambda r: r) + + def test_key_fn_exception(self): + def runner(c): + return Page(items=[{"id": 1}]) + def bad(_item): + raise RuntimeError("nope") + with self.assertRaises(PaginationAuditError): + walk_all_pages(runner, bad) + + +class TestAssertions(unittest.TestCase): + + def test_assert_no_duplicates_pass(self): + assert_no_duplicates(PaginationFindings()) + + def test_assert_no_duplicates_fail(self): + with self.assertRaises(PaginationAuditError): + assert_no_duplicates(PaginationFindings(duplicates=[1, 2])) + + def test_assert_no_cursor_loop_pass(self): + assert_no_cursor_loop(PaginationFindings()) + + def test_assert_no_cursor_loop_fail(self): + with self.assertRaises(PaginationAuditError): + assert_no_cursor_loop(PaginationFindings(cursor_loop=True)) + + def test_assert_terminated(self): + with self.assertRaises(PaginationAuditError): + assert_terminated(PaginationFindings(hit_max_pages=True)) + + def test_assert_expected_total_pass(self): + assert_expected_total( + PaginationFindings(unique_items=5), expected_total=5, + ) + + def test_assert_expected_total_fail(self): + with self.assertRaises(PaginationAuditError): + assert_expected_total( + PaginationFindings(unique_items=4), expected_total=5, + ) + + def test_assert_expected_total_bad_arg(self): + with self.assertRaises(PaginationAuditError): + assert_expected_total(PaginationFindings(), expected_total=-1) + + def test_assert_clean_pass(self): + assert_clean(PaginationFindings()) + + def test_assert_clean_rejects_non_findings(self): + with self.assertRaises(PaginationAuditError): + assert_clean("nope") # type: ignore[arg-type] + + +class TestAssertSortedBy(unittest.TestCase): + + def test_pass_ascending(self): + findings = PaginationFindings(item_keys_by_page=[[1, 2], [3, 4]]) + assert_sorted_by(findings, lambda x: x) + + def test_fail_ascending(self): + findings = PaginationFindings(item_keys_by_page=[[3], [1]]) + with self.assertRaises(PaginationAuditError): + assert_sorted_by(findings, lambda x: x) + + def test_pass_reverse(self): + findings = PaginationFindings(item_keys_by_page=[[5, 4], [3, 2]]) + assert_sorted_by(findings, lambda x: x, reverse=True) + + def test_empty_passes(self): + assert_sorted_by(PaginationFindings(), lambda x: x) + + def test_bad_keyfn(self): + with self.assertRaises(PaginationAuditError): + assert_sorted_by(PaginationFindings(), "nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_persona_runner.py b/test/unit_test/test_persona_runner.py new file mode 100644 index 0000000..0ad276a --- /dev/null +++ b/test/unit_test/test_persona_runner.py @@ -0,0 +1,146 @@ +"""Unit tests for je_web_runner.utils.persona_runner.""" +import unittest + +from je_web_runner.utils.persona_runner.runner import ( + MatrixSummary, + Persona, + PersonaCaseResult, + PersonaRunner, + PersonaRunnerError, + summarise, + summary_markdown, +) + + +def _make_personas(*names): + return [Persona(name=n) for n in names] + + +class TestPersona(unittest.TestCase): + + def test_rejects_empty_name(self): + with self.assertRaises(PersonaRunnerError): + Persona(name="") + + +class TestPersonaRunner(unittest.TestCase): + + def test_rejects_no_personas(self): + with self.assertRaises(PersonaRunnerError): + PersonaRunner(personas=[], action_files=["a"], case_runner=lambda p, f: None) + + def test_rejects_no_files(self): + with self.assertRaises(PersonaRunnerError): + PersonaRunner(personas=_make_personas("a"), action_files=[], + case_runner=lambda p, f: None) + + def test_rejects_duplicate_personas(self): + with self.assertRaises(PersonaRunnerError): + PersonaRunner(personas=_make_personas("a", "a"), action_files=["x"], + case_runner=lambda p, f: None) + + def test_rejects_duplicate_files(self): + with self.assertRaises(PersonaRunnerError): + PersonaRunner(personas=_make_personas("a"), action_files=["x", "x"], + case_runner=lambda p, f: None) + + def test_runs_full_matrix(self): + called = [] + runner = PersonaRunner( + personas=_make_personas("admin", "guest"), + action_files=["a.json", "b.json"], + case_runner=lambda p, f: called.append((p.name, f)), + ) + results = runner.run() + self.assertEqual(len(results), 4) + self.assertTrue(all(r.passed for r in results)) + self.assertEqual(len(called), 4) + + def test_records_failures(self): + def runner(persona, file): + if persona.name == "guest": + raise AssertionError("guest cannot") + results = PersonaRunner( + personas=_make_personas("admin", "guest"), + action_files=["x.json"], + case_runner=runner, + ).run() + self.assertEqual([r.passed for r in results], [True, False]) + self.assertIn("guest cannot", results[1].error or "") + + def test_stop_on_first_failure(self): + called = [] + + def runner(persona, file): + called.append((persona.name, file)) + if persona.name == "admin": + raise RuntimeError("nope") + results = PersonaRunner( + personas=_make_personas("admin", "guest"), + action_files=["a", "b"], + case_runner=runner, + stop_on_first_failure=True, + ).run() + self.assertEqual(len(results), 1) + self.assertEqual(len(called), 1) + + +class TestSummarise(unittest.TestCase): + + def test_counts(self): + results = [ + PersonaCaseResult(persona="admin", action_file="a", passed=True), + PersonaCaseResult(persona="admin", action_file="b", passed=False), + PersonaCaseResult(persona="guest", action_file="a", passed=True), + PersonaCaseResult(persona="guest", action_file="b", passed=True), + ] + s = summarise(results) + self.assertEqual(s.total, 4) + self.assertEqual(s.passed, 3) + self.assertEqual(s.failed, 1) + self.assertEqual(s.by_persona["admin"], {"passed": 1, "failed": 1}) + + def test_persona_only_failure_detected(self): + results = [ + PersonaCaseResult(persona="admin", action_file="dashboard", passed=False), + PersonaCaseResult(persona="guest", action_file="dashboard", passed=True), + ] + s = summarise(results) + self.assertIn("admin", s.persona_only_failures) + + def test_file_only_failure_detected(self): + results = [ + PersonaCaseResult(persona="admin", action_file="broken", passed=False), + PersonaCaseResult(persona="guest", action_file="broken", passed=False), + PersonaCaseResult(persona="admin", action_file="ok", passed=True), + PersonaCaseResult(persona="guest", action_file="ok", passed=True), + ] + s = summarise(results) + self.assertIn("broken", s.file_only_failures) + + def test_rejects_bad_input(self): + with self.assertRaises(PersonaRunnerError): + summarise(["string"]) # type: ignore[list-item] + + +class TestSummaryMarkdown(unittest.TestCase): + + def test_empty(self): + md = summary_markdown(MatrixSummary(total=0, passed=0, failed=0)) + self.assertIn("No persona matrix", md) + + def test_with_failures(self): + s = MatrixSummary( + total=4, passed=3, failed=1, + by_persona={"admin": {"passed": 1, "failed": 1}, + "guest": {"passed": 2, "failed": 0}}, + persona_only_failures=["admin"], + ) + md = summary_markdown(s) + self.assertIn("Persona matrix: 3/4", md) + self.assertIn("admin", md) + self.assertIn("Persona-specific regressions", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pii_in_screenshot.py b/test/unit_test/test_pii_in_screenshot.py new file mode 100644 index 0000000..1b55e24 --- /dev/null +++ b/test/unit_test/test_pii_in_screenshot.py @@ -0,0 +1,158 @@ +"""Unit tests for je_web_runner.utils.pii_in_screenshot.""" +import unittest + +from je_web_runner.utils.pii_in_screenshot.scanner import ( + DEFAULT_RULES, + PiiFinding, + PiiInScreenshotError, + ScanReport, + assert_clean, + scan_image, + scan_screenshots, + scan_text_only, +) + + +def _fake_backend(text): + def _b(_source): + return text + return _b + + +class TestScanText(unittest.TestCase): + + def test_email_detected(self): + findings = scan_text_only("Contact me at alice@example.com please") + names = [f.rule for f in findings] + self.assertIn("email", names) + + def test_credit_card_luhn_valid(self): + # 4111-1111-1111-1111 is a well-known Luhn-valid test card + findings = scan_text_only("Card: 4111 1111 1111 1111") + names = [f.rule for f in findings] + self.assertIn("credit_card", names) + + def test_credit_card_luhn_invalid_skipped(self): + # 1234-5678-9012-3456 fails Luhn + findings = scan_text_only("Bogus: 1234 5678 9012 3456") + self.assertNotIn("credit_card", [f.rule for f in findings]) + + def test_ssn_us(self): + findings = scan_text_only("SSN: 123-45-6789") + self.assertIn("ssn_us", [f.rule for f in findings]) + + def test_tw_id(self): + findings = scan_text_only("ID: A123456789") + self.assertIn("tw_national_id", [f.rule for f in findings]) + + def test_phone_e164(self): + findings = scan_text_only("Tel: +1 415-555-1212") + self.assertIn("phone_e164", [f.rule for f in findings]) + + def test_iban(self): + findings = scan_text_only("IBAN GB82WEST12345698765432") + self.assertIn("iban", [f.rule for f in findings]) + + def test_ipv4(self): + findings = scan_text_only("server 192.168.1.50 down") + self.assertIn("ipv4", [f.rule for f in findings]) + + def test_dedup(self): + text = "alice@x.com\nalice@x.com\nalice@x.com" + findings = scan_text_only(text) + emails = [f for f in findings if f.rule == "email"] + self.assertEqual(len(emails), 1) + + def test_redaction_format(self): + findings = scan_text_only("alice@example.com") + self.assertTrue(findings[0].redacted_match.startswith("al")) + self.assertIn("…", findings[0].redacted_match) + + def test_excerpt_marks_pii(self): + findings = scan_text_only("hello alice@x.com world") + self.assertIn("<>", findings[0].raw_excerpt) + + def test_non_string_rejected(self): + with self.assertRaises(PiiInScreenshotError): + scan_text_only(123) # type: ignore[arg-type] + + def test_clean_text(self): + self.assertEqual(scan_text_only("no secrets here"), []) + + +class TestScanImage(unittest.TestCase): + + def test_uses_backend(self): + findings = scan_image(b"", backend=_fake_backend("ssn 123-45-6789")) + self.assertIn("ssn_us", [f.rule for f in findings]) + + def test_ocr_error_wrapped(self): + def boom(_): + raise RuntimeError("bad image") + with self.assertRaises(PiiInScreenshotError): + scan_image(b"", backend=boom) + + +class TestScanScreenshots(unittest.TestCase): + + def test_scans_each_image(self): + backends = [ + _fake_backend("alice@x.com"), + _fake_backend("4111 1111 1111 1111"), + _fake_backend("clean image"), + ] + report = ScanReport() + for i, b in enumerate(backends): + for f in scan_image(b"", backend=b, image_label=f"img_{i}"): + report.findings.append(f) + report.by_severity[f.severity] = report.by_severity.get(f.severity, 0) + 1 + report.scanned += 1 + self.assertEqual(report.scanned, 3) + self.assertEqual(len(report.findings), 2) + + def test_rejects_empty_sources(self): + with self.assertRaises(PiiInScreenshotError): + scan_screenshots([]) + + def test_scan_screenshots_aggregates(self): + report = scan_screenshots( + [b"a", b"b"], + backend=_fake_backend("alice@x.com"), + ) + self.assertEqual(report.scanned, 2) + # Two images, same PII → finding labels differ → 2 records + self.assertEqual(len(report.findings), 2) + + +class TestAssertClean(unittest.TestCase): + + def test_clean_passes(self): + assert_clean(ScanReport()) + + def test_dirty_raises(self): + report = ScanReport( + scanned=1, + findings=[PiiFinding( + rule="email", severity="medium", redacted_match="al…om", + image="x", + )], + ) + with self.assertRaises(PiiInScreenshotError): + assert_clean(report) + + def test_rejects_non_report(self): + with self.assertRaises(PiiInScreenshotError): + assert_clean("not a report") # type: ignore[arg-type] + + +class TestDefaultRules(unittest.TestCase): + + def test_defaults_loaded(self): + names = {r.name for r in DEFAULT_RULES} + self.assertIn("email", names) + self.assertIn("credit_card", names) + self.assertIn("ssn_us", names) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pr_risk_score.py b/test/unit_test/test_pr_risk_score.py new file mode 100644 index 0000000..3561242 --- /dev/null +++ b/test/unit_test/test_pr_risk_score.py @@ -0,0 +1,152 @@ +"""Unit tests for je_web_runner.utils.pr_risk_score.""" +import unittest + +from je_web_runner.utils.pr_risk_score.scorer import ( + PrRiskScoreError, + PrSignals, + RiskReport, + RiskWeights, + aggregate_signals, + report_markdown, + score_pr, +) + + +class TestPrSignals(unittest.TestCase): + + def test_defaults_safe(self): + sig = PrSignals() + report = score_pr(sig) + self.assertEqual(report.score, 0.0) + self.assertEqual(report.level, "low") + + def test_negative_rejected(self): + with self.assertRaises(PrRiskScoreError): + PrSignals(flaky_tests_touched=-1) + + def test_bad_avg_flake_rejected(self): + with self.assertRaises(PrRiskScoreError): + PrSignals(avg_flake_score=1.5) + + +class TestScore(unittest.TestCase): + + def test_clean_pr_low(self): + sig = PrSignals(total_tests_touched=10, lines_added=100, lines_covered=100, + repo_modules=50, total_locators_touched=20) + report = score_pr(sig) + self.assertLess(report.score, 25) + self.assertEqual(report.level, "low") + + def test_flaky_pr_higher(self): + clean = score_pr(PrSignals( + total_tests_touched=10, lines_added=100, lines_covered=100, + )) + flaky = score_pr(PrSignals( + total_tests_touched=10, flaky_tests_touched=8, avg_flake_score=0.7, + lines_added=100, lines_covered=100, + )) + self.assertGreater(flaky.score, clean.score) + self.assertTrue(any("flake" in r for r in flaky.reasons)) + + def test_critical_pr_flagged(self): + sig = PrSignals( + total_tests_touched=5, flaky_tests_touched=5, avg_flake_score=1.0, + impacted_modules=20, repo_modules=20, impacted_critical_paths=4, + fragile_locators_touched=10, total_locators_touched=10, + lines_added=100, lines_covered=0, + migration_files_changed=2, security_files_changed=2, + ) + report = score_pr(sig) + self.assertGreaterEqual(report.score, 75.0) + self.assertEqual(report.level, "critical") + self.assertTrue(report.is_blocking()) + + def test_score_bounded(self): + sig = PrSignals( + total_tests_touched=1, flaky_tests_touched=1, avg_flake_score=1.0, + impacted_modules=10, repo_modules=10, impacted_critical_paths=20, + fragile_locators_touched=5, total_locators_touched=5, + lines_added=10, lines_covered=0, + migration_files_changed=10, security_files_changed=10, + ) + report = score_pr(sig) + self.assertLessEqual(report.score, 100.0) + self.assertGreaterEqual(report.score, 0.0) + + def test_contributions_sum_recorded(self): + sig = PrSignals( + total_tests_touched=2, flaky_tests_touched=1, avg_flake_score=0.5, + lines_added=10, lines_covered=5, + ) + report = score_pr(sig) + self.assertIn("flake", report.contributions) + self.assertIn("coverage_gap", report.contributions) + + def test_zero_weights_rejected(self): + weights = RiskWeights( + flake=0, blast_radius=0, critical_path=0, + locator_fragility=0, coverage_gap=0, migration=0, security=0, + ) + with self.assertRaises(PrRiskScoreError): + score_pr(PrSignals(), weights) + + def test_invalid_signals_type(self): + with self.assertRaises(PrRiskScoreError): + score_pr("not signals") # type: ignore[arg-type] + + def test_coverage_gap_zero_when_no_added(self): + sig = PrSignals(lines_added=0, lines_covered=0) + self.assertEqual(score_pr(sig).contributions["coverage_gap"], 0.0) + + def test_custom_weights_change_score(self): + sig = PrSignals( + lines_added=10, lines_covered=0, + total_tests_touched=10, flaky_tests_touched=10, avg_flake_score=1.0, + ) + low_flake_weight = score_pr(sig, RiskWeights(flake=0.1)) + high_flake_weight = score_pr(sig, RiskWeights(flake=10.0)) + self.assertGreater(high_flake_weight.score, low_flake_weight.score) + + +class TestAggregateSignals(unittest.TestCase): + + def test_sums_per_file(self): + per_file = [ + {"flaky_tests_touched": 1, "total_tests_touched": 3, "avg_flake_score": 0.5, + "repo_modules": 50, "lines_added": 100, "lines_covered": 80}, + {"flaky_tests_touched": 2, "total_tests_touched": 4, "avg_flake_score": 0.7, + "repo_modules": 50, "lines_added": 50, "lines_covered": 40}, + ] + sig = aggregate_signals(per_file) + self.assertEqual(sig.flaky_tests_touched, 3) + self.assertEqual(sig.total_tests_touched, 7) + self.assertEqual(sig.lines_added, 150) + self.assertEqual(sig.repo_modules, 50) + self.assertAlmostEqual(sig.avg_flake_score, 0.6, places=4) + + def test_ignores_non_dict_entries(self): + sig = aggregate_signals(["bad", None, {}]) # type: ignore[list-item] + self.assertEqual(sig.flaky_tests_touched, 0) + + def test_ignores_unknown_keys(self): + sig = aggregate_signals([{"weird_key": 99, "lines_added": 5}]) + self.assertEqual(sig.lines_added, 5) + + +class TestReportMarkdown(unittest.TestCase): + + def test_with_reasons(self): + report = RiskReport(score=80.0, level="critical", reasons=["flake: 50% × 2.0"]) + md = report_markdown(report) + self.assertIn("80.0", md) + self.assertIn("critical", md) + self.assertIn("flake", md) + + def test_no_reasons(self): + md = report_markdown(RiskReport(score=0.0, level="low")) + self.assertIn("No risk signals", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_prompt_drift_monitor.py b/test/unit_test/test_prompt_drift_monitor.py new file mode 100644 index 0000000..8a82eb8 --- /dev/null +++ b/test/unit_test/test_prompt_drift_monitor.py @@ -0,0 +1,199 @@ +"""Unit tests for je_web_runner.utils.prompt_drift_monitor.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.prompt_drift_monitor.monitor import ( + Baseline, + BaselineSample, + DriftFinding, + DriftReport, + PromptDriftError, + assert_no_drift, + capture_baseline, + check_drift, + load_baseline, + save_baseline, +) + + +def _fixed_embedder(vector): + return lambda _text: list(vector) + + +def _by_text_embedder(text): + """Embedding that depends on text content (for similarity drift).""" + return [float(text.count("a")), float(text.count("b")), float(text.count("c"))] + + +class TestCaptureBaseline(unittest.TestCase): + + def test_basic_capture(self): + baseline = capture_baseline( + [{"id": "q1", "prompt": "hi"}, {"id": "q2", "prompt": "bye"}], + _fixed_embedder([1.0, 0.0]), + lambda p: f"answer to {p}", + ) + self.assertEqual(len(baseline.samples), 2) + self.assertEqual(baseline.samples[0].embedding, [1.0, 0.0]) + + def test_empty_prompts_rejected(self): + with self.assertRaises(PromptDriftError): + capture_baseline([], _fixed_embedder([1.0]), lambda _: "x") + + def test_missing_id_or_prompt(self): + with self.assertRaises(PromptDriftError): + capture_baseline([{"id": "x"}], _fixed_embedder([1.0]), lambda _: "y") + with self.assertRaises(PromptDriftError): + capture_baseline([{"prompt": "x"}], _fixed_embedder([1.0]), lambda _: "y") + + def test_answerer_failure_wrapped(self): + def boom(_): + raise RuntimeError("rate limit") + with self.assertRaises(PromptDriftError): + capture_baseline([{"id": "x", "prompt": "y"}], _fixed_embedder([1.0]), boom) + + def test_anchors_captured(self): + baseline = capture_baseline( + [{"id": "q", "prompt": "hi", + "must_include": ["disclaimer"], "must_exclude": ["competitor"]}], + _fixed_embedder([1.0]), + lambda _: "x", + ) + self.assertEqual(baseline.samples[0].must_include, ["disclaimer"]) + self.assertEqual(baseline.samples[0].must_exclude, ["competitor"]) + + +class TestPersistence(unittest.TestCase): + + def test_save_and_load_round_trip(self): + baseline = capture_baseline( + [{"id": "q", "prompt": "hi"}], + _fixed_embedder([1.0, 2.0]), + lambda _: "answer", + ) + with tempfile.TemporaryDirectory() as tmp: + path = save_baseline(baseline, Path(tmp) / "b.json") + loaded = load_baseline(path) + self.assertEqual(len(loaded.samples), 1) + self.assertEqual(loaded.samples[0].embedding, [1.0, 2.0]) + + def test_load_missing_file(self): + with self.assertRaises(PromptDriftError): + load_baseline("/no/such/file.json") + + def test_load_invalid_json(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "b.json" + path.write_text("not json", encoding="utf-8") + with self.assertRaises(PromptDriftError): + load_baseline(path) + + def test_load_missing_samples_key(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "b.json" + path.write_text(json.dumps({"captured_at": "x"}), encoding="utf-8") + with self.assertRaises(PromptDriftError): + load_baseline(path) + + def test_load_malformed_sample(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "b.json" + path.write_text(json.dumps( + {"samples": [{"prompt_id": "x"}]} # missing fields + ), encoding="utf-8") + with self.assertRaises(PromptDriftError): + load_baseline(path) + + def test_save_rejects_non_baseline(self): + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(PromptDriftError): + save_baseline("not baseline", Path(tmp) / "x.json") # type: ignore[arg-type] + + +class TestCheckDrift(unittest.TestCase): + + def _baseline_for(self, prompts): + return capture_baseline(prompts, _by_text_embedder, lambda p: f"aaa {p}") + + def test_clean_run(self): + baseline = self._baseline_for([{"id": "q", "prompt": "hi"}]) + report = check_drift( + baseline, _by_text_embedder, + lambda p: f"aaa {p}", # exact baseline reproduction + similarity_threshold=0.99, + ) + self.assertTrue(report.passed()) + + def test_drift_detected(self): + baseline = self._baseline_for([{"id": "q", "prompt": "hi"}]) + report = check_drift( + baseline, _by_text_embedder, + lambda _: "ccc", # totally different vector + similarity_threshold=0.9, + ) + self.assertFalse(report.passed()) + self.assertTrue(report.findings[0].drifted) + + def test_missing_required(self): + baseline = capture_baseline( + [{"id": "q", "prompt": "hi", "must_include": ["disclaimer"]}], + _by_text_embedder, lambda _: "disclaimer aaa", + ) + report = check_drift( + baseline, _by_text_embedder, lambda _: "aaa without it", + ) + self.assertFalse(report.passed()) + self.assertIn("disclaimer", report.findings[0].missing_required) + + def test_forbidden_present(self): + baseline = capture_baseline( + [{"id": "q", "prompt": "hi", "must_exclude": ["competitor"]}], + _by_text_embedder, lambda _: "aaa clean", + ) + report = check_drift( + baseline, _by_text_embedder, lambda _: "aaa with competitor mentioned", + ) + self.assertIn("competitor", report.findings[0].forbidden_present) + + def test_bad_threshold(self): + with self.assertRaises(PromptDriftError): + check_drift(Baseline(), _by_text_embedder, lambda _: "x", + similarity_threshold=0.0) + with self.assertRaises(PromptDriftError): + check_drift(Baseline(), _by_text_embedder, lambda _: "x", + similarity_threshold=2.0) + + def test_answerer_failure(self): + baseline = self._baseline_for([{"id": "q", "prompt": "hi"}]) + def boom(_): + raise RuntimeError("down") + with self.assertRaises(PromptDriftError): + check_drift(baseline, _by_text_embedder, boom) + + def test_rejects_non_baseline(self): + with self.assertRaises(PromptDriftError): + check_drift("nope", _by_text_embedder, lambda _: "x") # type: ignore[arg-type] + + +class TestAssertNoDrift(unittest.TestCase): + + def test_pass(self): + assert_no_drift(DriftReport(threshold=0.9)) + + def test_fail(self): + report = DriftReport( + threshold=0.9, + findings=[DriftFinding(prompt_id="q", similarity=0.2, drifted=True)], + ) + with self.assertRaises(PromptDriftError): + assert_no_drift(report) + + def test_rejects_non_report(self): + with self.assertRaises(PromptDriftError): + assert_no_drift("not a report") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_pseudo_localization.py b/test/unit_test/test_pseudo_localization.py new file mode 100644 index 0000000..0b33571 --- /dev/null +++ b/test/unit_test/test_pseudo_localization.py @@ -0,0 +1,146 @@ +"""Unit tests for je_web_runner.utils.pseudo_localization.""" +import unittest + +from je_web_runner.utils.pseudo_localization.pseudo import ( + PseudoAuditReport, + PseudoConfig, + PseudoLocalizationError, + pseudo_localize, + pseudo_localize_dict, + scan_for_hardcoded, +) + + +class TestPseudoLocalize(unittest.TestCase): + + def test_default_wraps_and_accents(self): + out = pseudo_localize("Hello") + self.assertTrue(out.startswith("⟦")) + self.assertTrue(out.endswith("⟧")) + self.assertIn("é", out.lower()) + self.assertIn("─", out) + + def test_disable_bracket(self): + out = pseudo_localize("hi", PseudoConfig(bracket=False)) + self.assertFalse(out.startswith("⟦")) + + def test_disable_accent(self): + out = pseudo_localize("hello", PseudoConfig(accent=False, bracket=False)) + self.assertIn("hello", out) + + def test_expansion_grows_string(self): + cfg = PseudoConfig(accent=False, expansion_ratio=1.0, bracket=False) + out = pseudo_localize("hi", cfg) + # 1.0 ratio: original 2 chars + ~2 padding + spaces + self.assertGreater(len(out), 4) + + def test_no_expansion(self): + cfg = PseudoConfig(accent=False, expansion_ratio=0.0, bracket=False) + self.assertEqual(pseudo_localize("hi", cfg), "hi") + + def test_preserves_braced_placeholder(self): + out = pseudo_localize("Hello {name}", PseudoConfig( + expansion_ratio=0, bracket=False, + )) + self.assertIn("{name}", out) + + def test_preserves_printf_placeholder(self): + out = pseudo_localize("Got %d items", PseudoConfig( + expansion_ratio=0, bracket=False, + )) + self.assertIn("%d", out) + + def test_preserves_html_tag(self): + out = pseudo_localize("bold", PseudoConfig( + expansion_ratio=0, bracket=False, + )) + self.assertIn("", out) + self.assertIn("", out) + + def test_empty_string_unchanged(self): + self.assertEqual(pseudo_localize(""), "") + + def test_non_string_rejected(self): + with self.assertRaises(PseudoLocalizationError): + pseudo_localize(123) # type: ignore[arg-type] + + def test_negative_expansion_rejected(self): + with self.assertRaises(PseudoLocalizationError): + PseudoConfig(expansion_ratio=-1) + + +class TestPseudoLocalizeDict(unittest.TestCase): + + def test_translates_all_values(self): + out = pseudo_localize_dict({"login": "Sign in", "logout": "Sign out"}) + self.assertEqual(set(out.keys()), {"login", "logout"}) + for value in out.values(): + self.assertTrue(value.startswith("⟦")) + + def test_rejects_non_mapping(self): + with self.assertRaises(PseudoLocalizationError): + pseudo_localize_dict([("a", "b")]) # type: ignore[arg-type] + + def test_rejects_non_string_value(self): + with self.assertRaises(PseudoLocalizationError): + pseudo_localize_dict({"x": 123}) # type: ignore[dict-item] + + +class TestScanForHardcoded(unittest.TestCase): + + def test_finds_hardcoded(self): + catalogue = {"submit": "Submit", "cancel": "Cancel"} + # Submit appears verbatim → hardcoded leak + rendered = "⟦Šubmît⟧ Submit ⟦Çančél⟧" + report = scan_for_hardcoded(rendered, catalogue=catalogue) + self.assertEqual(len(report.hits), 1) + self.assertEqual(report.hits[0].string, "Submit") + self.assertEqual(report.hits[0].occurrences, 1) + self.assertFalse(report.passed()) + + def test_clean(self): + report = scan_for_hardcoded( + "⟦Šubmît⟧ ⟦Çančél⟧", + catalogue={"submit": "Submit", "cancel": "Cancel"}, + ) + self.assertTrue(report.passed()) + + def test_min_length_filter(self): + report = scan_for_hardcoded( + "ok ok ok", + catalogue={"ok": "ok"}, + min_length=3, + ) + # 'ok' is too short → ignored + self.assertTrue(report.passed()) + + def test_skips_non_ascii_only_catalogue(self): + # Catalogue value with no ASCII letters can't be detected as "hard-coded" + report = scan_for_hardcoded( + "你好 你好", + catalogue={"hello": "你好"}, + ) + self.assertTrue(report.passed()) + + def test_counts_multiple_occurrences(self): + report = scan_for_hardcoded( + "Submit Submit Submit", + catalogue={"submit": "Submit"}, + ) + self.assertEqual(report.hits[0].occurrences, 3) + + def test_bad_inputs(self): + with self.assertRaises(PseudoLocalizationError): + scan_for_hardcoded(123, catalogue={}) # type: ignore[arg-type] + with self.assertRaises(PseudoLocalizationError): + scan_for_hardcoded("x", catalogue={}, min_length=0) + + +class TestReport(unittest.TestCase): + + def test_default_passed(self): + self.assertTrue(PseudoAuditReport().passed()) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_quarantine_age_report.py b/test/unit_test/test_quarantine_age_report.py new file mode 100644 index 0000000..02bb98d --- /dev/null +++ b/test/unit_test/test_quarantine_age_report.py @@ -0,0 +1,181 @@ +"""Unit tests for je_web_runner.utils.quarantine_age_report.""" +import json +import tempfile +import unittest +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from je_web_runner.utils.quarantine_age_report.report import ( + AgedEntry, + AgeReport, + EscalationTier, + QuarantineAgeReportError, + age_entries, + assert_no_abandoned, + build_report, + load_and_age, + report_markdown, +) + +_NOW = datetime(2026, 5, 24, tzinfo=timezone.utc) + + +def _entry(test_id, days_ago, score=0.5): + when = _NOW - timedelta(days=days_ago) + return { + "test_id": test_id, + "reason": "flaky", + "flake_score": score, + "quarantined_at": when.isoformat(timespec="seconds"), + "runs_when_added": 10, + } + + +class TestAgeEntries(unittest.TestCase): + + def test_tiers(self): + rows = [ + _entry("fresh", 3), + _entry("lingering", 14), + _entry("stale", 45), + _entry("abandoned", 120), + ] + aged = age_entries(rows, now=_NOW) + tier_by_id = {e.test_id: e.tier for e in aged} + self.assertEqual(tier_by_id["fresh"], EscalationTier.FRESH) + self.assertEqual(tier_by_id["lingering"], EscalationTier.LINGERING) + self.assertEqual(tier_by_id["stale"], EscalationTier.STALE) + self.assertEqual(tier_by_id["abandoned"], EscalationTier.ABANDONED) + + def test_z_timezone(self): + rows = [{ + "test_id": "x", "reason": "", + "flake_score": 0.1, + "quarantined_at": "2026-05-01T00:00:00Z", + }] + aged = age_entries(rows, now=_NOW) + self.assertEqual(len(aged), 1) + self.assertGreater(aged[0].age_days, 20) + + def test_skips_non_dict(self): + aged = age_entries(["not dict", None]) # type: ignore[list-item] + self.assertEqual(aged, []) + + def test_skips_missing_fields(self): + rows = [{"test_id": "x"}, {"quarantined_at": "2026-01-01"}] + self.assertEqual(age_entries(rows, now=_NOW), []) + + def test_naive_now_rejected(self): + with self.assertRaises(QuarantineAgeReportError): + age_entries([_entry("x", 1)], now=datetime(2026, 5, 24)) + + def test_bad_timestamp_rejected(self): + rows = [{"test_id": "x", "reason": "", "flake_score": 0, + "quarantined_at": "garbage"}] + with self.assertRaises(QuarantineAgeReportError): + age_entries(rows, now=_NOW) + + +class TestBuildReport(unittest.TestCase): + + def test_counts(self): + aged = age_entries([_entry("a", 1), _entry("b", 100)], now=_NOW) + report = build_report(aged) + self.assertEqual(report.total_entries, 2) + self.assertEqual(report.by_tier["fresh"], 1) + self.assertEqual(report.by_tier["abandoned"], 1) + self.assertEqual(report.abandoned, ["b"]) + + def test_rejects_non_entry(self): + with self.assertRaises(QuarantineAgeReportError): + build_report(["nope"]) # type: ignore[list-item] + + +class TestLoadAndAge(unittest.TestCase): + + def test_round_trip(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "registry.json" + path.write_text(json.dumps({ + "updated_at": "2026-05-24", + "entries": [_entry("x", 5), _entry("y", 200)], + }), encoding="utf-8") + report = load_and_age(path, now=_NOW) + self.assertEqual(report.total_entries, 2) + + def test_missing_file(self): + with self.assertRaises(QuarantineAgeReportError): + load_and_age("/no/such/file.json") + + def test_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text("nope", encoding="utf-8") + with self.assertRaises(QuarantineAgeReportError): + load_and_age(p) + + def test_missing_entries_key(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text(json.dumps({"x": 1}), encoding="utf-8") + with self.assertRaises(QuarantineAgeReportError): + load_and_age(p) + + def test_entries_not_list(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text(json.dumps({"entries": "x"}), encoding="utf-8") + with self.assertRaises(QuarantineAgeReportError): + load_and_age(p) + + +class TestMarkdown(unittest.TestCase): + + def test_renders(self): + report = build_report(age_entries( + [_entry("a", 1), _entry("b", 200)], now=_NOW, + )) + md = report_markdown(report) + self.assertIn("Quarantine age report", md) + self.assertIn("abandoned", md) + self.assertIn("`b`", md) + + def test_caps_top_n(self): + rows = [_entry(f"t{i}", 200) for i in range(20)] + report = build_report(age_entries(rows, now=_NOW)) + md = report_markdown(report, top_n=5) + self.assertIn("+15 more", md) + + def test_bad_top_n(self): + with self.assertRaises(QuarantineAgeReportError): + report_markdown(AgeReport(), top_n=-1) + + def test_rejects_non_report(self): + with self.assertRaises(QuarantineAgeReportError): + report_markdown("nope") # type: ignore[arg-type] + + +class TestAssertNoAbandoned(unittest.TestCase): + + def test_pass(self): + assert_no_abandoned(AgeReport()) + + def test_fail(self): + report = build_report(age_entries([_entry("x", 200)], now=_NOW)) + with self.assertRaises(QuarantineAgeReportError): + assert_no_abandoned(report) + + def test_rejects_non_report(self): + with self.assertRaises(QuarantineAgeReportError): + assert_no_abandoned("nope") # type: ignore[arg-type] + + +class TestAgedEntryDict(unittest.TestCase): + + def test_to_dict_serialises_tier(self): + aged = age_entries([_entry("x", 1)], now=_NOW)[0] + self.assertIn(aged.to_dict()["tier"], ("fresh", "lingering", "stale", "abandoned")) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_repro_minimizer.py b/test/unit_test/test_repro_minimizer.py new file mode 100644 index 0000000..f2b6c20 --- /dev/null +++ b/test/unit_test/test_repro_minimizer.py @@ -0,0 +1,145 @@ +"""Unit tests for je_web_runner.utils.repro_minimizer.""" +import unittest + +from je_web_runner.utils.repro_minimizer.minimizer import ( + MinimizationResult, + ReproMinimizerError, + assert_minimized, + minimize, + report_markdown, +) + + +def _runner_failing_when_action_present(culprit): + """Returns runner that fails (returns False) iff `culprit` is in the subset.""" + def _run(subset): + return culprit not in subset # True = pass, False = fail + return _run + + +class TestMinimize(unittest.TestCase): + + def test_finds_single_culprit(self): + actions = list(range(20)) + runner = _runner_failing_when_action_present(7) + result = minimize(actions, runner) + self.assertIn(7, result.minimized_actions) + self.assertEqual(result.minimized_size, 1) + self.assertEqual(result.original_size, 20) + + def test_two_culprits_together(self): + actions = list(range(30)) + # Fails only when both 5 AND 17 are present + def runner(subset): + return not (5 in subset and 17 in subset) + result = minimize(actions, runner) + self.assertIn(5, result.minimized_actions) + self.assertIn(17, result.minimized_actions) + self.assertLessEqual(result.minimized_size, 5) + + def test_already_minimal(self): + actions = ["only_action"] + # ddmin can't go below 1, so it converges immediately + result = minimize(actions, _runner_failing_when_action_present("only_action")) + self.assertEqual(result.minimized_actions, ["only_action"]) + + def test_passing_input_rejected(self): + with self.assertRaises(ReproMinimizerError): + minimize([1, 2, 3], lambda subset: True) + + def test_skip_verify(self): + # Even if original "passes" by our stub, with verify_failing=False + # the minimizer just runs the procedure + result = minimize([1, 2, 3], lambda subset: True, verify_failing=False) + self.assertGreaterEqual(result.minimized_size, 1) + + def test_runner_must_be_callable(self): + with self.assertRaises(ReproMinimizerError): + minimize([1, 2], "not callable") # type: ignore[arg-type] + + def test_empty_actions(self): + with self.assertRaises(ReproMinimizerError): + minimize([], lambda s: False) + + def test_non_list_rejected(self): + with self.assertRaises(ReproMinimizerError): + minimize("string", lambda s: False) # type: ignore[arg-type] + + def test_max_iterations_bound(self): + with self.assertRaises(ReproMinimizerError): + minimize([1], lambda s: False, max_iterations=0) + + def test_runner_exception(self): + def boom(_subset): + raise RuntimeError("nope") + with self.assertRaises(ReproMinimizerError): + minimize([1, 2, 3], boom) + + def test_eval_count_tracked(self): + result = minimize(list(range(8)), + _runner_failing_when_action_present(3)) + self.assertGreater(result.eval_count, 1) + + def test_reduction_pct(self): + result = MinimizationResult( + original_size=10, minimized_actions=[1], minimized_size=1, + ) + self.assertEqual(result.reduction_pct, 90.0) + + def test_reduction_pct_zero_original(self): + result = MinimizationResult( + original_size=0, minimized_actions=[], minimized_size=0, + ) + self.assertEqual(result.reduction_pct, 0.0) + + +class TestAssertMinimized(unittest.TestCase): + + def test_pass(self): + assert_minimized( + MinimizationResult(original_size=10, minimized_actions=[1, 2], + minimized_size=2), + max_remaining=5, + ) + + def test_fail(self): + with self.assertRaises(ReproMinimizerError): + assert_minimized( + MinimizationResult(original_size=10, minimized_actions=list(range(8)), + minimized_size=8), + max_remaining=5, + ) + + def test_bad_max_remaining(self): + with self.assertRaises(ReproMinimizerError): + assert_minimized( + MinimizationResult(original_size=10, minimized_actions=[], + minimized_size=0), + max_remaining=-1, + ) + + def test_rejects_non_result(self): + with self.assertRaises(ReproMinimizerError): + assert_minimized("nope", max_remaining=1) # type: ignore[arg-type] + + +class TestReport(unittest.TestCase): + + def test_renders(self): + result = MinimizationResult( + original_size=60, minimized_actions=[1, 2, 3, 4], + minimized_size=4, iterations=7, eval_count=42, + duration_seconds=1.23, + ) + md = report_markdown(result) + self.assertIn("4 / 60", md) + self.assertIn("93%", md) + self.assertIn("7", md) + + def test_rejects_non_result(self): + with self.assertRaises(ReproMinimizerError): + report_markdown("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_screen_reader_runner.py b/test/unit_test/test_screen_reader_runner.py new file mode 100644 index 0000000..9a9f857 --- /dev/null +++ b/test/unit_test/test_screen_reader_runner.py @@ -0,0 +1,145 @@ +"""Unit tests for je_web_runner.utils.screen_reader_runner.""" +import unittest + +from je_web_runner.utils.screen_reader_runner.reader import ( + ScreenReaderError, + ScreenReaderTranscript, + Utterance, + ViolationKind, + assert_no_violations, + assert_reads, + walk_tree, +) + + +def _heading(level, name, children=None): + return {"role": "heading", "level": level, "name": name, "children": children or []} + + +def _button(name=""): + return {"role": "button", "name": name} + + +def _link(name): + return {"role": "link", "name": name} + + +def _image(alt=""): + return {"role": "image", "name": alt} + + +def _root(*children): + return {"role": "WebArea", "name": "Page", "children": list(children)} + + +class TestWalkBasics(unittest.TestCase): + + def test_rejects_non_dict(self): + with self.assertRaises(ScreenReaderError): + walk_tree("not a tree") # type: ignore[arg-type] + + def test_heading_spoken_with_level(self): + t = walk_tree(_root(_heading(1, "Welcome"))) + self.assertTrue(any("heading level 1" in u.text for u in t.utterances)) + + def test_button_spoken(self): + t = walk_tree(_root(_button("Save"))) + self.assertTrue(any("button: Save" in u.text for u in t.utterances)) + + def test_link_spoken(self): + t = walk_tree(_root(_link("Documentation"))) + self.assertTrue(any("link: Documentation" in u.text for u in t.utterances)) + + def test_image_with_alt(self): + t = walk_tree(_root(_image("Company logo"))) + self.assertTrue(any("image: Company logo" in u.text for u in t.utterances)) + + +class TestViolations(unittest.TestCase): + + def test_unnamed_button(self): + t = walk_tree(_root(_button(""))) + kinds = [v.kind for v in t.violations] + self.assertIn(ViolationKind.UNNAMED_INTERACTIVE, kinds) + self.assertIn(ViolationKind.EMPTY_BUTTON, kinds) + + def test_generic_link_text(self): + t = walk_tree(_root(_link("click here"))) + self.assertTrue(any(v.kind == ViolationKind.GENERIC_LINK_TEXT for v in t.violations)) + + def test_descriptive_link_passes(self): + t = walk_tree(_root(_link("Open the user guide"))) + self.assertFalse(any(v.kind == ViolationKind.GENERIC_LINK_TEXT for v in t.violations)) + + def test_missing_alt(self): + t = walk_tree(_root(_image(""))) + self.assertTrue(any(v.kind == ViolationKind.MISSING_ALT for v in t.violations)) + + def test_decorative_image_no_violation(self): + node = {"role": "image", "name": "", "decorative": True} + t = walk_tree(_root(node)) + self.assertFalse(any(v.kind == ViolationKind.MISSING_ALT for v in t.violations)) + + def test_heading_skip_detected(self): + t = walk_tree(_root(_heading(1, "A"), _heading(3, "B"))) + self.assertTrue(any(v.kind == ViolationKind.HEADING_SKIP for v in t.violations)) + + def test_heading_no_skip_with_h2_between(self): + t = walk_tree(_root(_heading(1, "A"), _heading(2, "B"), _heading(3, "C"))) + self.assertFalse(any(v.kind == ViolationKind.HEADING_SKIP for v in t.violations)) + + +class TestNested(unittest.TestCase): + + def test_walks_children_in_order(self): + tree = _root( + _heading(1, "First"), + {"role": "navigation", "name": "Main", + "children": [_link("Home"), _link("About")]}, + ) + t = walk_tree(tree) + order = [u.text for u in t.utterances] + self.assertEqual(order[0], "heading level 1: First") + self.assertIn("navigation: Main", order) + self.assertIn("link: Home", order) + self.assertIn("link: About", order) + + def test_static_text(self): + tree = _root({"role": "text", "name": "Welcome to the demo."}) + t = walk_tree(tree) + self.assertTrue(any("Welcome to the demo" in u.text for u in t.utterances)) + + +class TestSpeechAndAssertions(unittest.TestCase): + + def test_speech_joins(self): + t = walk_tree(_root(_heading(1, "A"), _button("Save"))) + self.assertIn("heading level 1: A", t.speech()) + self.assertIn("button: Save", t.speech()) + + def test_assert_no_violations_pass(self): + assert_no_violations(walk_tree(_root(_heading(1, "A"), _button("Save")))) + + def test_assert_no_violations_fail(self): + with self.assertRaises(ScreenReaderError): + assert_no_violations(walk_tree(_root(_button("")))) + + def test_assert_reads_pass(self): + u = assert_reads(walk_tree(_root(_button("Save"))), "Save") + self.assertIsInstance(u, Utterance) + + def test_assert_reads_fail(self): + with self.assertRaises(ScreenReaderError): + assert_reads(walk_tree(_root(_button("Save"))), "Cancel") + + def test_assert_reads_empty_phrase(self): + with self.assertRaises(ScreenReaderError): + assert_reads(ScreenReaderTranscript(), "") + + def test_assert_no_violations_rejects_bad_arg(self): + with self.assertRaises(ScreenReaderError): + assert_no_violations("not a transcript") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_session_to_test.py b/test/unit_test/test_session_to_test.py new file mode 100644 index 0000000..6756e61 --- /dev/null +++ b/test/unit_test/test_session_to_test.py @@ -0,0 +1,214 @@ +"""Unit tests for je_web_runner.utils.session_to_test.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.session_to_test.converter import ( + ConversionResult, + ConversionStats, + SessionToTestError, + convert_events, + convert_generic_events, + convert_rrweb_events, + write_actions_json, +) + + +def _rrweb_meta(href): + return {"type": 4, "timestamp": 0, "data": {"href": href, "width": 1280, "height": 720}} + + +def _rrweb_click(node_id, ts=1000): + return {"type": 3, "timestamp": ts, "data": {"source": 2, "type": 2, "id": node_id}} + + +def _rrweb_input(node_id, text, ts=2000): + return {"type": 3, "timestamp": ts, "data": {"source": 5, "id": node_id, "text": text}} + + +def _rrweb_scroll(x, y, ts=3000): + return {"type": 3, "timestamp": ts, "data": {"source": 3, "x": x, "y": y}} + + +class TestRrweb(unittest.TestCase): + + def test_meta_to_navigate(self): + result = convert_rrweb_events([ + _rrweb_meta("https://example.com"), + _rrweb_click(7), + ]) + self.assertEqual(result.actions[0], {"WR_to_url": ["https://example.com"]}) + self.assertEqual(result.actions[1]["WR_click_element"][0], "css selector") + self.assertIn("data-rrweb-id", result.actions[1]["WR_click_element"][1]) + + def test_input_event(self): + result = convert_rrweb_events([ + _rrweb_meta("https://x"), + _rrweb_input(3, "hello"), + ]) + self.assertEqual( + result.actions[1], + {"WR_input_to_element": ["css selector", '[data-rrweb-id="3"]', "hello"]}, + ) + + def test_scroll_becomes_comment(self): + result = convert_rrweb_events([ + _rrweb_meta("https://x"), + _rrweb_scroll(0, 500), + ]) + self.assertEqual(result.actions[1], {"WR_comment": ["scroll to 0,500"]}) + + def test_full_snapshot_skipped(self): + result = convert_rrweb_events([ + _rrweb_meta("https://x"), + {"type": 2, "timestamp": 0, "data": {}}, + _rrweb_click(1), + ]) + self.assertEqual(result.stats.actions_emitted, 2) + self.assertGreaterEqual(result.stats.skipped_events, 1) + + def test_mouse_without_id_skipped(self): + result = convert_rrweb_events([ + _rrweb_meta("https://x"), + {"type": 3, "timestamp": 1, "data": {"source": 2, "type": 2}}, + _rrweb_click(1), + ]) + # Only meta + valid click → 2 actions + self.assertEqual(result.stats.actions_emitted, 2) + self.assertGreaterEqual(result.stats.skipped_events, 1) + + def test_empty_emits_error(self): + with self.assertRaises(SessionToTestError): + convert_rrweb_events([]) + + def test_unknown_event_type_skipped(self): + result = convert_rrweb_events([ + _rrweb_meta("https://x"), + {"type": 99, "timestamp": 0, "data": {}}, + ]) + self.assertGreaterEqual(result.stats.skipped_events, 1) + + def test_meta_without_href_skipped(self): + with self.assertRaises(SessionToTestError): + convert_rrweb_events([ + {"type": 4, "timestamp": 0, "data": {}}, + ]) + + def test_non_list_rejected(self): + with self.assertRaises(SessionToTestError): + convert_rrweb_events({"not": "a list"}) # type: ignore[arg-type] + + +class TestGeneric(unittest.TestCase): + + def test_navigate(self): + result = convert_generic_events([ + {"kind": "navigate", "url": "https://x", "timestamp": 0}, + ]) + self.assertEqual(result.actions[0], {"WR_to_url": ["https://x"]}) + + def test_click_with_dict_target(self): + result = convert_generic_events([ + {"kind": "click", "target": {"by": "id", "value": "submit"}}, + ]) + self.assertEqual(result.actions[0], {"WR_click_element": ["id", "submit"]}) + + def test_click_with_string_target(self): + result = convert_generic_events([ + {"kind": "click", "target": "#submit"}, + ]) + self.assertEqual(result.actions[0], {"WR_click_element": ["css selector", "#submit"]}) + + def test_input_value(self): + result = convert_generic_events([ + {"kind": "input", "target": "#name", "value": "alice"}, + ]) + self.assertEqual( + result.actions[0], + {"WR_input_to_element": ["css selector", "#name", "alice"]}, + ) + + def test_submit_with_and_without_target(self): + result = convert_generic_events([ + {"kind": "submit", "target": "#form"}, + {"kind": "submit"}, + ]) + self.assertEqual(result.actions[0], {"WR_submit_element": ["css selector", "#form"]}) + self.assertEqual(result.actions[1], {"WR_comment": ["submit form (no target)"]}) + + def test_wait(self): + result = convert_generic_events([{"kind": "wait", "seconds": 1.5}]) + self.assertEqual(result.actions[0], {"WR_implicitly_wait": [1.5]}) + + def test_wait_bad_seconds_skipped(self): + with self.assertRaises(SessionToTestError): + convert_generic_events([{"kind": "wait", "seconds": "soon"}]) + + def test_unknown_kind_skipped(self): + with self.assertRaises(SessionToTestError): + convert_generic_events([{"kind": "wat"}]) + + +class TestAutoDetect(unittest.TestCase): + + def test_rrweb_detected_by_int_type(self): + events = [_rrweb_meta("https://x"), _rrweb_click(1)] + result = convert_events(events) + self.assertEqual(result.actions[0], {"WR_to_url": ["https://x"]}) + + def test_generic_detected(self): + result = convert_events([{"kind": "navigate", "url": "https://x"}]) + self.assertEqual(result.actions[0], {"WR_to_url": ["https://x"]}) + + def test_empty_rejected(self): + with self.assertRaises(SessionToTestError): + convert_events([]) + + +class TestFromFile(unittest.TestCase): + + def test_load_file(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "ev.json" + path.write_text(json.dumps([ + {"kind": "navigate", "url": "https://x"}, + ]), encoding="utf-8") + result = convert_events(path) + self.assertEqual(result.actions[0], {"WR_to_url": ["https://x"]}) + + def test_load_file_with_envelope(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "ev.json" + path.write_text(json.dumps({"events": [ + {"kind": "navigate", "url": "https://x"}, + ]}), encoding="utf-8") + result = convert_events(path) + self.assertEqual(len(result.actions), 1) + + def test_missing_file(self): + with self.assertRaises(SessionToTestError): + convert_events("/no/such.json") + + def test_bad_payload_type(self): + with self.assertRaises(SessionToTestError): + convert_events(123) # type: ignore[arg-type] + + +class TestWriteActions(unittest.TestCase): + + def test_write(self): + result = ConversionResult( + actions=[{"WR_to_url": ["https://x"]}], + stats=ConversionStats(input_events=1, actions_emitted=1), + ) + with tempfile.TemporaryDirectory() as tmp: + out = write_actions_json(result, Path(tmp) / "actions.json") + self.assertEqual( + json.loads(out.read_text(encoding="utf-8")), + [{"WR_to_url": ["https://x"]}], + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_sla_tracker.py b/test/unit_test/test_sla_tracker.py new file mode 100644 index 0000000..638b31f --- /dev/null +++ b/test/unit_test/test_sla_tracker.py @@ -0,0 +1,190 @@ +"""Unit tests for je_web_runner.utils.sla_tracker.""" +import json +import tempfile +import unittest +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from je_web_runner.utils.sla_tracker.tracker import ( + BucketResult, + SlaReport, + SlaTarget, + SlaTrackerError, + SuiteRun, + assert_meets_sla, + compute_sla, + load_runs, + report_markdown, +) + + +_BASE = datetime(2026, 5, 1, 12, 0, tzinfo=timezone.utc) + + +def _run(suite="checkout", days_offset=0, duration=300, passed=True): + return SuiteRun( + suite=suite, + started_at=_BASE + timedelta(days=days_offset), + duration_seconds=duration, + passed=passed, + ) + + +class TestSuiteRun(unittest.TestCase): + + def test_rejects_empty_suite(self): + with self.assertRaises(SlaTrackerError): + SuiteRun(suite="", started_at=_BASE, duration_seconds=0, + passed=True) + + def test_rejects_naive_datetime(self): + with self.assertRaises(SlaTrackerError): + SuiteRun(suite="x", started_at=datetime(2026, 1, 1), + duration_seconds=0, passed=True) + + def test_rejects_negative_duration(self): + with self.assertRaises(SlaTrackerError): + SuiteRun(suite="x", started_at=_BASE, duration_seconds=-1, + passed=True) + + +class TestSlaTarget(unittest.TestCase): + + def test_bad_duration(self): + with self.assertRaises(SlaTrackerError): + SlaTarget(max_duration_seconds=0, target_pass_pct=95) + + def test_bad_pct(self): + with self.assertRaises(SlaTrackerError): + SlaTarget(max_duration_seconds=600, target_pass_pct=0) + with self.assertRaises(SlaTrackerError): + SlaTarget(max_duration_seconds=600, target_pass_pct=150) + + +class TestComputeSla(unittest.TestCase): + + def test_all_met(self): + runs = [_run(duration=300) for _ in range(5)] + target = SlaTarget(max_duration_seconds=600, target_pass_pct=95) + report = compute_sla(runs, target) + self.assertEqual(report.overall_pct, 100.0) + self.assertTrue(report.passed()) + + def test_partial_met(self): + runs = [ + _run(duration=300), + _run(duration=900), # over budget + ] + report = compute_sla(runs, SlaTarget(600, 95)) + self.assertEqual(report.overall_pct, 50.0) + self.assertFalse(report.passed()) + + def test_week_bucketing(self): + runs = [ + _run(days_offset=0), + _run(days_offset=10), # different ISO week + ] + report = compute_sla(runs, SlaTarget(600, 95)) + self.assertEqual(len(report.buckets), 2) + + def test_day_bucketing(self): + runs = [ + _run(days_offset=0), + _run(days_offset=1), + ] + report = compute_sla(runs, SlaTarget(600, 95), bucket="day") + self.assertEqual(len(report.buckets), 2) + + def test_suite_filter(self): + runs = [ + _run(suite="checkout"), + _run(suite="profile"), + ] + report = compute_sla(runs, SlaTarget(600, 95), suite="checkout") + self.assertEqual(report.overall_runs, 1) + + def test_bad_bucket(self): + with self.assertRaises(SlaTrackerError): + compute_sla([], SlaTarget(600, 95), bucket="hour") + + def test_rejects_non_run(self): + with self.assertRaises(SlaTrackerError): + compute_sla(["nope"], SlaTarget(600, 95)) # type: ignore[list-item] + + def test_empty_runs(self): + report = compute_sla([], SlaTarget(600, 95)) + self.assertEqual(report.overall_runs, 0) + self.assertEqual(report.overall_pct, 0.0) + + +class TestLoadRuns(unittest.TestCase): + + def test_load(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "l.json" + path.write_text(json.dumps({"runs": [ + {"suite": "x", "time": "2026-05-01T12:00:00Z", + "duration_seconds": 100, "passed": True}, + {"path": "y", "time": "2026-05-02T12:00:00Z", + "duration_seconds": 200}, + {"suite": "skipme", "time": "bad timestamp", + "duration_seconds": 100}, + {"suite": "skip", "time": "2026-05-01T12:00:00Z"}, # no duration + ]}), encoding="utf-8") + runs = load_runs(path) + self.assertEqual(len(runs), 2) + self.assertEqual(runs[1].suite, "y") + + def test_missing(self): + with self.assertRaises(SlaTrackerError): + load_runs("/no/such/file.json") + + def test_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text("nope", encoding="utf-8") + with self.assertRaises(SlaTrackerError): + load_runs(p) + + def test_missing_runs(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text(json.dumps({"x": []}), encoding="utf-8") + with self.assertRaises(SlaTrackerError): + load_runs(p) + + +class TestAssertions(unittest.TestCase): + + def test_meets_pass(self): + runs = [_run(duration=100) for _ in range(10)] + report = compute_sla(runs, SlaTarget(600, 95)) + assert_meets_sla(report) + + def test_meets_fail(self): + runs = [_run(duration=900) for _ in range(10)] + report = compute_sla(runs, SlaTarget(600, 95)) + with self.assertRaises(SlaTrackerError): + assert_meets_sla(report) + + def test_rejects_non_report(self): + with self.assertRaises(SlaTrackerError): + assert_meets_sla("nope") # type: ignore[arg-type] + + +class TestMarkdown(unittest.TestCase): + + def test_renders(self): + runs = [_run(duration=100), _run(duration=900)] + report = compute_sla(runs, SlaTarget(600, 95)) + md = report_markdown(report) + self.assertIn("SLA", md) + self.assertIn("50.0", md) + + def test_rejects_non_report(self): + with self.assertRaises(SlaTrackerError): + report_markdown("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_slack_digest.py b/test/unit_test/test_slack_digest.py new file mode 100644 index 0000000..e0f6df6 --- /dev/null +++ b/test/unit_test/test_slack_digest.py @@ -0,0 +1,146 @@ +"""Unit tests for je_web_runner.utils.slack_digest.""" +import json +import unittest + +from je_web_runner.utils.slack_digest.digest import ( + CostTrend, + DigestInputs, + FlakeStat, + RiskyPr, + SlackDigestError, + build_slack_blocks, + build_slack_payload, + build_teams_card, + render_plain_text, +) + + +def _full_inputs(): + return DigestInputs( + period_label="last 7 days", + flake_changes=[ + FlakeStat(test_id="t1.json", action="added", flake_score=0.55), + FlakeStat(test_id="t2.json", action="released", flake_score=0.05), + FlakeStat(test_id="t3.json", action="still_in", flake_score=0.4), + ], + risky_prs=[ + RiskyPr(number=42, title="Auth rewrite", score=78.0, + url="https://github.com/x/y/pull/42"), + RiskyPr(number=43, title="Tiny tweak", score=20.0), + ], + cost=CostTrend(current_usd=120.0, previous_usd=100.0), + suite_pass_rate=0.94, + suite_pass_rate_previous=0.91, + extra_lines=["3 quarantined tests released back to main"], + ) + + +class TestInputsValidation(unittest.TestCase): + + def test_bad_pass_rate(self): + with self.assertRaises(SlackDigestError): + DigestInputs(suite_pass_rate=1.5) + + def test_bad_previous_pass_rate(self): + with self.assertRaises(SlackDigestError): + DigestInputs(suite_pass_rate_previous=-1) + + +class TestCostTrend(unittest.TestCase): + + def test_delta(self): + self.assertEqual(CostTrend(current_usd=110, previous_usd=100).delta_pct(), 10.0) + + def test_delta_zero_previous(self): + self.assertEqual(CostTrend(current_usd=100, previous_usd=0).delta_pct(), 100.0) + self.assertEqual(CostTrend(current_usd=0, previous_usd=0).delta_pct(), 0.0) + + +class TestSlackBlocks(unittest.TestCase): + + def test_renders_all_sections(self): + blocks = build_slack_blocks(_full_inputs()) + block_types = [b["type"] for b in blocks] + # header + 5 sections + self.assertEqual(block_types[0], "header") + self.assertGreaterEqual(block_types.count("section"), 4) + + def test_minimum_input_renders_nothing_notable(self): + blocks = build_slack_blocks(DigestInputs()) + joined = json.dumps(blocks) + self.assertIn("Nothing notable", joined) + + def test_flake_block_omitted_when_empty(self): + blocks = build_slack_blocks(DigestInputs(suite_pass_rate=0.9)) + joined = json.dumps(blocks) + self.assertNotIn("Quarantine activity", joined) + + def test_high_risk_pr_uses_url(self): + blocks = build_slack_blocks(DigestInputs( + risky_prs=[RiskyPr(number=99, title="Big change", score=80.0, + url="https://gh/x/y/pull/99")], + )) + joined = json.dumps(blocks) + self.assertIn("https://gh/x/y/pull/99", joined) + self.assertIn("Big change", joined) + + def test_pass_rate_block_includes_delta(self): + blocks = build_slack_blocks(DigestInputs( + suite_pass_rate=0.95, suite_pass_rate_previous=0.90, + )) + joined = json.dumps(blocks) + self.assertIn("95.0%", joined) + self.assertIn("pts vs prev", joined) + + def test_rejects_non_inputs(self): + with self.assertRaises(SlackDigestError): + build_slack_blocks("nope") # type: ignore[arg-type] + + def test_extra_lines_rendered(self): + blocks = build_slack_blocks(DigestInputs( + extra_lines=["something interesting"], + )) + joined = json.dumps(blocks) + self.assertIn("something interesting", joined) + + +class TestPayload(unittest.TestCase): + + def test_without_channel(self): + payload = build_slack_payload(_full_inputs()) + self.assertNotIn("channel", payload) + self.assertIn("blocks", payload) + + def test_with_channel(self): + payload = build_slack_payload(_full_inputs(), channel="#qa") + self.assertEqual(payload["channel"], "#qa") + + def test_bad_channel(self): + with self.assertRaises(SlackDigestError): + build_slack_payload(_full_inputs(), channel=123) # type: ignore[arg-type] + + +class TestTeamsCard(unittest.TestCase): + + def test_basic_shape(self): + card = build_teams_card(_full_inputs()) + self.assertEqual(card["type"], "AdaptiveCard") + self.assertGreater(len(card["body"]), 1) + # Header rendered bolder + self.assertEqual(card["body"][0]["weight"], "Bolder") + + +class TestPlainText(unittest.TestCase): + + def test_renders_text(self): + text = render_plain_text(_full_inputs()) + self.assertIn("Test digest", text) + self.assertIn("Quarantine activity", text) + + def test_minimum_input(self): + text = render_plain_text(DigestInputs()) + self.assertIn("Nothing notable", text) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_sri_verify.py b/test/unit_test/test_sri_verify.py new file mode 100644 index 0000000..eb1dce7 --- /dev/null +++ b/test/unit_test/test_sri_verify.py @@ -0,0 +1,190 @@ +"""Unit tests for je_web_runner.utils.sri_verify.""" +import unittest + +from je_web_runner.utils.sri_verify.verify import ( + ResourceTag, + SriFinding, + SriVerifyError, + Verdict, + assert_all_ok, + compute_integrity, + parse_html, + verify_html, + verify_tag, +) + + +_JS = b"console.log('hi');" +_GOOD = compute_integrity(_JS, "sha384") + + +class TestComputeIntegrity(unittest.TestCase): + + def test_sha384(self): + out = compute_integrity(b"hello", "sha384") + self.assertTrue(out.startswith("sha384-")) + + def test_sha256(self): + out = compute_integrity(b"hello", "sha256") + self.assertTrue(out.startswith("sha256-")) + + def test_unknown_alg(self): + with self.assertRaises(SriVerifyError): + compute_integrity(b"x", "blake3") + + def test_payload_must_be_bytes(self): + with self.assertRaises(SriVerifyError): + compute_integrity("text", "sha256") # type: ignore[arg-type] + + +class TestParseHtml(unittest.TestCase): + + def test_script_with_integrity(self): + html = ( + f'' + ) + tags = parse_html(html) + self.assertEqual(len(tags), 1) + self.assertEqual(tags[0].tag, "script") + self.assertEqual(tags[0].integrity, _GOOD) + self.assertEqual(tags[0].crossorigin, "anonymous") + + def test_link_stylesheet(self): + html = '' + tags = parse_html(html) + self.assertEqual(tags[0].tag, "link") + + def test_link_non_stylesheet_skipped(self): + html = '' + self.assertEqual(parse_html(html), []) + + def test_script_without_src_skipped(self): + html = "" + self.assertEqual(parse_html(html), []) + + def test_rejects_non_string(self): + with self.assertRaises(SriVerifyError): + parse_html(123) # type: ignore[arg-type] + + +class TestVerifyTag(unittest.TestCase): + + def test_missing_integrity(self): + tag = ResourceTag(tag="script", url="https://cdn/x.js") + finding = verify_tag(tag) + self.assertEqual(finding.verdict, Verdict.MISSING) + + def test_weak_alg(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", + integrity="sha1-abcdef==", crossorigin="anonymous", + ) + self.assertEqual(verify_tag(tag).verdict, Verdict.WEAK_ALG) + + def test_unknown_format(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", + integrity="not-an-integrity", + ) + # 'not-an-integrity' parses as alg='not', so it falls into WEAK_ALG + # (not in strong set). Use a value with no dash for UNKNOWN_FORMAT: + tag2 = ResourceTag( + tag="script", url="https://cdn/x.js", integrity="garbage", + ) + self.assertEqual(verify_tag(tag2).verdict, Verdict.UNKNOWN_FORMAT) + + def test_cross_origin_needs_crossorigin(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", integrity=_GOOD, + ) + self.assertEqual(verify_tag(tag).verdict, Verdict.NO_CROSSORIGIN) + + def test_same_origin_no_crossorigin_ok(self): + tag = ResourceTag(tag="script", url="/local.js", integrity=_GOOD) + self.assertEqual(verify_tag(tag).verdict, Verdict.OK) + + def test_payload_match(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", + integrity=_GOOD, crossorigin="anonymous", + ) + self.assertEqual(verify_tag(tag, payload=_JS).verdict, Verdict.OK) + + def test_payload_mismatch(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", + integrity=_GOOD, crossorigin="anonymous", + ) + self.assertEqual( + verify_tag(tag, payload=b"different").verdict, + Verdict.HASH_MISMATCH, + ) + + def test_disable_crossorigin_check(self): + tag = ResourceTag( + tag="script", url="https://cdn/x.js", integrity=_GOOD, + ) + self.assertEqual( + verify_tag(tag, require_crossorigin=False).verdict, Verdict.OK, + ) + + def test_rejects_non_tag(self): + with self.assertRaises(SriVerifyError): + verify_tag("nope") # type: ignore[arg-type] + + +class TestVerifyHtml(unittest.TestCase): + + def test_no_provider(self): + html = ( + f'' + '' + ) + findings = verify_html(html) + verdicts = {f.verdict for f in findings} + self.assertIn(Verdict.OK, verdicts) + self.assertIn(Verdict.MISSING, verdicts) + + def test_with_provider(self): + html = ( + f'' + ) + findings = verify_html(html, payload_provider=lambda _u: _JS) + self.assertEqual(findings[0].verdict, Verdict.OK) + + def test_provider_returning_bad_payload(self): + html = ( + f'' + ) + with self.assertRaises(SriVerifyError): + verify_html(html, payload_provider=lambda _u: "not bytes") + + def test_provider_raising(self): + html = ( + f'' + ) + def boom(_): + raise RuntimeError("network down") + with self.assertRaises(SriVerifyError): + verify_html(html, payload_provider=boom) + + +class TestAssertAllOk(unittest.TestCase): + + def test_pass(self): + assert_all_ok([SriFinding(tag="script", url="/x", verdict=Verdict.OK)]) + + def test_fail(self): + with self.assertRaises(SriVerifyError): + assert_all_ok([ + SriFinding(tag="script", url="/x", verdict=Verdict.MISSING), + ]) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_sse_assert.py b/test/unit_test/test_sse_assert.py new file mode 100644 index 0000000..f59b35c --- /dev/null +++ b/test/unit_test/test_sse_assert.py @@ -0,0 +1,205 @@ +"""Unit tests for je_web_runner.utils.sse_assert.""" +import json +import unittest + +from je_web_runner.utils.sse_assert.stream import ( + SseAssertError, + SseEvent, + SseRecorder, + assert_data_contains, + assert_event_count, + assert_json_shape, + assert_received_event, + assert_strictly_increasing_ids, + parse_sse_stream, + to_json, +) + + +class TestParse(unittest.TestCase): + + def test_basic_event(self): + stream = "data: hello\n\n" + events = parse_sse_stream(stream) + self.assertEqual(len(events), 1) + self.assertEqual(events[0].event, "message") + self.assertEqual(events[0].data, "hello") + + def test_event_type(self): + stream = "event: ping\ndata: 1\n\n" + events = parse_sse_stream(stream) + self.assertEqual(events[0].event, "ping") + + def test_id_and_retry(self): + stream = "id: 42\nretry: 1500\ndata: x\n\n" + events = parse_sse_stream(stream) + self.assertEqual(events[0].id, "42") + self.assertEqual(events[0].retry, 1500) + + def test_multiline_data(self): + stream = "data: line one\ndata: line two\n\n" + events = parse_sse_stream(stream) + self.assertEqual(events[0].data, "line one\nline two") + + def test_comment_ignored(self): + stream = ": keep-alive\ndata: real\n\n" + events = parse_sse_stream(stream) + self.assertEqual(events[0].data, "real") + + def test_crlf_normalised(self): + stream = "data: r\r\n\r\n" + events = parse_sse_stream(stream) + self.assertEqual(events[0].data, "r") + + def test_multiple_events(self): + stream = "data: a\n\ndata: b\n\ndata: c\n\n" + events = parse_sse_stream(stream) + self.assertEqual([e.data for e in events], ["a", "b", "c"]) + + def test_empty_returns_empty(self): + self.assertEqual(parse_sse_stream("\n\n\n"), []) + + def test_rejects_non_string(self): + with self.assertRaises(SseAssertError): + parse_sse_stream(123) # type: ignore[arg-type] + + +class TestRecorder(unittest.TestCase): + + def test_feed_complete_event(self): + rec = SseRecorder() + n = rec.feed("data: hello\n\n") + self.assertEqual(n, 1) + self.assertEqual(len(rec), 1) + + def test_feed_partial_then_complete(self): + rec = SseRecorder() + self.assertEqual(rec.feed("data: hel"), 0) + self.assertEqual(rec.feed("lo\n\n"), 1) + self.assertEqual(rec.events()[0].data, "hello") + + def test_feed_event_helper(self): + rec = SseRecorder() + rec.feed_event(SseEvent(event="msg", data="x")) + self.assertEqual(len(rec), 1) + + def test_clear(self): + rec = SseRecorder() + rec.feed("data: a\n\n") + rec.clear() + self.assertEqual(len(rec), 0) + + def test_filter_by_event_type(self): + rec = SseRecorder() + rec.feed("event: ping\ndata: 1\n\nevent: pong\ndata: 2\n\n") + self.assertEqual(len(rec.events(event_type="ping")), 1) + + def test_rejects_non_string_chunk(self): + with self.assertRaises(SseAssertError): + SseRecorder().feed(123) # type: ignore[arg-type] + + +class TestAssertCount(unittest.TestCase): + + def test_in_range(self): + rec = SseRecorder() + rec.feed("data: a\n\ndata: b\n\n") + self.assertEqual(assert_event_count(rec, minimum=2, maximum=5), 2) + + def test_below_minimum(self): + with self.assertRaises(SseAssertError): + assert_event_count(SseRecorder(), minimum=1) + + def test_above_maximum(self): + rec = SseRecorder() + rec.feed("data: a\n\ndata: b\n\n") + with self.assertRaises(SseAssertError): + assert_event_count(rec, maximum=1) + + def test_filter_by_event_type(self): + rec = SseRecorder() + rec.feed("event: ping\ndata: 1\n\nevent: pong\ndata: 2\n\n") + self.assertEqual(assert_event_count(rec, event_type="ping", minimum=1, maximum=1), 1) + + +class TestAssertReceived(unittest.TestCase): + + def test_match(self): + rec = SseRecorder() + rec.feed("data: success\n\n") + e = assert_received_event(rec, lambda e: "success" in e.data) + self.assertEqual(e.data, "success") + + def test_no_match(self): + rec = SseRecorder() + rec.feed("data: x\n\n") + with self.assertRaises(SseAssertError): + assert_received_event(rec, lambda e: False) + + +class TestAssertDataContains(unittest.TestCase): + + def test_match(self): + rec = SseRecorder() + rec.feed("data: hello world\n\n") + e = assert_data_contains(rec, "world") + self.assertEqual(e.data, "hello world") + + def test_miss(self): + with self.assertRaises(SseAssertError): + assert_data_contains(SseRecorder(), "x") + + def test_empty_needle(self): + with self.assertRaises(SseAssertError): + assert_data_contains(SseRecorder(), "") + + +class TestAssertJsonShape(unittest.TestCase): + + def test_match(self): + rec = SseRecorder() + rec.feed('data: {"id":1,"v":2}\n\n') + e = assert_json_shape(rec, ["id", "v"]) + self.assertEqual(e.as_json()["id"], 1) + + def test_missing_key(self): + rec = SseRecorder() + rec.feed('data: {"id":1}\n\n') + with self.assertRaises(SseAssertError): + assert_json_shape(rec, ["id", "missing"]) + + def test_non_json_skipped(self): + rec = SseRecorder() + rec.feed("data: not json\n\ndata: {\"k\":true}\n\n") + assert_json_shape(rec, ["k"]) + + def test_empty_keys_rejected(self): + with self.assertRaises(SseAssertError): + assert_json_shape(SseRecorder(), []) + + +class TestStrictlyIncreasing(unittest.TestCase): + + def test_pass(self): + rec = SseRecorder() + rec.feed("id: a\ndata: 1\n\nid: b\ndata: 2\n\n") + assert_strictly_increasing_ids(rec) + + def test_fail_on_duplicate(self): + rec = SseRecorder() + rec.feed("id: a\ndata: 1\n\nid: a\ndata: 2\n\n") + with self.assertRaises(SseAssertError): + assert_strictly_increasing_ids(rec) + + +class TestToJson(unittest.TestCase): + + def test_roundtrip(self): + rec = SseRecorder() + rec.feed("data: hi\n\n") + loaded = json.loads(to_json(rec.events())) + self.assertEqual(loaded[0]["data"], "hi") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_story_to_actions.py b/test/unit_test/test_story_to_actions.py new file mode 100644 index 0000000..9ac69f3 --- /dev/null +++ b/test/unit_test/test_story_to_actions.py @@ -0,0 +1,213 @@ +"""Unit tests for je_web_runner.utils.story_to_actions.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.story_to_actions.generator import ( + ALLOWED_ACTIONS, + FigmaHint, + StoryPrompt, + StoryToActionsError, + build_prompt_text, + generate_actions, + validate_actions, + write_actions_json, +) + + +class StubClient: + def __init__(self, response): + self.response = response + self.last_prompt = None + + def generate(self, prompt_text): + self.last_prompt = prompt_text + if isinstance(self.response, Exception): + raise self.response + return self.response + + +class TestStoryPrompt(unittest.TestCase): + + def test_rejects_empty_story(self): + with self.assertRaises(StoryToActionsError): + StoryPrompt(story="") + + def test_rejects_whitespace_story(self): + with self.assertRaises(StoryToActionsError): + StoryPrompt(story=" \n ") + + +class TestBuildPrompt(unittest.TestCase): + + def test_includes_story_and_url(self): + prompt = StoryPrompt(story="Add to cart", start_url="https://shop/") + text = build_prompt_text(prompt) + self.assertIn("Add to cart", text) + self.assertIn("https://shop/", text) + + def test_includes_figma_hints(self): + prompt = StoryPrompt( + story="Click checkout", + figma_hints=[FigmaHint(name="checkout_btn", type="button", + selector_hint="[data-test=checkout]", + text="Checkout")], + ) + text = build_prompt_text(prompt) + self.assertIn("checkout_btn", text) + self.assertIn("[data-test=checkout]", text) + + def test_includes_style_notes(self): + text = build_prompt_text(StoryPrompt(story="x", style_notes=["prefer id locators"])) + self.assertIn("prefer id locators", text) + + +class TestValidate(unittest.TestCase): + + def test_empty_rejected(self): + with self.assertRaises(StoryToActionsError): + validate_actions([]) + + def test_non_list_rejected(self): + with self.assertRaises(StoryToActionsError): + validate_actions({"WR_to_url": ["x"]}) # type: ignore[arg-type] + + def test_unknown_action_name(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_fly_to_moon": []}]) + + def test_multi_key_action_rejected(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_to_url": ["x"], "extra": []}]) + + def test_args_must_be_list(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_to_url": "x"}]) + + def test_to_url_needs_string(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_to_url": []}]) + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_to_url": [""]}]) + + def test_implicitly_wait_needs_number(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_implicitly_wait": ["soon"]}]) + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_implicitly_wait": [-1]}]) + + def test_click_needs_locator(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_click_element": ["#x"]}]) + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_click_element": ["unknown_by", "#x"]}]) + + def test_input_needs_text(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_input_to_element": ["id", "name"]}]) + + def test_assert_text_needs_expected(self): + with self.assertRaises(StoryToActionsError): + validate_actions([{"WR_assert_element_text": ["id", "x"]}]) + + def test_valid_passes(self): + validate_actions([ + {"WR_to_url": ["https://x"]}, + {"WR_click_element": ["id", "submit"]}, + {"WR_input_to_element": ["css selector", "#name", "alice"]}, + {"WR_assert_element_text": ["id", "status", "OK"]}, + {"WR_assert_element_visible": ["id", "ok"]}, + {"WR_implicitly_wait": [1.5]}, + {"WR_comment": ["done"]}, + ]) + + def test_allowed_actions_visible(self): + self.assertIn("WR_to_url", ALLOWED_ACTIONS) + + +class TestGenerateActions(unittest.TestCase): + + def test_happy_path(self): + client = StubClient(json.dumps([ + {"WR_to_url": ["https://x"]}, + {"WR_click_element": ["id", "submit"]}, + ])) + actions = generate_actions(StoryPrompt(story="open the page"), client) + self.assertEqual(len(actions), 2) + self.assertIsNotNone(client.last_prompt) + self.assertIn("open the page", client.last_prompt) + + def test_strips_markdown_fence(self): + client = StubClient('```json\n[{"WR_to_url": ["https://x"]}]\n```') + actions = generate_actions(StoryPrompt(story="x"), client) + self.assertEqual(actions, [{"WR_to_url": ["https://x"]}]) + + def test_prepends_start_url_when_missing(self): + client = StubClient(json.dumps([{"WR_click_element": ["id", "go"]}])) + actions = generate_actions( + StoryPrompt(story="x", start_url="https://shop/"), + client, + ) + self.assertEqual(actions[0], {"WR_to_url": ["https://shop/"]}) + self.assertEqual(len(actions), 2) + + def test_does_not_duplicate_start_url(self): + client = StubClient(json.dumps([ + {"WR_to_url": ["https://shop/"]}, + {"WR_click_element": ["id", "go"]}, + ])) + actions = generate_actions( + StoryPrompt(story="x", start_url="https://shop/"), + client, + ) + self.assertEqual(len(actions), 2) + self.assertEqual(actions[0], {"WR_to_url": ["https://shop/"]}) + + def test_invalid_action_propagates(self): + client = StubClient(json.dumps([{"WR_fake": []}])) + with self.assertRaises(StoryToActionsError): + generate_actions(StoryPrompt(story="x"), client) + + def test_client_error_wrapped(self): + client = StubClient(RuntimeError("network down")) + with self.assertRaises(StoryToActionsError): + generate_actions(StoryPrompt(story="x"), client) + + def test_non_string_response_rejected(self): + class WeirdClient: + def generate(self, _p): + return 42 + with self.assertRaises(StoryToActionsError): + generate_actions(StoryPrompt(story="x"), WeirdClient()) + + def test_bad_json_response(self): + client = StubClient("not json at all") + with self.assertRaises(StoryToActionsError): + generate_actions(StoryPrompt(story="x"), client) + + def test_non_list_response(self): + client = StubClient(json.dumps({"WR_to_url": ["x"]})) + with self.assertRaises(StoryToActionsError): + generate_actions(StoryPrompt(story="x"), client) + + +class TestWriteActions(unittest.TestCase): + + def test_write(self): + actions = [{"WR_to_url": ["https://x"]}] + with tempfile.TemporaryDirectory() as tmp: + out = write_actions_json(actions, Path(tmp) / "actions.json") + self.assertEqual( + json.loads(out.read_text(encoding="utf-8")), + actions, + ) + + def test_write_validates(self): + with tempfile.TemporaryDirectory() as tmp: + with self.assertRaises(StoryToActionsError): + write_actions_json([{"WR_fake": []}], Path(tmp) / "actions.json") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_auto_repair.py b/test/unit_test/test_test_auto_repair.py new file mode 100644 index 0000000..c599672 --- /dev/null +++ b/test/unit_test/test_test_auto_repair.py @@ -0,0 +1,217 @@ +"""Unit tests for je_web_runner.utils.test_auto_repair.""" +import json +import subprocess +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock + +from je_web_runner.utils.ai_assist.llm_assist import set_llm_callable +from je_web_runner.utils.failure_bundle.bundle import FailureBundle +from je_web_runner.utils.failure_triage.triage import TriageSignals +from je_web_runner.utils.test_auto_repair.repair import ( + RepairPlan, + TestAutoRepairError, + apply_repair, + collect_git_diff, + propose_repair, + render_repair_markdown, + repair_from_bundle, +) + + +_VALID_PAYLOAD = { + "summary": "Locator drifted; replaced #old-btn with [data-testid='submit']", + "confidence": 0.8, + "repaired_actions": [ + ["WR_save_test_object", {"object_type": "CSS_SELECTOR", + "test_object_name": "[data-testid='submit']"}], + ["WR_element_click", {"test_object_name": "[data-testid='submit']"}], + ], + "changes": [ + {"index": 0, "kind": "locator", + "before": "#old-btn", "after": "[data-testid='submit']", + "why": "DOM no longer has #old-btn"}, + ], + "risks": ["double-check submit handler still wires on test-id"], +} + + +class TestCollectGitDiff(unittest.TestCase): + + def test_returns_stdout_on_success(self): + fake = MagicMock(return_value=subprocess.CompletedProcess( + args=[], returncode=0, stdout="--- diff text ---", stderr="", + )) + text = collect_git_diff("/some/repo", runner=fake) + self.assertIn("diff text", text) + + def test_returns_empty_on_failure(self): + fake = MagicMock(return_value=subprocess.CompletedProcess( + args=[], returncode=128, stdout="", stderr="not a git repo", + )) + self.assertEqual(collect_git_diff("/x", runner=fake), "") + + def test_returns_empty_on_oserror(self): + def boom(*_a, **_kw): + raise OSError("git missing") + self.assertEqual(collect_git_diff("/x", runner=boom), "") + + def test_truncates_long_diffs(self): + fake = MagicMock(return_value=subprocess.CompletedProcess( + args=[], returncode=0, stdout="x" * 9999, stderr="", + )) + text = collect_git_diff("/x", runner=fake, max_chars=100) + self.assertLessEqual(len(text), 200) # 100 + truncated marker + self.assertIn("truncated", text) + + +class TestProposeRepair(unittest.TestCase): + + def setUp(self): + self.signals = TriageSignals( + test_name="login_test", + error_repr="TimeoutException: #old-btn not found", + error_signature="timeoutexception: #old-btn not found", + last_steps=[["WR_element_click", {"test_object_name": "#old-btn"}]], + ) + self.actions = [ + ["WR_save_test_object", {"object_type": "CSS_SELECTOR", + "test_object_name": "#old-btn"}], + ["WR_element_click", {"test_object_name": "#old-btn"}], + ] + + def tearDown(self): + set_llm_callable(None) + + def test_valid_payload_returns_plan(self): + set_llm_callable(lambda _p: json.dumps(_VALID_PAYLOAD)) + plan = propose_repair(self.actions, self.signals) + self.assertEqual(len(plan.repaired_actions), 2) + self.assertAlmostEqual(plan.confidence, 0.8) + self.assertEqual(plan.changes[0]["kind"], "locator") + + def test_missing_callable_raises(self): + set_llm_callable(None) + with self.assertRaises(TestAutoRepairError): + propose_repair(self.actions, self.signals) + + def test_non_list_actions_raise(self): + with self.assertRaises(TestAutoRepairError): + propose_repair("not a list", self.signals) # type: ignore[arg-type] + + def test_missing_repaired_actions_raises(self): + payload = dict(_VALID_PAYLOAD) + del payload["repaired_actions"] + set_llm_callable(lambda _p: json.dumps(payload)) + with self.assertRaises(TestAutoRepairError): + propose_repair(self.actions, self.signals) + + def test_repaired_actions_not_list_raises(self): + payload = dict(_VALID_PAYLOAD) + payload["repaired_actions"] = "oops" + set_llm_callable(lambda _p: json.dumps(payload)) + with self.assertRaises(TestAutoRepairError): + propose_repair(self.actions, self.signals) + + def test_invalid_json_raises(self): + set_llm_callable(lambda _p: "not json") + with self.assertRaises(TestAutoRepairError): + propose_repair(self.actions, self.signals) + + def test_confidence_clamped(self): + payload = dict(_VALID_PAYLOAD) + payload["confidence"] = 5.0 + set_llm_callable(lambda _p: json.dumps(payload)) + plan = propose_repair(self.actions, self.signals) + self.assertEqual(plan.confidence, 1.0) + + def test_string_risks_coerced(self): + payload = dict(_VALID_PAYLOAD) + payload["risks"] = "single risk note" + set_llm_callable(lambda _p: json.dumps(payload)) + plan = propose_repair(self.actions, self.signals) + self.assertEqual(plan.risks, ["single risk note"]) + + +class TestRepairFromBundle(unittest.TestCase): + + def tearDown(self): + set_llm_callable(None) + + def test_end_to_end(self): + with tempfile.TemporaryDirectory() as tmpdir: + action_path = Path(tmpdir) / "a.json" + action_path.write_text(json.dumps([ + ["WR_element_click", {"test_object_name": "#old-btn"}], + ]), encoding="utf-8") + bundle = FailureBundle(test_name="t", error_repr="boom") + bundle_path = bundle.write(Path(tmpdir) / "b.zip") + set_llm_callable(lambda _p: json.dumps(_VALID_PAYLOAD)) + fake_git = MagicMock(return_value=subprocess.CompletedProcess( + args=[], returncode=0, stdout="diff --git", stderr="", + )) + plan = repair_from_bundle( + action_path, bundle_path, + repo_dir=tmpdir, git_runner=fake_git, + ) + self.assertGreater(len(plan.repaired_actions), 0) + + def test_missing_action_file_raises(self): + with self.assertRaises(TestAutoRepairError): + repair_from_bundle("/no/such.json", "/no/such.zip") + + +class TestApplyRepair(unittest.TestCase): + + def test_writes_side_file_by_default(self): + with tempfile.TemporaryDirectory() as tmpdir: + src = Path(tmpdir) / "a.json" + src.write_text(json.dumps([["WR_x"]]), encoding="utf-8") + plan = RepairPlan( + summary="ok", confidence=0.9, + repaired_actions=[["WR_y"]], + ) + target = apply_repair(src, plan) + self.assertEqual(target.name, "a.json.repaired.json") + self.assertNotEqual(target, src) + payload = json.loads(target.read_text(encoding="utf-8")) + self.assertEqual(payload, [["WR_y"]]) + + def test_low_confidence_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + src = Path(tmpdir) / "a.json" + src.write_text("[]", encoding="utf-8") + plan = RepairPlan(summary="x", confidence=0.2, repaired_actions=[]) + with self.assertRaises(TestAutoRepairError): + apply_repair(src, plan) + + def test_explicit_output_path(self): + with tempfile.TemporaryDirectory() as tmpdir: + src = Path(tmpdir) / "a.json" + src.write_text("[]", encoding="utf-8") + out = Path(tmpdir) / "subdir" / "b.json" + plan = RepairPlan(summary="x", confidence=0.9, repaired_actions=[1]) + applied = apply_repair(src, plan, output_path=out) + self.assertEqual(applied, out) + self.assertTrue(out.exists()) + + +class TestRenderMarkdown(unittest.TestCase): + + def test_includes_changes_and_risks(self): + plan = RepairPlan( + summary="rewired locator", confidence=0.8, + repaired_actions=[], + changes=[{"index": 0, "kind": "locator", "why": "drift"}], + risks=["double check"], + ) + md = render_repair_markdown(plan) + self.assertIn("AI Test Auto-Repair", md) + self.assertIn("80%", md) + self.assertIn("`locator`", md) + self.assertIn("double check", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_categorizer.py b/test/unit_test/test_test_categorizer.py new file mode 100644 index 0000000..822c159 --- /dev/null +++ b/test/unit_test/test_test_categorizer.py @@ -0,0 +1,181 @@ +"""Unit tests for je_web_runner.utils.test_categorizer.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.test_categorizer.categorizer import ( + CategoryAssignment, + Rule, + TagDistribution, + TestCategorizerError, + aggregate, + categorize_actions, + categorize_dir, + categorize_file, +) + + +def _smoke(): + return [ + {"WR_to_url": ["https://x"]}, + {"WR_click_element": ["id", "go"]}, + {"WR_assert_element_text": ["id", "status", "OK"]}, + ] + + +def _perf(): + return [ + {"WR_to_url": ["https://x"]}, + {"WR_perf_capture_metrics": []}, + {"WR_lighthouse_run": []}, + ] + + +def _a11y(): + return [ + {"WR_to_url": ["https://x"]}, + {"WR_axe_audit": []}, + ] + + +def _security(): + return [ + {"WR_to_url": ["https://x"]}, + {"WR_csrf_check": []}, + ] + + +def _data_driven(): + return [{"WR_to_url": ["https://x"]}] + [ + {"WR_input_to_element": ["id", f"f{i}", "v"]} for i in range(50) + ] + + +def _visual(): + return [ + {"WR_to_url": ["https://x"]}, + {"WR_snapshot_take": ["#a"]}, + ] + + +def _api(): + return [{"WR_http_request": ["GET", "/api/x"]}] + + +class TestCategorize(unittest.TestCase): + + def test_smoke(self): + self.assertIn("smoke", categorize_actions(_smoke())) + + def test_regression_over_threshold(self): + long_actions = _smoke() + [{"WR_click_element": ["id", "x"]}] * 10 + tags = categorize_actions(long_actions) + self.assertIn("regression", tags) + self.assertNotIn("smoke", tags) + + def test_perf(self): + self.assertIn("perf", categorize_actions(_perf())) + + def test_a11y(self): + self.assertIn("a11y", categorize_actions(_a11y())) + + def test_security(self): + self.assertIn("security", categorize_actions(_security())) + + def test_data_driven(self): + self.assertIn("data_driven", categorize_actions(_data_driven())) + + def test_visual(self): + self.assertIn("visual", categorize_actions(_visual())) + + def test_api(self): + self.assertIn("api", categorize_actions(_api())) + + def test_multiple_tags(self): + actions = _smoke() + [{"WR_axe_audit": []}] + tags = categorize_actions(actions) + self.assertIn("smoke", tags) + self.assertIn("a11y", tags) + + def test_no_tags_empty(self): + self.assertEqual(categorize_actions([]), []) + + def test_rejects_non_list(self): + with self.assertRaises(TestCategorizerError): + categorize_actions("not list") # type: ignore[arg-type] + + def test_rejects_non_rule(self): + with self.assertRaises(TestCategorizerError): + categorize_actions([], rules=["not a rule"]) # type: ignore[list-item] + + def test_matcher_exception(self): + bad_rule = Rule( + tag="bad", + matcher=lambda a: (_ for _ in ()).throw(RuntimeError("oops")), + ) + with self.assertRaises(TestCategorizerError): + categorize_actions([], rules=[bad_rule]) + + +class TestFileAndDir(unittest.TestCase): + + def test_file(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "a.json" + path.write_text(json.dumps(_smoke()), encoding="utf-8") + result = categorize_file(path) + self.assertIn("smoke", result.tags) + self.assertEqual(result.action_count, 3) + + def test_file_missing(self): + with self.assertRaises(TestCategorizerError): + categorize_file("/no/such/file.json") + + def test_file_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "x.json" + path.write_text("not json", encoding="utf-8") + with self.assertRaises(TestCategorizerError): + categorize_file(path) + + def test_file_non_list(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "x.json" + path.write_text(json.dumps({"x": 1}), encoding="utf-8") + with self.assertRaises(TestCategorizerError): + categorize_file(path) + + def test_dir(self): + with tempfile.TemporaryDirectory() as tmp: + (Path(tmp) / "a.json").write_text(json.dumps(_smoke()), encoding="utf-8") + (Path(tmp) / "b.json").write_text(json.dumps(_perf()), encoding="utf-8") + results = categorize_dir(tmp) + self.assertEqual(len(results), 2) + + def test_dir_missing(self): + with self.assertRaises(TestCategorizerError): + categorize_dir("/no/such/dir") + + +class TestAggregate(unittest.TestCase): + + def test_counts(self): + assignments = [ + CategoryAssignment(test_id="a", tags=["smoke", "a11y"]), + CategoryAssignment(test_id="b", tags=["perf"]), + CategoryAssignment(test_id="c", tags=[]), + ] + dist = aggregate(assignments) + self.assertEqual(dist.total_tests, 3) + self.assertEqual(dist.untagged_tests, 1) + self.assertEqual(dist.by_tag["smoke"], 1) + self.assertEqual(dist.by_tag["a11y"], 1) + + def test_rejects_non_assignment(self): + with self.assertRaises(TestCategorizerError): + aggregate(["nope"]) # type: ignore[list-item] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_cost_estimator.py b/test/unit_test/test_test_cost_estimator.py new file mode 100644 index 0000000..2a85a14 --- /dev/null +++ b/test/unit_test/test_test_cost_estimator.py @@ -0,0 +1,176 @@ +"""Unit tests for je_web_runner.utils.test_cost_estimator.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.test_cost_estimator.estimator import ( + DEFAULT_RATE_CARDS, + CostEstimate, + RateCard, + RunRow, + TestCostEstimatorError, + estimate_markdown, + estimate_runs, + load_runs, + rate_card_index, +) + + +def _runs(*tuples): + return [RunRow(test_id=tid, runner=runner, duration_seconds=sec) + for tid, runner, sec in tuples] + + +class TestRateCard(unittest.TestCase): + + def test_rejects_negative_rate(self): + with self.assertRaises(TestCostEstimatorError): + RateCard(runner="x", usd_per_minute=-1) + + def test_rejects_negative_co2(self): + with self.assertRaises(TestCostEstimatorError): + RateCard(runner="x", usd_per_minute=1, grams_co2_per_minute=-1) + + def test_rejects_negative_minimum(self): + with self.assertRaises(TestCostEstimatorError): + RateCard(runner="x", usd_per_minute=1, minimum_minutes=-1) + + +class TestRateCardIndex(unittest.TestCase): + + def test_duplicates_rejected(self): + with self.assertRaises(TestCostEstimatorError): + rate_card_index([ + RateCard(runner="a", usd_per_minute=1), + RateCard(runner="a", usd_per_minute=2), + ]) + + def test_returns_dict(self): + idx = rate_card_index([RateCard(runner="a", usd_per_minute=1)]) + self.assertIn("a", idx) + + +class TestRunRow(unittest.TestCase): + + def test_rejects_negative_duration(self): + with self.assertRaises(TestCostEstimatorError): + RunRow(test_id="t", runner="local", duration_seconds=-1) + + +class TestLoadRuns(unittest.TestCase): + + def test_loads_rows(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "l.json" + p.write_text(json.dumps({"runs": [ + {"test_id": "t1", "runner": "saucelabs", "duration_seconds": 30}, + {"path": "t2", "runner": "local", "duration_seconds": 5}, + {"test_id": "skip_me", "runner": "local"}, # no duration → skip + ]}), encoding="utf-8") + runs = load_runs(p) + self.assertEqual(len(runs), 2) + self.assertEqual(runs[1].test_id, "t2") + + def test_missing_file(self): + with self.assertRaises(TestCostEstimatorError): + load_runs("/no/such/file.json") + + def test_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text("nope", encoding="utf-8") + with self.assertRaises(TestCostEstimatorError): + load_runs(p) + + def test_missing_runs(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text(json.dumps({"x": []}), encoding="utf-8") + with self.assertRaises(TestCostEstimatorError): + load_runs(p) + + +class TestEstimate(unittest.TestCase): + + def test_uses_defaults(self): + runs = _runs(("t1", "saucelabs", 120), ("t2", "saucelabs", 60)) + est = estimate_runs(runs) + self.assertEqual(est.total_runs, 2) + # 2 + 1 = 3 minutes × $0.18 = $0.54 + self.assertAlmostEqual(est.total_usd, 0.54, places=2) + + def test_minimum_minutes_applied(self): + runs = _runs(("t1", "saucelabs", 10)) # 10s = 0.17m, billed as 1m + est = estimate_runs(runs) + self.assertEqual(est.total_billed_minutes, 1.0) + + def test_unknown_runner_collected(self): + runs = _runs(("t1", "mystery_cloud", 60)) + est = estimate_runs(runs) + self.assertEqual(est.total_runs, 0) + self.assertIn("mystery_cloud", est.unknown_runners) + + def test_per_test_costs(self): + runs = _runs(("t1", "saucelabs", 600), ("t1", "saucelabs", 300)) + est = estimate_runs(runs) + self.assertGreater(est.by_test["t1"], 0) + + def test_by_runner_breakdown(self): + runs = _runs(("a", "saucelabs", 60), ("b", "browserstack", 60)) + est = estimate_runs(runs) + self.assertIn("saucelabs", est.by_runner) + self.assertIn("browserstack", est.by_runner) + + def test_empty_rejected(self): + with self.assertRaises(TestCostEstimatorError): + estimate_runs([]) + + def test_co2_accumulates(self): + runs = _runs(("t1", "saucelabs", 60)) + est = estimate_runs(runs) + self.assertGreater(est.total_grams_co2, 0) + + +class TestEstimateMarkdown(unittest.TestCase): + + def test_renders(self): + est = estimate_runs(_runs(("t1", "saucelabs", 600))) + md = estimate_markdown(est) + self.assertIn("Test cost estimate", md) + self.assertIn("saucelabs", md) + + def test_top_tests(self): + est = estimate_runs(_runs( + ("expensive", "saucelabs", 1200), + ("cheap", "saucelabs", 60), + )) + md = estimate_markdown(est, top_tests=1) + self.assertIn("expensive", md) + + def test_zero_top_tests(self): + est = estimate_runs(_runs(("t1", "saucelabs", 60))) + md = estimate_markdown(est, top_tests=0) + self.assertNotIn("costliest", md) + + def test_bad_top_tests(self): + est = estimate_runs(_runs(("t1", "saucelabs", 60))) + with self.assertRaises(TestCostEstimatorError): + estimate_markdown(est, top_tests=-1) + + def test_rejects_non_estimate(self): + with self.assertRaises(TestCostEstimatorError): + estimate_markdown("not estimate") # type: ignore[arg-type] + + +class TestDefaultCards(unittest.TestCase): + + def test_have_well_known_runners(self): + names = {c.runner for c in DEFAULT_RATE_CARDS} + self.assertIn("local", names) + self.assertIn("saucelabs", names) + self.assertIn("browserstack", names) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_debt_dashboard.py b/test/unit_test/test_test_debt_dashboard.py new file mode 100644 index 0000000..1cad979 --- /dev/null +++ b/test/unit_test/test_test_debt_dashboard.py @@ -0,0 +1,223 @@ +"""Unit tests for je_web_runner.utils.test_debt_dashboard.""" +import json +import os +import tempfile +import unittest +from datetime import datetime, timezone +from pathlib import Path + +from je_web_runner.utils.test_debt_dashboard.debt import ( + CodeownersIndex, + DebtItem, + DebtKind, + DebtReport, + TestDebtDashboardError, + assert_under_age_limit, + parse_codeowners, + report_markdown, + scan_action_json, + scan_directory, + scan_python_file, +) + +_NOW = datetime(2026, 5, 24, tzinfo=timezone.utc) + + +def _write(path: Path, body: str): + path.write_text(body, encoding="utf-8") + + +class TestCodeowners(unittest.TestCase): + + def test_parse(self): + idx = parse_codeowners( + "# comment\n* @team/all\n/test/checkout/ @team/checkout\n" + ) + self.assertEqual(len(idx.rules), 2) + + def test_owner_for(self): + idx = parse_codeowners( + "* @team/all\n/test/checkout/*.py @team/checkout\n" + ) + self.assertEqual(idx.owner_for("test/checkout/test_x.py"), "@team/checkout") + self.assertEqual(idx.owner_for("other/x.py"), "@team/all") + + def test_double_star(self): + idx = parse_codeowners("/test/**/auth_*.py @team/auth\n") + self.assertEqual(idx.owner_for("test/sub/dir/auth_login.py"), "@team/auth") + + +class TestScanPython(unittest.TestCase): + + def test_skip_with_reason(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "test_x.py" + _write(path, '''import pytest +@pytest.mark.skip(reason="known broken") +def test_foo(): + pass +''') + items = scan_python_file(path, now=_NOW) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].kind, DebtKind.SKIP) + self.assertEqual(items[0].reason, "known broken") + self.assertEqual(items[0].test_name, "test_foo") + + def test_xfail(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "test_x.py" + _write(path, '''import pytest +@pytest.mark.xfail(reason="server changed") +def test_bar(): + assert False +''') + items = scan_python_file(path, now=_NOW) + self.assertEqual(items[0].kind, DebtKind.XFAIL) + self.assertEqual(items[0].reason, "server changed") + + def test_todo(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "test_x.py" + _write(path, '''def test_baz(): + # TODO fix the assertion below + assert True +''') + items = scan_python_file(path, now=_NOW) + self.assertEqual(items[0].kind, DebtKind.TODO) + self.assertIn("fix the assertion", items[0].reason) + + def test_owner_assigned(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "test_x.py" + _write(path, '''import pytest +@pytest.mark.skip(reason="x") +def test_y(): pass +''') + owners = parse_codeowners("* @team/qa\n") + items = scan_python_file(path, now=_NOW, owners=owners) + self.assertEqual(items[0].owner, "@team/qa") + + def test_missing(self): + with self.assertRaises(TestDebtDashboardError): + scan_python_file("/no/such/file.py") + + +class TestScanActionJson(unittest.TestCase): + + def test_skip_marker(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "actions.json" + path.write_text(json.dumps([ + {"WR_to_url": ["https://x"]}, + {"_skip": True, "_reason": "until login fixed"}, + ]), encoding="utf-8") + items = scan_action_json(path, now=_NOW) + self.assertEqual(len(items), 1) + self.assertEqual(items[0].reason, "until login fixed") + + def test_non_list_ignored(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "x.json" + path.write_text(json.dumps({"x": 1}), encoding="utf-8") + self.assertEqual(scan_action_json(path, now=_NOW), []) + + def test_bad_json_ignored(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "x.json" + path.write_text("not json", encoding="utf-8") + self.assertEqual(scan_action_json(path, now=_NOW), []) + + def test_missing(self): + with self.assertRaises(TestDebtDashboardError): + scan_action_json("/no/such/file.json") + + +class TestScanDirectory(unittest.TestCase): + + def test_walks_tree(self): + with tempfile.TemporaryDirectory() as tmp: + (Path(tmp) / "test_a.py").write_text('''import pytest +@pytest.mark.skip(reason="x") +def test_y(): pass +''', encoding="utf-8") + sub = Path(tmp) / "sub" + sub.mkdir() + (sub / "test_b.py").write_text('''def test_z(): + # FIXME later + pass +''', encoding="utf-8") + report = scan_directory(tmp, now=_NOW) + self.assertGreaterEqual(len(report.items), 2) + kinds = report.by_kind() + self.assertIn("skip", kinds) + self.assertIn("todo", kinds) + + def test_missing_dir(self): + with self.assertRaises(TestDebtDashboardError): + scan_directory("/no/such/dir") + + +class TestAggregates(unittest.TestCase): + + def test_by_owner(self): + report = DebtReport(items=[ + DebtItem(kind=DebtKind.SKIP, path="a", line=1, + test_name=None, reason="x", age_days=1, + owner="@team/a"), + DebtItem(kind=DebtKind.SKIP, path="b", line=1, + test_name=None, reason="x", age_days=1), + ]) + owners = report.by_owner() + self.assertEqual(owners["@team/a"], 1) + self.assertEqual(owners["(unowned)"], 1) + + def test_older_than(self): + report = DebtReport(items=[ + DebtItem(kind=DebtKind.SKIP, path="a", line=1, + test_name=None, reason="x", age_days=10), + DebtItem(kind=DebtKind.SKIP, path="b", line=1, + test_name=None, reason="x", age_days=100), + ]) + self.assertEqual(len(report.older_than(50)), 1) + + +class TestAssertions(unittest.TestCase): + + def test_assert_under_age_pass(self): + report = DebtReport(items=[DebtItem( + kind=DebtKind.SKIP, path="x", line=1, test_name=None, + reason="", age_days=1, + )]) + assert_under_age_limit(report, max_days=10) + + def test_assert_under_age_fail(self): + report = DebtReport(items=[DebtItem( + kind=DebtKind.SKIP, path="x", line=1, test_name=None, + reason="", age_days=100, + )]) + with self.assertRaises(TestDebtDashboardError): + assert_under_age_limit(report, max_days=10) + + def test_assert_under_age_bad(self): + with self.assertRaises(TestDebtDashboardError): + assert_under_age_limit(DebtReport(), max_days=-1) + + +class TestMarkdown(unittest.TestCase): + + def test_renders(self): + report = DebtReport(items=[DebtItem( + kind=DebtKind.SKIP, path="x", line=1, test_name=None, + reason="", age_days=1, + )]) + md = report_markdown(report) + self.assertIn("Test debt", md) + self.assertIn("skip", md) + + def test_rejects_non_report(self): + with self.assertRaises(TestDebtDashboardError): + report_markdown("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_dedup_ai.py b/test/unit_test/test_test_dedup_ai.py new file mode 100644 index 0000000..eabf505 --- /dev/null +++ b/test/unit_test/test_test_dedup_ai.py @@ -0,0 +1,202 @@ +"""Unit tests for je_web_runner.utils.test_dedup_ai.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.test_dedup_ai.dedup import ( + ActionFile, + DuplicateCluster, + TestDedupError, + clusters_markdown, + load_dir, + semantic_clusters, + stable_fingerprint, + structural_clusters, +) + + +def _login_actions(url="https://x", user="alice"): + return [ + {"WR_to_url": [url]}, + {"WR_input_to_element": ["id", "username", user]}, + {"WR_input_to_element": ["id", "password", "pw"]}, + {"WR_click_element": ["id", "submit"]}, + ] + + +def _checkout_actions(): + return [ + {"WR_to_url": ["https://shop/"]}, + {"WR_click_element": ["id", "cart"]}, + {"WR_click_element": ["id", "checkout"]}, + ] + + +def _file(path, actions): + f = ActionFile(path=path, actions=actions) + f.fingerprint = stable_fingerprint(actions) # use the same internals + return f + + +class TestStructural(unittest.TestCase): + + def test_finds_exact_duplicates(self): + files = [ + ActionFile.load(self._write_file("a.json", _login_actions("https://x", "a"))), + ActionFile.load(self._write_file("b.json", _login_actions("https://y", "b"))), + ActionFile.load(self._write_file("c.json", _checkout_actions())), + ] + clusters = structural_clusters(files) + self.assertEqual(len(clusters), 1) + self.assertEqual(len(clusters[0].members), 2) + + def _write_file(self, name, actions): + path = Path(self._tmpdir) / name + path.write_text(json.dumps(actions), encoding="utf-8") + return path + + def setUp(self): + self._td = tempfile.TemporaryDirectory() + self._tmpdir = self._td.name + + def tearDown(self): + self._td.cleanup() + + def test_singletons_dropped(self): + files = [ + ActionFile.load(self._write_file("a.json", _login_actions())), + ActionFile.load(self._write_file("b.json", _checkout_actions())), + ] + self.assertEqual(structural_clusters(files), []) + + def test_empty_input_rejected(self): + with self.assertRaises(TestDedupError): + structural_clusters([]) + + +class TestLoading(unittest.TestCase): + + def test_load_missing(self): + with self.assertRaises(TestDedupError): + ActionFile.load("/no/such/file.json") + + def test_load_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text("not json", encoding="utf-8") + with self.assertRaises(TestDedupError): + ActionFile.load(p) + + def test_load_non_list(self): + with tempfile.TemporaryDirectory() as tmp: + p = Path(tmp) / "x.json" + p.write_text("{\"x\":1}", encoding="utf-8") + with self.assertRaises(TestDedupError): + ActionFile.load(p) + + def test_load_dir(self): + with tempfile.TemporaryDirectory() as tmp: + (Path(tmp) / "a.json").write_text(json.dumps(_login_actions()), encoding="utf-8") + (Path(tmp) / "b.json").write_text(json.dumps(_checkout_actions()), encoding="utf-8") + files = load_dir(tmp) + self.assertEqual(len(files), 2) + + def test_load_dir_missing(self): + with self.assertRaises(TestDedupError): + load_dir("/no/such/dir") + + +class TestSemantic(unittest.TestCase): + + def test_clusters_by_threshold(self): + files = [ + _file("a.json", _login_actions()), + _file("b.json", _login_actions("https://y")), + _file("c.json", _checkout_actions()), + ] + # Stub embedder: identical for login files, different for checkout + def embed(text): + return [1.0, 0.0] if "WR_input" in text else [0.0, 1.0] + + clusters = semantic_clusters(files, embed, similarity_threshold=0.95) + self.assertEqual(len(clusters), 1) + self.assertEqual(len(clusters[0].members), 2) + + def test_no_clusters_when_threshold_too_high(self): + files = [ + _file("a.json", _login_actions()), + _file("b.json", _login_actions()), + ] + # Slight noise → cosine = 0.99 + vectors = [[1.0, 0.0], [0.99, 0.1414]] + ptr = {"i": 0} + + def embed(_): + v = vectors[ptr["i"]] + ptr["i"] += 1 + return v + clusters = semantic_clusters(files, embed, similarity_threshold=0.999) + self.assertEqual(clusters, []) + + def test_bad_threshold(self): + with self.assertRaises(TestDedupError): + semantic_clusters([_file("a", _login_actions())], lambda _: [1.0], + similarity_threshold=0.0) + with self.assertRaises(TestDedupError): + semantic_clusters([_file("a", _login_actions())], lambda _: [1.0], + similarity_threshold=1.5) + + def test_bad_vector(self): + with self.assertRaises(TestDedupError): + semantic_clusters([_file("a", _login_actions())], lambda _: "not vector") + + def test_embedder_exception(self): + def bad(_): + raise RuntimeError("rate limit") + with self.assertRaises(TestDedupError): + semantic_clusters([_file("a", _login_actions())], bad) + + def test_empty_rejected(self): + with self.assertRaises(TestDedupError): + semantic_clusters([], lambda _: [1.0]) + + +class TestFingerprint(unittest.TestCase): + + def test_stable_across_calls(self): + a = stable_fingerprint(_login_actions()) + b = stable_fingerprint(_login_actions("https://different", "different")) + # Same structure, different data → same fingerprint + self.assertEqual(a, b) + + def test_different_for_different_structure(self): + self.assertNotEqual( + stable_fingerprint(_login_actions()), + stable_fingerprint(_checkout_actions()), + ) + + def test_rejects_bad_input(self): + with self.assertRaises(TestDedupError): + stable_fingerprint("not a list") # type: ignore[arg-type] + with self.assertRaises(TestDedupError): + stable_fingerprint([{"a": "b", "extra": 1}]) + + +class TestMarkdown(unittest.TestCase): + + def test_empty(self): + self.assertIn("No duplicate", clusters_markdown([])) + + def test_with_clusters(self): + cluster = DuplicateCluster( + mode="structural", members=["a.json", "b.json"], + representative="a.json", + ) + md = clusters_markdown([cluster]) + self.assertIn("a.json", md) + self.assertIn("structural", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_owners_map.py b/test/unit_test/test_test_owners_map.py new file mode 100644 index 0000000..d7df35e --- /dev/null +++ b/test/unit_test/test_test_owners_map.py @@ -0,0 +1,238 @@ +"""Unit tests for je_web_runner.utils.test_owners_map.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.test_owners_map.owners import ( + CodeownersRule, + OwnerAudit, + OwnersFile, + OwnersMap, + TestOwnersMapError, + assert_no_unowned, + audit_markdown, + audit_unowned, + load_codeowners_file, + load_overrides, + parse_codeowners, +) + + +_CODEOWNERS_TEXT = """\ +# Global default +* @team/all +# Test directories +/test/checkout/ @team/checkout +/test/profile/*.json @team/profile +/test/auth/**/*.py @team/security +""" + + +class TestParseCodeowners(unittest.TestCase): + + def test_parses_lines(self): + owners = parse_codeowners(_CODEOWNERS_TEXT) + self.assertEqual(len(owners.rules), 4) + + def test_skips_comments_and_blank(self): + text = "# only comments\n\n \n" + self.assertEqual(parse_codeowners(text).rules, []) + + def test_inline_comments_stripped(self): + owners = parse_codeowners("/foo @team/foo # legacy\n") + self.assertEqual(owners.rules[0].owners, ["@team/foo"]) + + def test_skips_short_lines(self): + # pattern without owner is ignored + self.assertEqual(parse_codeowners("/lonely\n").rules, []) + + def test_rejects_non_string(self): + with self.assertRaises(TestOwnersMapError): + parse_codeowners(123) # type: ignore[arg-type] + + +class TestLookup(unittest.TestCase): + + def setUp(self): + self.owners = parse_codeowners(_CODEOWNERS_TEXT) + + def test_default(self): + self.assertEqual(self.owners.lookup("other/foo.py"), ["@team/all"]) + + def test_dir_match(self): + self.assertEqual( + self.owners.lookup("test/checkout/sub/login.py"), + ["@team/checkout"], + ) + + def test_glob_with_extension(self): + self.assertEqual( + self.owners.lookup("test/profile/edit.json"), + ["@team/profile"], + ) + + def test_double_star(self): + self.assertEqual( + self.owners.lookup("test/auth/sub/login.py"), + ["@team/security"], + ) + + def test_last_match_wins(self): + text = "* @a\n/test/x.py @b\n" + owners = parse_codeowners(text) + self.assertEqual(owners.lookup("test/x.py"), ["@b"]) + + def test_rejects_empty_path(self): + with self.assertRaises(TestOwnersMapError): + self.owners.lookup("") + + +class TestLoadCodeownersFile(unittest.TestCase): + + def test_load(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "CODEOWNERS" + path.write_text(_CODEOWNERS_TEXT, encoding="utf-8") + owners = load_codeowners_file(path) + self.assertEqual(len(owners.rules), 4) + + def test_missing(self): + with self.assertRaises(TestOwnersMapError): + load_codeowners_file("/no/such/file") + + +class TestOverrides(unittest.TestCase): + + def test_load(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "o.json" + path.write_text(json.dumps({ + "test/checkout/login.py": ["@team/auth"], + }), encoding="utf-8") + overrides = load_overrides(path) + self.assertEqual(overrides["test/checkout/login.py"], ["@team/auth"]) + + def test_missing(self): + with self.assertRaises(TestOwnersMapError): + load_overrides("/no/such/file.json") + + def test_bad_json(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "o.json" + path.write_text("not json", encoding="utf-8") + with self.assertRaises(TestOwnersMapError): + load_overrides(path) + + def test_non_object(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "o.json" + path.write_text(json.dumps([1, 2]), encoding="utf-8") + with self.assertRaises(TestOwnersMapError): + load_overrides(path) + + def test_bad_value(self): + with tempfile.TemporaryDirectory() as tmp: + path = Path(tmp) / "o.json" + path.write_text(json.dumps({"x": "not a list"}), encoding="utf-8") + with self.assertRaises(TestOwnersMapError): + load_overrides(path) + + +class TestOwnersMap(unittest.TestCase): + + def _map(self, overrides=None): + return OwnersMap( + codeowners=parse_codeowners(_CODEOWNERS_TEXT), + overrides=overrides or {}, + ) + + def test_codeowners_path(self): + self.assertEqual( + self._map().owners_for("test/checkout/login.py"), + ["@team/checkout"], + ) + + def test_override_wins(self): + m = self._map({"test/checkout/login.py": ["@team/auth-override"]}) + self.assertEqual( + m.owners_for("test/checkout/login.py"), + ["@team/auth-override"], + ) + + def test_no_match_returns_default(self): + self.assertEqual( + self._map().owners_for("anywhere/else.py"), + ["@team/all"], + ) + + def test_empty_test_id_rejected(self): + with self.assertRaises(TestOwnersMapError): + self._map().owners_for("") + + +class TestAudit(unittest.TestCase): + + def test_counts(self): + m = OwnersMap(codeowners=parse_codeowners(_CODEOWNERS_TEXT)) + audit = audit_unowned( + ["test/checkout/login.py", "test/auth/sub/login.py", + "other/foo.py"], + m, + ) + self.assertEqual(audit.total_tests, 3) + self.assertEqual(audit.unowned, []) + self.assertEqual(audit.by_owner["@team/checkout"], 1) + self.assertEqual(audit.by_owner["@team/security"], 1) + + def test_unowned_detected(self): + owners = OwnersFile(rules=[ + CodeownersRule(pattern="/test/owned/", owners=["@team/x"]), + ]) + audit = audit_unowned( + ["test/owned/a.py", "test/orphan/b.py"], + OwnersMap(codeowners=owners), + ) + self.assertEqual(audit.unowned, ["test/orphan/b.py"]) + + def test_rejects_non_map(self): + with self.assertRaises(TestOwnersMapError): + audit_unowned([], "nope") # type: ignore[arg-type] + + +class TestAssertions(unittest.TestCase): + + def test_pass(self): + assert_no_unowned(OwnerAudit(total_tests=1)) + + def test_fail(self): + audit = OwnerAudit(total_tests=2, unowned=["a", "b"]) + with self.assertRaises(TestOwnersMapError): + assert_no_unowned(audit) + + def test_rejects_non_audit(self): + with self.assertRaises(TestOwnersMapError): + assert_no_unowned("nope") # type: ignore[arg-type] + + +class TestMarkdown(unittest.TestCase): + + def test_renders(self): + audit = OwnerAudit( + total_tests=3, unowned=["a"], by_owner={"@team/x": 2, "@team/y": 1}, + ) + md = audit_markdown(audit) + self.assertIn("unowned: **1**", md) + self.assertIn("@team/x", md) + + def test_bad_top_owners(self): + with self.assertRaises(TestOwnersMapError): + audit_markdown(OwnerAudit(total_tests=0), top_owners=-1) + + def test_rejects_non_audit(self): + with self.assertRaises(TestOwnersMapError): + audit_markdown("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_test_scheduler.py b/test/unit_test/test_test_scheduler.py new file mode 100644 index 0000000..b69e43d --- /dev/null +++ b/test/unit_test/test_test_scheduler.py @@ -0,0 +1,217 @@ +"""Unit tests for je_web_runner.utils.test_scheduler.""" +import json +import tempfile +import unittest +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from je_web_runner.utils.test_scheduler.scheduler import ( + Schedule, + TestCandidate, + TestSchedulerError, + build_candidates_from_ledger, + render_schedule_markdown, + schedule_tests, + value_density, + value_of, +) + + +def _iso(dt): + return dt.replace(tzinfo=timezone.utc).isoformat(timespec="seconds") + + +class TestCandidateValidation(unittest.TestCase): + + def test_empty_id_raises(self): + with self.assertRaises(TestSchedulerError): + TestCandidate(test_id="", duration_seconds=1) + + def test_zero_duration_raises(self): + with self.assertRaises(TestSchedulerError): + TestCandidate(test_id="t", duration_seconds=0) + + def test_fail_rate_out_of_range(self): + with self.assertRaises(TestSchedulerError): + TestCandidate(test_id="t", duration_seconds=1, fail_rate=1.5) + + def test_impact_score_out_of_range(self): + with self.assertRaises(TestSchedulerError): + TestCandidate(test_id="t", duration_seconds=1, impact_score=-0.1) + + +class TestValueModel(unittest.TestCase): + + def test_value_components(self): + c = TestCandidate( + test_id="t", duration_seconds=10, + fail_rate=0.5, impact_score=1.0, + last_run_age_hours=48, manual_priority=1.0, + ) + # value = 0.5*1 + 1.0*1.5 + (48/24)*1 + 1*2 = 0.5 + 1.5 + 2 + 2 = 6.0 + self.assertAlmostEqual(value_of(c), 6.0) + self.assertAlmostEqual(value_density(c), 0.6) + + +class TestScheduleTests(unittest.TestCase): + + def test_picks_highest_density_first(self): + candidates = [ + TestCandidate(test_id="slow_low", duration_seconds=100, fail_rate=0.1), + TestCandidate(test_id="fast_high", duration_seconds=5, fail_rate=0.9), + TestCandidate(test_id="medium", duration_seconds=20, fail_rate=0.5), + ] + sched = schedule_tests(candidates, time_budget_seconds=30) + self.assertEqual(sched.selected[0], "fast_high") + self.assertIn("medium", sched.selected) + self.assertNotIn("slow_low", sched.selected) + + def test_respects_time_budget(self): + candidates = [ + TestCandidate(test_id="a", duration_seconds=10, fail_rate=0.5), + TestCandidate(test_id="b", duration_seconds=10, fail_rate=0.5), + TestCandidate(test_id="c", duration_seconds=10, fail_rate=0.5), + ] + sched = schedule_tests(candidates, time_budget_seconds=25) + # 2 fit, 1 doesn't + self.assertEqual(len(sched.selected), 2) + self.assertGreaterEqual(sched.leftover_seconds, 0) + + def test_cloud_quota_limits(self): + candidates = [ + TestCandidate(test_id="cloud_a", duration_seconds=10, + fail_rate=0.9, needs_cloud_session=True), + TestCandidate(test_id="cloud_b", duration_seconds=10, + fail_rate=0.9, needs_cloud_session=True), + TestCandidate(test_id="local", duration_seconds=10, + fail_rate=0.8, needs_cloud_session=False), + ] + sched = schedule_tests( + candidates, time_budget_seconds=100, cloud_slot_budget=1, + ) + cloud_in_selected = [t for t in sched.selected if t.startswith("cloud_")] + self.assertEqual(len(cloud_in_selected), 1) + self.assertIn("local", sched.selected) + self.assertEqual(sched.leftover_cloud_slots, 0) + + def test_pinned_tests_always_included(self): + candidates = [ + TestCandidate(test_id="pinned", duration_seconds=20, fail_rate=0.0), + TestCandidate(test_id="hot", duration_seconds=5, fail_rate=0.9), + ] + sched = schedule_tests( + candidates, time_budget_seconds=30, + pinned_test_ids=["pinned"], + ) + self.assertEqual(sched.selected[0], "pinned") + self.assertIn("hot", sched.selected) + + def test_pinned_overrun_raises(self): + candidates = [ + TestCandidate(test_id="huge", duration_seconds=1000), + ] + with self.assertRaises(TestSchedulerError): + schedule_tests( + candidates, time_budget_seconds=10, + pinned_test_ids=["huge"], + ) + + def test_pinned_unknown_id_raises(self): + candidates = [TestCandidate(test_id="a", duration_seconds=1)] + with self.assertRaises(TestSchedulerError): + schedule_tests( + candidates, time_budget_seconds=100, + pinned_test_ids=["ghost"], + ) + + def test_invalid_budget_raises(self): + with self.assertRaises(TestSchedulerError): + schedule_tests([], time_budget_seconds=0) + with self.assertRaises(TestSchedulerError): + schedule_tests([], time_budget_seconds=100, cloud_slot_budget=-1) + + def test_empty_candidates_returns_empty_schedule(self): + sched = schedule_tests([], time_budget_seconds=100) + self.assertEqual(sched.selected, []) + self.assertEqual(sched.skipped, []) + + def test_cloud_unlimited(self): + candidates = [ + TestCandidate(test_id="x", duration_seconds=1, + fail_rate=0.5, needs_cloud_session=True), + ] + sched = schedule_tests(candidates, time_budget_seconds=10) + # cloud_slot_budget is None → no constraint + self.assertIn("x", sched.selected) + self.assertEqual(sched.leftover_cloud_slots, -1) # sentinel + + +class TestBuildCandidatesFromLedger(unittest.TestCase): + + def test_builds_from_runs(self): + now = datetime.now(timezone.utc) + runs = [ + {"path": "a.json", "passed": True, + "duration_seconds": 12, "time": _iso(now)}, + {"path": "a.json", "passed": False, + "duration_seconds": 14, "time": _iso(now)}, + {"path": "b.json", "passed": True, + "duration_seconds": 30, "time": _iso(now - timedelta(hours=12))}, + ] + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "l.json" + path.write_text(json.dumps({"runs": runs}), encoding="utf-8") + cands = build_candidates_from_ledger(path) + by_id = {c.test_id: c for c in cands} + self.assertEqual(by_id["a.json"].fail_rate, 0.5) + self.assertAlmostEqual(by_id["a.json"].duration_seconds, 13.0) + self.assertGreater(by_id["b.json"].last_run_age_hours, 10) + + def test_missing_ledger_returns_empty(self): + cands = build_candidates_from_ledger("/no/such.json") + self.assertEqual(cands, []) + + def test_malformed_ledger_raises(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "l.json" + path.write_text("{}", encoding="utf-8") + with self.assertRaises(TestSchedulerError): + build_candidates_from_ledger(path) + + def test_default_duration_when_missing(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "l.json" + path.write_text(json.dumps({"runs": [ + {"path": "x.json", "passed": True}, + ]}), encoding="utf-8") + cands = build_candidates_from_ledger(path, default_duration_seconds=45) + self.assertEqual(cands[0].duration_seconds, 45) + + def test_cloud_set_flags_test(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "l.json" + path.write_text(json.dumps({"runs": [ + {"path": "x.json", "passed": True, "duration_seconds": 10}, + ]}), encoding="utf-8") + cands = build_candidates_from_ledger(path, cloud_tests=["x.json"]) + self.assertTrue(cands[0].needs_cloud_session) + + +class TestRendering(unittest.TestCase): + + def test_markdown_lists_selected_and_skipped(self): + sched = Schedule( + selected=["a", "b"], skipped=["c"], + total_seconds=20, total_cloud_slots=1, + leftover_seconds=10, leftover_cloud_slots=0, + value_recovered=3.5, + ) + md = render_schedule_markdown(sched) + self.assertIn("Test schedule", md) + self.assertIn("Selected:** 2", md) + self.assertIn("`a`", md) + self.assertIn("`c`", md) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_third_party_budget.py b/test/unit_test/test_third_party_budget.py new file mode 100644 index 0000000..cdecf8b --- /dev/null +++ b/test/unit_test/test_third_party_budget.py @@ -0,0 +1,215 @@ +"""Unit tests for je_web_runner.utils.third_party_budget.""" +import json +import unittest + +from je_web_runner.utils.third_party_budget.budget import ( + ThirdPartyBudget, + ThirdPartyBudgetError, + ThirdPartyReport, + ThirdPartyRequest, + assert_within_budget, + classify_har, + evaluate, +) + + +def _entry(url, transfer=0, resource_type="script", timings=None): + return { + "_resourceType": resource_type, + "request": {"url": url}, + "response": {"_transferSize": transfer, "content": {"size": transfer}}, + "timings": timings or {"wait": 50, "receive": 10}, + } + + +def _har(*entries): + return {"log": {"entries": list(entries)}} + + +class TestClassify(unittest.TestCase): + + def test_first_party_skipped(self): + reqs = classify_har( + _har(_entry("https://app.com/static/a.js")), + first_party_hostname="app.com", + ) + self.assertEqual(reqs, []) + + def test_first_party_subdomain_skipped(self): + reqs = classify_har( + _har(_entry("https://cdn.app.com/a.js")), + first_party_hostname="app.com", + ) + self.assertEqual(reqs, []) + + def test_known_vendor_tagged(self): + reqs = classify_har( + _har(_entry("https://www.google-analytics.com/ga.js", transfer=10_000)), + first_party_hostname="app.com", + ) + self.assertEqual(reqs[0].vendor, "google_analytics") + self.assertEqual(reqs[0].bytes_transferred, 10_000) + + def test_unknown_third_party_tagged(self): + reqs = classify_har( + _har(_entry("https://random-ad-network.com/track.js")), + first_party_hostname="app.com", + ) + self.assertEqual(reqs[0].vendor, "unknown_third_party") + + def test_extra_vendor(self): + reqs = classify_har( + _har(_entry("https://myco-analytics.com/x.js")), + first_party_hostname="app.com", + extra_vendors={"myco": ("myco-analytics.com",)}, + ) + self.assertEqual(reqs[0].vendor, "myco") + + def test_subdomain_vendor_matches(self): + reqs = classify_har( + _har(_entry("https://collect.www.google-analytics.com/g/x")), + first_party_hostname="app.com", + ) + self.assertEqual(reqs[0].vendor, "google_analytics") + + def test_duration_summed_from_timings(self): + reqs = classify_har( + _har(_entry("https://cdn.segment.com/a.js", + timings={"wait": 100, "receive": 20, "blocked": 5})), + first_party_hostname="app.com", + ) + self.assertEqual(reqs[0].duration_ms, 125) + + def test_blocking_flag(self): + reqs = classify_har( + _har(_entry("https://cdn.segment.com/a.js", resource_type="script"), + _entry("https://cdn.segment.com/img.png", resource_type="image")), + first_party_hostname="app.com", + ) + self.assertTrue(reqs[0].blocking) + self.assertFalse(reqs[1].blocking) + + def test_skips_no_url(self): + reqs = classify_har( + _har({"request": {}, "response": {}}), + first_party_hostname="app.com", + ) + self.assertEqual(reqs, []) + + def test_str_har(self): + reqs = classify_har( + json.dumps(_har(_entry("https://cdn.segment.com/a.js"))), + first_party_hostname="app.com", + ) + self.assertEqual(len(reqs), 1) + + def test_bad_har(self): + with self.assertRaises(ThirdPartyBudgetError): + classify_har("nope", first_party_hostname="app.com") + + def test_bad_first_party(self): + with self.assertRaises(ThirdPartyBudgetError): + classify_har(_har(), first_party_hostname="") + + +class TestBudget(unittest.TestCase): + + def test_negative_rejected(self): + with self.assertRaises(ThirdPartyBudgetError): + ThirdPartyBudget(max_requests=-1) + with self.assertRaises(ThirdPartyBudgetError): + ThirdPartyBudget(max_bytes=-1) + + +class TestEvaluate(unittest.TestCase): + + def _reqs(self, *configs): + out = [] + for url, size, blocking, vendor in configs: + out.append(ThirdPartyRequest( + url=url, vendor=vendor, hostname="x", + bytes_transferred=size, duration_ms=size / 10, + blocking=blocking, + )) + return out + + def test_passes(self): + report = evaluate( + self._reqs( + ("https://x/a", 1000, True, "google_analytics"), + ("https://y/b", 2000, False, "stripe"), + ), + ThirdPartyBudget(max_requests=10, max_bytes=10_000), + ) + self.assertTrue(report.passed()) + self.assertEqual(report.total_bytes, 3000) + + def test_breach_requests(self): + report = evaluate( + self._reqs(*[("https://x/a", 1, True, "ga")] * 5), + ThirdPartyBudget(max_requests=2), + ) + self.assertFalse(report.passed()) + + def test_breach_bytes(self): + report = evaluate( + self._reqs(("https://x/a", 10_000, True, "ga")), + ThirdPartyBudget(max_bytes=1000), + ) + self.assertFalse(report.passed()) + + def test_breach_blocking_ms(self): + report = evaluate( + self._reqs(("https://x/a", 10_000, True, "ga")), + ThirdPartyBudget(max_blocking_ms=100), + ) + # duration_ms = 10_000 / 10 = 1000 > 100 + self.assertFalse(report.passed()) + + def test_breach_vendors(self): + report = evaluate( + self._reqs( + ("https://a", 1, True, "ga"), + ("https://b", 1, True, "stripe"), + ("https://c", 1, True, "hotjar"), + ), + ThirdPartyBudget(max_vendors=2), + ) + self.assertFalse(report.passed()) + + def test_by_vendor_aggregation(self): + report = evaluate( + self._reqs( + ("https://a", 100, True, "ga"), + ("https://b", 200, False, "ga"), + ), + ThirdPartyBudget(), + ) + self.assertEqual(report.by_vendor["ga"]["requests"], 2) + self.assertEqual(report.by_vendor["ga"]["bytes"], 300) + + def test_rejects_non_request(self): + with self.assertRaises(ThirdPartyBudgetError): + evaluate(["not a request"], ThirdPartyBudget()) # type: ignore[list-item] + + def test_rejects_non_budget(self): + with self.assertRaises(ThirdPartyBudgetError): + evaluate([], "not a budget") # type: ignore[arg-type] + + +class TestAssert(unittest.TestCase): + + def test_pass(self): + assert_within_budget(ThirdPartyReport()) + + def test_fail(self): + with self.assertRaises(ThirdPartyBudgetError): + assert_within_budget(ThirdPartyReport(breaches=["x"])) + + def test_rejects_non_report(self): + with self.assertRaises(ThirdPartyBudgetError): + assert_within_budget("nope") # type: ignore[arg-type] + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_time_freezer.py b/test/unit_test/test_time_freezer.py new file mode 100644 index 0000000..87a6692 --- /dev/null +++ b/test/unit_test/test_time_freezer.py @@ -0,0 +1,123 @@ +"""Unit tests for je_web_runner.utils.time_freezer.""" +import unittest +from datetime import datetime, timezone + +from je_web_runner.utils.time_freezer.freezer import ( + FreezeConfig, + TimeFreezerError, + attach_to_cdp, + build_freezer_script, + freeze_at, + slow_motion, + to_epoch_ms, +) + + +class TestToEpochMs(unittest.TestCase): + + def test_iso_string(self): + self.assertEqual( + to_epoch_ms("2026-01-01T00:00:00Z"), + int(datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp() * 1000), + ) + + def test_iso_with_offset(self): + ms = to_epoch_ms("2026-01-01T08:00:00+08:00") + self.assertEqual(ms, int(datetime(2026, 1, 1, 0, 0, tzinfo=timezone.utc).timestamp() * 1000)) + + def test_naive_datetime_treated_as_utc(self): + ms = to_epoch_ms(datetime(2026, 1, 1)) + self.assertEqual(ms, int(datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp() * 1000)) + + def test_int_seconds(self): + # Anything below 1e12 is treated as seconds, then × 1000 + self.assertEqual(to_epoch_ms(1735689600), 1735689600000) + + def test_int_milliseconds(self): + self.assertEqual(to_epoch_ms(1735689600000), 1735689600000) + + def test_bool_rejected(self): + with self.assertRaises(TimeFreezerError): + to_epoch_ms(True) + + def test_bad_string(self): + with self.assertRaises(TimeFreezerError): + to_epoch_ms("not a time") + + def test_unsupported_type(self): + with self.assertRaises(TimeFreezerError): + to_epoch_ms([]) # type: ignore[arg-type] + + +class TestFreezeConfig(unittest.TestCase): + + def test_default(self): + cfg = FreezeConfig(epoch_ms=1) + self.assertEqual(cfg.advance_ms_per_real_second, 0.0) + + def test_negative_epoch_rejected(self): + with self.assertRaises(TimeFreezerError): + FreezeConfig(epoch_ms=-1) + + def test_negative_slope_rejected(self): + with self.assertRaises(TimeFreezerError): + FreezeConfig(epoch_ms=0, advance_ms_per_real_second=-1.0) + + +class TestBuildScript(unittest.TestCase): + + def test_embeds_epoch(self): + cfg = FreezeConfig(epoch_ms=123_456_789) + script = build_freezer_script(cfg) + self.assertIn("123456789", script) + self.assertIn("FakeDate", script) + + def test_disable_date_patch(self): + cfg = FreezeConfig(epoch_ms=1, patch_date_constructor=False) + script = build_freezer_script(cfg) + self.assertIn("Date.now = virtualNow", script) + + def test_disable_performance_patch(self): + cfg = FreezeConfig(epoch_ms=1, patch_performance_now=False) + script = build_freezer_script(cfg) + self.assertIn("__PATCH_PERF__ = false", script) + + def test_rejects_non_config(self): + with self.assertRaises(TimeFreezerError): + build_freezer_script("string") # type: ignore[arg-type] + + +class TestAttach(unittest.TestCase): + + def test_calls_attach_with_script(self): + seen: list = [] + result = attach_to_cdp(seen.append, FreezeConfig(epoch_ms=42)) + self.assertIn("42", seen[0]) + self.assertIsNone(result) + + def test_wraps_attach_failure(self): + def boom(_script): + raise RuntimeError("no cdp") + with self.assertRaises(TimeFreezerError): + attach_to_cdp(boom, FreezeConfig(epoch_ms=1)) + + +class TestConvenience(unittest.TestCase): + + def test_freeze_at_from_iso(self): + cfg = freeze_at("2026-05-24T12:00:00Z") + self.assertEqual(cfg.advance_ms_per_real_second, 0.0) + + def test_slow_motion(self): + cfg = slow_motion("2026-01-01T00:00:00Z", real_seconds_per_virtual_second=10) + self.assertEqual(cfg.advance_ms_per_real_second, 100.0) # 1000ms/10s + + def test_slow_motion_validates(self): + with self.assertRaises(TimeFreezerError): + slow_motion("2026-01-01T00:00:00Z", real_seconds_per_virtual_second=0) + with self.assertRaises(TimeFreezerError): + slow_motion("2026-01-01T00:00:00Z", real_seconds_per_virtual_second=-1) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_token_leak_detector.py b/test/unit_test/test_token_leak_detector.py new file mode 100644 index 0000000..f1343c1 --- /dev/null +++ b/test/unit_test/test_token_leak_detector.py @@ -0,0 +1,200 @@ +"""Unit tests for je_web_runner.utils.token_leak_detector.""" +import base64 +import json +import unittest + +from je_web_runner.utils.token_leak_detector.detector import ( + DEFAULT_PATTERNS, + TokenFinding, + TokenLeakError, + TokenPattern, + assert_no_leaks, + filter_by_severity, + scan_har, + scan_log_lines, + scan_text, +) + + +def _real_jwt(): + header = base64.urlsafe_b64encode(b'{"alg":"HS256","typ":"JWT"}').rstrip(b"=").decode() + body = base64.urlsafe_b64encode(b'{"sub":"x"}').rstrip(b"=").decode() + sig = base64.urlsafe_b64encode(b"signature_signature").rstrip(b"=").decode() + return f"{header}.{body}.{sig}" + + +class TestScanText(unittest.TestCase): + + def test_finds_jwt(self): + text = f"Authorization: Bearer {_real_jwt()}" + findings = scan_text(text) + names = [f.pattern for f in findings] + self.assertIn("jwt", names) + + def test_skips_fake_jwt(self): + text = "eyJfake.eyJfake.eyJfake_padding_padding_padding" + findings = scan_text(text) + self.assertNotIn("jwt", [f.pattern for f in findings]) + + def test_aws_access_key(self): + text = "AKIAIOSFODNN7EXAMPLE in source" + findings = scan_text(text) + self.assertTrue(any(f.pattern == "aws_access_key_id" for f in findings)) + + def test_aws_secret_assignment(self): + text = 'aws_secret_access_key = "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY"' + findings = scan_text(text) + self.assertTrue(any(f.pattern == "aws_secret_access_key_assignment" for f in findings)) + + def test_github_token(self): + text = f"ghp_{'a' * 36}" + findings = scan_text(text) + self.assertTrue(any(f.pattern == "github_token" for f in findings)) + + def test_stripe_live(self): + text = f"sk_live_{'a' * 24}" + findings = scan_text(text) + self.assertTrue(any(f.pattern == "stripe_live_secret" for f in findings)) + + def test_google_api_key(self): + text = "AIza" + "B" * 35 + findings = scan_text(text) + self.assertTrue(any(f.pattern == "google_api_key" for f in findings)) + + def test_session_assignment(self): + text = 'session_id: "abc123def456ghi789jkl"' + findings = scan_text(text) + self.assertTrue(any(f.pattern == "session_token_assignment" for f in findings)) + + def test_redaction(self): + text = f"ghp_{'x' * 36}" + findings = scan_text(text) + self.assertTrue(findings[0].token_suffix.startswith("…")) + + def test_dedup_within_source(self): + token = f"ghp_{'a' * 36}" + text = "\n".join([token] * 5) + findings = scan_text(text, source="x", location="y") + # All have the same suffix → dedup to 1 + self.assertEqual(len(findings), 1) + + def test_dedup_differs_by_location(self): + token = f"ghp_{'a' * 36}" + findings = scan_text(token, source="x", location="loc1") + findings += scan_text(token, source="x", location="loc2") + self.assertEqual(len(findings), 2) + + def test_non_string_rejected(self): + with self.assertRaises(TokenLeakError): + scan_text(123) # type: ignore[arg-type] + + def test_clean_text_returns_empty(self): + self.assertEqual(scan_text("nothing interesting here"), []) + + +class TestScanHar(unittest.TestCase): + + def test_scans_response_body(self): + har = { + "log": { + "entries": [ + { + "request": {"url": "https://api/x"}, + "response": { + "content": {"text": f"token={_real_jwt()}"}, + }, + } + ] + } + } + findings = scan_har(har) + self.assertTrue(any(f.location == "https://api/x" for f in findings)) + + def test_scans_request_body(self): + har = { + "log": { + "entries": [ + { + "request": { + "url": "https://api/x", + "postData": {"text": "AKIAIOSFODNN7EXAMPLE"}, + }, + "response": {"content": {}}, + } + ] + } + } + findings = scan_har(har) + self.assertTrue(any(f.source == "har.request" for f in findings)) + + def test_string_har(self): + har = json.dumps({"log": {"entries": []}}) + self.assertEqual(scan_har(har), []) + + def test_bad_har_string(self): + with self.assertRaises(TokenLeakError): + scan_har("not json") + + def test_bad_har_type(self): + with self.assertRaises(TokenLeakError): + scan_har(123) # type: ignore[arg-type] + + def test_bad_har_json_shape(self): + with self.assertRaises(TokenLeakError): + scan_har(json.dumps([1, 2])) + + +class TestScanLog(unittest.TestCase): + + def test_scans_lines(self): + token = f"ghp_{'a' * 36}" + findings = scan_log_lines(["nothing here", token, "another line"]) + self.assertEqual(len(findings), 1) + self.assertTrue(findings[0].location.startswith("line:")) + + def test_skips_non_strings(self): + findings = scan_log_lines([None, 1, "ok"]) # type: ignore[list-item] + self.assertEqual(findings, []) + + +class TestAssertions(unittest.TestCase): + + def test_assert_no_leaks_pass(self): + assert_no_leaks([]) + + def test_assert_no_leaks_fail(self): + with self.assertRaises(TokenLeakError): + assert_no_leaks([TokenFinding("jwt", "critical", "…abc123", "har")]) + + def test_filter_by_severity(self): + findings = [ + TokenFinding("jwt", "critical", "…1", "x"), + TokenFinding("session", "medium", "…2", "x"), + TokenFinding("api", "low", "…3", "x"), + ] + self.assertEqual(len(filter_by_severity(findings, minimum="medium")), 2) + self.assertEqual(len(filter_by_severity(findings, minimum="critical")), 1) + + def test_bad_severity(self): + with self.assertRaises(TokenLeakError): + filter_by_severity([], minimum="weird") + + +class TestCustomPattern(unittest.TestCase): + + def test_custom_pattern(self): + import re + custom = TokenPattern( + name="my_token", + pattern=re.compile(r"MYAPP-[A-Z0-9]{20}"), + severity="high", + ) + findings = scan_text("MYAPP-ABCDEFGHIJKLMNOPQRST", patterns=[custom]) + self.assertEqual(findings[0].pattern, "my_token") + + def test_default_patterns_loaded(self): + self.assertGreater(len(DEFAULT_PATTERNS), 5) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_view_transitions.py b/test/unit_test/test_view_transitions.py new file mode 100644 index 0000000..e38a670 --- /dev/null +++ b/test/unit_test/test_view_transitions.py @@ -0,0 +1,144 @@ +"""Unit tests for je_web_runner.utils.view_transitions.""" +import unittest + +from je_web_runner.utils.view_transitions.transitions import ( + TransitionRun, + ViewTransitionsError, + assert_all_finished, + assert_cls_under, + assert_group_present, + assert_under_duration, + build_instrumentation_script, + parse_log, +) + + +class TestInstrumentation(unittest.TestCase): + + def test_script_contains_install_guard(self): + js = build_instrumentation_script() + self.assertIn("__wr_vt_installed__", js) + self.assertIn("startViewTransition", js) + + +class TestParseLog(unittest.TestCase): + + def test_parses_basic(self): + runs = parse_log([{ + "id": "vt_1", "startedAt": 100.0, "finishedAt": 350.0, + "durationMs": 250.0, "error": None, + "layoutShifts": 0.02, "maxShiftValue": 0.01, + "groups": ["header"], + }]) + self.assertEqual(len(runs), 1) + self.assertTrue(runs[0].is_finished()) + self.assertEqual(runs[0].groups, ["header"]) + + def test_error_marked_not_finished(self): + runs = parse_log([{ + "id": "vt_1", "startedAt": 0, "finishedAt": 10, + "durationMs": 10, "error": "AbortError", + }]) + self.assertFalse(runs[0].is_finished()) + + def test_ignores_non_dict(self): + runs = parse_log(["not dict", None]) # type: ignore[list-item] + self.assertEqual(runs, []) + + def test_rejects_non_list(self): + with self.assertRaises(ViewTransitionsError): + parse_log("nope") # type: ignore[arg-type] + + +class TestAssertFinished(unittest.TestCase): + + def test_pass(self): + runs = parse_log([{"id": "a", "startedAt": 0, "finishedAt": 1, "durationMs": 1}]) + assert_all_finished(runs) + + def test_empty_rejected(self): + with self.assertRaises(ViewTransitionsError): + assert_all_finished([]) + + def test_fail_with_error(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 1, "durationMs": 1, + "error": "boom", + }]) + with self.assertRaises(ViewTransitionsError): + assert_all_finished(runs) + + def test_fail_unfinished(self): + runs = parse_log([{"id": "a", "startedAt": 0}]) + with self.assertRaises(ViewTransitionsError): + assert_all_finished(runs) + + +class TestAssertDuration(unittest.TestCase): + + def test_pass(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 100, "durationMs": 100, + }]) + assert_under_duration(runs, max_duration_ms=200) + + def test_fail(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 300, "durationMs": 300, + }]) + with self.assertRaises(ViewTransitionsError): + assert_under_duration(runs, max_duration_ms=200) + + def test_bad_threshold(self): + with self.assertRaises(ViewTransitionsError): + assert_under_duration([], max_duration_ms=0) + + +class TestAssertCls(unittest.TestCase): + + def test_pass(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 1, "durationMs": 1, + "layoutShifts": 0.05, "maxShiftValue": 0.02, + }]) + assert_cls_under(runs) + + def test_fail_cumulative(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 1, "durationMs": 1, + "layoutShifts": 0.5, "maxShiftValue": 0.02, + }]) + with self.assertRaises(ViewTransitionsError): + assert_cls_under(runs) + + def test_fail_single(self): + runs = parse_log([{ + "id": "a", "startedAt": 0, "finishedAt": 1, "durationMs": 1, + "layoutShifts": 0.05, "maxShiftValue": 0.5, + }]) + with self.assertRaises(ViewTransitionsError): + assert_cls_under(runs) + + def test_bad_threshold(self): + with self.assertRaises(ViewTransitionsError): + assert_cls_under([], max_cls=-1) + + +class TestAssertGroup(unittest.TestCase): + + def test_pass(self): + runs = [TransitionRun(id="a", started_at=0, groups=["root", "nav"])] + assert_group_present(runs, "nav") + + def test_fail(self): + runs = [TransitionRun(id="a", started_at=0, groups=["root"])] + with self.assertRaises(ViewTransitionsError): + assert_group_present(runs, "nav") + + def test_empty_name(self): + with self.assertRaises(ViewTransitionsError): + assert_group_present([], "") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_visual_ai.py b/test/unit_test/test_visual_ai.py new file mode 100644 index 0000000..6799c94 --- /dev/null +++ b/test/unit_test/test_visual_ai.py @@ -0,0 +1,281 @@ +"""Unit tests for je_web_runner.utils.visual_ai.""" +import tempfile +import unittest +from io import BytesIO +from pathlib import Path + +from je_web_runner.utils.visual_ai.perceptual import ( + HashResult, + SimilarityResult, + VisualAIError, + assert_visual_similar, + average_hash, + compare_images, + difference_hash, + hamming_distance, + hash_similarity, + perceptual_hash, +) + + +def _require_pillow(): + try: + from PIL import Image # noqa: F401 + return True + except ImportError: + return False + + +def _make_solid(rgb, size=32): + from PIL import Image + img = Image.new("RGB", (size, size), rgb) + buf = BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def _make_checker(square_size=4, total=32): + from PIL import Image + img = Image.new("RGB", (total, total), (0, 0, 0)) + px = img.load() + for y in range(total): + for x in range(total): + if ((x // square_size) + (y // square_size)) % 2 == 0: + px[x, y] = (255, 255, 255) + buf = BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +def _make_gradient(total=32, flipped=False): + from PIL import Image + img = Image.new("RGB", (total, total), (0, 0, 0)) + px = img.load() + for y in range(total): + for x in range(total): + v = int(255 * (x / max(1, total - 1))) + if flipped: + v = 255 - v + px[x, y] = (v, v, v) + buf = BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + +@unittest.skipUnless(_require_pillow(), "Pillow not installed") +class TestHashFunctions(unittest.TestCase): + + def test_average_hash_default_size(self): + img = _make_solid((128, 128, 128)) + h = average_hash(img) + self.assertEqual(h.kind, "aHash") + self.assertEqual(len(h.bits), 64) + + def test_difference_hash_length(self): + img = _make_gradient() + h = difference_hash(img) + self.assertEqual(h.kind, "dHash") + self.assertEqual(len(h.bits), 64) + + def test_perceptual_hash_length(self): + img = _make_gradient() + h = perceptual_hash(img) + self.assertEqual(h.kind, "pHash") + # 8*8 - 1 (DC removed) + self.assertEqual(len(h.bits), 63) + + def test_hex_round_trip_no_crash(self): + img = _make_solid((50, 100, 200)) + h = average_hash(img) + self.assertTrue(len(h.hex()) > 0) + + def test_identical_images_perfect_similarity(self): + img = _make_gradient() + a = perceptual_hash(img) + b = perceptual_hash(img) + self.assertAlmostEqual(hash_similarity(a, b), 1.0) + self.assertEqual(hamming_distance(a, b), 0) + + def test_different_kinds_raise(self): + img = _make_solid((0, 0, 0)) + a = average_hash(img) + d = difference_hash(img) + with self.assertRaises(VisualAIError): + hamming_distance(a, d) + + def test_different_lengths_raise(self): + a = HashResult("aHash", "1010") + b = HashResult("aHash", "101010") + with self.assertRaises(VisualAIError): + hamming_distance(a, b) + + +@unittest.skipUnless(_require_pillow(), "Pillow not installed") +class TestCompareImages(unittest.TestCase): + + def test_identical_images_pass(self): + img = _make_gradient() + result = compare_images(img, img, threshold=0.95) + self.assertTrue(result.passed) + self.assertAlmostEqual(result.composite, 1.0, places=2) + + def test_wildly_different_images_fail(self): + # Use structurally different content (checker vs gradient). Pure- + # colour images degenerate aHash/dHash because every bit is "all + # below mean" on both sides — that's a corner case, not realistic. + a = _make_checker(square_size=4) + b = _make_gradient(flipped=False) + result = compare_images(a, b, threshold=0.85) + self.assertFalse(result.passed) + + def test_similar_charts_pass_above_threshold(self): + # two gradients differing by a tiny perturbation + from PIL import Image + base = _make_gradient() + perturbed_img = Image.open(BytesIO(base)).convert("RGB") + px = perturbed_img.load() + # Add subtle noise to a handful of pixels (much less than before) + for i in range(8): + x = (i * 7) % 32 + y = (i * 5) % 32 + px[x, y] = (min(255, px[x, y][0] + 10), 0, 0) + buf = BytesIO() + perturbed_img.save(buf, format="PNG") + result = compare_images(base, buf.getvalue(), threshold=0.85) + self.assertTrue(result.passed) + + def test_invalid_threshold_raises(self): + img = _make_solid((0, 0, 0)) + with self.assertRaises(VisualAIError): + compare_images(img, img, threshold=1.5) + + def test_weights_must_sum_to_one(self): + img = _make_solid((0, 0, 0)) + with self.assertRaises(VisualAIError): + compare_images(img, img, weights=(0.5, 0.5, 0.5, 0.5)) + + def test_file_path_input(self): + with tempfile.TemporaryDirectory() as tmpdir: + path_a = Path(tmpdir) / "a.png" + path_a.write_bytes(_make_gradient()) + result = compare_images(path_a, path_a) + self.assertTrue(result.passed) + + +@unittest.skipUnless(_require_pillow(), "Pillow not installed") +class TestAssertVisualSimilar(unittest.TestCase): + + def test_pass(self): + img = _make_gradient() + result = assert_visual_similar(img, img, threshold=0.9) + self.assertIsInstance(result, SimilarityResult) + + def test_fail_raises(self): + a = _make_checker(square_size=4) + b = _make_gradient() + with self.assertRaises(VisualAIError): + assert_visual_similar(a, b, threshold=0.85) + + +@unittest.skipUnless(_require_pillow(), "Pillow not installed") +class TestROIAndMask(unittest.TestCase): + + def _half_red_half_blue(self, size=64): + from PIL import Image + img = Image.new("RGB", (size, size), (0, 0, 255)) + px = img.load() + for y in range(size): + for x in range(size // 2): + px[x, y] = (255, 0, 0) + buf = BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + def _half_red_half_green(self, size=64): + from PIL import Image + img = Image.new("RGB", (size, size), (0, 255, 0)) + px = img.load() + for y in range(size): + for x in range(size // 2): + px[x, y] = (255, 0, 0) + buf = BytesIO() + img.save(buf, format="PNG") + return buf.getvalue() + + def test_crop_box_focuses_comparison(self): + a = self._half_red_half_blue() + b = self._half_red_half_green() + # Whole image differs (blue right half vs green right half), + # but the left half (red) is identical. + no_crop = compare_images(a, b, threshold=0.99) + cropped = compare_images( + a, b, threshold=0.99, crop_box=(0, 0, 32, 64), + ) + self.assertGreater(cropped.composite, no_crop.composite) + self.assertAlmostEqual(cropped.composite, 1.0, places=2) + + def test_mask_boxes_hides_dynamic_region(self): + a = self._half_red_half_blue() + b = self._half_red_half_green() + # Mask out the right half so only the identical red side counts. + result = compare_images( + a, b, threshold=0.99, mask_boxes=[(32, 0, 64, 64)], + ) + self.assertTrue(result.passed) + + def test_invalid_box_shape_raises(self): + a = self._half_red_half_blue() + with self.assertRaises(VisualAIError): + compare_images(a, a, crop_box=(0, 0, 10)) # type: ignore[arg-type] + + def test_inverted_box_raises(self): + a = self._half_red_half_blue() + with self.assertRaises(VisualAIError): + compare_images(a, a, crop_box=(10, 10, 5, 20)) + + def test_negative_origin_raises(self): + a = self._half_red_half_blue() + with self.assertRaises(VisualAIError): + compare_images(a, a, crop_box=(-1, 0, 10, 10)) + + def test_box_exceeding_image_raises(self): + a = self._half_red_half_blue(size=32) + with self.assertRaises(VisualAIError): + compare_images(a, a, crop_box=(0, 0, 100, 100)) + + def test_assert_with_mask_passes(self): + a = self._half_red_half_blue() + b = self._half_red_half_green() + result = assert_visual_similar( + a, b, threshold=0.95, mask_boxes=[(32, 0, 64, 64)], + ) + self.assertTrue(result.passed) + + def test_hash_accepts_crop(self): + from je_web_runner.utils.visual_ai.perceptual import perceptual_hash + a = self._half_red_half_blue() + # Just make sure passing a crop_box doesn't error + h = perceptual_hash(a, crop_box=(0, 0, 32, 32)) + self.assertEqual(h.kind, "pHash") + + +class TestInputErrors(unittest.TestCase): + + @unittest.skipUnless(_require_pillow(), "Pillow not installed") + def test_missing_file(self): + with self.assertRaises(VisualAIError): + average_hash("/no/such/file.png") + + @unittest.skipUnless(_require_pillow(), "Pillow not installed") + def test_unsupported_type(self): + with self.assertRaises(VisualAIError): + average_hash(42) # type: ignore[arg-type] + + @unittest.skipUnless(_require_pillow(), "Pillow not installed") + def test_bad_bytes(self): + with self.assertRaises(VisualAIError): + average_hash(b"not an image") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_walkthrough_docs.py b/test/unit_test/test_walkthrough_docs.py new file mode 100644 index 0000000..a4d2bf1 --- /dev/null +++ b/test/unit_test/test_walkthrough_docs.py @@ -0,0 +1,215 @@ +"""Unit tests for je_web_runner.utils.walkthrough_docs.""" +import json +import tempfile +import unittest +from pathlib import Path + +from je_web_runner.utils.ai_assist.llm_assist import set_llm_callable +from je_web_runner.utils.walkthrough_docs.generator import ( + Walkthrough, + WalkthroughError, + WalkthroughStep, + build_walkthrough, + collect_steps, + narrate_steps, + render_confluence, + render_markdown, + save_walkthrough, +) + + +SAMPLE_ACTIONS = [ + ["WR_init", {}], # noise — filtered + ["WR_to_url", {"url": "https://shop.example/cart"}], + ["WR_save_test_object", {"object_type": "ID", # noise + "test_object_name": "checkout"}], + ["WR_element_click", {"test_object_name": "checkout"}], + ["WR_set_timeout", {"timeout": 5}], # noise + ["WR_element_input", {"test_object_name": "promo", "text": "SAVE10"}], + ["WR_element_assert_text", {"test_object_name": "total", + "expected": "$90"}], +] + + +class TestCollectSteps(unittest.TestCase): + + def test_filters_noise(self): + steps = collect_steps(SAMPLE_ACTIONS) + commands = [s.action_command for s in steps] + self.assertIn("WR_to_url", commands) + self.assertIn("WR_element_click", commands) + self.assertNotIn("WR_init", commands) + self.assertNotIn("WR_save_test_object", commands) + self.assertNotIn("WR_set_timeout", commands) + + def test_skip_noise_disabled(self): + steps = collect_steps(SAMPLE_ACTIONS, skip_noise=False) + commands = [s.action_command for s in steps] + self.assertIn("WR_init", commands) + + def test_non_list_raises(self): + with self.assertRaises(WalkthroughError): + collect_steps("not a list") # type: ignore[arg-type] + + def test_screenshot_bytes_attached(self): + png = b"\x89PNG\r\n\x1a\n" + b"x" * 20 + steps = collect_steps(SAMPLE_ACTIONS, screenshots={3: png}) + click_step = next(s for s in steps if s.action_command == "WR_element_click") + self.assertIsNotNone(click_step.screenshot_b64) + + def test_screenshot_file_attached(self): + with tempfile.TemporaryDirectory() as tmpdir: + path = Path(tmpdir) / "shot.png" + path.write_bytes(b"\x89PNG\r\n\x1a\n" + b"x" * 4) + steps = collect_steps(SAMPLE_ACTIONS, screenshots={3: path}) + click_step = next(s for s in steps if s.action_command == "WR_element_click") + self.assertEqual(click_step.screenshot_path, str(path)) + self.assertIsNotNone(click_step.screenshot_b64) + self.assertEqual(click_step.screenshot_mime, "image/png") + + def test_missing_screenshot_file_is_warned_not_raised(self): + steps = collect_steps(SAMPLE_ACTIONS, screenshots={3: "/no/such.png"}) + click_step = next(s for s in steps if s.action_command == "WR_element_click") + self.assertIsNone(click_step.screenshot_path) + + +class TestNarrate(unittest.TestCase): + + def tearDown(self): + set_llm_callable(None) + + def test_assigns_narrations_in_order(self): + wt = Walkthrough( + title="Checkout", + steps=[ + WalkthroughStep(index=0, action_command="WR_to_url"), + WalkthroughStep(index=1, action_command="WR_element_click"), + ], + ) + payload = json.dumps({ + "steps": [ + "Open the cart page.", + "Click the checkout button.", + ] + }) + set_llm_callable(lambda _p: payload) + narrate_steps(wt) + self.assertEqual(wt.steps[0].narration, "Open the cart page.") + self.assertEqual(wt.steps[1].narration, "Click the checkout button.") + + def test_no_steps_returns_early(self): + wt = Walkthrough(title="empty") + # should not call LLM at all — no callable registered, no error + narrate_steps(wt) + self.assertEqual(wt.steps, []) + + def test_missing_steps_key_raises(self): + wt = Walkthrough(title="x", steps=[ + WalkthroughStep(index=0, action_command="WR_x"), + ]) + set_llm_callable(lambda _p: "{}") + with self.assertRaises(WalkthroughError): + narrate_steps(wt) + + def test_no_callable_raises_when_steps_exist(self): + wt = Walkthrough(title="x", steps=[ + WalkthroughStep(index=0, action_command="WR_x"), + ]) + set_llm_callable(None) + with self.assertRaises(WalkthroughError): + narrate_steps(wt) + + +class TestRendering(unittest.TestCase): + + def _wt(self, with_image=False): + steps = [ + WalkthroughStep(index=0, action_command="WR_to_url", + kwargs={"url": "https://x"}, narration="Visit the site."), + WalkthroughStep(index=1, action_command="WR_element_click", + kwargs={"test_object_name": "btn"}, + narration="Click the button."), + ] + if with_image: + steps[1].screenshot_b64 = "deadbeef" + steps[1].screenshot_path = "/tmp/shot.png" + return Walkthrough(title="Sample", description="A demo flow", steps=steps) + + def test_markdown_has_steps(self): + md = render_markdown(self._wt()) + self.assertIn("# Sample", md) + self.assertIn("Step 1. Visit the site.", md) + self.assertIn("Step 2. Click the button.", md) + self.assertIn("`WR_to_url`", md) + + def test_markdown_embeds_data_uri(self): + md = render_markdown(self._wt(with_image=True)) + self.assertIn("data:image/png;base64,deadbeef", md) + + def test_markdown_uses_path_when_no_embed(self): + md = render_markdown(self._wt(with_image=True), embed_images=False) + self.assertIn("/tmp/shot.png", md) + self.assertNotIn("data:image", md) + + def test_confluence_xml_escapes(self): + wt = Walkthrough(title="", steps=[ + WalkthroughStep(index=0, action_command="WR_to_url", + kwargs={"url": "https://x"}, narration="A & B"), + ]) + x = render_confluence(wt) + self.assertIn("<bad>", x) + self.assertIn("A & B", x) + + def test_confluence_uses_attachment_for_image(self): + wt = self._wt(with_image=True) + x = render_confluence(wt) + self.assertIn('ri:filename="shot.png"', x) + + +class TestBuildWalkthrough(unittest.TestCase): + + def tearDown(self): + set_llm_callable(None) + + def test_no_narrate(self): + wt = build_walkthrough("login", SAMPLE_ACTIONS, narrate=False) + self.assertEqual(wt.title, "login") + self.assertTrue(all(s.narration == "" for s in wt.steps)) + + def test_with_narrate(self): + narrations = json.dumps({"steps": ["a"] * 100}) + set_llm_callable(lambda _p: narrations) + wt = build_walkthrough("login", SAMPLE_ACTIONS, narrate=True) + self.assertTrue(any(s.narration for s in wt.steps)) + + +class TestSaveWalkthrough(unittest.TestCase): + + def test_markdown_file(self): + wt = Walkthrough(title="x", steps=[ + WalkthroughStep(index=0, action_command="WR_x", narration="n"), + ]) + with tempfile.TemporaryDirectory() as tmpdir: + path = save_walkthrough(wt, Path(tmpdir) / "out.md") + self.assertTrue(path.exists()) + text = path.read_text(encoding="utf-8") + self.assertIn("# x", text) + + def test_confluence_file(self): + wt = Walkthrough(title="x", steps=[ + WalkthroughStep(index=0, action_command="WR_x", narration="n"), + ]) + with tempfile.TemporaryDirectory() as tmpdir: + path = save_walkthrough(wt, Path(tmpdir) / "out.xml", fmt="confluence") + text = path.read_text(encoding="utf-8") + self.assertIn("

x

", text) + + def test_unknown_fmt_raises(self): + wt = Walkthrough(title="x") + with tempfile.TemporaryDirectory() as tmpdir: + with self.assertRaises(WalkthroughError): + save_walkthrough(wt, Path(tmpdir) / "out.x", fmt="rst") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_webhook_receiver.py b/test/unit_test/test_webhook_receiver.py new file mode 100644 index 0000000..1265efc --- /dev/null +++ b/test/unit_test/test_webhook_receiver.py @@ -0,0 +1,187 @@ +"""Unit tests for je_web_runner.utils.webhook_receiver.""" +import json +import unittest +from urllib.request import Request, urlopen + +from je_web_runner.utils.webhook_receiver.receiver import ( + ReceivedRequest, + WebhookReceiverError, + WebhookServer, + assert_received_json_matching, + assert_received_path, + assert_received_with_header, +) + + +def _post(url, body=b"", headers=None): + req = Request(url, data=body, headers=headers or {}, method="POST") + with urlopen(req, timeout=5) as response: + return response.status, response.read() + + +def _get(url): + with urlopen(url, timeout=5) as response: + return response.status, response.read() + + +class TestWebhookServer(unittest.TestCase): + + def test_starts_and_stops(self): + with WebhookServer() as server: + self.assertTrue(server.base_url.startswith("http://127.0.0.1:")) + status, _ = _get(server.base_url + "/ping") + self.assertEqual(status, 200) + + def test_captures_post_body(self): + with WebhookServer() as server: + _post( + server.base_url + "/hooks/order", + body=json.dumps({"id": 42}).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + requests = server.received() + self.assertEqual(len(requests), 1) + self.assertEqual(requests[0].path, "/hooks/order") + self.assertEqual(requests[0].method, "POST") + self.assertEqual(requests[0].body_json(), {"id": 42}) + + def test_query_parsed(self): + with WebhookServer() as server: + _get(server.base_url + "/x?a=1&a=2&b=3") + request = server.received()[0] + self.assertEqual(request.query["a"], ["1", "2"]) + self.assertEqual(request.query["b"], ["3"]) + + def test_clear(self): + with WebhookServer() as server: + _get(server.base_url + "/x") + server.clear() + self.assertEqual(server.received(), []) + + def test_custom_response(self): + def resp(req): + return {"status": 201, "body": b"created"} + with WebhookServer(response_fn=resp) as server: + status, body = _post(server.base_url + "/x") + self.assertEqual(status, 201) + self.assertEqual(body, b"created") + + def test_wait_for(self): + with WebhookServer() as server: + _post(server.base_url + "/expected") + request = server.wait_for(lambda r: r.path == "/expected", timeout=1.0) + self.assertEqual(request.path, "/expected") + + def test_wait_for_timeout(self): + with WebhookServer() as server: + with self.assertRaises(WebhookReceiverError): + server.wait_for(lambda r: r.path == "/never", timeout=0.2) + + def test_wait_for_bad_args(self): + with WebhookServer() as server: + with self.assertRaises(WebhookReceiverError): + server.wait_for(lambda r: True, timeout=0) + with self.assertRaises(WebhookReceiverError): + server.wait_for(lambda r: True, timeout=1, interval=0) + + def test_double_start_rejected(self): + server = WebhookServer().start() + try: + with self.assertRaises(WebhookReceiverError): + server.start() + finally: + server.stop() + + def test_bad_host(self): + with self.assertRaises(WebhookReceiverError): + WebhookServer(host="") + + def test_bad_port(self): + with self.assertRaises(WebhookReceiverError): + WebhookServer(port=0) + with self.assertRaises(WebhookReceiverError): + WebhookServer(port=99999) + + +class TestReceivedRequest(unittest.TestCase): + + def test_body_text(self): + request = ReceivedRequest(method="POST", path="/", body=b"hi") + self.assertEqual(request.body_text(), "hi") + + def test_body_json_bad(self): + request = ReceivedRequest(method="POST", path="/", body=b"not json") + with self.assertRaises(WebhookReceiverError): + request.body_json() + + def test_to_dict(self): + request = ReceivedRequest(method="POST", path="/", body=b"hi") + self.assertEqual(request.to_dict()["body"], "hi") + + +class TestAssertions(unittest.TestCase): + + def test_assert_received_path(self): + with WebhookServer() as server: + _post(server.base_url + "/x") + _post(server.base_url + "/x") + self.assertEqual(assert_received_path(server, "/x", minimum=2), 2) + + def test_assert_received_path_method_filter(self): + with WebhookServer() as server: + _get(server.base_url + "/x") + self.assertEqual( + assert_received_path(server, "/x", method="GET"), 1, + ) + + def test_assert_received_path_fail(self): + with WebhookServer() as server: + with self.assertRaises(WebhookReceiverError): + assert_received_path(server, "/missing") + + def test_assert_received_path_empty(self): + with WebhookServer() as server: + with self.assertRaises(WebhookReceiverError): + assert_received_path(server, "") + + def test_assert_received_with_header(self): + with WebhookServer() as server: + _post( + server.base_url + "/x", + headers={"X-Token": "abc123"}, + ) + request = assert_received_with_header(server, "X-Token", "abc123") + self.assertEqual(request.path, "/x") + + def test_assert_received_with_header_miss(self): + with WebhookServer() as server: + _post(server.base_url + "/x") + with self.assertRaises(WebhookReceiverError): + assert_received_with_header(server, "X-Missing", "x") + + def test_assert_received_with_header_empty(self): + with WebhookServer() as server: + with self.assertRaises(WebhookReceiverError): + assert_received_with_header(server, "", "x") + + def test_assert_received_json_matching(self): + with WebhookServer() as server: + _post( + server.base_url + "/x", + body=json.dumps({"event": "order.created"}).encode("utf-8"), + headers={"Content-Type": "application/json"}, + ) + request = assert_received_json_matching( + server, lambda p: p.get("event") == "order.created", + ) + self.assertEqual(request.path, "/x") + + def test_assert_received_json_matching_miss(self): + with WebhookServer() as server: + _post(server.base_url + "/x", body=b"not json") + with self.assertRaises(WebhookReceiverError): + assert_received_json_matching(server, lambda p: True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_webrtc_assert.py b/test/unit_test/test_webrtc_assert.py new file mode 100644 index 0000000..b88c007 --- /dev/null +++ b/test/unit_test/test_webrtc_assert.py @@ -0,0 +1,214 @@ +"""Unit tests for je_web_runner.utils.webrtc_assert.""" +import unittest + +from je_web_runner.utils.webrtc_assert.peer import ( + ConnectionState, + IceState, + PeerSnapshot, + RtpStats, + SignalingState, + TrackInfo, + WebRtcAssertError, + aggregate_stats, + assert_connected, + assert_min_bytes_flowed, + assert_no_packet_loss, + assert_sdp_has_codec, + assert_track_present, + export_snapshot, +) + + +def _connected_snapshot(**overrides): + base = dict( + connection_state=ConnectionState.CONNECTED, + ice_connection_state=IceState.CONNECTED, + signaling_state=SignalingState.STABLE, + local_sdp="m=audio 9 UDP/TLS/RTP/SAVPF 111\na=rtpmap:111 opus/48000/2", + remote_sdp="m=video 9 UDP/TLS/RTP/SAVPF 96\na=rtpmap:96 VP8/90000", + remote_tracks=[TrackInfo(kind="audio"), TrackInfo(kind="video")], + local_tracks=[TrackInfo(kind="audio")], + ) + base.update(overrides) + return PeerSnapshot(**base) + + +class TestFromDict(unittest.TestCase): + + def test_minimal(self): + snap = PeerSnapshot.from_dict({ + "connectionState": "connected", + "iceConnectionState": "completed", + }) + self.assertEqual(snap.connection_state, ConnectionState.CONNECTED) + self.assertEqual(snap.ice_connection_state, IceState.COMPLETED) + + def test_with_tracks(self): + snap = PeerSnapshot.from_dict({ + "connectionState": "connected", + "iceConnectionState": "connected", + "remoteTracks": [{"kind": "audio", "readyState": "live"}], + }) + self.assertEqual(len(snap.remote_tracks), 1) + + def test_unknown_state_rejected(self): + with self.assertRaises(WebRtcAssertError): + PeerSnapshot.from_dict({"connectionState": "weird"}) + + def test_rejects_non_dict(self): + with self.assertRaises(WebRtcAssertError): + PeerSnapshot.from_dict("not a dict") # type: ignore[arg-type] + + +class TestAssertConnected(unittest.TestCase): + + def test_passes(self): + assert_connected(_connected_snapshot()) + + def test_fails_when_not_connected(self): + with self.assertRaises(WebRtcAssertError): + assert_connected(_connected_snapshot( + connection_state=ConnectionState.FAILED, + )) + + def test_fails_when_ice_disconnected(self): + with self.assertRaises(WebRtcAssertError): + assert_connected(_connected_snapshot( + ice_connection_state=IceState.DISCONNECTED, + )) + + def test_accepts_ice_completed(self): + assert_connected(_connected_snapshot( + ice_connection_state=IceState.COMPLETED, + )) + + def test_rejects_non_snapshot(self): + with self.assertRaises(WebRtcAssertError): + assert_connected("not snap") # type: ignore[arg-type] + + +class TestTrackPresent(unittest.TestCase): + + def test_remote_audio_present(self): + track = assert_track_present(_connected_snapshot(), "audio") + self.assertEqual(track.kind, "audio") + + def test_local_side(self): + track = assert_track_present(_connected_snapshot(), "audio", side="local") + self.assertEqual(track.kind, "audio") + + def test_missing_kind(self): + with self.assertRaises(WebRtcAssertError): + assert_track_present(_connected_snapshot(), "data") + + def test_ended_track_skipped(self): + snap = _connected_snapshot( + remote_tracks=[TrackInfo(kind="audio", ready_state="ended")], + ) + with self.assertRaises(WebRtcAssertError): + assert_track_present(snap, "audio") + + def test_invalid_side(self): + with self.assertRaises(WebRtcAssertError): + assert_track_present(_connected_snapshot(), "audio", side="weird") + + +class TestSdpCodec(unittest.TestCase): + + def test_local_codec_match(self): + assert_sdp_has_codec(_connected_snapshot(), "opus") + + def test_remote_codec_match(self): + assert_sdp_has_codec(_connected_snapshot(), "VP8", side="remote") + + def test_missing(self): + with self.assertRaises(WebRtcAssertError): + assert_sdp_has_codec(_connected_snapshot(), "h264") + + def test_empty_sdp(self): + with self.assertRaises(WebRtcAssertError): + assert_sdp_has_codec(_connected_snapshot(local_sdp=""), "opus") + + def test_empty_codec_name(self): + with self.assertRaises(WebRtcAssertError): + assert_sdp_has_codec(_connected_snapshot(), "") + + +class TestAggregateStats(unittest.TestCase): + + def test_aggregates_inbound_and_outbound(self): + raw = [ + {"type": "inbound-rtp", "kind": "audio", "packetsReceived": 100, + "packetsLost": 1, "bytesReceived": 10_000, "jitter": 0.01}, + {"type": "outbound-rtp", "kind": "audio", "packetsSent": 100, + "bytesSent": 12_000}, + {"type": "candidate-pair", "kind": "audio"}, # ignored + ] + stats = aggregate_stats(raw) + kinds = {(s.direction, s.kind) for s in stats} + self.assertEqual(kinds, {("inbound", "audio"), ("outbound", "audio")}) + + def test_ignores_non_dict(self): + stats = aggregate_stats(["not dict", None]) # type: ignore[list-item] + self.assertEqual(stats, []) + + def test_rejects_non_list(self): + with self.assertRaises(WebRtcAssertError): + aggregate_stats({"type": "inbound-rtp"}) # type: ignore[arg-type] + + +class TestPacketLoss(unittest.TestCase): + + def test_pass_under_threshold(self): + stats = [RtpStats(direction="inbound", kind="audio", packets=100, packets_lost=1)] + assert_no_packet_loss(stats, max_loss_ratio=0.05) + + def test_fail_over_threshold(self): + stats = [RtpStats(direction="inbound", kind="audio", packets=100, packets_lost=10)] + with self.assertRaises(WebRtcAssertError): + assert_no_packet_loss(stats, max_loss_ratio=0.05) + + def test_filter_by_direction(self): + stats = [ + RtpStats(direction="inbound", kind="audio", packets=100, packets_lost=10), + RtpStats(direction="outbound", kind="audio", packets=100, packets_lost=0), + ] + # Only outbound being checked → passes + assert_no_packet_loss(stats, direction="outbound", max_loss_ratio=0.01) + + def test_bad_ratio(self): + with self.assertRaises(WebRtcAssertError): + assert_no_packet_loss([], max_loss_ratio=2.0) + + +class TestMinBytes(unittest.TestCase): + + def test_pass(self): + stats = [RtpStats(direction="outbound", kind="video", bytes=10_000)] + assert_min_bytes_flowed(stats, direction="outbound", kind="video", minimum=1000) + + def test_fail_too_few(self): + stats = [RtpStats(direction="outbound", kind="video", bytes=10)] + with self.assertRaises(WebRtcAssertError): + assert_min_bytes_flowed(stats, direction="outbound", kind="video", minimum=1000) + + def test_missing_stream(self): + with self.assertRaises(WebRtcAssertError): + assert_min_bytes_flowed([], direction="outbound", kind="video", minimum=1) + + def test_negative_minimum(self): + with self.assertRaises(WebRtcAssertError): + assert_min_bytes_flowed([], direction="x", kind="y", minimum=-1) + + +class TestExport(unittest.TestCase): + + def test_export(self): + data = export_snapshot(_connected_snapshot()) + self.assertEqual(data["connection_state"], "connected") + self.assertEqual(data["ice_connection_state"], "connected") + self.assertIn("remote_tracks", data) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_websocket_assert.py b/test/unit_test/test_websocket_assert.py new file mode 100644 index 0000000..4f26421 --- /dev/null +++ b/test/unit_test/test_websocket_assert.py @@ -0,0 +1,216 @@ +"""Unit tests for je_web_runner.utils.websocket_assert.""" +import json +import re +import unittest + +from je_web_runner.utils.websocket_assert.frames import ( + RECEIVED, + SENT, + WebSocketAssertError, + WsFrame, + WsFrameRecorder, + assert_frame_count, + assert_frame_received, + assert_json_shape, + assert_payload_contains, + assert_pubsub_pattern, + to_json, +) + + +def _recorder_with(*frames): + rec = WsFrameRecorder() + for f in frames: + rec.record(f) + return rec + + +class TestFrame(unittest.TestCase): + + def test_rejects_bad_direction(self): + with self.assertRaises(WebSocketAssertError): + WsFrame(direction="weird", url="ws://x", payload="") + + def test_as_json_decodes(self): + f = WsFrame(direction=SENT, url="ws://x", payload='{"a":1}') + self.assertEqual(f.as_json(), {"a": 1}) + + def test_as_json_raises_on_bad(self): + f = WsFrame(direction=SENT, url="ws://x", payload="not json") + with self.assertRaises(WebSocketAssertError): + f.as_json() + + +class TestRecorder(unittest.TestCase): + + def test_helpers_record_with_direction(self): + rec = WsFrameRecorder() + rec.record_sent("ws://x", "hi") + rec.record_received("ws://x", "bye") + self.assertEqual(len(rec), 2) + self.assertEqual(rec.frames()[0].direction, SENT) + self.assertEqual(rec.frames()[1].direction, RECEIVED) + + def test_filter_by_direction(self): + rec = WsFrameRecorder() + rec.record_sent("ws://x", "a") + rec.record_received("ws://x", "b") + self.assertEqual(len(rec.frames(direction=SENT)), 1) + self.assertEqual(len(rec.frames(direction=RECEIVED)), 1) + + def test_filter_by_url(self): + rec = WsFrameRecorder() + rec.record_sent("ws://api/sub", "a") + rec.record_sent("ws://api/other", "b") + self.assertEqual(len(rec.frames(url_match="sub")), 1) + self.assertEqual(len(rec.frames(url_match=re.compile(r"/api/.+"))), 2) + + def test_clear(self): + rec = WsFrameRecorder() + rec.record_sent("ws://x", "a") + rec.clear() + self.assertEqual(len(rec), 0) + + def test_record_rejects_non_frame(self): + rec = WsFrameRecorder() + with self.assertRaises(WebSocketAssertError): + rec.record("string payload") # type: ignore[arg-type] + + +class TestAssertCount(unittest.TestCase): + + def test_in_range(self): + rec = _recorder_with( + WsFrame(SENT, "ws://x", "a"), + WsFrame(RECEIVED, "ws://x", "b"), + ) + self.assertEqual(assert_frame_count(rec, minimum=2, maximum=5), 2) + + def test_below_minimum(self): + rec = WsFrameRecorder() + with self.assertRaises(WebSocketAssertError): + assert_frame_count(rec, minimum=1) + + def test_above_maximum(self): + rec = _recorder_with(WsFrame(SENT, "ws://x", "a"), WsFrame(SENT, "ws://x", "b")) + with self.assertRaises(WebSocketAssertError): + assert_frame_count(rec, maximum=1) + + def test_negative_minimum_rejected(self): + with self.assertRaises(WebSocketAssertError): + assert_frame_count(WsFrameRecorder(), minimum=-1) + + def test_max_lt_min_rejected(self): + with self.assertRaises(WebSocketAssertError): + assert_frame_count(WsFrameRecorder(), minimum=3, maximum=1) + + +class TestAssertReceived(unittest.TestCase): + + def test_finds_match(self): + rec = _recorder_with( + WsFrame(RECEIVED, "ws://x", '{"type":"ack"}'), + ) + f = assert_frame_received(rec, lambda fr: "ack" in fr.payload, description="ack") + self.assertIn("ack", f.payload) + + def test_no_match(self): + rec = _recorder_with(WsFrame(RECEIVED, "ws://x", "nope")) + with self.assertRaises(WebSocketAssertError): + assert_frame_received(rec, lambda fr: False) + + def test_predicate_exception_treated_as_no_match(self): + rec = _recorder_with(WsFrame(RECEIVED, "ws://x", "a")) + + def bad(_): + raise RuntimeError("oops") + with self.assertRaises(WebSocketAssertError): + assert_frame_received(rec, bad) + + +class TestAssertPayloadContains(unittest.TestCase): + + def test_match(self): + rec = _recorder_with(WsFrame(SENT, "ws://x", "hello world")) + f = assert_payload_contains(rec, "world") + self.assertEqual(f.direction, SENT) + + def test_miss(self): + rec = _recorder_with(WsFrame(SENT, "ws://x", "hello")) + with self.assertRaises(WebSocketAssertError): + assert_payload_contains(rec, "missing") + + def test_empty_needle_rejected(self): + with self.assertRaises(WebSocketAssertError): + assert_payload_contains(WsFrameRecorder(), "") + + +class TestAssertJsonShape(unittest.TestCase): + + def test_found(self): + rec = _recorder_with(WsFrame(RECEIVED, "ws://x", '{"id":1,"v":2}')) + f = assert_json_shape(rec, ["id", "v"]) + self.assertEqual(f.as_json()["id"], 1) + + def test_missing_key(self): + rec = _recorder_with(WsFrame(RECEIVED, "ws://x", '{"id":1}')) + with self.assertRaises(WebSocketAssertError): + assert_json_shape(rec, ["id", "missing"]) + + def test_non_json_frames_skipped(self): + rec = _recorder_with( + WsFrame(RECEIVED, "ws://x", "not json"), + WsFrame(RECEIVED, "ws://x", '{"k":true}'), + ) + assert_json_shape(rec, ["k"]) + + def test_empty_keys_rejected(self): + with self.assertRaises(WebSocketAssertError): + assert_json_shape(WsFrameRecorder(), []) + + +class TestAssertPubsub(unittest.TestCase): + + def test_subscribe_then_publish(self): + rec = _recorder_with( + WsFrame(SENT, "ws://x", '{"op":"subscribe","ch":"prices"}'), + WsFrame(RECEIVED, "ws://x", '{"ch":"prices","data":1}'), + ) + assert_pubsub_pattern( + rec, + subscribe_matcher=lambda f: '"subscribe"' in f.payload, + publish_matcher=lambda f: '"data":1' in f.payload, + ) + + def test_publish_before_subscribe_fails(self): + rec = _recorder_with( + WsFrame(RECEIVED, "ws://x", '{"data":1}'), + WsFrame(SENT, "ws://x", '{"op":"subscribe"}'), + ) + with self.assertRaises(WebSocketAssertError): + assert_pubsub_pattern( + rec, + subscribe_matcher=lambda f: "subscribe" in f.payload, + publish_matcher=lambda f: "data" in f.payload, + ) + + def test_no_pair_at_all(self): + with self.assertRaises(WebSocketAssertError): + assert_pubsub_pattern( + WsFrameRecorder(), + subscribe_matcher=lambda f: True, + publish_matcher=lambda f: True, + ) + + +class TestToJson(unittest.TestCase): + + def test_roundtrip(self): + rec = _recorder_with(WsFrame(SENT, "ws://x", "hi")) + text = to_json(rec.frames()) + loaded = json.loads(text) + self.assertEqual(loaded[0]["payload"], "hi") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/unit_test/test_webtransport_assert.py b/test/unit_test/test_webtransport_assert.py new file mode 100644 index 0000000..4f778ce --- /dev/null +++ b/test/unit_test/test_webtransport_assert.py @@ -0,0 +1,196 @@ +"""Unit tests for je_web_runner.utils.webtransport_assert.""" +import json +import unittest + +from je_web_runner.utils.webtransport_assert.streams import ( + DATAGRAM, + RECEIVED, + SENT, + STREAM, + WebTransportAssertError, + WtFrame, + WtFrameRecorder, + assert_datagram_count, + assert_json_shape, + assert_payload_contains, + assert_stream_complete, + to_json, +) + + +class TestFrame(unittest.TestCase): + + def test_bad_direction(self): + with self.assertRaises(WebTransportAssertError): + WtFrame(direction="weird", channel=DATAGRAM, payload=b"x") + + def test_bad_channel(self): + with self.assertRaises(WebTransportAssertError): + WtFrame(direction=SENT, channel="weird", payload=b"x") + + def test_payload_must_be_bytes(self): + with self.assertRaises(WebTransportAssertError): + WtFrame(direction=SENT, channel=DATAGRAM, payload="text") # type: ignore[arg-type] + + def test_stream_requires_id(self): + with self.assertRaises(WebTransportAssertError): + WtFrame(direction=SENT, channel=STREAM, payload=b"x") + + def test_as_text(self): + f = WtFrame(direction=SENT, channel=DATAGRAM, payload=b"hello") + self.assertEqual(f.as_text(), "hello") + + def test_as_json(self): + f = WtFrame(direction=SENT, channel=DATAGRAM, payload=b'{"a":1}') + self.assertEqual(f.as_json(), {"a": 1}) + + def test_as_json_bad(self): + f = WtFrame(direction=SENT, channel=DATAGRAM, payload=b"not json") + with self.assertRaises(WebTransportAssertError): + f.as_json() + + +class TestRecorder(unittest.TestCase): + + def test_records_datagrams(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"hi") + rec.record_received_datagram(b"bye") + self.assertEqual(len(rec), 2) + + def test_records_stream_chunks(self): + rec = WtFrameRecorder() + rec.record_stream_chunk(RECEIVED, stream_id=1, payload=b"a") + rec.record_stream_chunk(RECEIVED, stream_id=1, payload=b"b", fin=True) + chunks = rec.frames(stream_id=1) + self.assertEqual(len(chunks), 2) + self.assertTrue(chunks[-1].fin) + + def test_clear(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"x") + rec.clear() + self.assertEqual(len(rec), 0) + + def test_filter_combinations(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"a") + rec.record_received_datagram(b"b") + rec.record_stream_chunk(SENT, stream_id=7, payload=b"c") + self.assertEqual(len(rec.frames(channel=DATAGRAM)), 2) + self.assertEqual(len(rec.frames(channel=STREAM)), 1) + self.assertEqual(len(rec.frames(direction=SENT)), 2) + self.assertEqual(rec.stream_ids(), [7]) + + def test_filter_rejects_unknown(self): + rec = WtFrameRecorder() + with self.assertRaises(WebTransportAssertError): + rec.frames(direction="x") + with self.assertRaises(WebTransportAssertError): + rec.frames(channel="x") + + def test_record_rejects_non_frame(self): + with self.assertRaises(WebTransportAssertError): + WtFrameRecorder().record("not a frame") # type: ignore[arg-type] + + +class TestAssertDatagramCount(unittest.TestCase): + + def test_in_range(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"a") + rec.record_received_datagram(b"b") + self.assertEqual(assert_datagram_count(rec, minimum=2), 2) + + def test_below_minimum(self): + with self.assertRaises(WebTransportAssertError): + assert_datagram_count(WtFrameRecorder(), minimum=1) + + def test_above_maximum(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"a") + rec.record_sent_datagram(b"b") + with self.assertRaises(WebTransportAssertError): + assert_datagram_count(rec, maximum=1) + + def test_filter_by_direction(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"a") + rec.record_received_datagram(b"b") + self.assertEqual( + assert_datagram_count(rec, direction=SENT, minimum=1, maximum=1), 1, + ) + + def test_max_lt_min_rejected(self): + with self.assertRaises(WebTransportAssertError): + assert_datagram_count(WtFrameRecorder(), minimum=3, maximum=1) + + +class TestAssertStreamComplete(unittest.TestCase): + + def test_pass(self): + rec = WtFrameRecorder() + rec.record_stream_chunk(RECEIVED, stream_id=1, payload=b"hello ") + rec.record_stream_chunk(RECEIVED, stream_id=1, payload=b"world", fin=True) + self.assertEqual(assert_stream_complete(rec, 1), b"hello world") + + def test_missing_stream(self): + with self.assertRaises(WebTransportAssertError): + assert_stream_complete(WtFrameRecorder(), 1) + + def test_no_fin(self): + rec = WtFrameRecorder() + rec.record_stream_chunk(RECEIVED, stream_id=1, payload=b"x") + with self.assertRaises(WebTransportAssertError): + assert_stream_complete(rec, 1) + + def test_bad_direction(self): + with self.assertRaises(WebTransportAssertError): + assert_stream_complete(WtFrameRecorder(), 1, direction="weird") + + +class TestAssertPayloadContains(unittest.TestCase): + + def test_match(self): + rec = WtFrameRecorder() + rec.record_received_datagram(b"hello world") + self.assertIsNotNone(assert_payload_contains(rec, b"world")) + + def test_miss(self): + with self.assertRaises(WebTransportAssertError): + assert_payload_contains(WtFrameRecorder(), b"x") + + def test_empty_needle(self): + with self.assertRaises(WebTransportAssertError): + assert_payload_contains(WtFrameRecorder(), b"") + + +class TestAssertJsonShape(unittest.TestCase): + + def test_match(self): + rec = WtFrameRecorder() + rec.record_received_datagram(b'{"id":1,"v":2}') + assert_json_shape(rec, ["id", "v"]) + + def test_miss(self): + rec = WtFrameRecorder() + rec.record_received_datagram(b'{"id":1}') + with self.assertRaises(WebTransportAssertError): + assert_json_shape(rec, ["id", "missing"]) + + def test_empty_keys(self): + with self.assertRaises(WebTransportAssertError): + assert_json_shape(WtFrameRecorder(), []) + + +class TestToJson(unittest.TestCase): + + def test_roundtrip(self): + rec = WtFrameRecorder() + rec.record_sent_datagram(b"hi") + loaded = json.loads(to_json(rec.frames())) + self.assertEqual(loaded[0]["payload_b64"], "6869") # 'hi' hex + + +if __name__ == "__main__": + unittest.main() From a522ac3500a1cf70e28f22b01db75def0ba573ff Mon Sep 17 00:00:00 2001 From: JeffreyChen Date: Sun, 24 May 2026 18:25:57 +0800 Subject: [PATCH 2/8] Document the specialized modules in READMEs and Sphinx tree Adds a Specialized Modules section to README.md grouping the 73 new modules by capability area (Web Platform APIs / Security / Perf / Backend Integration / AI Workflow / a11y-i18n-Visual / Governance), mirrors the same in README_zh-TW.md and README_zh-CN.md, and adds a dedicated Sphinx chapter under docs/source/{Eng,Zh}/doc/specialized_modules/ wired into the Quality & Data chapter of both language indices. --- README.md | 224 ++++++++ README/README_zh-CN.md | 87 +++ README/README_zh-TW.md | 87 +++ .../specialized_modules_doc.rst | 530 ++++++++++++++++++ docs/source/Eng/eng_index.rst | 1 + .../specialized_modules_doc.rst | 507 +++++++++++++++++ docs/source/Zh/zh_index.rst | 1 + 7 files changed, 1437 insertions(+) create mode 100644 docs/source/Eng/doc/specialized_modules/specialized_modules_doc.rst create mode 100644 docs/source/Zh/doc/specialized_modules/specialized_modules_doc.rst diff --git a/README.md b/README.md index 4b93ee0..be089e6 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ WebRunner (`je_web_runner`) started as a Selenium wrapper and grew into a full a - [Observability](#observability) - [Test Orchestration](#test-orchestration) - [Quality & Security](#quality--security) +- [Specialized Modules](#specialized-modules) - [Advanced WebDriverWrapper](#advanced-webdriverwrapper) - [Browser Internals](#browser-internals) - [Test Data](#test-data) @@ -845,6 +846,229 @@ Test orchestration: - **Test impact analysis** — `impact_analysis.build_index("./actions")` walks every action JSON file and projects locator names, URLs, template names, and `WR_*` commands into a reverse index; `affected_action_files(index, locators=["primary_cta"])` answers "which tests touch this?" so diff-aware shards can go beyond filename matching. +## Specialized Modules + +A second wave of utility modules, each in its own subpackage under +`je_web_runner/utils/`, organised by capability area. Each module is +fully unit-tested and ships independent of the core executor (import +only what you use). + +### Web Platform APIs + +- **`webtransport_assert`** — HTTP/3 WebTransport datagram + stream + frame recorder with count / payload / JSON-shape / stream-complete + assertions (mirror of `websocket_assert` and `sse_assert`). +- **`indexed_db_explorer`** — Browser-side harvest JS + typed + `IdbSnapshot`; assertions cover store existence, record count, key + presence, index presence, plus per-store diff. +- **`file_system_access`** — JS shim mocking `showOpenFilePicker` / + `showSaveFilePicker` / `showDirectoryPicker`; records every write + performed against the fake handle for later assertion. +- **`notifications_audit`** — Tracks `Notification.requestPermission` + call timing (user-gesture check, min page age) and policy violations + (re-prompt after deny, notification spam after deny, tag reuse). +- **`sse_assert`** — Server-Sent Events stream recorder + chunk-buffer + feed + count / data-contains / JSON-shape / strictly-increasing-id + assertions. +- **`websocket_assert`** — WebSocket frame recorder + count / payload / + pubsub-pattern / JSON-shape assertions. +- **`webrtc_assert`** — `PeerSnapshot.from_dict`, `aggregate_stats` + (getStats), and connected / track-present / SDP-codec / packet-loss / + min-bytes assertions. +- **`view_transitions`** — Instrumentation snippet for the View + Transitions API + duration budget / CLS budget / group-name asserts. + +### Security & Headers + +- **`mixed_content_audit`** — HAR + console-message scan for HTTP + resources on HTTPS pages (active vs passive vs HSTS-upgrade). +- **`clickjacking_audit`** — X-Frame-Options + `frame-ancestors` parser + + iframe-probe page generator; STRICT / SAMEORIGIN / ALLOWED / MISSING + verdict. +- **`open_redirect_detector`** — Eight-payload probe set (`//evil`, + `@userinfo`, `javascript:`, `data:`, mixed-case bypass…) + + classifier (BLOCKED / ALLOWED / AMBIGUOUS). +- **`sri_verify`** — Parse `