From 62f35c4b4502f4ff906ff4576f85b345c10a4556 Mon Sep 17 00:00:00 2001 From: Lloyd Date: Wed, 27 May 2026 14:27:59 +0100 Subject: [PATCH 1/8] fix: update transport key generation to use 16-byte length and add corresponding test --- repeater/data_acquisition/sqlite_handler.py | 11 ++++--- tests/test_sqlite_handler_easy.py | 33 +++++++++++++++++++++ 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/repeater/data_acquisition/sqlite_handler.py b/repeater/data_acquisition/sqlite_handler.py index 4c0a636..6b5d427 100644 --- a/repeater/data_acquisition/sqlite_handler.py +++ b/repeater/data_acquisition/sqlite_handler.py @@ -1641,13 +1641,13 @@ def get_adverts_count_by_contact_type( logger.error(f"Failed to get adverts count for contact_type '{contact_type}': {e}") return 0 - def generate_transport_key(self, name: str, key_length_bytes: int = 32) -> str: + def generate_transport_key(self, name: str, key_length_bytes: int = 16) -> str: """ - Generate a transport key using the proper MeshCore key derivation. + Generate a transport key using MeshCore-compatible key derivation. Args: name: The key name to derive the key from - key_length_bytes: Length of the key in bytes (default: 32 bytes = 256 bits) + key_length_bytes: Fallback random key length in bytes (default: 16) Returns: A base64-encoded transport key derived from the name @@ -1655,7 +1655,6 @@ def generate_transport_key(self, name: str, key_length_bytes: int = 32) -> str: try: from pymc_core.protocol.transport_keys import get_auto_key_for - # Use the proper MeshCore key derivation function key_bytes = get_auto_key_for(name) # Encode to base64 for safe storage and transmission @@ -1668,9 +1667,9 @@ def generate_transport_key(self, name: str, key_length_bytes: int = 32) -> str: except Exception as e: logger.error(f"Failed to generate transport key using get_auto_key_for: {e}") - # Fallback to secure random if MeshCore function fails + # Fallback to a transport-compatible random 16-byte key if derivation fails. try: - random_bytes = secrets.token_bytes(key_length_bytes) + random_bytes = secrets.token_bytes(16) key = base64.b64encode(random_bytes).decode("utf-8") logger.warning(f"Using fallback random key generation for '{name}'") return key diff --git a/tests/test_sqlite_handler_easy.py b/tests/test_sqlite_handler_easy.py index 18bd699..9d2510c 100644 --- a/tests/test_sqlite_handler_easy.py +++ b/tests/test_sqlite_handler_easy.py @@ -1,4 +1,7 @@ +import base64 from pathlib import Path +import sys +import types import pytest @@ -61,6 +64,36 @@ def test_transport_key_crud_cycle(tmp_path): assert h.delete_transport_key(key_id) is False +def test_generate_transport_key_uses_implicit_hashtag_region(tmp_path, monkeypatch): + h = _make_handler(tmp_path) + + captured = {} + fake_transport_keys = types.ModuleType("pymc_core.protocol.transport_keys") + + def _fake_get_auto_key_for(name: str) -> bytes: + captured["name"] = name + return b"0123456789abcdef" + + fake_transport_keys.get_auto_key_for = _fake_get_auto_key_for + + fake_protocol = types.ModuleType("pymc_core.protocol") + fake_protocol.transport_keys = fake_transport_keys + + fake_core = types.ModuleType("pymc_core") + fake_core.protocol = fake_protocol + + monkeypatch.setitem(sys.modules, "pymc_core", fake_core) + monkeypatch.setitem(sys.modules, "pymc_core.protocol", fake_protocol) + monkeypatch.setitem(sys.modules, "pymc_core.protocol.transport_keys", fake_transport_keys) + + generated = h.generate_transport_key("eu") + generated_bytes = base64.b64decode(generated) + + assert captured["name"] == "eu" + assert generated_bytes == b"0123456789abcdef" + assert len(generated_bytes) == 16 + + def test_room_messages_and_sync_flow(tmp_path): h = _make_handler(tmp_path) From 456e97a896ae470ea3d9fc656bbda8bb046bde54 Mon Sep 17 00:00:00 2001 From: Lloyd Date: Wed, 27 May 2026 14:28:11 +0100 Subject: [PATCH 2/8] refactor: update pre-commit configuration and dependencies for improved Python 3.9 support --- .pre-commit-config.yaml | 51 ++++++++++++++++++----------------------- README.md | 7 +++--- debian/control | 2 +- pyproject.toml | 15 +++++------- 4 files changed, 32 insertions(+), 43 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2f2968c..35cb451 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,38 +3,31 @@ # Setup: pre-commit install # Run manually: pre-commit run --all-files +default_language_version: + python: python3 + repos: - # Generic file hygiene checks + # Python-focused safety checks - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v5.0.0 hooks: - - id: trailing-whitespace - - id: end-of-file-fixer - id: check-yaml - - id: check-added-large-files - - # Python formatting (Black) - apply to all Python files - - repo: https://github.com/psf/black - rev: 24.4.2 - hooks: - - id: black - language_version: python3 - args: ["--line-length=100"] - - # Python import sorting (isort) - apply to all Python files - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: ["--profile", "black", "--line-length=100"] + - id: check-ast + files: ^.*\.py$ + - id: debug-statements + files: ^.*\.py$ + - id: check-docstring-first + files: ^.*\.py$ + - id: check-builtin-literals + files: ^.*\.py$ - # Python linting (flake8) - strict settings for code quality - - repo: https://github.com/pycqa/flake8 - rev: 6.0.0 + # Modern Python linting + import sorting + formatting + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.13 hooks: - - id: flake8 - # Strict but reasonable settings - args: [ - "--max-line-length=100", - "--extend-ignore=E203,W503" - ] + - id: ruff-check + args: ["--fix"] + files: ^.*\.py$ + - id: ruff-format + args: ["--check"] + files: ^.*\.py$ diff --git a/README.md b/README.md index b048689..d77e60b 100644 --- a/README.md +++ b/README.md @@ -399,7 +399,7 @@ I welcome contributions! To contribute to pyMC_repeater: ### Development Setup ```bash -# Install in development mode with dev tools (black, pytest, isort, mypy, etc) +# Install in development mode with dev tools (ruff, pytest, mypy, etc) pip install -e ".[dev]" # Setup pre-commit hooks for code quality @@ -413,9 +413,8 @@ pre-commit run --all-files **Note:** Hardware support (LoRa radio drivers) is included in the base installation automatically via `pymc_core[hardware]`. Pre-commit hooks will automatically: -- Format code with Black -- Sort imports with isort -- Lint with flake8 +- Lint and auto-fix Python issues with Ruff +- Validate formatting with Ruff formatter - Fix trailing whitespace and other file issues ## Support diff --git a/debian/control b/debian/control index acaafd9..ce2a53a 100644 --- a/debian/control +++ b/debian/control @@ -16,7 +16,7 @@ Build-Depends: debhelper-compat (= 13), git Standards-Version: 4.6.2 Homepage: https://github.com/rightup/pyMC_Repeater -X-Python3-Version: >= 3.8 +X-Python3-Version: >= 3.9 Package: pymc-repeater Architecture: all diff --git a/pyproject.toml b/pyproject.toml index 137a957..9a503a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,14 +11,13 @@ authors = [ description = "PyMC Repeater Daemon" readme = "README.md" license = {text = "MIT"} -requires-python = ">=3.8" +requires-python = ">=3.9" classifiers = [ "Development Status :: 4 - Beta", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: POSIX :: Linux", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", @@ -55,8 +54,7 @@ rrd = [ dev = [ "pytest>=7.4.0", "pytest-asyncio>=0.21.0", - "black>=23.0.0", - "isort>=5.12.0", + "ruff>=0.11.13", "mypy>=1.7.0", ] @@ -78,13 +76,12 @@ repeater = [ "presets/*.yaml", ] -[tool.black] +[tool.ruff] line-length = 100 -target-version = ['py38', 'py39', 'py310', 'py311', 'py312'] +target-version = "py39" -[tool.isort] -profile = "black" -line_length = 100 +[tool.ruff.lint] +extend-ignore = ["E701"] [tool.setuptools_scm] version_scheme = "guess-next-dev" From faa3296a5082277a59582641953624f633c23e16 Mon Sep 17 00:00:00 2001 From: Lloyd Date: Wed, 27 May 2026 14:56:01 +0100 Subject: [PATCH 3/8] refactor: remove unused imports from test files for cleaner code --- tests/test_engine.py | 4 ---- tests/test_flood_loop_dedup.py | 2 -- tests/test_handler_helpers_mesh_cli.py | 1 - tests/test_handler_helpers_room_server.py | 1 - tests/test_http_server_unit.py | 1 - tests/test_identity_manager_and_repeater_cli.py | 2 -- tests/test_keygen_local_cli.py | 3 +-- tests/test_main_py_coverage.py | 1 - tests/test_mqtt_publish_integration.py | 1 - tests/test_path_hash_mode_advert.py | 1 - tests/test_path_hash_protocol.py | 2 -- tests/test_update_endpoints_unit.py | 1 - 12 files changed, 1 insertion(+), 19 deletions(-) diff --git a/tests/test_engine.py b/tests/test_engine.py index d5aa5d3..c5dedea 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -7,11 +7,7 @@ """ import asyncio import base64 -import copy -import math import time -from collections import OrderedDict -from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest diff --git a/tests/test_flood_loop_dedup.py b/tests/test_flood_loop_dedup.py index 76c2b2a..dfc6b0d 100644 --- a/tests/test_flood_loop_dedup.py +++ b/tests/test_flood_loop_dedup.py @@ -13,11 +13,9 @@ """ from unittest.mock import MagicMock, patch -import pytest from pymc_core.protocol import Packet, PathUtils from pymc_core.protocol.constants import ( - MAX_PATH_SIZE, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_FLOOD, ) diff --git a/tests/test_handler_helpers_mesh_cli.py b/tests/test_handler_helpers_mesh_cli.py index 5fcd27c..2e82d9d 100644 --- a/tests/test_handler_helpers_mesh_cli.py +++ b/tests/test_handler_helpers_mesh_cli.py @@ -1,7 +1,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch -import pytest from repeater.handler_helpers.mesh_cli import MeshCLI diff --git a/tests/test_handler_helpers_room_server.py b/tests/test_handler_helpers_room_server.py index 36ff146..1332f0c 100644 --- a/tests/test_handler_helpers_room_server.py +++ b/tests/test_handler_helpers_room_server.py @@ -1,4 +1,3 @@ -import asyncio import time from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/test_http_server_unit.py b/tests/test_http_server_unit.py index 757e54a..6f4d788 100644 --- a/tests/test_http_server_unit.py +++ b/tests/test_http_server_unit.py @@ -2,7 +2,6 @@ import logging from pathlib import Path from types import SimpleNamespace -from unittest.mock import MagicMock import cherrypy import pytest diff --git a/tests/test_identity_manager_and_repeater_cli.py b/tests/test_identity_manager_and_repeater_cli.py index ce149d3..20df1b5 100644 --- a/tests/test_identity_manager_and_repeater_cli.py +++ b/tests/test_identity_manager_and_repeater_cli.py @@ -1,7 +1,5 @@ -from types import SimpleNamespace from unittest.mock import MagicMock, patch -import pytest from repeater.handler_helpers.repeater_cli import MeshCLI, RepeaterCLI from repeater.identity_manager import IdentityManager diff --git a/tests/test_keygen_local_cli.py b/tests/test_keygen_local_cli.py index 7ecc1f6..639c74e 100644 --- a/tests/test_keygen_local_cli.py +++ b/tests/test_keygen_local_cli.py @@ -1,7 +1,6 @@ import hashlib import json -from pathlib import Path -from unittest.mock import MagicMock, patch +from unittest.mock import patch import pytest diff --git a/tests/test_main_py_coverage.py b/tests/test_main_py_coverage.py index 199caeb..e7e832a 100644 --- a/tests/test_main_py_coverage.py +++ b/tests/test_main_py_coverage.py @@ -1,4 +1,3 @@ -import asyncio from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch diff --git a/tests/test_mqtt_publish_integration.py b/tests/test_mqtt_publish_integration.py index d22ccd8..9099922 100644 --- a/tests/test_mqtt_publish_integration.py +++ b/tests/test_mqtt_publish_integration.py @@ -11,7 +11,6 @@ """ import json -import math from unittest.mock import MagicMock from repeater.airtime import AirtimeManager diff --git a/tests/test_path_hash_mode_advert.py b/tests/test_path_hash_mode_advert.py index b8aeb17..fc724a5 100644 --- a/tests/test_path_hash_mode_advert.py +++ b/tests/test_path_hash_mode_advert.py @@ -5,7 +5,6 @@ dispatcher.send_packet() must have path_len encoding set so get_path_hash_size() returns 2 or 3. The dispatcher applies this in send_packet() before transmit. """ -from unittest.mock import AsyncMock, MagicMock import pytest diff --git a/tests/test_path_hash_protocol.py b/tests/test_path_hash_protocol.py index 2e7ead3..2f93778 100644 --- a/tests/test_path_hash_protocol.py +++ b/tests/test_path_hash_protocol.py @@ -12,7 +12,6 @@ - Max-hop boundary enforcement per hash size """ import struct -from collections import OrderedDict from unittest.mock import MagicMock, patch import pytest @@ -23,7 +22,6 @@ PATH_HASH_COUNT_MASK, PATH_HASH_SIZE_SHIFT, PAYLOAD_TYPE_TRACE, - PH_TYPE_SHIFT, ROUTE_TYPE_DIRECT, ROUTE_TYPE_FLOOD, ) diff --git a/tests/test_update_endpoints_unit.py b/tests/test_update_endpoints_unit.py index be222b9..168e8e8 100644 --- a/tests/test_update_endpoints_unit.py +++ b/tests/test_update_endpoints_unit.py @@ -1,6 +1,5 @@ from datetime import datetime, timedelta, timezone from types import SimpleNamespace -from unittest.mock import patch import cherrypy import pytest From 45a44eb47bbdbf98c4d2300fa9c396c89ca607b3 Mon Sep 17 00:00:00 2001 From: Lloyd Date: Wed, 27 May 2026 20:15:10 +0100 Subject: [PATCH 4/8] Refactor test cases and base code for consistency and readability - Updated byte representations in tests to use lowercase hex format for consistency. - Reformatted code for better readability, including line breaks and indentation adjustments. - Consolidated multiple lines into single lines where appropriate to enhance clarity. - Ensured that all test cases maintain consistent formatting and style across the test suite. --- .pre-commit-config.yaml | 23 +- pyproject.toml | 2 +- repeater/airtime.py | 4 +- repeater/companion/bridge.py | 4 +- repeater/companion/frame_server.py | 2 +- repeater/config.py | 52 +- repeater/config_manager.py | 148 +++-- repeater/data_acquisition/__init__.py | 1 + repeater/data_acquisition/glass_handler.py | 83 ++- repeater/data_acquisition/gps_service.py | 27 +- repeater/data_acquisition/hardware_stats.py | 18 +- repeater/data_acquisition/mqtt_handler.py | 211 ++++--- repeater/data_acquisition/rrdtool_handler.py | 4 +- repeater/data_acquisition/sqlite_handler.py | 183 ++++-- .../data_acquisition/storage_collector.py | 18 +- .../data_acquisition/websocket_handler.py | 31 +- repeater/engine.py | 100 +-- repeater/handler_helpers/acl.py | 23 +- repeater/handler_helpers/advert.py | 139 +++-- repeater/handler_helpers/login.py | 44 +- repeater/handler_helpers/mesh_cli.py | 4 +- repeater/handler_helpers/path.py | 2 +- repeater/handler_helpers/protocol_request.py | 58 +- repeater/handler_helpers/repeater_cli.py | 132 ++-- repeater/handler_helpers/room_server.py | 13 +- repeater/handler_helpers/text.py | 26 +- repeater/handler_helpers/trace.py | 7 +- repeater/identity_manager.py | 1 - repeater/keygen.py | 6 +- repeater/local_cli.py | 40 +- repeater/main.py | 33 +- repeater/packet_router.py | 21 +- repeater/sensors/base.py | 28 +- repeater/sensors/ens210.py | 10 +- repeater/sensors/lafvin_ups_3s.py | 69 ++- repeater/sensors/manager.py | 22 +- repeater/sensors/registry.py | 4 +- repeater/sensors/shtc3.py | 11 +- repeater/sensors/waveshare_ups_d.py | 55 +- repeater/sensors/waveshare_ups_e.py | 67 +- repeater/service_utils.py | 36 +- repeater/web/api_endpoints.py | 462 ++++++++------ repeater/web/auth/api_tokens.py | 12 +- repeater/web/auth/cherrypy_tool.py | 25 +- repeater/web/auth/jwt_handler.py | 2 +- repeater/web/auth_endpoints.py | 533 ++++++++-------- repeater/web/cad_calibration_engine.py | 11 +- repeater/web/companion_endpoints.py | 8 +- repeater/web/companion_ws_proxy.py | 13 +- repeater/web/http_server.py | 21 +- repeater/web/update_endpoints.py | 290 +++++---- tests/test_airtime.py | 2 +- tests/test_api_endpoints_core_coverage.py | 451 ++++++++------ tests/test_auth_components.py | 4 +- tests/test_auth_endpoints.py | 42 +- tests/test_companion_bridge_frame_utils.py | 17 +- tests/test_companion_ws_proxy.py | 24 +- tests/test_config_manager.py | 2 +- tests/test_engine.py | 583 +++++++++++------- tests/test_flood_loop_dedup.py | 144 ++--- tests/test_glass_handler.py | 4 +- tests/test_gps_service.py | 13 +- tests/test_handler_helpers_acl_advert.py | 6 +- tests/test_handler_helpers_mesh_cli.py | 14 +- ...test_handler_helpers_path_protocol_text.py | 65 +- tests/test_handler_helpers_room_server.py | 38 +- ...t_handler_helpers_trace_discovery_login.py | 45 +- tests/test_http_server_unit.py | 24 +- .../test_identity_manager_and_repeater_cli.py | 6 +- tests/test_keygen_local_cli.py | 8 +- tests/test_main_py_coverage.py | 18 +- tests/test_main_py_more.py | 6 +- tests/test_mqtt_publish_integration.py | 8 +- tests/test_packet_duration.py | 2 +- tests/test_packet_router.py | 44 +- tests/test_path_hash_protocol.py | 145 ++--- tests/test_radio_config.py | 2 +- tests/test_sensors.py | 19 +- tests/test_service_utils.py | 8 +- tests/test_sqlite_handler_easy.py | 4 +- ...est_storage_collector_ws_stats_throttle.py | 4 +- tests/test_tx_lock.py | 33 +- tests/test_update_endpoints_unit.py | 14 +- 83 files changed, 2795 insertions(+), 2143 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 35cb451..4a1bd84 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: # Modern Python linting + import sorting + formatting - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.13 + rev: v0.15.14 hooks: - id: ruff-check args: ["--fix"] @@ -31,3 +31,24 @@ repos: - id: ruff-format args: ["--check"] files: ^.*\.py$ + + # Security-focused static analysis + - repo: https://github.com/PyCQA/bandit + rev: 1.7.9 + hooks: + - id: bandit + # B104: intentional LAN listeners, B105: setup-required placeholder credentials. + args: ["-q", "-l", "-i", "-s", "B104,B105"] + files: ^.*\.py$ + exclude: ^tests/ + + # Test suite gate + - repo: local + hooks: + - id: pytest + name: pytest + entry: python -m pytest -q + language: system + pass_filenames: false + always_run: true + types: [python] diff --git a/pyproject.toml b/pyproject.toml index 9a503a8..13b4ef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ rrd = [ dev = [ "pytest>=7.4.0", "pytest-asyncio>=0.21.0", - "ruff>=0.11.13", + "ruff>=0.15.14", "mypy>=1.7.0", ] diff --git a/repeater/airtime.py b/repeater/airtime.py index f67e80d..69ee3f3 100644 --- a/repeater/airtime.py +++ b/repeater/airtime.py @@ -54,7 +54,7 @@ def calculate_airtime( Airtime in milliseconds """ sf = spreading_factor or self.spreading_factor - bw_hz = (bandwidth_hz or self.bandwidth) + bw_hz = bandwidth_hz or self.bandwidth cr = coding_rate or self.coding_rate preamble_len = preamble_len or self.preamble_length crc = 1 if crc_enabled else 0 @@ -64,7 +64,7 @@ def calculate_airtime( de = 1 if (sf >= 11 and bw_hz <= 125000) else 0 # Symbol time in milliseconds: T_sym = 2^SF / BW_kHz - t_sym = (2 ** sf) / (bw_hz / 1000) + t_sym = (2**sf) / (bw_hz / 1000) # Preamble time: T_preamble = (n_preamble + 4.25) * T_sym t_preamble = (preamble_len + 4.25) * t_sym diff --git a/repeater/companion/bridge.py b/repeater/companion/bridge.py index dc0787e..ac09b6f 100644 --- a/repeater/companion/bridge.py +++ b/repeater/companion/bridge.py @@ -80,9 +80,7 @@ def _save_prefs(self) -> None: try: prefs_dict = dataclasses.asdict(self.prefs) prefs_safe = _to_json_safe(prefs_dict) - self._sqlite_handler.companion_save_prefs( - str(self._companion_hash), prefs_safe - ) + self._sqlite_handler.companion_save_prefs(str(self._companion_hash), prefs_safe) if self._on_prefs_saved: try: self._on_prefs_saved(self.prefs.node_name) diff --git a/repeater/companion/frame_server.py b/repeater/companion/frame_server.py index 90aa22b..530d520 100644 --- a/repeater/companion/frame_server.py +++ b/repeater/companion/frame_server.py @@ -33,7 +33,7 @@ def __init__( companion_hash: str, port: int = 5000, bind_address: str = "0.0.0.0", - client_idle_timeout_sec: Optional[int] = 8 * 60 * 60, # 8 hours + client_idle_timeout_sec: Optional[int] = 8 * 60 * 60, # 8 hours sqlite_handler=None, local_hash: Optional[int] = None, stats_getter=None, diff --git a/repeater/config.py b/repeater/config.py index 8c7d975..1fc2f74 100644 --- a/repeater/config.py +++ b/repeater/config.py @@ -51,9 +51,7 @@ def resolve_storage_dir( ) -> Path: storage_dir_cfg = ( - config.get("storage", {}).get("storage_dir") - or config.get("storage_dir") - or default + config.get("storage", {}).get("storage_dir") or config.get("storage_dir") or default ) storage_dir = Path(str(storage_dir_cfg)).expanduser() @@ -70,10 +68,10 @@ def resolve_storage_dir( def get_node_info(config: Dict[str, Any]) -> Dict[str, Any]: """ Extract node name, radio configuration, and MQTT settings from config. - + Args: config: Configuration dictionary - + Returns: Dictionary with node_name, radio_config, and MQTT configuration """ @@ -87,10 +85,10 @@ def get_node_info(config: Dict[str, Any]) -> Dict[str, Any]: radio_freq_mhz = radio_freq / 1_000_000 radio_bw_khz = radio_bw / 1_000 radio_config_str = f"{radio_freq_mhz},{radio_bw_khz},{radio_sf},{radio_cr}" - + # Handle getting the config from mqtt brokers, falling back to letsmesh if it doesn't exist mqtt_config = config.get("mqtt_brokers", config.get("letsmesh", {})) - + return { "node_name": node_name, "radio_config": radio_config_str, @@ -143,7 +141,7 @@ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: "inform_interval_seconds": 30, "request_timeout_seconds": 10, "verify_tls": True, - "api_token": "", + "api_token": None, "cert_store_dir": "/etc/pymc_repeater/glass", } @@ -184,14 +182,14 @@ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: if "security" not in config["repeater"]: logger.warning( "No 'security' section found under 'repeater' in config. " - "Adding defaults — please review and update passwords." + "Adding secure placeholders — complete setup wizard before login." ) config["repeater"]["security"] = { "max_clients": 1, - "admin_password": "admin123", - "guest_password": "guest123", + "admin_password": None, + "guest_password": None, "allow_read_only": False, - "jwt_secret": "", + "jwt_secret": None, "jwt_expiry_minutes": 60, } @@ -215,17 +213,17 @@ def load_config(config_path: Optional[str] = None) -> Dict[str, Any]: def save_config(config_data: Dict[str, Any], config_path: Optional[str] = None) -> bool: """ Save configuration to YAML file. - + Args: config_data: Configuration dictionary to save config_path: Path to config file (uses default if None) - + Returns: True if successful, False otherwise """ if config_path is None: config_path = os.getenv("PYMC_REPEATER_CONFIG", "/etc/pymc_repeater/config.yaml") - + try: # Create backup of existing config config_file = Path(config_path) @@ -247,7 +245,7 @@ def save_config(config_data: Dict[str, Any], config_path: Optional[str] = None) logger.info(f"Saved configuration to {config_path}") return True - + except Exception as e: logger.error(f"Failed to save configuration: {e}") return False @@ -256,29 +254,29 @@ def save_config(config_data: Dict[str, Any], config_path: Optional[str] = None) def update_unscoped_flood_policy(allow: bool, config_path: Optional[str] = None) -> bool: """ Update the unscoped flood policy in the configuration. - + Args: allow: True to allow unscoped flooding, False to deny config_path: Path to config file (uses default if None) - + Returns: True if successful, False otherwise """ try: # Load current config config = load_config(config_path) - + # Ensure mesh section exists if "mesh" not in config: config["mesh"] = {} - + # Set global flood policy config["mesh"]["global_flood_allow"] = allow config["mesh"]["unscoped_flood_allow"] = allow - + # Save updated config return save_config(config, config_path) - + except Exception as e: logger.error(f"Failed to update unscoped flood policy: {e}") return False @@ -345,7 +343,7 @@ def _parse_int(value, *, default=None): if isinstance(value, int): return value if isinstance(value, str): - return int(value.strip().rstrip(','), 0) + return int(value.strip().rstrip(","), 0) raise ValueError(f"Invalid int value type: {type(value)}") def _parse_int_list(value): @@ -517,9 +515,7 @@ def _parse_int_list(value): host = tcp_cfg.get("host") if not host: - raise ValueError( - "Missing 'host' in 'pymc_tcp' section (modem hostname or LAN IP)" - ) + raise ValueError("Missing 'host' in 'pymc_tcp' section (modem hostname or LAN IP)") radio_cfg = board_config.get("radio") or {} radio = TCPLoRaRadio( @@ -563,9 +559,7 @@ def _parse_int_list(value): port = usb_cfg.get("port") if not port: - raise ValueError( - "Missing 'port' in 'pymc_usb' section (e.g. /dev/ttyACM0)" - ) + raise ValueError("Missing 'port' in 'pymc_usb' section (e.g. /dev/ttyACM0)") radio_cfg = board_config.get("radio") or {} radio = USBLoRaRadio( diff --git a/repeater/config_manager.py b/repeater/config_manager.py index 56cb05e..1a4a3a4 100644 --- a/repeater/config_manager.py +++ b/repeater/config_manager.py @@ -8,11 +8,11 @@ class ConfigManager: """Manages configuration persistence and live updates to the daemon.""" - + def __init__(self, config_path: str, config: dict, daemon_instance=None): """ Initialize ConfigManager. - + Args: config_path: Path to the YAML config file config: Reference to the config dictionary @@ -41,11 +41,7 @@ def _sync_repeater_handler_radio_config(self, radio_cfg: Dict[str, Any]) -> None repeater_handler.radio_config = {} repeater_handler.radio_config.update( - { - key: value - for key, value in radio_cfg.items() - if value not in (None, 0) - } + {key: value for key, value in radio_cfg.items() if value not in (None, 0)} ) def _kiss_transport_restart_required(self) -> bool: @@ -64,7 +60,11 @@ def _kiss_transport_restart_required(self) -> bool: logger.info("KISS port change detected; service restart required") return True - if configured_baudrate and runtime_baudrate and int(configured_baudrate) != int(runtime_baudrate): + if ( + configured_baudrate + and runtime_baudrate + and int(configured_baudrate) != int(runtime_baudrate) + ): logger.info("KISS baud rate change detected; service restart required") return True @@ -144,89 +144,89 @@ def _apply_live_radio_config(self) -> bool: except Exception as e: logger.error(f"Failed to apply live radio config: {e}", exc_info=True) return False - + def save_to_file(self) -> bool: """ Save current config to YAML file. - + Returns: True if successful, False otherwise """ try: os.makedirs(os.path.dirname(self.config_path), exist_ok=True) - with open(self.config_path, 'w') as f: + with open(self.config_path, "w") as f: # Use safe_dump with explicit width to prevent line wrapping # Setting width to a very large number prevents truncation of long strings like identity keys yaml.safe_dump( - self.config, - f, - default_flow_style=False, - indent=2, + self.config, + f, + default_flow_style=False, + indent=2, width=1000000, # Very large width to prevent any line wrapping sort_keys=False, - allow_unicode=True + allow_unicode=True, ) logger.info(f"Configuration saved to {self.config_path}") return True except Exception as e: logger.error(f"Failed to save config to {self.config_path}: {e}", exc_info=True) return False - + def live_update_daemon(self, sections: Optional[List[str]] = None) -> bool: """ Apply configuration changes to the running daemon's in-memory config. - + Args: sections: List of config sections to update (e.g., ['repeater', 'delays']). If None, updates all common sections. - + Returns: True if live update was successful, False otherwise """ - if not self.daemon or not hasattr(self.daemon, 'config'): + if not self.daemon or not hasattr(self.daemon, "config"): logger.warning("Daemon not available for live update") return False - + try: daemon_config = self.daemon.config live_update_ok = True - + # Default sections to update if not specified if sections is None: - sections = ['repeater', 'delays', 'radio', 'acl', 'identities', 'glass'] - + sections = ["repeater", "delays", "radio", "acl", "identities", "glass"] + # Update each section for section in sections: if section in self.config: if section not in daemon_config: daemon_config[section] = {} - + # Deep copy the section to avoid reference issues if isinstance(self.config[section], dict): daemon_config[section].update(self.config[section]) else: daemon_config[section] = self.config[section] - + logger.debug(f"Live updated daemon config section: {section}") - + logger.info(f"Live updated daemon config sections: {', '.join(sections)}") - + # Also reload runtime config in RepeaterHandler if delays or repeater sections changed - if self.daemon and hasattr(self.daemon, 'repeater_handler'): - if any(s in ['delays', 'repeater'] for s in sections): - if hasattr(self.daemon.repeater_handler, 'reload_runtime_config'): + if self.daemon and hasattr(self.daemon, "repeater_handler"): + if any(s in ["delays", "repeater"] for s in sections): + if hasattr(self.daemon.repeater_handler, "reload_runtime_config"): self.daemon.repeater_handler.reload_runtime_config() logger.info("Reloaded RepeaterHandler runtime config") - + # Also reload advert_helper config if repeater section changed - if self.daemon and hasattr(self.daemon, 'advert_helper') and self.daemon.advert_helper: - if 'repeater' in sections: - if hasattr(self.daemon.advert_helper, 'reload_config'): + if self.daemon and hasattr(self.daemon, "advert_helper") and self.daemon.advert_helper: + if "repeater" in sections: + if hasattr(self.daemon.advert_helper, "reload_config"): self.daemon.advert_helper.reload_config() logger.info("Reloaded AdvertHelper config") # Re-apply dispatcher path hash mode when mesh section changed - if 'mesh' in sections and self.daemon and hasattr(self.daemon, 'dispatcher'): + if "mesh" in sections and self.daemon and hasattr(self.daemon, "dispatcher"): mesh_cfg = self.daemon.config.get("mesh", {}) path_hash_mode = mesh_cfg.get("path_hash_mode", 0) if path_hash_mode not in (0, 1, 2): @@ -237,37 +237,39 @@ def live_update_daemon(self, sections: Optional[List[str]] = None) -> bool: self.daemon.dispatcher.set_default_path_hash_mode(path_hash_mode) logger.info(f"Reloaded path hash mode: mesh.path_hash_mode={path_hash_mode}") - if 'radio_type' in sections: + if "radio_type" in sections: logger.info("radio_type change detected; service restart required") live_update_ok = False - if 'kiss' in sections and self._kiss_transport_restart_required(): + if "kiss" in sections and self._kiss_transport_restart_required(): live_update_ok = False - if 'radio' in sections: + if "radio" in sections: live_update_ok = self._apply_live_radio_config() and live_update_ok - + return live_update_ok - + except Exception as e: logger.error(f"Failed to live update daemon config: {e}", exc_info=True) return False - - def update_and_save(self, - updates: Dict[str, Any], - live_update: bool = True, - live_update_sections: Optional[List[str]] = None) -> Dict[str, Any]: + + def update_and_save( + self, + updates: Dict[str, Any], + live_update: bool = True, + live_update_sections: Optional[List[str]] = None, + ) -> Dict[str, Any]: """ Apply updates to config, save to file, and optionally live update daemon. - + This is the main method that should be used by both mesh_cli and api_endpoints. - + Args: updates: Dictionary of config updates in nested format. Example: {"repeater": {"node_name": "NewName"}, "delays": {"tx_delay_factor": 1.5}} live_update: Whether to apply changes to running daemon immediately live_update_sections: Specific sections to live update. If None, auto-detects from updates. - + Returns: Dict with keys: - success: bool - Whether operation succeeded @@ -275,62 +277,58 @@ def update_and_save(self, - live_updated: bool - Whether daemon was live updated - error: str (optional) - Error message if failed """ - result: Dict[str, Any] = { - "success": False, - "saved": False, - "live_updated": False - } - + result: Dict[str, Any] = {"success": False, "saved": False, "live_updated": False} + try: # Apply updates to config for section, values in updates.items(): if section not in self.config: self.config[section] = {} - + if isinstance(values, dict): self.config[section].update(values) else: self.config[section] = values - + # Save to file result["saved"] = self.save_to_file() - + if not result["saved"]: result["error"] = "Failed to save config to file" return result - + # Live update daemon if requested if live_update: # Auto-detect sections if not specified if live_update_sections is None: live_update_sections = list(updates.keys()) - + result["live_updated"] = self.live_update_daemon(live_update_sections) - + result["success"] = result["saved"] return result - + except Exception as e: logger.error(f"Error in update_and_save: {e}", exc_info=True) result["error"] = str(e) return result - + def update_nested(self, path: str, value: Any, live_update: bool = True) -> Dict[str, Any]: """ Update a nested config value using dot notation. - + Convenience method for simple updates like "repeater.node_name" = "NewName" - + Args: path: Dot-separated path to config value (e.g., "repeater.node_name") value: Value to set live_update: Whether to apply changes to running daemon - + Returns: Result dict from update_and_save """ - parts = path.split('.') - + parts = path.split(".") + if len(parts) == 1: # Top-level key updates = {parts[0]: value} @@ -349,26 +347,26 @@ def update_nested(self, path: str, value: Any, live_update: bool = True) -> Dict current[part] = {} current = current[part] current[parts[-1]] = value - + # Determine which section to live update section = parts[0] - + return self.update_and_save( updates=updates, live_update=live_update, - live_update_sections=[section] if live_update else None + live_update_sections=[section] if live_update else None, ) - + def get_status(self) -> Dict[str, Any]: """ Get status information about the ConfigManager. - + Returns: Dict with config file path, existence, daemon availability """ return { "config_path": self.config_path, "config_exists": os.path.exists(self.config_path), - "daemon_available": self.daemon is not None and hasattr(self.daemon, 'config'), - "config_sections": list(self.config.keys()) if self.config else [] + "daemon_available": self.daemon is not None and hasattr(self.daemon, "config"), + "config_sections": list(self.config.keys()) if self.config else [], } diff --git a/repeater/data_acquisition/__init__.py b/repeater/data_acquisition/__init__.py index 0c1f238..c7ad188 100644 --- a/repeater/data_acquisition/__init__.py +++ b/repeater/data_acquisition/__init__.py @@ -3,4 +3,5 @@ from .rrdtool_handler import RRDToolHandler from .sqlite_handler import SQLiteHandler from .storage_collector import StorageCollector + __all__ = ["SQLiteHandler", "RRDToolHandler", "StorageCollector", "GlassHandler", "GPSService"] diff --git a/repeater/data_acquisition/glass_handler.py b/repeater/data_acquisition/glass_handler.py index 61b6269..4a4cffa 100644 --- a/repeater/data_acquisition/glass_handler.py +++ b/repeater/data_acquisition/glass_handler.py @@ -8,10 +8,11 @@ from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, List, Optional, Tuple -from urllib.parse import urlparse from urllib import error, request +from urllib.parse import urlparse import psutil + try: import paho.mqtt.client as mqtt except ImportError: @@ -131,7 +132,8 @@ def _reload_runtime_settings(self) -> None: int(glass_cfg.get("inform_interval_seconds", self.inform_interval_seconds)) ) self.cert_store_dir = str( - glass_cfg.get("cert_store_dir", "/etc/pymc_repeater/glass") or "/etc/pymc_repeater/glass" + glass_cfg.get("cert_store_dir", "/etc/pymc_repeater/glass") + or "/etc/pymc_repeater/glass" ) self.client_cert_path = ( str(glass_cfg.get("client_cert_path")).strip() @@ -144,9 +146,7 @@ def _reload_runtime_settings(self) -> None: else None ) self.ca_cert_path = ( - str(glass_cfg.get("ca_cert_path")).strip() - if glass_cfg.get("ca_cert_path") - else None + str(glass_cfg.get("ca_cert_path")).strip() if glass_cfg.get("ca_cert_path") else None ) managed_cfg = self._load_managed_settings() parsed_base_url = urlparse(self.base_url) @@ -164,7 +164,9 @@ def _reload_runtime_settings(self) -> None: self.mqtt_tls_enabled = bool(managed_cfg.get("mqtt_tls_enabled", False)) username = managed_cfg.get("mqtt_username") password = managed_cfg.get("mqtt_password") - self.mqtt_username = str(username).strip() if isinstance(username, str) and username else None + self.mqtt_username = ( + str(username).strip() if isinstance(username, str) and username else None + ) self.mqtt_password = str(password) if isinstance(password, str) and password else None def _managed_settings_path(self) -> Path: @@ -406,7 +408,9 @@ def _extract_location_from_settings(self, settings: Dict[str, Any]) -> Optional[ def _collect_system_stats(self) -> Dict[str, Any]: temperature_c = None try: - temperatures = psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {} + temperatures = ( + psutil.sensors_temperatures() if hasattr(psutil, "sensors_temperatures") else {} + ) if temperatures: for values in temperatures.values(): if values: @@ -436,6 +440,7 @@ async def _post_inform(self, payload: Dict[str, Any]) -> Dict[str, Any]: def _post_inform_sync(self, payload: Dict[str, Any]) -> Dict[str, Any]: url = f"{self.base_url}/inform" + self._validate_http_url(url) headers = {"Content-Type": "application/json"} if self.api_token: headers["Authorization"] = f"Bearer {self.api_token}" @@ -449,7 +454,7 @@ def _post_inform_sync(self, payload: Dict[str, Any]) -> Dict[str, Any]: req, timeout=self.request_timeout_seconds, context=ssl_context, - ) as response: + ) as response: # nosec B310 response_bytes = response.read() except error.HTTPError as exc: details = "" @@ -484,7 +489,9 @@ def _build_ssl_context(self, url: str) -> Optional[ssl.SSLContext]: else: context = ssl.create_default_context() else: - context = ssl._create_unverified_context() + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE if self.client_cert_path or self.client_key_path: cert_path = self._require_ssl_file(self.client_cert_path, "client_cert_path") @@ -502,6 +509,14 @@ def _require_ssl_file(path_value: Optional[str], field_name: str) -> str: raise RuntimeError(f"Configured {field_name} does not exist: {normalized}") return normalized + @staticmethod + def _validate_http_url(url: str) -> None: + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise RuntimeError(f"Unsupported Glass base_url scheme: {parsed.scheme or ''}") + if not parsed.netloc: + raise RuntimeError("Glass base_url must include a host") + async def _handle_command_response(self, response: Dict[str, Any]) -> None: command_id = str(response.get("command_id", "")).strip() action = str(response.get("action", "")).strip() @@ -583,16 +598,22 @@ async def _execute_command_action( radio_values = params.get("radio", params) if not isinstance(radio_values, dict): return False, "radio settings must be an object", None - success, message = self._apply_config_update({"radio": radio_values}, merge_mode="patch") + success, message = self._apply_config_update( + {"radio": radio_values}, merge_mode="patch" + ) return success, message, None if action == "run_diagnostic": stats = self.daemon_instance.get_stats() if self.daemon_instance else {} - return True, ( - f"rx={int(stats.get('rx_count', 0))}, " - f"tx={int(stats.get('forwarded_count', 0))}, " - f"dropped={int(stats.get('dropped_count', 0))}" - ), None + return ( + True, + ( + f"rx={int(stats.get('rx_count', 0))}, " + f"tx={int(stats.get('forwarded_count', 0))}, " + f"dropped={int(stats.get('dropped_count', 0))}" + ), + None, + ) if action == "export_config": normalized_config = self._normalize_for_hash(self.config) @@ -639,7 +660,9 @@ def _apply_config_update(self, updates: Any, merge_mode: str = "patch") -> Tuple live_updated = self.config_manager.live_update_daemon(sections) return ( bool(saved and live_updated), - "Config replaced" if saved and live_updated else "Failed to persist replace update", + "Config replaced" + if saved and live_updated + else "Failed to persist replace update", ) return True, "Config replaced" @@ -699,7 +722,9 @@ def _apply_cert_renewal(self, response: Dict[str, Any]) -> Tuple[bool, str]: client_key = response.get("client_key") ca_cert = response.get("ca_cert") - if not all(isinstance(item, str) and item.strip() for item in (client_cert, client_key, ca_cert)): + if not all( + isinstance(item, str) and item.strip() for item in (client_cert, client_key, ca_cert) + ): return False, "Missing certificate payload values" cert_dir = Path(self.cert_store_dir) @@ -826,7 +851,11 @@ def _init_mqtt_publisher(self) -> None: if self.mqtt_username: client.username_pw_set(self.mqtt_username, self.mqtt_password) if self.mqtt_tls_enabled: - ca_certs = self._require_ssl_file(self.ca_cert_path, "ca_cert_path") if self.ca_cert_path else None + ca_certs = ( + self._require_ssl_file(self.ca_cert_path, "ca_cert_path") + if self.ca_cert_path + else None + ) certfile = None keyfile = None if self.client_cert_path or self.client_key_path: @@ -890,7 +919,18 @@ def _on_mqtt_disconnect(self, _client, _userdata, reason_code, _properties=None) def _current_mqtt_signature( self, - ) -> Tuple[str, int, str, bool, bool, Optional[str], Optional[str], Optional[str], Optional[str], Optional[str]]: + ) -> Tuple[ + str, + int, + str, + bool, + bool, + Optional[str], + Optional[str], + Optional[str], + Optional[str], + Optional[str], + ]: return ( self.mqtt_broker_host, self.mqtt_broker_port, @@ -923,10 +963,7 @@ def _sync_mqtt_publisher(self) -> None: @staticmethod def _deep_merge(target: Dict[str, Any], source: Dict[str, Any]) -> None: for key, value in source.items(): - if ( - isinstance(value, dict) - and isinstance(target.get(key), dict) - ): + if isinstance(value, dict) and isinstance(target.get(key), dict): GlassHandler._deep_merge(target[key], value) else: target[key] = value diff --git a/repeater/data_acquisition/gps_service.py b/repeater/data_acquisition/gps_service.py index 9976e84..fb4676d 100644 --- a/repeater/data_acquisition/gps_service.py +++ b/repeater/data_acquisition/gps_service.py @@ -597,7 +597,9 @@ def __init__( # Backward-compatible alias: use_gps_for_repeater_location=True means # GPS advertising is enabled for repeater-originated location fields. legacy_use_gps_location = gps_config.get("use_gps_for_repeater_location") - advertise_gps_default = bool(legacy_use_gps_location) if legacy_use_gps_location is not None else False + advertise_gps_default = ( + bool(legacy_use_gps_location) if legacy_use_gps_location is not None else False + ) self.advertise_gps_location = bool( gps_config.get("advertise_gps_location", advertise_gps_default) ) @@ -607,9 +609,7 @@ def __init__( "location_precision_digits", gps_config.get("repeater_location_precision_digits"), ) - self.location_precision_digits = _normalize_precision_digits( - precision_value - ) + self.location_precision_digits = _normalize_precision_digits(precision_value) self.source = str(gps_config.get("source", "serial")).lower() self.device = gps_config.get("device", "/dev/serial0") self.baud_rate = int(gps_config.get("baud_rate", 9600)) @@ -631,7 +631,9 @@ def __init__( # Backward-compatible alias: update_repeater_location_from_fix # predates persist_gps_fix_to_config. legacy_update_from_fix = gps_config.get("update_repeater_location_from_fix") - persist_fix_default = bool(legacy_update_from_fix) if legacy_update_from_fix is not None else False + persist_fix_default = ( + bool(legacy_update_from_fix) if legacy_update_from_fix is not None else False + ) self.persist_gps_fix_enabled = bool( gps_config.get("persist_gps_fix_to_config", persist_fix_default) ) @@ -639,9 +641,7 @@ def __init__( "persist_gps_fix_interval_seconds", gps_config.get("location_update_interval_seconds", 600.0), ) - self.persist_gps_fix_interval_seconds = max( - 1.0, float(persist_interval_value) - ) + self.persist_gps_fix_interval_seconds = max(1.0, float(persist_interval_value)) self._location_update_callback = location_update_callback self._location_update_lock = threading.RLock() self._last_location_update_monotonic: Optional[float] = None @@ -740,7 +740,9 @@ def _apply_precision(self, value: Optional[float]) -> Optional[float]: return value return round(value, self.location_precision_digits) - def _resolve_repeater_location(self, snapshot: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _resolve_repeater_location( + self, snapshot: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: fallback_lat = _to_float(self.repeater_config.get("latitude")) fallback_lon = _to_float(self.repeater_config.get("longitude")) fallback_location = { @@ -980,9 +982,7 @@ def _apply_effective_position(self, snapshot: Dict[str, Any]): snapshot["position_meta"] = { "source": position_source, "source_label": position_source_label, - "policy": "fallback_to_config" - if self.api_fallback_to_config_location - else "gps_only", + "policy": "fallback_to_config" if self.api_fallback_to_config_location else "gps_only", "manual_config_available": manual_position is not None, "gps_fix_valid": gps_fix_valid, } @@ -1052,8 +1052,7 @@ def _maybe_sync_system_time(self): with self._time_sync_lock: if ( self._last_time_sync_monotonic is not None - and now_monotonic - self._last_time_sync_monotonic - < self.time_sync_interval_seconds + and now_monotonic - self._last_time_sync_monotonic < self.time_sync_interval_seconds ): return diff --git a/repeater/data_acquisition/hardware_stats.py b/repeater/data_acquisition/hardware_stats.py index 465478e..8ab18f0 100644 --- a/repeater/data_acquisition/hardware_stats.py +++ b/repeater/data_acquisition/hardware_stats.py @@ -18,11 +18,10 @@ class HardwareStatsCollector: - def __init__(self): self.start_time = time.time() - + def get_stats(self): if not PSUTIL_AVAILABLE: @@ -32,13 +31,12 @@ def get_stats(self): try: # Get current timestamp now = time.time() - uptime = now - self.start_time - + # CPU stats cpu_percent = psutil.cpu_percent(interval=0.1) cpu_count = psutil.cpu_count() cpu_freq = psutil.cpu_freq() - + # Memory stats memory = psutil.virtual_memory() @@ -47,7 +45,7 @@ def get_stats(self): # Network stats (total across all interfaces) net_io = psutil.net_io_counters() - + # Load average (Unix only) load_avg = None try: @@ -55,11 +53,11 @@ def get_stats(self): except (AttributeError, OSError): # Not available on all systems - use zeros load_avg = (0.0, 0.0, 0.0) - + # System boot time boot_time = psutil.boot_time() system_uptime = now - boot_time - + # Temperature (if available) temperatures = {} try: @@ -71,7 +69,7 @@ def get_stats(self): except (AttributeError, OSError): # Temperature sensors not available pass - + # Format data structure to match Vue component expectations stats = { "cpu": { @@ -104,7 +102,7 @@ def get_stats(self): # Add temperatures if available if temperatures: stats["temperatures"] = temperatures - + return stats except Exception as e: diff --git a/repeater/data_acquisition/mqtt_handler.py b/repeater/data_acquisition/mqtt_handler.py index 0d94484..fda3c08 100644 --- a/repeater/data_acquisition/mqtt_handler.py +++ b/repeater/data_acquisition/mqtt_handler.py @@ -2,15 +2,13 @@ import binascii import json import logging -import string import threading from datetime import datetime, timedelta, timezone from typing import Any, Callable, Dict, List, Optional import paho.mqtt.client as mqtt -from nacl.signing import SigningKey -from repeater import __version__, config +from repeater import __version__ from repeater.presets import get_preset # Try to import paho-mqtt error code mappings @@ -176,7 +174,7 @@ def __init__( broker_index: int, node_name: str, on_connect_callback: Optional[Callable] = None, - on_disconnect_callback: Optional[Callable] = None + on_disconnect_callback: Optional[Callable] = None, ): self.broker = broker self.local_identity = local_identity @@ -194,22 +192,27 @@ def __init__( self._reconnect_attempts = 0 self._reconnect_timer = None self._max_reconnect_delay = 300 # 5 minutes max - self._keepalive = broker.get("keepalive", 30) # default tighter than paho's 60s to beat NAT/proxy timeouts + self._keepalive = broker.get( + "keepalive", 30 + ) # default tighter than paho's 60s to beat NAT/proxy timeouts self._jwt_refresh_timer = None self._shutdown_requested = False self._last_jwt_claims = None - self.transport = broker.get('transport', 'websockets') - - self.use_jwt_auth = broker.get('use_jwt_auth', False) - self.username = broker.get('username', None) - self.password = broker.get('password', None) - - self.format=broker.get("format", "letsmesh") - self.tls=broker.get("tls", { - "enabled": False, - "insecure": False, - }) - + self.transport = broker.get("transport", "websockets") + + self.use_jwt_auth = broker.get("use_jwt_auth", False) + self.username = broker.get("username", None) + self.password = broker.get("password", None) + + self.format = broker.get("format", "letsmesh") + self.tls = broker.get( + "tls", + { + "enabled": False, + "insecure": False, + }, + ) + client_id = f"meshcore_{self.public_key}_{broker['host']}_{self.format}" client_kwargs = { "client_id": client_id, @@ -232,7 +235,7 @@ def __init__( self.client.on_disconnect = self._on_disconnect # If None, will be use defaults depending on the format value - self.base_topic=broker.get("base_topic", None) + self.base_topic = broker.get("base_topic", None) self.enabled = broker.get("enabled", False) self.retain_status = broker.get("retain_status", False) @@ -247,17 +250,20 @@ def __init__( # MC2MQTT family: canonical meshcoretomqtt topic structure. self.base_topic = f"meshcore/{self.iata_code}/{self.public_key}" else: - logger.warning(f"Unknown broker format '{self.format}' for {self.broker['name']}, defaulting to MC2MQTT topic structure") + logger.warning( + f"Unknown broker format '{self.format}' for {self.broker['name']}, defaulting to MC2MQTT topic structure" + ) self.base_topic = f"meshcore/{self.iata_code}/{self.public_key}" - + from pymc_core.protocol.utils import PAYLOAD_TYPES - + disallowed_types = broker.get("disallowed_packet_types", []) type_name_map = {name: code for code, name in PAYLOAD_TYPES.items()} - + self.disallowed_types = [type_name_map.get(name.upper(), None) for name in disallowed_types] - self.disallowed_types = [val for val in self.disallowed_types if val is not None] # Filter out invalid names - + self.disallowed_types = [ + val for val in self.disallowed_types if val is not None + ] # Filter out invalid names def _generate_jwt(self) -> str: """Generate MeshCore-style Ed25519 JWT token""" @@ -276,7 +282,12 @@ def _generate_jwt(self) -> str: payload["aud"] = self.broker["audience"] # Only include email/owner for verified TLS connections - if self.tls and self.tls.get("enabled", False) and self._tls_verified and (self.email or self.owner): + if ( + self.tls + and self.tls.get("enabled", False) + and self._tls_verified + and (self.email or self.owner) + ): payload["email"] = self.email payload["owner"] = self.owner else: @@ -359,7 +370,9 @@ def _on_disconnect(self, client, userdata, rc, *extra): if rc_value != 0: # Unexpected disconnect error_msg = get_mqtt_error_message(rc_value, is_disconnect=True) if was_running: - logger.warning(f"Disconnected from {self.broker['name']} (rc={rc_value}): {error_msg}") + logger.warning( + f"Disconnected from {self.broker['name']} (rc={rc_value}): {error_msg}" + ) else: logger.debug( f"Duplicate disconnect callback from {self.broker['name']} while already disconnected " @@ -410,8 +423,10 @@ def _attempt_reconnect(self, reason: str = "connection lost"): # Stop the loop if it's still running (websocket mode requires clean restart) try: self.client.loop_stop() - except: - pass + except Exception as e: + logger.debug( + f"loop_stop during reconnect was ignored for {self.broker['name']}: {e}" + ) self._set_credentials() @@ -436,13 +451,17 @@ def _set_credentials(self): f"user=v1_{self.public_key[:8]}...{self.public_key[-8:]}" ) elif self.username and self.password: - logger.info(f"Using provided credentials for {self.broker['name']} (username: {self.username})") + logger.info( + f"Using provided credentials for {self.broker['name']} (username: {self.username})" + ) self.client.username_pw_set(username=self.username, password=self.password) else: - logger.info(f"No credentials set for {self.broker['name']} (JWT auth disabled and no username/password provided)") - + logger.info( + f"No credentials set for {self.broker['name']} (JWT auth disabled and no username/password provided)" + ) + self._connect_time = datetime.now(timezone.utc) - + except Exception as e: logger.error(f"Failed to set JWT credentials for {self.broker['name']}: {e}") raise @@ -452,26 +471,27 @@ def connect(self): self._shutdown_requested = False # Conditional TLS setup - if self.enabled == False: + if not self.enabled: logger.info(f"Connection to {self.broker['name']} is disabled in configuration") return if self.transport == "websockets": - protocol = "ws" + protocol = "ws" elif self.transport == "tcp": protocol = "mqtt" else: raise ValueError(f"Invalid transport '{self.transport}' for {self.broker['name']}") - + # Setup TLS independent of transport - MQTT over TLS can be used with both websockets and raw TCP if self.tls and self.tls.get("enabled", False): import ssl + self.client.tls_set(cert_reqs=ssl.CERT_REQUIRED, tls_version=ssl.PROTOCOL_TLS_CLIENT) self.client.tls_insecure_set(self.tls.get("insecure", False)) self._tls_verified = True # Ensure to update the protocol is we're running TLS on websockets - if( self.transport == "websockets" ): + if self.transport == "websockets": protocol = "wss" # Set JWT credentials before CONNECT handshake @@ -506,12 +526,14 @@ def disconnect(self): def publish(self, subtopic: str, payload: str, retain: bool = False, qos: int = 0): """Publish message to broker""" - + # Legacy MQTT config uses singular "packet" topic, while LetsMesh uses "packets". Handle this for compatibility. if self.format == "mqtt" and subtopic == "packets": subtopic = "packet" - - if(subtopic == "status"): # Override the status topic retain and qos settings based on broker configuration + + if ( + subtopic == "status" + ): # Override the status topic retain and qos settings based on broker configuration retain = self.retain_status qos = 1 if self.retain_status else 0 @@ -565,7 +587,7 @@ def _schedule_jwt_refresh(self): _trace( f"JWT refresh scheduled for {self.broker['name']} in {refresh_delay:.0f}s " - f"({refresh_threshold*100:.0f}% of {self.jwt_expiry_minutes}min token lifetime)" + f"({refresh_threshold * 100:.0f}% of {self.jwt_expiry_minutes}min token lifetime)" ) self._jwt_refresh_timer = threading.Timer(refresh_delay, self.reconnect_for_token_expiry) self._jwt_refresh_timer.daemon = True @@ -588,7 +610,6 @@ def reconnect_for_token_expiry(self): # MeshCore → MQTT Publisher # ==================================================================== class MeshCoreToMqttPusher: - def __init__( self, local_identity, @@ -618,11 +639,11 @@ def __init__( self.stats_provider = stats_provider self._status_task = None self._running = False - self._shutdown_requested = False + self._shutdown_requested = False self._lock = threading.Lock() self._connect_timers: List[threading.Timer] = [] - # Initialize brokers list + # Initialize brokers list mqtt_brokers_config = config.get("mqtt_brokers", {}) letsmesh_config = config.get("letsmesh", {}) mqtt_config = config.get("mqtt", {}) @@ -633,7 +654,9 @@ def __init__( brokers.extend(mqtt_brokers_config.get("brokers", [])) if letsmesh_config or mqtt_config: - logger.warning("Multiple MQTT broker configurations found (mqtt_brokers, letsmesh, mqtt). Only mqtt_brokers will be used") + logger.warning( + "Multiple MQTT broker configurations found (mqtt_brokers, letsmesh, mqtt). Only mqtt_brokers will be used" + ) else: if mqtt_config: @@ -675,12 +698,12 @@ def __init__( ) broker_config = {**broker_config, "format": "letsmesh"} self.brokers.append(broker_config) - logger.info(f"Added broker: {broker_config['name']} (format={broker_config.get('format', 'unknown')})") + logger.info( + f"Added broker: {broker_config['name']} (format={broker_config.get('format', 'unknown')})" + ) else: logger.warning(f"Skipping invalid broker config: {broker_config}") - - # Create broker connections self.connections: List[_BrokerConnection] = [] for idx, broker in enumerate(self.brokers): @@ -709,7 +732,7 @@ def __init__( "status_interval": self.status_interval, "owner": self.owner, "email": self.email, - "brokers": brokers + "brokers": brokers, } # Update the configuration with the new configuration @@ -724,7 +747,7 @@ def convert_mqtt_to_broker_config(self, mqtt_cfg: dict) -> dict: "name": mqtt_cfg["broker"], "host": mqtt_cfg["broker"], "port": mqtt_cfg["port"], - "use_jwt_auth": False, # The legacy MQTT config does not support JWT auth, so we set this to False + "use_jwt_auth": False, # The legacy MQTT config does not support JWT auth, so we set this to False "username": mqtt_cfg.get("username", None), "password": mqtt_cfg.get("password", None), "transport": transport, @@ -768,23 +791,27 @@ def convert_letsmesh_to_broker_config(self, letsmesh_cfg: dict) -> List[dict]: # Append any user-defined additional brokers as full entries. for add_broker in letsmesh_cfg.get("additional_brokers", []): - logger.info(f"Imported additional LetsMesh broker from 'letsmesh' config: {add_broker.get('name')}") - entries.append({ - "enabled": enabled, - "name": add_broker["name"], - "host": add_broker["host"], - "port": add_broker["port"], - "audience": add_broker["audience"], - "use_jwt_auth": add_broker.get("use_jwt_auth", True), - "transport": add_broker.get("transport", "websockets"), - "format": "letsmesh", - "base_topic": None, - "retain_status": False, - "tls": { - "enabled": add_broker.get("tls", {}).get("enabled", True), - "insecure": add_broker.get("tls", {}).get("insecure", False), - }, - }) + logger.info( + f"Imported additional LetsMesh broker from 'letsmesh' config: {add_broker.get('name')}" + ) + entries.append( + { + "enabled": enabled, + "name": add_broker["name"], + "host": add_broker["host"], + "port": add_broker["port"], + "audience": add_broker["audience"], + "use_jwt_auth": add_broker.get("use_jwt_auth", True), + "transport": add_broker.get("transport", "websockets"), + "format": "letsmesh", + "base_topic": None, + "retain_status": False, + "tls": { + "enabled": add_broker.get("tls", {}).get("enabled", True), + "insecure": add_broker.get("tls", {}).get("insecure", False), + }, + } + ) return entries @@ -833,7 +860,7 @@ def connect(self): timer = threading.Timer(delay, lambda c=conn: self._delayed_connect(c)) timer.daemon = True timer.start() - self._connect_timers.append(timer) + self._connect_timers.append(timer) except Exception as e: logger.error(f"Failed to connect to {conn.broker['name']}: {e}") @@ -855,8 +882,8 @@ def disconnect(self): for timer in self._connect_timers: try: timer.cancel() - except Exception: - pass + except Exception as exc: + logger.debug(f"Error cancelling MQTT connect timer: {exc}") self._connect_timers = [] # Stop the heartbeat loop @@ -864,9 +891,11 @@ def disconnect(self): # Publish offline status before disconnecting try: - self.publish_status(state="offline", origin=self.node_name, radio_config=self.radio_config) - except Exception: - pass + self.publish_status( + state="offline", origin=self.node_name, radio_config=self.radio_config + ) + except Exception as exc: + logger.debug(f"Failed to publish MQTT offline status during disconnect: {exc}") # Disconnect all brokers for conn in self.connections: @@ -874,7 +903,7 @@ def disconnect(self): conn.disconnect() except Exception as e: logger.error(f"Error disconnecting from {conn.broker['name']}: {e}") - + self._status_task = None logger.info("Disconnected from all brokers") @@ -899,7 +928,11 @@ def _status_heartbeat_loop(self): # Packet helpers # ---------------------------------------------------------------- def _process_packet(self, pkt: dict) -> dict: - return {"timestamp": datetime.now(timezone.utc).isoformat(), "origin_id": self.public_key, **pkt} + return { + "timestamp": datetime.now(timezone.utc).isoformat(), + "origin_id": self.public_key, + **pkt, + } def publish_packet(self, pkt: dict, subtopic="packets", retain=False): return self.publish(subtopic, self._process_packet(pkt), retain) @@ -954,7 +987,9 @@ def publish(self, subtopic: str, payload: dict, retain: bool = False, qos: int = message = json.dumps(payload) # _BrokerConnection now handles topic prefixing, so we only log the subtopic here - logger.debug(f"Publishing topic='{subtopic}', {_summarize_payload_for_log(payload, message)}") + logger.debug( + f"Publishing topic='{subtopic}', {_summarize_payload_for_log(payload, message)}" + ) packet_type = payload.get("type") @@ -968,14 +1003,13 @@ def publish(self, subtopic: str, payload: dict, retain: bool = False, qos: int = result = conn.publish(subtopic, message, retain=retain, qos=qos) results.append((conn.broker["name"], result)) _trace(f"Published to {conn.broker['name']} -- {subtopic}") - elif conn.enabled == False: + elif not conn.enabled: results.append((conn.broker["name"], "Skipped due to being disabled")) if not results: logger.warning(f"No active broker connections for publishing to {subtopic}") return results - def publish_mqtt(self, payload: dict, subtopic: str, retain: bool = False, qos: int = 0): """Publish message to brokers using the legacy custom-MQTT format only. @@ -988,7 +1022,9 @@ def publish_mqtt(self, payload: dict, subtopic: str, retain: bool = False, qos: message = json.dumps(payload) # _BrokerConnection now handles topic prefixing, so we only log the subtopic here - logger.debug(f"Publishing topic='{subtopic}', {_summarize_payload_for_log(payload, message)}") + logger.debug( + f"Publishing topic='{subtopic}', {_summarize_payload_for_log(payload, message)}" + ) results = [] with self._lock: @@ -1004,8 +1040,10 @@ def publish_mqtt(self, payload: dict, subtopic: str, retain: bool = False, qos: continue result = conn.publish(subtopic, message, retain=retain, qos=qos) results.append((conn.broker["name"], result)) - _trace(f"Published to {conn.broker['name']} (format={conn.format}) -- {subtopic}") - elif conn.enabled == False: + _trace( + f"Published to {conn.broker['name']} (format={conn.format}) -- {subtopic}" + ) + elif not conn.enabled: results.append((conn.broker["name"], "Skipped due to being disabled")) if not results: @@ -1073,7 +1111,6 @@ def get_mqtt_error_message(rc: int, is_disconnect: bool = False) -> str: 16: "The connection was lost.", 17: "Client timeout", # MQTT v5 codes - 4: "Disconnect with Will message", 128: "Unspecified error", 129: "Malformed packet", 130: "Protocol error", @@ -1105,12 +1142,12 @@ def get_mqtt_error_message(rc: int, is_disconnect: bool = False) -> str: if HAS_REASON_CODES and ReasonCode is not None: try: - - reason = ReasonCode(mqtt.CONNACK if not is_disconnect else mqtt.DISCONNECT, identifier=rc) - name = reason.getName() if hasattr(reason, 'getName') else str(reason) + reason = ReasonCode( + mqtt.CONNACK if not is_disconnect else mqtt.DISCONNECT, identifier=rc + ) + name = reason.getName() if hasattr(reason, "getName") else str(reason) return f"{name} (code {rc})" except Exception as e: - _fallback = (disconnect_errors if is_disconnect else connect_errors).get(rc) if _fallback is None: logger.debug(f"Could not decode reason code {rc}: {e}") @@ -1127,7 +1164,7 @@ def get_mqtt_error_message(rc: int, is_disconnect: bool = False) -> str: paho_error = mqtt.error_string(rc) if paho_error and paho_error != "Unknown error.": return paho_error - except Exception: - pass + except Exception as exc: + logger.debug(f"Failed to map paho MQTT error string for code {rc}: {exc}") return error_dict.get(rc, f"Unknown error code {rc}") diff --git a/repeater/data_acquisition/rrdtool_handler.py b/repeater/data_acquisition/rrdtool_handler.py index ca075aa..54930cb 100644 --- a/repeater/data_acquisition/rrdtool_handler.py +++ b/repeater/data_acquisition/rrdtool_handler.py @@ -1,7 +1,7 @@ import logging import time from pathlib import Path -from typing import Any, Dict, Optional +from typing import Optional try: import rrdtool @@ -231,7 +231,7 @@ def get_packet_type_stats(self, hours: int = 24) -> Optional[dict]: rrd_data = self.get_data(start_time, end_time) if not rrd_data or "packet_types" not in rrd_data: - logger.warning(f"No RRD data available") + logger.warning("No RRD data available") return None type_totals = {} diff --git a/repeater/data_acquisition/sqlite_handler.py b/repeater/data_acquisition/sqlite_handler.py index 6b5d427..a41e0da 100644 --- a/repeater/data_acquisition/sqlite_handler.py +++ b/repeater/data_acquisition/sqlite_handler.py @@ -457,9 +457,7 @@ def _run_migrations(self): if not existing: # Replace the non-unique index with a UNIQUE one - conn.execute( - "DROP INDEX IF EXISTS idx_companion_contacts_pubkey" - ) + conn.execute("DROP INDEX IF EXISTS idx_companion_contacts_pubkey") conn.execute( "CREATE UNIQUE INDEX IF NOT EXISTS idx_companion_contacts_hash_pubkey " "ON companion_contacts (companion_hash, pubkey)" @@ -478,11 +476,18 @@ def _run_migrations(self): ).fetchone() if not existing: - for table in ("companion_contacts", "companion_channels", "companion_messages"): - conn.execute( - f"UPDATE {table} SET companion_hash = '0x' || companion_hash " - "WHERE companion_hash NOT LIKE '0x%'" - ) + conn.execute( + "UPDATE companion_contacts SET companion_hash = '0x' || companion_hash " + "WHERE companion_hash NOT LIKE '0x%'" + ) + conn.execute( + "UPDATE companion_channels SET companion_hash = '0x' || companion_hash " + "WHERE companion_hash NOT LIKE '0x%'" + ) + conn.execute( + "UPDATE companion_messages SET companion_hash = '0x' || companion_hash " + "WHERE companion_hash NOT LIKE '0x%'" + ) conn.execute( "INSERT INTO migrations (migration_name, applied_at) VALUES (?, ?)", (migration_name, time.time()), @@ -811,13 +816,13 @@ def store_crc_errors(self, record: dict): """Store a CRC error batch (delta count since last poll).""" try: with self._connect() as conn: - conn.execute(""" + conn.execute( + """ INSERT INTO crc_errors (timestamp, count) VALUES (?, ?) - """, ( - record.get("timestamp", time.time()), - record.get("count", 1) - )) + """, + (record.get("timestamp", time.time()), record.get("count", 1)), + ) except Exception as e: logger.error(f"Failed to store CRC errors in SQLite: {e}") @@ -827,8 +832,7 @@ def get_crc_error_count(self, hours: int = 24) -> int: cutoff = time.time() - (hours * 3600) with self._connect() as conn: row = conn.execute( - "SELECT COALESCE(SUM(count), 0) FROM crc_errors WHERE timestamp > ?", - (cutoff,) + "SELECT COALESCE(SUM(count), 0) FROM crc_errors WHERE timestamp > ?", (cutoff,) ).fetchone() return row[0] if row else 0 except Exception as e: @@ -1088,7 +1092,13 @@ def _airtime_ms(length_bytes: int) -> float: bucket_ts = int(row["timestamp"] / bucket_seconds) * bucket_seconds ms = _airtime_ms(row["length"]) if bucket_ts not in buckets: - buckets[bucket_ts] = {"timestamp": bucket_ts, "rx_ms": 0.0, "tx_ms": 0.0, "rx_count": 0, "tx_count": 0} + buckets[bucket_ts] = { + "timestamp": bucket_ts, + "rx_ms": 0.0, + "tx_ms": 0.0, + "rx_count": 0, + "tx_count": 0, + } if row["transmitted"]: buckets[bucket_ts]["tx_ms"] += ms buckets[bucket_ts]["tx_count"] += 1 @@ -1140,6 +1150,7 @@ def get_packet_type_stats(self, hours: int = 24) -> dict: # Align with pyMC_core feat/newRadios PAYLOAD_TYPES (0x0B = CONTROL) try: from pymc_core.protocol.utils import PAYLOAD_TYPES as _PT + _human = { "REQ": "Request", "RESPONSE": "Response", @@ -1388,13 +1399,21 @@ def get_table_stats(self) -> dict: try: db_size = self.sqlite_path.stat().st_size if self.sqlite_path.exists() else 0 - tables_with_timestamp = { - "packets": "timestamp", - "adverts": "timestamp", - "noise_floor": "timestamp", - "crc_errors": "timestamp", - "room_messages": "created_at", - "companion_messages": "created_at", + tables_with_timestamp = [ + "packets", + "adverts", + "noise_floor", + "crc_errors", + "room_messages", + "companion_messages", + ] + stats_queries = { + "packets": "SELECT COUNT(*), MIN(timestamp), MAX(timestamp) FROM packets", + "adverts": "SELECT COUNT(*), MIN(timestamp), MAX(timestamp) FROM adverts", + "noise_floor": "SELECT COUNT(*), MIN(timestamp), MAX(timestamp) FROM noise_floor", + "crc_errors": "SELECT COUNT(*), MIN(timestamp), MAX(timestamp) FROM crc_errors", + "room_messages": "SELECT COUNT(*), MIN(created_at), MAX(created_at) FROM room_messages", + "companion_messages": "SELECT COUNT(*), MIN(created_at), MAX(created_at) FROM companion_messages", } tables_without_timestamp = [ "transport_keys", @@ -1405,6 +1424,15 @@ def get_table_stats(self) -> dict: "companion_prefs", "migrations", ] + count_queries = { + "transport_keys": "SELECT COUNT(*) FROM transport_keys", + "api_tokens": "SELECT COUNT(*) FROM api_tokens", + "room_client_sync": "SELECT COUNT(*) FROM room_client_sync", + "companion_contacts": "SELECT COUNT(*) FROM companion_contacts", + "companion_channels": "SELECT COUNT(*) FROM companion_channels", + "companion_prefs": "SELECT COUNT(*) FROM companion_prefs", + "migrations": "SELECT COUNT(*) FROM migrations", + } table_info = [] with self._connect() as conn: @@ -1416,12 +1444,10 @@ def get_table_stats(self) -> dict: ).fetchall() } - for table, ts_col in tables_with_timestamp.items(): + for table in tables_with_timestamp: if table not in existing: continue - row = conn.execute( - f"SELECT COUNT(*), MIN({ts_col}), MAX({ts_col}) FROM {table}" # noqa: S608 - ).fetchone() + row = conn.execute(stats_queries[table]).fetchone() count, oldest, newest = row[0], row[1], row[2] table_info.append( { @@ -1436,7 +1462,7 @@ def get_table_stats(self) -> dict: for table in tables_without_timestamp: if table not in existing: continue - count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0] # noqa: S608 + count = conn.execute(count_queries[table]).fetchone()[0] table_info.append( { "name": table, @@ -1469,9 +1495,22 @@ def purge_table(self, table_name: str) -> int: if table_name not in PURGEABLE: raise ValueError(f"Table '{table_name}' cannot be purged") + purge_queries = { + "packets": "DELETE FROM packets", + "adverts": "DELETE FROM adverts", + "noise_floor": "DELETE FROM noise_floor", + "crc_errors": "DELETE FROM crc_errors", + "room_messages": "DELETE FROM room_messages", + "room_client_sync": "DELETE FROM room_client_sync", + "companion_contacts": "DELETE FROM companion_contacts", + "companion_channels": "DELETE FROM companion_channels", + "companion_messages": "DELETE FROM companion_messages", + "companion_prefs": "DELETE FROM companion_prefs", + } + try: with self._connect() as conn: - result = conn.execute(f"DELETE FROM {table_name}") # noqa: S608 + result = conn.execute(purge_queries[table_name]) conn.commit() logger.info(f"Purged {result.rowcount} rows from {table_name}") return result.rowcount @@ -1508,7 +1547,12 @@ def cleanup_old_data(self, days: int = 7): conn.commit() - if packets_deleted > 0 or adverts_deleted > 0 or noise_deleted > 0 or crc_deleted > 0: + if ( + packets_deleted > 0 + or adverts_deleted > 0 + or noise_deleted > 0 + or crc_deleted > 0 + ): logger.info( f"Cleaned up {packets_deleted} old packets, {adverts_deleted} old adverts, {noise_deleted} old noise measurements, {crc_deleted} old CRC error records" ) @@ -1557,7 +1601,11 @@ def get_cumulative_counts(self) -> dict: return {"rx_total": 0, "tx_total": 0, "drop_total": 0, "type_counts": {}} def get_adverts_by_contact_type( - self, contact_type: str, limit: Optional[int] = None, offset: Optional[int] = None, hours: Optional[int] = None + self, + contact_type: str, + limit: Optional[int] = None, + offset: Optional[int] = None, + hours: Optional[int] = None, ) -> List[dict]: try: @@ -1779,38 +1827,51 @@ def update_transport_key( last_used: Optional[float] = None, ) -> bool: try: - updates = [] - params = [] - - if name is not None: - updates.append("name = ?") - params.append(name) - if flood_policy is not None: - updates.append("flood_policy = ?") - params.append(flood_policy) - if transport_key is not None: - updates.append("transport_key = ?") - params.append(transport_key) - if parent_id is not None: - updates.append("parent_id = ?") - params.append(parent_id) - if last_used is not None: - updates.append("last_used = ?") - params.append(last_used) - - if not updates: + has_name = name is not None + has_flood_policy = flood_policy is not None + has_transport_key = transport_key is not None + has_parent_id = parent_id is not None + has_last_used = last_used is not None + + if not any( + [ + has_name, + has_flood_policy, + has_transport_key, + has_parent_id, + has_last_used, + ] + ): return False - updates.append("updated_at = ?") - params.append(time.time()) - params.append(key_id) + params = ( + int(has_name), + name, + int(has_flood_policy), + flood_policy, + int(has_transport_key), + transport_key, + int(has_parent_id), + parent_id, + int(has_last_used), + last_used, + time.time(), + key_id, + ) with self._connect() as conn: cursor = conn.execute( - f""" - UPDATE transport_keys SET {', '.join(updates)} + """ + UPDATE transport_keys + SET + name = CASE WHEN ? THEN ? ELSE name END, + flood_policy = CASE WHEN ? THEN ? ELSE flood_policy END, + transport_key = CASE WHEN ? THEN ? ELSE transport_key END, + parent_id = CASE WHEN ? THEN ? ELSE parent_id END, + last_used = CASE WHEN ? THEN ? ELSE last_used END, + updated_at = ? WHERE id = ? - """, + """, params, ) return cursor.rowcount > 0 @@ -1911,9 +1972,7 @@ def sync_transport_keys(self, entries: List[Dict[str, Any]]) -> Dict[str, int]: transport_key = self.generate_transport_key(node["name"]) generated_keys += 1 parent_id = ( - db_ids.get(node["parent_node_id"]) - if node.get("parent_node_id") - else None + db_ids.get(node["parent_node_id"]) if node.get("parent_node_id") else None ) cursor = conn.execute( """ @@ -2035,8 +2094,8 @@ def upsert_client_sync(self, room_hash: str, client_pubkey: str, **kwargs) -> bo # Use INSERT OR REPLACE for single atomic upsert conn.execute( f""" - INSERT OR REPLACE INTO room_client_sync ({', '.join(columns)}) - VALUES ({', '.join(placeholders)}) + INSERT OR REPLACE INTO room_client_sync ({", ".join(columns)}) + VALUES ({", ".join(placeholders)}) """, values, ) diff --git a/repeater/data_acquisition/storage_collector.py b/repeater/data_acquisition/storage_collector.py index bc0a5b2..fb363d7 100644 --- a/repeater/data_acquisition/storage_collector.py +++ b/repeater/data_acquisition/storage_collector.py @@ -1,10 +1,7 @@ import asyncio -import json import logging import time -from datetime import datetime -from pathlib import Path -from typing import Any, Dict, Optional +from typing import Optional from repeater.config import resolve_storage_dir @@ -31,7 +28,9 @@ def __init__(self, config: dict, local_identity=None, repeater_handler=None): # Initialize MQTT handler if configured self.mqtt_handler = None - if (config.get("mqtt_brokers", {}) or config.get("letsmesh", {}) or config.get("mqtt", {})) and local_identity: + if ( + config.get("mqtt_brokers", {}) or config.get("letsmesh", {}) or config.get("mqtt", {}) + ) and local_identity: try: # Pass local_identity directly (supports both standard and firmware keys) self.mqtt_handler = MeshCoreToMqttPusher( @@ -42,9 +41,7 @@ def __init__(self, config: dict, local_identity=None, repeater_handler=None): self.mqtt_handler.connect() public_key_hex = local_identity.get_public_key().hex() - logger.info( - f"MQTT handler initialized with public key: {public_key_hex[:16]}..." - ) + logger.info(f"MQTT handler initialized with public key: {public_key_hex[:16]}...") except Exception as e: logger.error(f"Failed to initialize MQTT handler: {e}") self.mqtt_handler = None @@ -196,7 +193,9 @@ def _publish_packet_sync(self, packet_record: dict, skip_mqtt: bool): self._last_ws_stats_broadcast = now_mono packet_stats_24h = self.sqlite_handler.get_packet_stats(hours=24) uptime_seconds = ( - time.time() - self.repeater_handler.start_time if self.repeater_handler else 0 + time.time() - self.repeater_handler.start_time + if self.repeater_handler + else 0 ) self.websocket_broadcast_stats( { @@ -207,7 +206,6 @@ def _publish_packet_sync(self, packet_record: dict, skip_mqtt: bool): except Exception as e: logger.debug(f"WebSocket broadcast failed: {e}") - self._publish_packet_to_mqtt(packet_record) def _publish_packet_to_mqtt(self, packet_record: dict): diff --git a/repeater/data_acquisition/websocket_handler.py b/repeater/data_acquisition/websocket_handler.py index bf88b66..ed772e9 100644 --- a/repeater/data_acquisition/websocket_handler.py +++ b/repeater/data_acquisition/websocket_handler.py @@ -27,17 +27,16 @@ class PacketWebSocket(WebSocket): - def opened(self): """Called when a WebSocket connection is established""" # Authenticate using JWT provided as query parameter (token=) jwt_handler = cherrypy.config.get("jwt_handler") - + # Get query string from environ qs = "" if hasattr(self, "environ"): qs = self.environ.get("QUERY_STRING", "") - + params = parse_qs(qs) token = params.get("token", [None])[0] client_id = params.get("client_id", [None])[0] @@ -46,7 +45,7 @@ def opened(self): logger.warning("WebSocket connection rejected: no JWT handler configured") self.close(code=1011, reason="server configuration error") return - + if not token: logger.warning("WebSocket connection rejected: missing token") self.close(code=1008, reason="unauthorized") @@ -92,17 +91,17 @@ def received_message(self, message): elif data.get("type") == "pong": # Client responded to our ping pass - except Exception: - pass + except Exception as exc: + logger.debug(f"Ignoring malformed WebSocket message: {exc}") def broadcast_packet(packet_data: dict): if not _connected_clients: return - + message = json.dumps({"type": "packet", "data": packet_data}) - + for client in list(_connected_clients): try: client.send(message) @@ -115,9 +114,9 @@ def broadcast_stats(stats_data: dict): if not _connected_clients: return - + message = json.dumps({"type": "stats", "data": stats_data}) - + for client in list(_connected_clients): try: client.send(message) @@ -134,15 +133,15 @@ def has_connected_clients() -> bool: def _heartbeat_loop(): """Background thread to send periodic pings to all connected clients""" global _heartbeat_running - + while _heartbeat_running: time.sleep(PING_INTERVAL) - + if not _connected_clients: continue - + ping_message = json.dumps({"type": "ping"}) - + for client in list(_connected_clients): try: client.send(ping_message) @@ -154,10 +153,10 @@ def _heartbeat_loop(): def init_websocket(): """Initialize WebSocket plugin and start heartbeat""" global _heartbeat_thread, _heartbeat_running - + WebSocketPlugin(cherrypy.engine).subscribe() cherrypy.tools.websocket = WebSocketTool() - + # Start heartbeat thread if not _heartbeat_running: _heartbeat_running = True diff --git a/repeater/engine.py b/repeater/engine.py index 0ca401f..cfcd489 100644 --- a/repeater/engine.py +++ b/repeater/engine.py @@ -1,8 +1,7 @@ import asyncio import copy import logging -import random -import struct +import secrets import time from collections import OrderedDict, deque from typing import Optional, Tuple @@ -15,14 +14,12 @@ PAYLOAD_TYPE_ANON_REQ, PAYLOAD_TYPE_TRACE, PH_ROUTE_MASK, - PH_TYPE_MASK, - PH_TYPE_SHIFT, ROUTE_TYPE_DIRECT, ROUTE_TYPE_FLOOD, ROUTE_TYPE_TRANSPORT_DIRECT, ROUTE_TYPE_TRANSPORT_FLOOD, ) -from pymc_core.protocol.packet_utils import PacketHeaderUtils, PacketTimingUtils, PathUtils +from pymc_core.protocol.packet_utils import PacketHeaderUtils, PathUtils from repeater.airtime import AirtimeManager from repeater.data_acquisition import StorageCollector @@ -47,13 +44,20 @@ class RepeaterHandler(BaseHandler): - @staticmethod def payload_type() -> int: return 0xFF # Special marker (not a real payload type) - def __init__(self, config: dict, dispatcher, local_hash: int, *, local_hash_bytes=None, send_advert_func=None): + def __init__( + self, + config: dict, + dispatcher, + local_hash: int, + *, + local_hash_bytes=None, + send_advert_func=None, + ): self.config = config self.dispatcher = dispatcher @@ -115,7 +119,6 @@ def __init__(self, config: dict, dispatcher, local_hash: int, *, local_hash_byte # Storage collector for persistent packet logging try: - local_identity = dispatcher.local_identity if dispatcher else None self.storage = StorageCollector(config, local_identity, repeater_handler=self) logger.info("StorageCollector initialized successfully") @@ -131,7 +134,7 @@ def __init__(self, config: dict, dispatcher, local_hash: int, *, local_hash_byte self._background_task = None self._cached_noise_floor = None self._last_crc_error_count = 0 # Track radio counter for delta persistence - + # Cache transport keys for efficient lookup self._transport_keys_cache = None self._transport_keys_cache_time = 0 @@ -175,8 +178,8 @@ async def __call__( try: rx_airtime_ms = self.airtime_mgr.calculate_airtime(packet.get_raw_length()) self.airtime_mgr.record_rx(rx_airtime_ms) - except Exception: - pass + except Exception as exc: + logger.debug(f"Failed to record RX airtime: {exc}") route_type = packet.header & PH_ROUTE_MASK pkt_hash_full = packet.calculate_packet_hash().hex().upper() @@ -272,9 +275,7 @@ async def __call__( tx_metadata = getattr(fwd_pkt, "_tx_metadata", None) if tx_metadata: lbt_attempts = tx_metadata.get("lbt_attempts", 0) - lbt_backoff_delays_ms = tx_metadata.get( - "lbt_backoff_delays_ms", [] - ) + lbt_backoff_delays_ms = tx_metadata.get("lbt_backoff_delays_ms", []) lbt_channel_busy = tx_metadata.get("lbt_channel_busy", False) if lbt_attempts > 0: total_lbt_delay = sum(lbt_backoff_delays_ms) @@ -501,9 +502,16 @@ def record_duplicate(self, packet: Packet, rssi: int = 0, snr: float = 0.0) -> N src_hash, dst_hash = self._packet_record_src_dst(packet, payload_type) packet_record = self._build_packet_record( - packet, payload_type, route_type_parsed, rssi, snr, - original_path_hashes, path_hash_size, path_hash, - src_hash, dst_hash, + packet, + payload_type, + route_type_parsed, + rssi, + snr, + original_path_hashes, + path_hash_size, + path_hash, + src_hash, + dst_hash, transmitted=False, drop_reason="Duplicate", is_duplicate=True, @@ -667,7 +675,9 @@ def _get_drop_reason(self, packet: Packet, packet_hash: Optional[str] = None) -> if route_type == ROUTE_TYPE_FLOOD: # Check if unscoped flood policy blocked it - unscoped_flood_allow = self.config.get("mesh", {}).get("unscoped_flood_allow", self.config.get("mesh", {}).get("global_flood_allow", True)) + unscoped_flood_allow = self.config.get("mesh", {}).get( + "unscoped_flood_allow", self.config.get("mesh", {}).get("global_flood_allow", True) + ) if not unscoped_flood_allow: return "Unscoped flood policy disabled" @@ -747,8 +757,9 @@ def _is_flood_looped(self, packet: Packet, mode: Optional[str] = None) -> bool: path = packet.path or bytearray() local_hash = self.local_hash_bytes[:hash_size] local_count = sum( - 1 for i in range(hop_count) - if bytes(path[i * hash_size:(i + 1) * hash_size]) == local_hash + 1 + for i in range(hop_count) + if bytes(path[i * hash_size : (i + 1) * hash_size]) == local_hash ) return local_count >= max_counter @@ -757,7 +768,7 @@ def _check_transport_codes(self, packet: Packet) -> Tuple[bool, str]: if not self.storage: logger.warning("Transport code check failed: no storage available") return False, "No storage available for transport key validation" - + try: from pymc_core.protocol.transport_keys import calc_transport_code @@ -770,31 +781,24 @@ def _check_transport_codes(self, packet: Packet) -> Tuple[bool, str]: # Refresh cache self._transport_keys_cache = self.storage.get_transport_keys() self._transport_keys_cache_time = current_time - + transport_keys = self._transport_keys_cache - + if not transport_keys: return False, "No transport keys configured" - + # Check if packet has transport codes if not packet.has_transport_codes(): return False, "No transport codes present" transport_code_0 = packet.transport_codes[0] # First transport code - payload = packet.get_payload() - payload_type = ( - packet.get_payload_type() - if hasattr(packet, "get_payload_type") - else ((packet.header & 0x3C) >> 2) - ) - # Check packet against each transport key for key_record in transport_keys: transport_key_encoded = key_record.get("transport_key") key_name = key_record.get("name", "unknown") flood_policy = key_record.get("flood_policy", "deny") - + if not transport_key_encoded: continue @@ -865,17 +869,19 @@ def flood_forward(self, packet: Packet, packet_hash: Optional[str] = None) -> Op if not packet.drop_reason: packet.drop_reason = "Marked do not retransmit" return None - + # Check unscoped flood policy - unscoped_flood_allow = self.config.get("mesh", {}).get("unscoped_flood_allow", self.config.get("mesh", {}).get("global_flood_allow", True)) + unscoped_flood_allow = self.config.get("mesh", {}).get( + "unscoped_flood_allow", self.config.get("mesh", {}).get("global_flood_allow", True) + ) route_type = packet.header & PH_ROUTE_MASK if route_type == ROUTE_TYPE_FLOOD: if not unscoped_flood_allow: packet.drop_reason = "Unscoped flood policy disabled" return None - #Check transport scopes flood policy - if route_type == ROUTE_TYPE_TRANSPORT_FLOOD: + # Check transport scopes flood policy + if route_type == ROUTE_TYPE_TRANSPORT_FLOOD: allowed, check_reason = self._check_transport_codes(packet) if not allowed: packet.drop_reason = "Transport code not allowed to flood" @@ -1006,7 +1012,7 @@ def _calculate_tx_delay(self, packet: Packet, snr: float = 0.0) -> float: # Flood packets: random(0-5) * (airtime * 52/50 / 2) * tx_delay_factor # This creates collision avoidance with tunable delay base_delay_ms = (airtime_ms * 52 / 50) / 2.0 # From C++ implementation - random_mult = random.uniform(0, 5) # Random multiplier for collision avoidance + random_mult = secrets.randbelow(5001) / 1000.0 delay_ms = base_delay_ms * random_mult * self.tx_delay_factor delay_s = delay_ms / 1000.0 else: # DIRECT @@ -1119,8 +1125,8 @@ async def delayed_send(): can_tx_now, _ = self.airtime_mgr.can_transmit(airtime_ms) if not can_tx_now: logger.warning( - "Packet dropped at TX time: duty-cycle exceeded " - "(airtime=%.1fms)", airtime_ms, + "Packet dropped at TX time: duty-cycle exceeded (airtime=%.1fms)", + airtime_ms, ) return @@ -1254,7 +1260,10 @@ def get_stats(self) -> dict: "web": self.config.get("web", {}), # Include web configuration "mesh": { "loop_detect": self.config.get("mesh", {}).get("loop_detect", "off"), - "unscoped_flood_allow": self.config.get("mesh", {}).get("unscoped_flood_allow", self.config.get("mesh", {}).get("global_flood_allow", True)), + "unscoped_flood_allow": self.config.get("mesh", {}).get( + "unscoped_flood_allow", + self.config.get("mesh", {}).get("global_flood_allow", True), + ), "path_hash_mode": self.config.get("mesh", {}).get("path_hash_mode", 0), }, "mqtt_brokers": self.config.get("mqtt_brokers", {}), @@ -1298,8 +1307,7 @@ async def _background_timer_loop(self): if self.storage: try: retention_days = ( - self.config - .get("storage", {}) + self.config.get("storage", {}) .get("retention", {}) .get("sqlite_cleanup_days", 31) ) @@ -1390,10 +1398,10 @@ def reload_runtime_config(self): self.loop_detect_mode = self._normalize_loop_detect_mode( self.config.get("mesh", {}).get("loop_detect", LOOP_DETECT_OFF) ) - + # Note: Radio config changes require restart as they affect hardware # Note: Airtime manager has its own config reference that gets updated - + logger.info("Runtime configuration reloaded successfully") except Exception as e: logger.error(f"Error reloading runtime config: {e}") @@ -1413,5 +1421,5 @@ def cleanup(self): def __del__(self): try: self.cleanup() - except Exception: - pass + except Exception as exc: + logger.debug(f"Engine cleanup during __del__ failed: {exc}") diff --git a/repeater/handler_helpers/acl.py b/repeater/handler_helpers/acl.py index 3351999..0d5d659 100644 --- a/repeater/handler_helpers/acl.py +++ b/repeater/handler_helpers/acl.py @@ -35,12 +35,11 @@ def is_guest(self) -> bool: class ACL: - def __init__( self, max_clients: int = 50, - admin_password: str = "admin123", - guest_password: str = "guest123", + admin_password: Optional[str] = None, + guest_password: Optional[str] = None, allow_read_only: bool = True, ): self.max_clients = max_clients @@ -50,10 +49,10 @@ def __init__( self.clients: Dict[bytes, ClientInfo] = {} def authenticate_client( - self, - client_identity: Identity, - shared_secret: bytes, - password: str, + self, + client_identity: Identity, + shared_secret: bytes, + password: str, timestamp: int, sync_since: int = None, target_identity_hash: int = None, @@ -62,18 +61,18 @@ def authenticate_client( ) -> tuple[bool, int]: target_identity_config = target_identity_config or {} - + # Check for identity-specific passwords (required for room servers) identity_settings = target_identity_config.get("settings", {}) - + # Determine if this is a room server by checking the type field identity_type = target_identity_config.get("type", "") is_room_server = identity_type == "room_server" - + # Log sync_since if provided (room server format) if sync_since is not None: logger.debug(f"Client sync_since timestamp: {sync_since}") - + if is_room_server: # Room servers use passwords from their settings section only # Empty strings are treated as "not set" @@ -153,7 +152,7 @@ def authenticate_client( client.permissions &= ~PERM_ACL_ROLE_MASK client.permissions |= permissions client.shared_secret = shared_secret - + # Store sync_since for room server clients if sync_since is not None: client.sync_since = sync_since diff --git a/repeater/handler_helpers/advert.py b/repeater/handler_helpers/advert.py index e95f5e0..a91e9e5 100644 --- a/repeater/handler_helpers/advert.py +++ b/repeater/handler_helpers/advert.py @@ -20,6 +20,7 @@ class MeshActivityTier(Enum): """Mesh activity levels for adaptive rate limiting.""" + QUIET = "quiet" NORMAL = "normal" BUSY = "busy" @@ -28,10 +29,10 @@ class MeshActivityTier(Enum): # Tier multipliers for rate limit scaling TIER_MULTIPLIERS = { - MeshActivityTier.QUIET: 0.0, # No rate limiting - MeshActivityTier.NORMAL: 0.5, # Light limiting - MeshActivityTier.BUSY: 1.0, # Standard limiting - MeshActivityTier.CONGESTED: 2.0, # Aggressive limiting + MeshActivityTier.QUIET: 0.0, # No rate limiting + MeshActivityTier.NORMAL: 0.5, # Light limiting + MeshActivityTier.BUSY: 1.0, # Standard limiting + MeshActivityTier.CONGESTED: 2.0, # Aggressive limiting } @@ -50,10 +51,10 @@ def __init__(self, local_identity, storage, config=None, log_fn=None): self.local_identity = local_identity self.storage = storage self.config = config or {} - + # Create AdvertHandler internally as a parsing utility self.advert_handler = AdvertHandler(log_fn=log_fn or logger.info) - + # Cache for tracking known neighbors (avoid repeated database queries) self._known_neighbors = set() @@ -63,8 +64,10 @@ def __init__(self, local_identity, storage, config=None, log_fn=None): adaptive_cfg = repeater_cfg.get("advert_adaptive", {}) self._adaptive_enabled = bool(adaptive_cfg.get("enabled", True)) self._ewma_alpha = max(0.01, min(1.0, float(adaptive_cfg.get("ewma_alpha", 0.1)))) - self._tier_hysteresis_seconds = max(0.0, float(adaptive_cfg.get("hysteresis_seconds", 300.0))) - + self._tier_hysteresis_seconds = max( + 0.0, float(adaptive_cfg.get("hysteresis_seconds", 300.0)) + ) + # Tier thresholds (packets per minute) thresholds = adaptive_cfg.get("thresholds", {}) self._threshold_normal = float(thresholds.get("normal", 1.0)) @@ -76,15 +79,21 @@ def __init__(self, local_identity, storage, config=None, log_fn=None): self._rate_limit_enabled = bool(rate_cfg.get("enabled", True)) self._base_bucket_capacity = max(1.0, float(rate_cfg.get("bucket_capacity", 2))) self._base_refill_tokens = max(0.1, float(rate_cfg.get("refill_tokens", 1.0))) - self._base_refill_interval = max(1.0, float(rate_cfg.get("refill_interval_seconds", 36000.0))) + self._base_refill_interval = max( + 1.0, float(rate_cfg.get("refill_interval_seconds", 36000.0)) + ) self._base_min_interval = max(0.0, float(rate_cfg.get("min_interval_seconds", 3600.0))) # --- Penalty box config --- penalty_cfg = repeater_cfg.get("advert_penalty_box", {}) self._penalty_enabled = bool(penalty_cfg.get("enabled", True)) self._penalty_violation_threshold = max(1, int(penalty_cfg.get("violation_threshold", 2))) - self._penalty_decay_seconds = max(1.0, float(penalty_cfg.get("violation_decay_seconds", 43200.0))) - self._penalty_base_seconds = max(1.0, float(penalty_cfg.get("base_penalty_seconds", 21600.0))) + self._penalty_decay_seconds = max( + 1.0, float(penalty_cfg.get("violation_decay_seconds", 43200.0)) + ) + self._penalty_base_seconds = max( + 1.0, float(penalty_cfg.get("base_penalty_seconds", 21600.0)) + ) self._penalty_multiplier = max(1.0, float(penalty_cfg.get("penalty_multiplier", 2.0))) self._penalty_max_seconds = max( self._penalty_base_seconds, @@ -123,11 +132,11 @@ def __init__(self, local_identity, storage, config=None, log_fn=None): self._stats_adverts_dropped = 0 self._stats_advert_duplicates = 0 self._stats_tier_changes = 0 - + # Recent drops tracking — bounded deque so append is O(1) and the # oldest entry is evicted automatically (no pop(0) O(n) shift needed). self._recent_drops: deque = deque(maxlen=20) - + # Memory management self._last_cleanup = time.time() self._cleanup_interval_seconds = 3600.0 # Clean up every hour @@ -157,33 +166,31 @@ def _cleanup_old_state(self, now: float) -> None: while len(self._recent_advert_hashes) > self._advert_dedupe_max_hashes: self._recent_advert_hashes.popitem(last=False) - expired_penalties = [pk for pk, until in self._penalty_until.items() if until < now] for pk in expired_penalties: del self._penalty_until[pk] - inactive_pubkeys = [ - pk for pk, state in self._bucket_state.items() + pk + for pk, state in self._bucket_state.items() if now - state.get("last_seen", 0) > self._bucket_state_retention_seconds ] for pk in inactive_pubkeys: del self._bucket_state[pk] if pk in self._violation_state: del self._violation_state[pk] - + # 3. Decay old violations based on decay time for pk, vstate in list(self._violation_state.items()): last_violation = vstate.get("last_violation", 0) if now - last_violation > self._penalty_decay_seconds: # Reset violation count after decay period vstate["count"] = 0 - + if len(self._bucket_state) > self._max_tracked_pubkeys: # Sort by last_seen and remove oldest 10% sorted_pubkeys = sorted( - self._bucket_state.items(), - key=lambda x: x[1].get("last_seen", 0) + self._bucket_state.items(), key=lambda x: x[1].get("last_seen", 0) ) to_remove = int(len(sorted_pubkeys) * 0.1) for pk, _ in sorted_pubkeys[:to_remove]: @@ -192,12 +199,12 @@ def _cleanup_old_state(self, now: float) -> None: del self._violation_state[pk] if pk in self._penalty_until: del self._penalty_until[pk] - + # 5. Limit known neighbors set to prevent unbounded growth if len(self._known_neighbors) > 1000: # itertools.islice avoids materialising the full list first (O(n) → O(k)) self._known_neighbors = set(itertools.islice(self._known_neighbors, 500)) - + if expired_penalties or inactive_pubkeys: logger.debug( f"Cleaned up {len(expired_penalties)} expired penalties, " @@ -235,7 +242,9 @@ def _dedupe_advert_packet_hash(self, packet, now: float) -> bool: # Adaptive tier calculation # ------------------------------------------------------------------------- - def _update_metrics_window(self, now: float, is_advert: bool = True, is_duplicate: bool = False) -> None: + def _update_metrics_window( + self, now: float, is_advert: bool = True, is_duplicate: bool = False + ) -> None: """Update rolling metrics window and EWMA.""" elapsed = now - self._last_metrics_update @@ -243,9 +252,7 @@ def _update_metrics_window(self, now: float, is_advert: bool = True, is_duplicat # Calculate rates for window adverts_per_min = (self._adverts_in_window / elapsed) * 60.0 packets_per_min = (self._packets_in_window / elapsed) * 60.0 - dup_ratio = ( - self._duplicates_in_window / max(1, self._packets_in_window) - ) + dup_ratio = self._duplicates_in_window / max(1, self._packets_in_window) # Update EWMA alpha = self._ewma_alpha @@ -258,7 +265,7 @@ def _update_metrics_window(self, now: float, is_advert: bool = True, is_duplicat self._packets_in_window = 0 self._duplicates_in_window = 0 self._last_metrics_update = now - + # Periodic cleanup if now - self._last_cleanup >= self._cleanup_interval_seconds: self._cleanup_old_state(now) @@ -343,7 +350,7 @@ def _get_effective_limits(self) -> Tuple[float, float, float, float]: def _refill_tokens_if_needed(self, pubkey: str, now: float) -> dict: """Refill token bucket using effective (tier-scaled) limits.""" bucket_cap, refill_tokens, refill_interval, _ = self._get_effective_limits() - + state = self._bucket_state.get(pubkey) if state is None: state = { @@ -405,7 +412,7 @@ def _allow_advert(self, pubkey: str, now: float) -> Tuple[bool, str]: # Update metrics and tier self._update_metrics_window(now, is_advert=True) self._update_tier(now) - + if not self._rate_limit_enabled: self._stats_adverts_allowed += 1 return True, "" @@ -454,22 +461,24 @@ def get_rate_limit_stats(self) -> dict: """Get comprehensive rate limiting and adaptive tier statistics.""" now = time.time() bucket_cap, refill_tokens, refill_interval, min_interval = self._get_effective_limits() - + # Active penalties active_penalties = { pk[:16]: round(until - now, 1) for pk, until in self._penalty_until.items() if until > now } - + # Per-pubkey bucket states bucket_summary = {} for pk, state in self._bucket_state.items(): bucket_summary[pk[:16]] = { "tokens": round(state["tokens"], 2), - "last_seen_ago": round(now - state["last_seen"], 1) if state["last_seen"] > 0 else None, + "last_seen_ago": round(now - state["last_seen"], 1) + if state["last_seen"] > 0 + else None, } - + return { "adaptive": { "enabled": self._adaptive_enabled, @@ -494,7 +503,8 @@ def get_rate_limit_stats(self) -> dict: "adverts_dropped": self._stats_adverts_dropped, "adverts_duplicate_reheard": self._stats_advert_duplicates, "drop_rate": round( - self._stats_adverts_dropped / max(1, self._stats_adverts_allowed + self._stats_adverts_dropped), + self._stats_adverts_dropped + / max(1, self._stats_adverts_allowed + self._stats_adverts_dropped), 3, ), }, @@ -512,7 +522,7 @@ def get_rate_limit_stats(self) -> dict: "pubkey": drop["pubkey"], "name": drop["name"], "reason": drop["reason"], - "seconds_ago": round(now - drop["timestamp"], 1) + "seconds_ago": round(now - drop["timestamp"], 1), } for drop in reversed(self._recent_drops) # Most recent first ], @@ -534,16 +544,16 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: # Set signal metrics on packet for handler to use packet._snr = snr packet._rssi = rssi - + # Use AdvertHandler to parse the packet - it now returns parsed data advert_data = await self.advert_handler(packet) - + if not advert_data or not advert_data.get("valid"): logger.warning("Invalid advert packet received, dropping.") packet.mark_do_not_retransmit() packet.drop_reason = "Invalid advert packet" return - + # Extract data from parsed advert pubkey = advert_data["public_key"] node_name = advert_data["name"] @@ -568,10 +578,10 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: logger.warning(f"Dropping advert from '{node_name}' ({pubkey[:16]}...): {reason}") packet.mark_do_not_retransmit() packet.drop_reason = reason - + # Track recent drop (deduplicate by pubkey) pubkey_short = pubkey[:16] - + # Remove any existing entry for this pubkey, then append the # updated record. Rebuilding as a deque preserves maxlen so # the oldest entry is evicted automatically — no pop(0) needed. @@ -579,15 +589,12 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: (d for d in self._recent_drops if d["pubkey"] != pubkey_short), maxlen=20, ) - self._recent_drops.append({ - "pubkey": pubkey_short, - "name": node_name, - "reason": reason, - "timestamp": now - }) - + self._recent_drops.append( + {"pubkey": pubkey_short, "name": node_name, "reason": reason, "timestamp": now} + ) + return - + # Skip our own adverts if self.local_identity: local_pubkey = self.local_identity.get_public_key().hex() @@ -605,24 +612,22 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: if pubkey not in self._known_neighbors: # Only check database if not in cache if self.storage: - current_neighbors = await asyncio.to_thread( - self.storage.get_neighbors - ) + current_neighbors = await asyncio.to_thread(self.storage.get_neighbors) else: current_neighbors = {} is_new_neighbor = pubkey not in current_neighbors - + if is_new_neighbor: self._known_neighbors.add(pubkey) logger.info(f"Discovered new neighbor: {node_name} ({pubkey[:16]}...)") else: is_new_neighbor = False - + # Determine zero-hop: direct routes are always zero-hop, # flood routes are zero-hop if path_len <= 1 (received directly) path_len = len(packet.path) if packet.path else 0 zero_hop = path_len == 0 - + # Build advert record advert_record = { "timestamp": current_time, @@ -638,7 +643,7 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: "is_new_neighbor": is_new_neighbor, "zero_hop": zero_hop, } - + # Store to database (run in thread so event loop stays responsive; # blocking here can cause companion TCP clients to disconnect) if self.storage: @@ -649,7 +654,7 @@ async def process_advert_packet(self, packet, rssi: int, snr: float) -> None: ) except Exception as e: logger.error(f"Failed to store advert record: {e}") - + except Exception as e: logger.error(f"Error processing advert packet: {e}", exc_info=True) @@ -662,8 +667,10 @@ def reload_config(self) -> None: adaptive_cfg = repeater_cfg.get("advert_adaptive", {}) self._adaptive_enabled = bool(adaptive_cfg.get("enabled", True)) self._ewma_alpha = max(0.01, min(1.0, float(adaptive_cfg.get("ewma_alpha", 0.1)))) - self._tier_hysteresis_seconds = max(0.0, float(adaptive_cfg.get("hysteresis_seconds", 300.0))) - + self._tier_hysteresis_seconds = max( + 0.0, float(adaptive_cfg.get("hysteresis_seconds", 300.0)) + ) + thresholds = adaptive_cfg.get("thresholds", {}) self._threshold_normal = float(thresholds.get("normal", 1.0)) self._threshold_busy = float(thresholds.get("busy", 5.0)) @@ -674,15 +681,23 @@ def reload_config(self) -> None: self._rate_limit_enabled = bool(rate_cfg.get("enabled", True)) self._base_bucket_capacity = max(1.0, float(rate_cfg.get("bucket_capacity", 2))) self._base_refill_tokens = max(0.1, float(rate_cfg.get("refill_tokens", 1.0))) - self._base_refill_interval = max(1.0, float(rate_cfg.get("refill_interval_seconds", 36000.0))) + self._base_refill_interval = max( + 1.0, float(rate_cfg.get("refill_interval_seconds", 36000.0)) + ) self._base_min_interval = max(0.0, float(rate_cfg.get("min_interval_seconds", 3600.0))) # Penalty box config penalty_cfg = repeater_cfg.get("advert_penalty_box", {}) self._penalty_enabled = bool(penalty_cfg.get("enabled", True)) - self._penalty_violation_threshold = max(1, int(penalty_cfg.get("violation_threshold", 2))) - self._penalty_decay_seconds = max(1.0, float(penalty_cfg.get("violation_decay_seconds", 43200.0))) - self._penalty_base_seconds = max(1.0, float(penalty_cfg.get("base_penalty_seconds", 21600.0))) + self._penalty_violation_threshold = max( + 1, int(penalty_cfg.get("violation_threshold", 2)) + ) + self._penalty_decay_seconds = max( + 1.0, float(penalty_cfg.get("violation_decay_seconds", 43200.0)) + ) + self._penalty_base_seconds = max( + 1.0, float(penalty_cfg.get("base_penalty_seconds", 21600.0)) + ) self._penalty_multiplier = max(1.0, float(penalty_cfg.get("penalty_multiplier", 2.0))) self._penalty_max_seconds = max( self._penalty_base_seconds, diff --git a/repeater/handler_helpers/login.py b/repeater/handler_helpers/login.py index 1561dec..6c05a4e 100644 --- a/repeater/handler_helpers/login.py +++ b/repeater/handler_helpers/login.py @@ -19,7 +19,7 @@ def __init__(self, identity_manager, packet_injector=None, log_fn=None): self.identity_manager = identity_manager self.packet_injector = packet_injector self.log_fn = log_fn or logger.info - + self.handlers = {} self.acls = {} # Per-identity ACLs keyed by hash_byte self._pending_tasks = set() @@ -44,19 +44,19 @@ def register_identity( config = config or {} hash_byte = identity.get_public_key()[0] - + # Create ACL for this identity from repeater.handler_helpers.acl import ACL - + # Get security config for this identity if identity_type == "room_server": # Room servers use passwords from their settings section only settings = config.get("settings", {}) - + # Empty strings ('') are treated as "not set" by using 'or None' admin_password = settings.get("admin_password") or None guest_password = settings.get("guest_password") or None - + # Validate room servers have passwords configured if not admin_password and not guest_password: logger.error( @@ -64,7 +64,7 @@ def register_identity( f"Add them to 'settings' section. Skipping registration." ) return - + # Use configured passwords from settings final_security = { "max_clients": settings.get("max_clients", 50), @@ -75,18 +75,24 @@ def register_identity( else: # Repeater uses security from repeater.security in config security = config.get("repeater", {}).get("security", {}) + admin_password = security.get("admin_password") or None + guest_password = security.get("guest_password") or None final_security = { "max_clients": security.get("max_clients", 10), - "admin_password": security.get("admin_password", "admin123"), - "guest_password": security.get("guest_password", "guest123"), - "allow_read_only": security.get("allow_read_only", True), + "admin_password": admin_password, + "guest_password": guest_password, + "allow_read_only": security.get("allow_read_only", False), } + if not admin_password and not guest_password: + logger.warning( + f"Repeater '{name}' has no admin/guest password configured; setup is required before login." + ) logger.debug( f"Repeater security config: admin_pw={'SET' if final_security['admin_password'] else 'NONE'}, " f"guest_pw={'SET' if final_security['guest_password'] else 'NONE'}, " f"max_clients={final_security['max_clients']}" ) - + # Create ACL for this identity identity_acl = ACL( max_clients=final_security["max_clients"], @@ -94,7 +100,7 @@ def register_identity( guest_password=final_security["guest_password"], allow_read_only=final_security["allow_read_only"], ) - + self.acls[hash_byte] = identity_acl logger.info(f"Created ACL for {identity_type} '{name}': hash=0x{hash_byte:02X}") @@ -119,9 +125,9 @@ def auth_callback_with_context( authenticate_callback=auth_callback_with_context, is_room_server=(identity_type == "room_server"), ) - + handler.set_send_packet_callback(self._send_packet_with_delay) - + self.handlers[hash_byte] = handler logger.info(f"Registered {identity_type} '{name}' login handler: hash=0x{hash_byte:02X}") @@ -131,9 +137,9 @@ async def process_login_packet(self, packet): try: if len(packet.payload) < 1: return False - + dest_hash = packet.payload[0] - + handler = self.handlers.get(dest_hash) if handler: logger.debug(f"Routing login to identity: hash=0x{dest_hash:02X}") @@ -154,7 +160,7 @@ async def process_login_packet(self, packet): return False def _send_packet_with_delay(self, packet, delay_ms: int): - + if self.packet_injector: task = asyncio.create_task(self._delayed_send(packet, delay_ms)) self._track_task(task) @@ -162,7 +168,7 @@ def _send_packet_with_delay(self, packet, delay_ms: int): logger.error("No packet injector configured, cannot send login response") async def _delayed_send(self, packet, delay_ms: int): - + await asyncio.sleep(delay_ms / 1000.0) try: await self.packet_injector(packet, wait_for_ack=False) @@ -173,7 +179,7 @@ async def _delayed_send(self, packet, delay_ms: int): def get_acl_dict(self): """Return dictionary of ACLs keyed by identity hash.""" return self.acls - + def get_acl_for_identity(self, hash_byte: int): """Get ACL for a specific identity.""" return self.acls.get(hash_byte) @@ -183,7 +189,7 @@ def list_authenticated_clients(self, hash_byte: int = None): if hash_byte is not None: acl = self.acls.get(hash_byte) return acl.get_all_clients() if acl else [] - + # Return clients from all ACLs all_clients = [] for acl in self.acls.values(): diff --git a/repeater/handler_helpers/mesh_cli.py b/repeater/handler_helpers/mesh_cli.py index f56f1f6..aa379f7 100644 --- a/repeater/handler_helpers/mesh_cli.py +++ b/repeater/handler_helpers/mesh_cli.py @@ -1,15 +1,12 @@ import logging -import time from pathlib import Path from typing import Any, Callable, Dict, Optional -import yaml logger = logging.getLogger(__name__) class MeshCLI: - def __init__( self, config_path: str, @@ -33,6 +30,7 @@ def __init__( # Store event loop reference for thread-safe scheduling import asyncio + try: self._event_loop = asyncio.get_running_loop() except RuntimeError: diff --git a/repeater/handler_helpers/path.py b/repeater/handler_helpers/path.py index d482118..f9055c2 100644 --- a/repeater/handler_helpers/path.py +++ b/repeater/handler_helpers/path.py @@ -62,7 +62,7 @@ async def process_path_packet(self, packet): # Parse decrypted PATH data # Format: path_len(1) + path[path_len] + extra_type(1) + extra[...] if len(decrypted) < 1: - logger.debug(f"Decrypted PATH data too short") + logger.debug("Decrypted PATH data too short") return False path_len = decrypted[0] diff --git a/repeater/handler_helpers/protocol_request.py b/repeater/handler_helpers/protocol_request.py index ca8d4ae..74770a4 100644 --- a/repeater/handler_helpers/protocol_request.py +++ b/repeater/handler_helpers/protocol_request.py @@ -14,7 +14,6 @@ REQ_TYPE_GET_NEIGHBOURS, REQ_TYPE_GET_OWNER_INFO, REQ_TYPE_GET_STATUS, - REQ_TYPE_GET_TELEMETRY_DATA, SERVER_RESPONSE_DELAY_MS, ProtocolRequestHandler, ) @@ -43,23 +42,23 @@ def __init__( self.engine = engine self.neighbor_tracker = neighbor_tracker self.config = config or {} - + # Dictionary of core handlers keyed by dest_hash self.handlers = {} - + def register_identity(self, name: str, identity, identity_type: str = "repeater"): hash_byte = identity.get_public_key()[0] - + # Get ACL for this identity identity_acl = self.acl_dict.get(hash_byte) if not identity_acl: logger.warning(f"Cannot register identity '{name}': no ACL for hash 0x{hash_byte:02X}") return - + # Create ACL contacts wrapper acl_contacts = self._create_acl_contacts_wrapper(identity_acl) - + # Build request handlers dict request_handlers = { REQ_TYPE_GET_STATUS: self._handle_get_status, @@ -67,7 +66,7 @@ def register_identity(self, name: str, identity, identity_type: str = "repeater" REQ_TYPE_GET_NEIGHBOURS: self._handle_get_neighbours, REQ_TYPE_GET_OWNER_INFO: self._handle_get_owner_info, } - + # Create core handler handler = ProtocolRequestHandler( local_identity=identity, @@ -76,14 +75,14 @@ def register_identity(self, name: str, identity, identity_type: str = "repeater" request_handlers=request_handlers, log_fn=logger.info, ) - + self.handlers[hash_byte] = { "handler": handler, "identity": identity, "name": name, "type": identity_type, } - + logger.info(f"Registered protocol request handler for '{name}': hash=0x{hash_byte:02X}") def _create_acl_contacts_wrapper(self, acl): @@ -92,47 +91,47 @@ def _create_acl_contacts_wrapper(self, acl): class ACLContactsWrapper: def __init__(self, identity_acl): self._acl = identity_acl - + @property def contacts(self): return self._acl.get_all_clients() - + return ACLContactsWrapper(acl) - + def _get_client_from_acl(self, acl, src_hash: int): """Get client from ACL by source hash.""" for client_info in acl.get_all_clients(): if client_info.id.get_public_key()[0] == src_hash: return client_info return None - + async def process_request_packet(self, packet): try: if len(packet.payload) < 2: return False - + dest_hash = packet.payload[0] - + handler_info = self.handlers.get(dest_hash) if not handler_info: return False - + # Let core handler build response response_packet = await handler_info["handler"](packet) - + # Send response after delay if response_packet and self.packet_injector: await asyncio.sleep(SERVER_RESPONSE_DELAY_MS / 1000.0) await self.packet_injector(response_packet, wait_for_ack=False) - + packet.mark_do_not_retransmit() return True - + except Exception as e: logger.error(f"Error processing protocol request: {e}", exc_info=True) return False - + def _handle_get_status(self, client, timestamp: int, req_data: bytes): """Build 56-byte RepeaterStats (firmware layout from MeshCore simple_repeater/MyMesh.h).""" # RepeaterStats: uint16 batt, uint16 curr_tx_queue_len, int16 noise_floor, int16 last_rssi, @@ -192,11 +191,7 @@ def _handle_get_status(self, client, timestamp: int, req_data: bytes): n_recv_direct = getattr(self.engine, "recv_direct_count", 0) if self.engine else 0 n_direct_dups = getattr(self.engine, "direct_dup_count", 0) if self.engine else 0 n_flood_dups = getattr(self.engine, "flood_dup_count", 0) if self.engine else 0 - n_recv_errors = ( - int(getattr(self.radio, "crc_error_count", 0) or 0) - if self.radio - else 0 - ) + n_recv_errors = int(getattr(self.radio, "crc_error_count", 0) or 0) if self.radio else 0 # Pack 56-byte RepeaterStats (layout matches firmware) stats = struct.pack( @@ -235,8 +230,10 @@ def _handle_get_status(self, client, timestamp: int, req_data: bytes): def _make_handle_get_access_list(self, identity_acl): """Create a closure for GET_ACCESS_LIST bound to a specific identity ACL.""" + def _handler(client, timestamp: int, req_data: bytes): return self._handle_get_access_list(client, timestamp, req_data, identity_acl) + return _handler def _handle_get_access_list(self, client, timestamp: int, req_data: bytes, identity_acl): @@ -313,13 +310,13 @@ def _handle_get_neighbours(self, client, timestamp: int, req_data: bytes): # Sort (matches C++ order_by values) if order_by == 0: - entries.sort(key=lambda e: e[1]) # newest first (smallest heard_ago) + entries.sort(key=lambda e: e[1]) # newest first (smallest heard_ago) elif order_by == 1: entries.sort(key=lambda e: e[1], reverse=True) # oldest first elif order_by == 2: entries.sort(key=lambda e: e[2], reverse=True) # strongest SNR first elif order_by == 3: - entries.sort(key=lambda e: e[2]) # weakest SNR first + entries.sort(key=lambda e: e[2]) # weakest SNR first total_count = len(entries) @@ -350,7 +347,11 @@ def _handle_get_neighbours(self, client, timestamp: int, req_data: bytes): logger.debug( "GET_NEIGHBOURS: total=%d, returned=%d, offset=%d, order=%d, requested=%d", - total_count, results_count, offset, order_by, count, + total_count, + results_count, + offset, + order_by, + count, ) return header + bytes(results) @@ -366,6 +367,7 @@ def _handle_get_owner_info(self, client, timestamp: int, req_data: bytes): # Version: use package version if available, fallback to "pyMC" try: from importlib.metadata import version as pkg_version + fw_version = pkg_version("pymc-repeater") except Exception: fw_version = "pyMC" diff --git a/repeater/handler_helpers/repeater_cli.py b/repeater/handler_helpers/repeater_cli.py index e2f993b..99a675d 100644 --- a/repeater/handler_helpers/repeater_cli.py +++ b/repeater/handler_helpers/repeater_cli.py @@ -5,11 +5,9 @@ """ import logging -import time from pathlib import Path -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict -import yaml logger = logging.getLogger(__name__) @@ -22,8 +20,8 @@ class MeshCLI: """ def __init__( - self, - config_path: str, + self, + config_path: str, config: Dict[str, Any], save_config_callback: Callable, identity_type: str = "repeater", @@ -31,7 +29,7 @@ def __init__( ): """ Initialize the CLI handler. - + Args: config_path: Path to the config.yaml file config: Current configuration dictionary @@ -51,19 +49,19 @@ def __init__( def handle_command(self, sender_pubkey: bytes, command: str, is_admin: bool) -> str: """ Handle an incoming command from a client. - + Args: sender_pubkey: Public key of sender command: Command string (may include XX| prefix) is_admin: Whether sender has admin permissions - + Returns: Reply string to send back to sender """ # Check admin permission first if not is_admin: return "Error: Admin permission required" - + logger.debug(f"handle_command received: '{command}' (len={len(command)})") # Extract optional sequence prefix (XX|) @@ -72,26 +70,26 @@ def handle_command(self, sender_pubkey: bytes, command: str, is_admin: bool) -> prefix = command[:3] command = command[3:] logger.debug(f"Extracted prefix: '{prefix}', remaining command: '{command}'") - + # Strip leading/trailing whitespace command = command.strip() logger.debug(f"After strip: '{command}'") - + # Route to appropriate handler reply = self._route_command(command) - + # Add prefix back to reply if present if prefix: return prefix + reply return reply - + def _route_command(self, command: str) -> str: """Route command to appropriate handler method.""" - + # Help if command == "help" or command.startswith("help "): return self._cmd_help(command) - + # System commands elif command == "reboot": return self._cmd_reboot() @@ -109,57 +107,57 @@ def _route_command(self, command: str) -> str: return self._cmd_clear_stats() elif command == "ver": return self._cmd_version() - + # Get commands elif command.startswith("get "): return self._cmd_get(command[4:]) - + # Set commands elif command.startswith("set "): return self._cmd_set(command[4:]) - + # ACL commands elif command.startswith("setperm "): return self._cmd_setperm(command) elif command == "get acl": return "Error: Use 'get acl' via serial console only" - + # Region commands (repeaters only) elif command.startswith("region"): if self.enable_regions: return self._cmd_region(command) else: return "Error: Region commands not available for room servers" - + # Neighbor commands elif command == "neighbors": return self._cmd_neighbors() elif command.startswith("neighbor.remove "): return self._cmd_neighbor_remove(command) - + # Temporary radio params elif command.startswith("tempradio "): return self._cmd_tempradio(command) - + # Sensor commands elif command.startswith("sensor "): return "Error: Sensor commands not implemented in Python repeater" - + # GPS commands elif command.startswith("gps"): return "Error: GPS commands not implemented in Python repeater" - + # Logging commands elif command.startswith("log "): return self._cmd_log(command) - + # Statistics commands elif command.startswith("stats-"): return "Error: Stats commands not fully implemented yet" - + else: return "Unknown command" - + # ==================== Help Command ==================== def _cmd_help(self, command: str) -> str: @@ -167,7 +165,7 @@ def _cmd_help(self, command: str) -> str: parts = command.split(None, 1) if len(parts) == 2: return self._help_detail(parts[1]) - + lines = [ "=== pyMC CLI Commands ===", "", @@ -260,25 +258,25 @@ def _help_detail(self, topic: str) -> str: return details.get(topic, f"No detailed help for '{topic}'. Type 'help' for command list.") # ==================== System Commands == - + def _cmd_reboot(self) -> str: """Reboot the repeater process.""" from repeater.service_utils import restart_service - + logger.warning("Reboot command received via repeater CLI") success, message = restart_service() - + if success: return f"OK - {message}" else: return f"Error: {message}" - + def _cmd_advert(self) -> str: """Send self advertisement.""" logger.info("Advert command received") # TODO: Trigger advertisement through packet handler return "Error: Not yet implemented" - + def _cmd_clock(self, command: str) -> str: """Handle clock commands.""" if command == "clock": @@ -292,15 +290,15 @@ def _cmd_clock(self, command: str) -> str: return "OK - clock sync not needed (system time used)" else: return "Unknown clock command" - + def _cmd_time(self, command: str) -> str: """Set time - not supported in Python (use system time).""" return "Error: Time setting not supported (system time is used)" - + def _cmd_password(self, command: str) -> str: """Change admin password.""" new_password = command[9:].strip() - + if not new_password: return "Error: Password cannot be empty" @@ -317,12 +315,12 @@ def _cmd_password(self, command: str) -> str: except Exception as e: logger.error(f"Failed to save password: {e}") return "Error: Failed to save password" - + def _cmd_clear_stats(self) -> str: """Clear statistics.""" # TODO: Implement stats clearing return "Error: Not yet implemented" - + def _cmd_version(self) -> str: """Get version information.""" role = "room_server" if self.identity_type == "room_server" else "repeater" @@ -330,7 +328,7 @@ def _cmd_version(self) -> str: return f"pyMC_{role} v{version}" # ==================== Get Commands ==================== - + def _cmd_get(self, param: str) -> str: """Handle get commands.""" param = param.strip() @@ -379,7 +377,7 @@ def _cmd_get(self, param: str) -> str: elif param == "public.key": # TODO: Get from identity return "Error: Not yet implemented" - + elif param == "role": role = "room_server" if self.identity_type == "room_server" else "repeater" return f"> {role}" @@ -430,15 +428,15 @@ def _cmd_get(self, param: str) -> str: else: return f"??: {param}" - + # ==================== Set Commands ==================== - + def _cmd_set(self, param: str) -> str: """Handle set commands.""" parts = param.split(None, 1) if len(parts) < 2: return "Error: Missing value" - + key, value = parts[0], parts[1] try: @@ -579,43 +577,43 @@ def _cmd_set(self, param: str) -> str: else: return f"unknown config: {key}" - + except ValueError as e: return f"Error: invalid value - {e}" except Exception as e: logger.error(f"Set command error: {e}") return f"Error: {e}" - + # ==================== ACL Commands ==================== - + def _cmd_setperm(self, command: str) -> str: """Set permissions for a public key.""" # Format: setperm {pubkey-hex} {permissions-int} parts = command[8:].split() if len(parts) < 2: return "Err - bad params" - + pubkey_hex = parts[0] try: permissions = int(parts[1]) except ValueError: return "Err - invalid permissions" - + # TODO: Apply permissions via ACL logger.info(f"setperm command: {pubkey_hex} -> {permissions}") return "Error: Not yet implemented - use config file" - + # ==================== Region Commands ==================== - + def _cmd_region(self, command: str) -> str: """Handle region commands.""" parts = command.split() - + if len(parts) == 1: return "Error: Region commands not implemented in Python repeater" - + subcommand = parts[1] - + if subcommand == "load": return "Error: Region commands not implemented" elif subcommand == "save": @@ -624,42 +622,42 @@ def _cmd_region(self, command: str) -> str: return "Error: Region commands not implemented" else: return "Err - ??" - + # ==================== Neighbor Commands ==================== - + def _cmd_neighbors(self) -> str: """List neighbors.""" # TODO: Get neighbors from routing table return "Error: Not yet implemented" - + def _cmd_neighbor_remove(self, command: str) -> str: """Remove a neighbor.""" pubkey_hex = command[16:].strip() - + if not pubkey_hex: return "ERR: Missing pubkey" - + # TODO: Remove neighbor from routing table logger.info(f"neighbor.remove: {pubkey_hex}") return "Error: Not yet implemented" - + # ==================== Temporary Radio Commands ==================== - + def _cmd_tempradio(self, command: str) -> str: """Apply temporary radio parameters.""" # Format: tempradio {freq} {bw} {sf} {cr} {timeout_mins} parts = command[10:].split() - + if len(parts) < 5: return "Error: Expected freq bw sf cr timeout_mins" - + try: freq = float(parts[0]) bw = float(parts[1]) sf = int(parts[2]) cr = int(parts[3]) timeout_mins = int(parts[4]) - + # Validate if not (300.0 <= freq <= 2500.0): return "Error: invalid frequency" @@ -671,16 +669,16 @@ def _cmd_tempradio(self, command: str) -> str: return "Error: invalid coding rate" if timeout_mins <= 0: return "Error: invalid timeout" - + # TODO: Apply temporary radio parameters logger.info(f"tempradio: {freq}MHz {bw}kHz SF{sf} CR4/{cr} for {timeout_mins}min") return "Error: Not yet implemented" - + except ValueError: return "Error, invalid params" - + # ==================== Logging Commands ==================== - + def _cmd_log(self, command: str) -> str: """Handle log commands.""" if command == "log start": diff --git a/repeater/handler_helpers/room_server.py b/repeater/handler_helpers/room_server.py index f65b0c6..632e82c 100644 --- a/repeater/handler_helpers/room_server.py +++ b/repeater/handler_helpers/room_server.py @@ -1,7 +1,8 @@ import asyncio import logging +import secrets import time -from typing import Dict, Optional +from typing import Dict from pymc_core.protocol import CryptoUtils, PacketBuilder from pymc_core.protocol.constants import PAYLOAD_TYPE_TXT_MSG @@ -46,7 +47,6 @@ class GlobalRateLimiter: - def __init__(self, min_gap_seconds: float = 0.1): self.min_gap = min_gap_seconds # Minimum gap between consecutive messages self.lock = asyncio.Lock() # Only one transmission at a time @@ -60,7 +60,7 @@ async def acquire(self): time_since_last = now - self.last_release_time if time_since_last < self.min_gap: wait_time = self.min_gap - time_since_last - logger.debug(f"Global rate limiter: waiting {wait_time*1000:.0f}ms") + logger.debug(f"Global rate limiter: waiting {wait_time * 1000:.0f}ms") await asyncio.sleep(wait_time) # Lock is now held - caller can transmit # Will be released when context exits @@ -70,7 +70,6 @@ def release(self): class RoomServer: - def __init__( self, room_hash: int, @@ -307,7 +306,7 @@ async def add_post( return True else: - logger.error(f"Failed to store message to database") + logger.error("Failed to store message to database") return False except Exception as e: @@ -537,9 +536,7 @@ async def _evict_failed_clients(self): async def _sync_loop(self): # SAFETY: Stagger room startup to prevent thundering herd - import random - - startup_delay = random.uniform(0, 5) # 0-5 second random delay + startup_delay = secrets.randbelow(5001) / 1000.0 # 0-5 second random delay await asyncio.sleep(startup_delay) logger.info(f"Room '{self.room_name}' sync loop starting (delayed {startup_delay:.1f}s)") diff --git a/repeater/handler_helpers/text.py b/repeater/handler_helpers/text.py index e6f07a5..e4f1e91 100644 --- a/repeater/handler_helpers/text.py +++ b/repeater/handler_helpers/text.py @@ -8,7 +8,6 @@ import asyncio import logging -import struct import time from pymc_core.node.handlers.text import TextMessageHandler @@ -24,7 +23,6 @@ class TextHelper: - def __init__( self, identity_manager, @@ -175,21 +173,19 @@ def register_identity( except RuntimeError: # No running event loop in this thread if self._loop and self._loop.is_running(): - future = asyncio.run_coroutine_threadsafe( - room_server.start(), self._loop - ) + future = asyncio.run_coroutine_threadsafe(room_server.start(), self._loop) future.add_done_callback( - lambda f: logger.error( - f"Room server '{name}' failed: {f.exception()}", - exc_info=f.exception(), + lambda f: ( + logger.error( + f"Room server '{name}' failed: {f.exception()}", + exc_info=f.exception(), + ) + if not f.cancelled() and f.exception() + else None ) - if not f.cancelled() and f.exception() - else None ) else: - logger.error( - f"Cannot start room server '{name}': no event loop available" - ) + logger.error(f"Cannot start room server '{name}': no event loop available") logger.info( f"Registered room server '{name}': hash=0x{hash_byte:02X}, " @@ -271,7 +267,7 @@ async def _on_message_received( # Placeholder - can be overridden or callback can be added logger.debug( - f"Message received for {identity_type} '{identity_name}' " f"from 0x{src_hash:02X}" + f"Message received for {identity_type} '{identity_name}' from 0x{src_hash:02X}" ) # Extract decrypted message if available @@ -511,7 +507,7 @@ async def _send_cli_reply(self, original_packet, reply_text: str, handler_info: """ import time - from pymc_core.protocol import Identity, PacketBuilder + from pymc_core.protocol import PacketBuilder from pymc_core.protocol.constants import PAYLOAD_TYPE_TXT_MSG try: diff --git a/repeater/handler_helpers/trace.py b/repeater/handler_helpers/trace.py index 448f6b0..80ce715 100644 --- a/repeater/handler_helpers/trace.py +++ b/repeater/handler_helpers/trace.py @@ -53,9 +53,7 @@ def __init__( self.packet_injector = packet_injector # Function to inject packets into router # Ping callback system - track pending ping requests by tag - self.pending_pings = ( - {} - ) # {tag: {'event': asyncio.Event(), 'result': dict, 'target': int, 'sent_at': float}} + self.pending_pings = {} # {tag: {'event': asyncio.Event(), 'result': dict, 'target': int, 'sent_at': float}} # Optional: when trace reaches final node, call this (packet, parsed_data) to push 0x89 to companions self.on_trace_complete = None # async (packet, parsed_data) -> None @@ -103,8 +101,7 @@ async def process_trace_packet(self, packet) -> None: rssi_val = getattr(packet, "rssi", 0) if rssi_val == 0: logger.warning( - f"Ignoring trace response for tag {trace_tag} " - "with RSSI=0 (no signal data)" + f"Ignoring trace response for tag {trace_tag} with RSSI=0 (no signal data)" ) return # wait for a valid response or let timeout handle it ping_info = self.pending_pings[trace_tag] diff --git a/repeater/identity_manager.py b/repeater/identity_manager.py index 7de98b7..9953ca7 100644 --- a/repeater/identity_manager.py +++ b/repeater/identity_manager.py @@ -5,7 +5,6 @@ class IdentityManager: - def __init__(self, config: dict): self.config = config self.identities: Dict[int, Tuple[Any, dict, str]] = {} diff --git a/repeater/keygen.py b/repeater/keygen.py index 65032be..4a92b08 100644 --- a/repeater/keygen.py +++ b/repeater/keygen.py @@ -29,9 +29,9 @@ def generate_meshcore_keypair() -> Tuple[bytes, bytes]: # 3. Ed25519 scalar clamping on first 32 bytes clamped = bytearray(digest[:32]) - clamped[0] &= 248 # Clear bottom 3 bits - clamped[31] &= 63 # Clear top 2 bits - clamped[31] |= 64 # Set bit 6 + clamped[0] &= 248 # Clear bottom 3 bits + clamped[31] &= 63 # Clear top 2 bits + clamped[31] |= 64 # Set bit 6 # 4. Derive public key public_key = crypto_scalarmult_ed25519_base_noclamp(bytes(clamped)) diff --git a/repeater/local_cli.py b/repeater/local_cli.py index b376740..9ee79d7 100644 --- a/repeater/local_cli.py +++ b/repeater/local_cli.py @@ -5,6 +5,8 @@ """ import sys +from typing import Optional +from urllib.parse import urlparse CONFIG_PATHS = [ @@ -13,6 +15,14 @@ ] +def _validate_http_url(url: str) -> None: + parsed = urlparse(url) + if parsed.scheme not in {"http", "https"}: + raise ValueError(f"Unsupported URL scheme: {parsed.scheme or ''}") + if not parsed.hostname: + raise ValueError("URL must include a host") + + def _load_config(config_path=None): """Load repeater config.yaml, trying common paths.""" import yaml @@ -27,7 +37,7 @@ def _load_config(config_path=None): return {} -def run_client_cli(host: str = "127.0.0.1", port: int = 8000, password: str = ""): +def run_client_cli(host: str = "127.0.0.1", port: int = 8000, password: Optional[str] = None): """ Standalone CLI client that connects to a running repeater's HTTP API. """ @@ -41,18 +51,21 @@ def run_client_cli(host: str = "127.0.0.1", port: int = 8000, password: str = "" token = None if password: try: - auth_data = json.dumps({ - "username": "admin", - "password": password, - "client_id": "pymc-cli", - }).encode() + auth_data = json.dumps( + { + "username": "admin", + "password": password, + "client_id": "pymc-cli", + } + ).encode() req = urllib.request.Request( f"{base_url}/auth/login", data=auth_data, headers={"Content-Type": "application/json"}, method="POST", ) - with urllib.request.urlopen(req, timeout=5) as resp: + _validate_http_url(req.full_url) + with urllib.request.urlopen(req, timeout=5) as resp: # nosec B310 result = json.loads(resp.read()) token = result.get("token") or result.get("data", {}).get("token") except urllib.error.URLError as e: @@ -92,7 +105,8 @@ def run_client_cli(host: str = "127.0.0.1", port: int = 8000, password: str = "" }, method="POST", ) - with urllib.request.urlopen(req, timeout=10) as resp: + _validate_http_url(req.full_url) + with urllib.request.urlopen(req, timeout=10) as resp: # nosec B310 result = json.loads(resp.read()) if result.get("success"): print(result["data"]["reply"]) @@ -112,15 +126,19 @@ def main(): description="Connect to a running pyMC Repeater and issue CLI commands" ) parser.add_argument( - "--config", default=None, + "--config", + default=None, help="Path to config.yaml (auto-detected if not set)", ) parser.add_argument( - "--host", default=None, + "--host", + default=None, help="Repeater HTTP host (default: 127.0.0.1)", ) parser.add_argument( - "--port", type=int, default=None, + "--port", + type=int, + default=None, help="Repeater HTTP port (default: from config or 8000)", ) args = parser.parse_args() diff --git a/repeater/main.py b/repeater/main.py index 5c98e3c..157a69f 100644 --- a/repeater/main.py +++ b/repeater/main.py @@ -3,11 +3,11 @@ import logging import os import signal -import sys import socket +import sys import time -from repeater.companion.utils import validate_companion_node_name, normalize_companion_identity_key +from repeater.companion.utils import normalize_companion_identity_key, validate_companion_node_name from repeater.config import NullRadio, get_radio_for_board, load_config, save_config from repeater.config_manager import ConfigManager from repeater.data_acquisition.glass_handler import GlassHandler @@ -31,7 +31,6 @@ class RepeaterDaemon: - def __init__(self, config: dict, radio=None): self.config = config @@ -76,14 +75,14 @@ async def initialize(self): logger.info(f"Initializing repeater: {self.config['repeater']['node_name']}") - #----------------------------------------------- - # Get the actual Network IP Address + # ----------------------------------------------- + # Get the actual Network IP Address try: # This looks for the IP assigned to the default hostname host_name = socket.gethostname() # We try to get the IP associated with the hostname self.network_ip = socket.gethostbyname(host_name) - + # If that still gives 127.0.x.x, let's try a different internal method if self.network_ip.startswith("127."): s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) @@ -96,7 +95,7 @@ async def initialize(self): self.network_ip = "Unknown" logger.info(f"System Network IP: {self.network_ip}") - #----------------------------------------------- + # ----------------------------------------------- if self.radio is None: radio_type_raw = self.config.get("radio_type") @@ -202,7 +201,9 @@ async def initialize(self): self.dispatcher._is_own_packet = lambda pkt: False self.repeater_handler = RepeaterHandler( - self.config, self.dispatcher, self.local_hash, + self.config, + self.dispatcher, + self.local_hash, local_hash_bytes=self.local_hash_bytes, send_advert_func=self.send_advert, ) @@ -407,7 +408,9 @@ async def initialize(self): and self.repeater_handler.storage and hasattr(self.repeater_handler.storage, "set_glass_publisher") ): - self.repeater_handler.storage.set_glass_publisher(self.glass_handler.publish_telemetry) + self.repeater_handler.storage.set_glass_publisher( + self.glass_handler.publish_telemetry + ) except Exception as e: logger.error(f"Failed to initialize dispatcher: {e}") @@ -426,7 +429,7 @@ async def _load_additional_identities(self): identity_key = room_config.get("identity_key") if not name or not identity_key: - logger.warning(f"Skipping room server config: missing name or identity_key") + logger.warning("Skipping room server config: missing name or identity_key") continue # Convert identity_key to bytes if it's a hex string @@ -477,7 +480,7 @@ async def _load_additional_identities(self): async def _load_companion_identities(self) -> None: """Load companion identities from config and create CompanionBridge + frame server for each.""" from pymc_core import LocalIdentity - from pymc_core.companion.models import Channel, Contact + from pymc_core.companion.models import Channel from repeater.companion import CompanionFrameServer, RepeaterCompanionBridge @@ -511,7 +514,9 @@ async def _load_companion_identities(self) -> None: if isinstance(identity_key, str): try: - identity_key_bytes = bytes.fromhex(normalize_companion_identity_key(identity_key)) + identity_key_bytes = bytes.fromhex( + normalize_companion_identity_key(identity_key) + ) except ValueError as e: logger.error(f"Companion '{name}' identity_key invalid hex: {e}") continue @@ -535,11 +540,12 @@ async def _load_companion_identities(self) -> None: node_name = settings.get("node_name", name) tcp_port = settings.get("tcp_port", 5000) bind_address = settings.get("bind_address", "0.0.0.0") - tcp_timeout_raw = settings.get("tcp_timeout", 8 * 60 * 60) # 8 hours + tcp_timeout_raw = settings.get("tcp_timeout", 8 * 60 * 60) # 8 hours client_idle_timeout_sec = None if tcp_timeout_raw == 0 else int(tcp_timeout_raw) def _make_sync_node_name_to_config(companion_name: str): """Return a callback that syncs node_name to config for this companion (binds name at creation).""" + def _sync(new_node_name: str) -> None: try: validated = validate_companion_node_name(new_node_name) @@ -555,6 +561,7 @@ def _sync(new_node_name: str) -> None: if config_path: save_config(self.config, config_path) break + return _sync bridge = RepeaterCompanionBridge( diff --git a/repeater/packet_router.py b/repeater/packet_router.py index 0b5024e..8731782 100644 --- a/repeater/packet_router.py +++ b/repeater/packet_router.py @@ -44,7 +44,6 @@ def _is_direct_final_hop(packet) -> bool: class PacketRouter: - def __init__(self, daemon_instance): self.daemon = daemon_instance self.queue = asyncio.Queue(maxsize=500) @@ -75,7 +74,7 @@ async def start(self): self.running = True self.router_task = asyncio.create_task(self._process_queue()) logger.info("Packet router started") - + async def stop(self): self.running = False if self.router_task: @@ -121,7 +120,7 @@ def _on_route_done(self, task: asyncio.Task) -> None: exc = task.exception() if exc is not None: logger.error("_route_packet raised: %s", exc, exc_info=exc) - + def _should_deliver_path_to_companions(self, packet) -> bool: """Return True if this PATH/protocol-response should be delivered to companions (first of duplicates).""" key = _companion_dedup_key(packet) @@ -173,9 +172,7 @@ async def inject_packet(self, packet, wait_for_ack: bool = False): # (avoids duty-cycle or dispatcher races where a later packet goes out first) async with self._inject_lock: # Use local_transmission=True to bypass forwarding logic - await self.daemon.repeater_handler( - packet, metadata, local_transmission=True - ) + await self.daemon.repeater_handler(packet, metadata, local_transmission=True) # Mark so when this packet is dequeued we don't pass to engine again (avoid double-send / double-count) packet._injected_for_tx = True @@ -189,7 +186,11 @@ async def inject_packet(self, packet, wait_for_ack: bool = False): ) # Log protocol REQ (e.g. status/telemetry) so we can confirm target node ptype = getattr(packet, "get_payload_type", lambda: None)() - if ptype == ProtocolRequestHandler.payload_type() and packet.payload and packet_len >= 1: + if ( + ptype == ProtocolRequestHandler.payload_type() + and packet.payload + and packet_len >= 1 + ): logger.info( "Injected protocol REQ: dest=0x%02x, payload=%d bytes", packet.payload[0], @@ -200,7 +201,7 @@ async def inject_packet(self, packet, wait_for_ack: bool = False): except Exception as e: logger.error(f"Error injecting packet through engine: {e}") return False - + async def _process_queue(self): while self.running: try: @@ -213,7 +214,9 @@ async def _process_queue(self): logger.warning( "In-flight task cap reached (%d/%d), dropping packet " "(session total dropped: %d)", - self._in_flight, self._max_in_flight, self._cap_drop_count, + self._in_flight, + self._max_in_flight, + self._cap_drop_count, ) continue self._in_flight += 1 diff --git a/repeater/sensors/base.py b/repeater/sensors/base.py index 1aaefb0..b14748d 100644 --- a/repeater/sensors/base.py +++ b/repeater/sensors/base.py @@ -2,23 +2,36 @@ import importlib.util import logging -import subprocess +import re + +# Required for optional dependency installation in controlled, validated path. +import subprocess # nosec B404 import sys from abc import ABC, abstractmethod from datetime import datetime, timezone from typing import Any, Dict, Iterable, Optional, Tuple +_PIP_PACKAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") + + class SensorBase(ABC): """Base class for lightweight sensor plug-ins.""" sensor_type = "sensor" - def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log: Optional[logging.Logger] = None): + def __init__( + self, + name: str, + config: Optional[Dict[str, Any]] = None, + log: Optional[logging.Logger] = None, + ): self.name = name self.config = config or {} self.settings = self.config.get("settings", {}) if isinstance(self.config, dict) else {} - self.enabled = bool(self.config.get("enabled", True)) if isinstance(self.config, dict) else True + self.enabled = ( + bool(self.config.get("enabled", True)) if isinstance(self.config, dict) else True + ) self.log = log or logging.getLogger(self.__class__.__name__) @abstractmethod @@ -90,13 +103,20 @@ def ensure_python_modules(self, modules: Iterable[Tuple[str, str]]) -> bool: return False for import_name, package_name in missing: + if not _PIP_PACKAGE_RE.fullmatch(package_name): + self.log.warning( + "Refusing to install dependency with unsupported package name for %s: %r", + self.name, + package_name, + ) + return False self.log.info("Installing missing dependency for %s: %s", self.name, package_name) result = subprocess.run( [sys.executable, "-m", "pip", "install", package_name], capture_output=True, text=True, check=False, - ) + ) # nosec B603 if result.returncode != 0: self.log.warning( "Failed installing %s for %s: %s", diff --git a/repeater/sensors/ens210.py b/repeater/sensors/ens210.py index 4ab6f29..549d517 100644 --- a/repeater/sensors/ens210.py +++ b/repeater/sensors/ens210.py @@ -23,10 +23,10 @@ from .registry import SensorRegistry # ENS210 register addresses -_REG_SENS_RUN = 0x21 +_REG_SENS_RUN = 0x21 _REG_SENS_START = 0x22 -_REG_T_VAL = 0x30 -_REG_H_VAL = 0x33 +_REG_T_VAL = 0x30 +_REG_H_VAL = 0x33 @SensorRegistry.register("ens210") @@ -39,7 +39,9 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) self.i2c_address = int(self.settings.get("i2c_address", 0x43)) self.bus_number = int(self.settings.get("bus_number", 1)) self._poll_interval = 0.05 # 50 ms between validity checks - self._poll_attempts = max(1, int(float(self.settings.get("read_timeout_seconds", 1.0)) / self._poll_interval)) + self._poll_attempts = max( + 1, int(float(self.settings.get("read_timeout_seconds", 1.0)) / self._poll_interval) + ) self.available = False if not self.ensure_python_modules( diff --git a/repeater/sensors/lafvin_ups_3s.py b/repeater/sensors/lafvin_ups_3s.py index 5885ef1..aa14cb9 100644 --- a/repeater/sensors/lafvin_ups_3s.py +++ b/repeater/sensors/lafvin_ups_3s.py @@ -30,30 +30,36 @@ from .registry import SensorRegistry # INA219 register addresses -_REG_CONFIG = 0x00 -_REG_SHUNT = 0x01 -_REG_BUS = 0x02 -_REG_POWER = 0x03 -_REG_CURRENT = 0x04 +_REG_CONFIG = 0x00 +_REG_SHUNT = 0x01 +_REG_BUS = 0x02 +_REG_POWER = 0x03 +_REG_CURRENT = 0x04 _REG_CALIBRATION = 0x05 # 32V range, ±320mV shunt gain, 12-bit ADC, continuous shunt+bus _CONFIG_VALUE = 0x399F # 3S LiPo/Li-ion pack voltage thresholds (3 cells in series) -_V_MAX = 12.6 # 4.20 V/cell × 3 — fully charged -_V_MIN = 9.0 # 3.00 V/cell × 3 — cutoff +_V_MAX = 12.6 # 4.20 V/cell × 3 — fully charged +_V_MIN = 9.0 # 3.00 V/cell × 3 — cutoff def _pack_voltage_to_percent(v: float) -> int: """Piecewise linear SoC estimate for a 3S Li-ion/LiPo pack (9.0–12.6 V).""" cell = v / 3.0 - if cell >= 4.20: return 100 - if cell >= 4.00: return int(85 + (cell - 4.00) / 0.20 * 15) - if cell >= 3.80: return int(60 + (cell - 3.80) / 0.20 * 25) - if cell >= 3.70: return int(40 + (cell - 3.70) / 0.10 * 20) - if cell >= 3.50: return int(15 + (cell - 3.50) / 0.20 * 25) - if cell >= 3.00: return int( (cell - 3.00) / 0.50 * 15) + if cell >= 4.20: + return 100 + if cell >= 4.00: + return int(85 + (cell - 4.00) / 0.20 * 15) + if cell >= 3.80: + return int(60 + (cell - 3.80) / 0.20 * 25) + if cell >= 3.70: + return int(40 + (cell - 3.70) / 0.10 * 20) + if cell >= 3.50: + return int(15 + (cell - 3.50) / 0.20 * 25) + if cell >= 3.00: + return int((cell - 3.00) / 0.50 * 15) return 0 @@ -65,16 +71,16 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) super().__init__(name=name, config=config, log=log) addr = self.settings.get("i2c_address", 0x41) - self.i2c_address = int(addr, 0) if isinstance(addr, str) else int(addr) - self.bus_number = int(self.settings.get("bus_number", 1)) - self.shunt_ohms = float(self.settings.get("shunt_ohms", 0.1)) - self.max_amps = float(self.settings.get("max_amps", 5.0)) + self.i2c_address = int(addr, 0) if isinstance(addr, str) else int(addr) + self.bus_number = int(self.settings.get("bus_number", 1)) + self.shunt_ohms = float(self.settings.get("shunt_ohms", 0.1)) + self.max_amps = float(self.settings.get("max_amps", 5.0)) # INA219 calibration per datasheet - self.current_lsb = self.max_amps / 32768.0 - cal = int(0.04096 / (self.current_lsb * self.shunt_ohms)) - self.calibration = max(1, min(cal, 0xFFFF)) - self.power_lsb = self.current_lsb * 20.0 + self.current_lsb = self.max_amps / 32768.0 + cal = int(0.04096 / (self.current_lsb * self.shunt_ohms)) + self.calibration = max(1, min(cal, 0xFFFF)) + self.power_lsb = self.current_lsb * 20.0 self.available = False @@ -83,6 +89,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) try: import smbus2 # type: ignore[import-not-found] + self._smbus2 = smbus2 bus = smbus2.SMBus(self.bus_number) @@ -109,9 +116,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) ) def _write(self, bus, reg: int, val: int) -> None: - bus.write_i2c_block_data( - self.i2c_address, reg, [(val >> 8) & 0xFF, val & 0xFF] - ) + bus.write_i2c_block_data(self.i2c_address, reg, [(val >> 8) & 0xFF, val & 0xFF]) def _read_u(self, bus, reg: int) -> int: d = bus.read_i2c_block_data(self.i2c_address, reg, 2) @@ -131,10 +136,10 @@ def _read(self) -> Dict[str, Any]: try: self._write(bus, _REG_CALIBRATION, self.calibration) - bus_v = (self._read_u(bus, _REG_BUS) >> 3) * 4 / 1000.0 - shunt_mv = self._read_s(bus, _REG_SHUNT) * 0.01 + bus_v = (self._read_u(bus, _REG_BUS) >> 3) * 4 / 1000.0 + shunt_mv = self._read_s(bus, _REG_SHUNT) * 0.01 current_ma = self._read_s(bus, _REG_CURRENT) * self.current_lsb * 1000.0 - power_mw = self._read_u(bus, _REG_POWER) * self.power_lsb * 1000.0 + power_mw = self._read_u(bus, _REG_POWER) * self.power_lsb * 1000.0 finally: bus.close() @@ -148,12 +153,12 @@ def _read(self) -> Dict[str, Any]: state = "idle" return { - "bus_voltage_v": round(bus_v, 3), + "bus_voltage_v": round(bus_v, 3), "shunt_voltage_mv": round(shunt_mv, 2), - "current_ma": round(current_ma, 1), - "power_mw": round(power_mw, 1), - "battery_percent": pct, - "charge_state": state, + "current_ma": round(current_ma, 1), + "power_mw": round(power_mw, 1), + "battery_percent": pct, + "charge_state": state, } except Exception as exc: raise RuntimeError(f"LAFVIN UPS 3S read failed: {exc}") from exc diff --git a/repeater/sensors/manager.py b/repeater/sensors/manager.py index 5edf39e..367756b 100644 --- a/repeater/sensors/manager.py +++ b/repeater/sensors/manager.py @@ -22,14 +22,14 @@ def __init__( self.log = log or logging.getLogger(self.__class__.__name__) self.registry = registry self.sensors = [] - + # Background polling self._poll_thread: Optional[threading.Thread] = None self._stop_event = threading.Event() self._latest_readings: List[Dict[str, Any]] = [] self._readings_lock = threading.RLock() self._running = False - + self.reload() def _get_sensor_definitions(self) -> List[Dict[str, Any]]: @@ -81,17 +81,17 @@ def start(self) -> None: if self._running: return self.reload() - + # Start background polling thread if enabled and sensors exist section = self.config.get("sensors", {}) if not isinstance(section, dict) or not section.get("enabled", False): self.log.debug("Sensor manager disabled in config") return - + if not self.sensors: self.log.debug("No sensors loaded; skipping background polling") return - + self._stop_event.clear() self._poll_thread = threading.Thread( target=self._poll_loop, name="sensor-manager", daemon=True @@ -137,9 +137,9 @@ def _poll_loop(self) -> None: poll_interval = float(section.get("poll_interval_seconds", 30.0)) except (TypeError, ValueError): pass - + self.log.debug("Sensor polling loop started (interval=%.1f sec)", poll_interval) - + while not self._stop_event.is_set(): try: readings = self.read_all() @@ -147,10 +147,10 @@ def _poll_loop(self) -> None: self._latest_readings = readings except Exception as exc: self.log.warning("Sensor poll cycle failed: %s", exc) - + # Wait for next poll or stop signal self._stop_event.wait(poll_interval) - + self.log.debug("Sensor polling loop stopped") def get_summary(self) -> Dict[str, Any]: @@ -161,11 +161,11 @@ def get_summary(self) -> Dict[str, Any]: poll_interval = float(section.get("poll_interval_seconds", 30.0)) except (TypeError, ValueError): pass - + # Get cached readings (or empty list if not running) with self._readings_lock: readings = list(self._latest_readings) if self._latest_readings else [] - + return { "enabled": bool(isinstance(section, dict) and section.get("enabled", False)), "poll_interval_seconds": poll_interval, diff --git a/repeater/sensors/registry.py b/repeater/sensors/registry.py index 5fa615c..094e01a 100644 --- a/repeater/sensors/registry.py +++ b/repeater/sensors/registry.py @@ -23,7 +23,9 @@ def _decorator(target): return _decorator @classmethod - def create(cls, sensor_type: str, name: str, config: Optional[Dict[str, Any]] = None, **kwargs) -> SensorBase: + def create( + cls, sensor_type: str, name: str, config: Optional[Dict[str, Any]] = None, **kwargs + ) -> SensorBase: key = str(sensor_type).strip().lower() factory = cls._factories.get(key) if factory is None: diff --git a/repeater/sensors/shtc3.py b/repeater/sensors/shtc3.py index 7ec3099..92174e8 100644 --- a/repeater/sensors/shtc3.py +++ b/repeater/sensors/shtc3.py @@ -25,8 +25,8 @@ from .registry import SensorRegistry # SHTC3 two-byte command words -_CMD_WAKE = [0x35, 0x17] -_CMD_MEAS = [0x7C, 0xA2] # T-first, normal power mode +_CMD_WAKE = [0x35, 0x17] +_CMD_MEAS = [0x7C, 0xA2] # T-first, normal power mode _CMD_SLEEP = [0xB0, 0x98] @@ -47,6 +47,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) try: import smbus2 # type: ignore[import-not-found] + self._smbus2 = smbus2 # Verify sensor is reachable: wake then immediately sleep @@ -93,16 +94,16 @@ def _read(self) -> Dict[str, Any]: bus.i2c_rdwr(smbus2.i2c_msg.write(self.i2c_address, _CMD_SLEEP)) # Bytes: T_MSB, T_LSB, T_CRC, RH_MSB, RH_LSB, RH_CRC - t_raw = (data[0] << 8) | data[1] + t_raw = (data[0] << 8) | data[1] rh_raw = (data[3] << 8) | data[4] temp_c = round(-45.0 + 175.0 * t_raw / 65536.0, 2) temp_f = round(temp_c * 9.0 / 5.0 + 32.0, 2) - rh = round(100.0 * rh_raw / 65536.0, 2) + rh = round(100.0 * rh_raw / 65536.0, 2) return { "temperature_c": temp_c, "temperature_f": temp_f, - "humidity_pct": rh, + "humidity_pct": rh, } except Exception as exc: raise RuntimeError(f"SHTC3 read failed: {exc}") from exc diff --git a/repeater/sensors/waveshare_ups_d.py b/repeater/sensors/waveshare_ups_d.py index 5b62232..709451c 100644 --- a/repeater/sensors/waveshare_ups_d.py +++ b/repeater/sensors/waveshare_ups_d.py @@ -26,30 +26,36 @@ # INA219 register addresses _REG_CONFIG = 0x00 -_REG_SHUNT = 0x01 -_REG_BUS = 0x02 -_REG_POWER = 0x03 +_REG_SHUNT = 0x01 +_REG_BUS = 0x02 +_REG_POWER = 0x03 _REG_CURRENT = 0x04 -_REG_CAL = 0x05 +_REG_CAL = 0x05 # 32V range, ±320mV gain, 128-sample averaging, continuous shunt+bus _CONFIG_VALUE = 0x3FFF # Waveshare UPS HAT (D) calibration — 0.01Ω shunt (per Waveshare sample code) # current_lsb ≈ 0.1524 mA/LSB, power_lsb = current_lsb × 20 -_CAL_VALUE = 26868 -_CURRENT_LSB = 0.0001524 # A per LSB -_POWER_LSB = _CURRENT_LSB * 20.0 # W per LSB +_CAL_VALUE = 26868 +_CURRENT_LSB = 0.0001524 # A per LSB +_POWER_LSB = _CURRENT_LSB * 20.0 # W per LSB def _voltage_to_percent(v: float) -> int: """Piecewise linear SoC estimate for a single 21700 Li-ion cell (3.0–4.2 V).""" - if v >= 4.20: return 100 - if v >= 4.00: return int(85 + (v - 4.00) / 0.20 * 15) - if v >= 3.80: return int(60 + (v - 3.80) / 0.20 * 25) - if v >= 3.70: return int(40 + (v - 3.70) / 0.10 * 20) - if v >= 3.50: return int(15 + (v - 3.50) / 0.20 * 25) - if v >= 3.00: return int( (v - 3.00) / 0.50 * 15) + if v >= 4.20: + return 100 + if v >= 4.00: + return int(85 + (v - 4.00) / 0.20 * 15) + if v >= 3.80: + return int(60 + (v - 3.80) / 0.20 * 25) + if v >= 3.70: + return int(40 + (v - 3.70) / 0.10 * 20) + if v >= 3.50: + return int(15 + (v - 3.50) / 0.20 * 25) + if v >= 3.00: + return int((v - 3.00) / 0.50 * 15) return 0 @@ -61,7 +67,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) super().__init__(name=name, config=config, log=log) self.i2c_address = int(self.settings.get("i2c_address", 0x43)) - self.bus_number = int(self.settings.get("bus_number", 1)) + self.bus_number = int(self.settings.get("bus_number", 1)) self.available = False @@ -70,6 +76,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) try: import smbus2 # type: ignore[import-not-found] + self._smbus2 = smbus2 bus = smbus2.SMBus(self.bus_number) @@ -96,9 +103,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) ) def _write(self, bus, reg: int, val: int) -> None: - bus.write_i2c_block_data( - self.i2c_address, reg, [(val >> 8) & 0xFF, val & 0xFF] - ) + bus.write_i2c_block_data(self.i2c_address, reg, [(val >> 8) & 0xFF, val & 0xFF]) def _read_u(self, bus, reg: int) -> int: d = bus.read_i2c_block_data(self.i2c_address, reg, 2) @@ -119,10 +124,10 @@ def _read(self) -> Dict[str, Any]: # Re-apply calibration in case of external reset self._write(bus, _REG_CAL, _CAL_VALUE) - bus_v = (self._read_u(bus, _REG_BUS) >> 3) * 4 / 1000.0 - shunt_mv = self._read_s(bus, _REG_SHUNT) * 0.01 + bus_v = (self._read_u(bus, _REG_BUS) >> 3) * 4 / 1000.0 + shunt_mv = self._read_s(bus, _REG_SHUNT) * 0.01 current_ma = self._read_s(bus, _REG_CURRENT) * _CURRENT_LSB * 1000.0 - power_mw = self._read_u(bus, _REG_POWER) * _POWER_LSB * 1000.0 + power_mw = self._read_u(bus, _REG_POWER) * _POWER_LSB * 1000.0 finally: bus.close() @@ -137,12 +142,12 @@ def _read(self) -> Dict[str, Any]: state = "idle" return { - "bus_voltage_v": round(bus_v, 3), + "bus_voltage_v": round(bus_v, 3), "shunt_voltage_mv": round(shunt_mv, 2), - "current_ma": round(current_ma, 1), - "power_mw": round(power_mw, 1), - "battery_percent": pct, - "charge_state": state, + "current_ma": round(current_ma, 1), + "power_mw": round(power_mw, 1), + "battery_percent": pct, + "charge_state": state, } except Exception as exc: raise RuntimeError(f"Waveshare UPS HAT (D) read failed: {exc}") from exc diff --git a/repeater/sensors/waveshare_ups_e.py b/repeater/sensors/waveshare_ups_e.py index 0ec3e5c..225f0a7 100644 --- a/repeater/sensors/waveshare_ups_e.py +++ b/repeater/sensors/waveshare_ups_e.py @@ -26,14 +26,14 @@ from .registry import SensorRegistry # Register map -_REG_STATUS = 0x02 # 1 byte — charge state flags -_REG_VBUS = 0x10 # 6 bytes — input (VBUS) voltage, current, power -_REG_BATT = 0x20 # 12 bytes — pack voltage, current, percent, mAh, time -_REG_CELLS = 0x30 # 8 bytes — four cell voltages (LE uint16 each) +_REG_STATUS = 0x02 # 1 byte — charge state flags +_REG_VBUS = 0x10 # 6 bytes — input (VBUS) voltage, current, power +_REG_BATT = 0x20 # 12 bytes — pack voltage, current, percent, mAh, time +_REG_CELLS = 0x30 # 8 bytes — four cell voltages (LE uint16 each) # Charge state flag bits _FLAG_FAST_CHARGE = 0x40 -_FLAG_CHARGING = 0x80 +_FLAG_CHARGING = 0x80 _FLAG_DISCHARGING = 0x20 @@ -50,7 +50,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) addr = self.settings.get("i2c_address", 0x2D) self.i2c_address = int(addr, 0) if isinstance(addr, str) else int(addr) - self.bus_number = int(self.settings.get("bus_number", 1)) + self.bus_number = int(self.settings.get("bus_number", 1)) self.low_cell_mv = int(self.settings.get("low_cell_mv", 3150)) self.available = False @@ -60,6 +60,7 @@ def __init__(self, name: str, config: Optional[Dict[str, Any]] = None, log=None) try: import smbus2 # type: ignore[import-not-found] + self._smbus2 = smbus2 bus = smbus2.SMBus(self.bus_number) @@ -91,50 +92,56 @@ def _read(self) -> Dict[str, Any]: bus = self._smbus2.SMBus(self.bus_number) try: status = bus.read_i2c_block_data(self.i2c_address, _REG_STATUS, 1)[0] - vb = bus.read_i2c_block_data(self.i2c_address, _REG_VBUS, 6) - bd = bus.read_i2c_block_data(self.i2c_address, _REG_BATT, 12) - cd = bus.read_i2c_block_data(self.i2c_address, _REG_CELLS, 8) + vb = bus.read_i2c_block_data(self.i2c_address, _REG_VBUS, 6) + bd = bus.read_i2c_block_data(self.i2c_address, _REG_BATT, 12) + cd = bus.read_i2c_block_data(self.i2c_address, _REG_CELLS, 8) finally: bus.close() # Charge state - if status & _FLAG_FAST_CHARGE: charge_state = "fast_charging" - elif status & _FLAG_CHARGING: charge_state = "charging" - elif status & _FLAG_DISCHARGING: charge_state = "discharging" - else: charge_state = "idle" + if status & _FLAG_FAST_CHARGE: + charge_state = "fast_charging" + elif status & _FLAG_CHARGING: + charge_state = "charging" + elif status & _FLAG_DISCHARGING: + charge_state = "discharging" + else: + charge_state = "idle" # VBUS (input power from charger) vbus_voltage_mv = _u16le(vb, 0) vbus_current_ma = _u16le(vb, 2) - vbus_power_mw = _u16le(vb, 4) + vbus_power_mw = _u16le(vb, 4) # Battery pack batt_voltage_mv = _u16le(bd, 0) batt_current_ma = _u16le(bd, 2) - if batt_current_ma > 0x7FFF: # signed 16-bit + if batt_current_ma > 0x7FFF: # signed 16-bit batt_current_ma -= 0xFFFF - batt_percent = _u16le(bd, 4) - remaining_mah = _u16le(bd, 6) - time_remaining = _u16le(bd, 8) - time_to_full = _u16le(bd, 10) + batt_percent = _u16le(bd, 4) + remaining_mah = _u16le(bd, 6) + time_remaining = _u16le(bd, 8) + time_to_full = _u16le(bd, 10) # Per-cell voltages (4 cells) cells_mv = [ - _u16le(cd, 0), _u16le(cd, 2), - _u16le(cd, 4), _u16le(cd, 6), + _u16le(cd, 0), + _u16le(cd, 2), + _u16le(cd, 4), + _u16le(cd, 6), ] result: Dict[str, Any] = { - "charge_state": charge_state, - "battery_voltage_mv": batt_voltage_mv, - "battery_current_ma": batt_current_ma, - "battery_percent": batt_percent, + "charge_state": charge_state, + "battery_voltage_mv": batt_voltage_mv, + "battery_current_ma": batt_current_ma, + "battery_percent": batt_percent, "remaining_capacity_mah": remaining_mah, - "vbus_voltage_mv": vbus_voltage_mv, - "vbus_current_ma": vbus_current_ma, - "vbus_power_mw": vbus_power_mw, - "cell_voltages_mv": cells_mv, - "low_cell_warning": any(0 < v < self.low_cell_mv for v in cells_mv), + "vbus_voltage_mv": vbus_voltage_mv, + "vbus_current_ma": vbus_current_ma, + "vbus_power_mw": vbus_power_mw, + "cell_voltages_mv": cells_mv, + "low_cell_warning": any(0 < v < self.low_cell_mv for v in cells_mv), } # Only include whichever time estimate is relevant diff --git a/repeater/service_utils.py b/repeater/service_utils.py index c52902d..e6c14ef 100644 --- a/repeater/service_utils.py +++ b/repeater/service_utils.py @@ -5,7 +5,7 @@ import logging import os -import subprocess +import subprocess # nosec B404 import threading import time from typing import Dict, Optional, Tuple @@ -14,6 +14,9 @@ INIT_SCRIPT = "/etc/init.d/S80pymc-repeater" BUILDROOT_METADATA_PATH = "/etc/pymc-image-build-id" _CONTAINER_RESTART_DELAY_SECONDS = 1.0 +_SH_BIN = "/bin/sh" +_SYSTEMCTL_BIN = "/bin/systemctl" +_SUDO_BIN = "/usr/bin/sudo" def is_buildroot() -> bool: @@ -95,12 +98,12 @@ def get_container_restart_message() -> str: def restart_service() -> Tuple[bool, str]: """ Restart the pymc-repeater service. - + On Buildroot/Luckfox, use the shipped init script directly. On systemd hosts, try polkit-based restart first (plain systemctl), then fall back to sudo-based restart (requires sudoers.d rule installed by manage.sh). - + Returns: Tuple[bool, str]: (success, message) """ @@ -116,12 +119,12 @@ def restart_service() -> Tuple[bool, str]: try: subprocess.Popen( - ["/bin/sh", "-c", f"sleep 1; exec {INIT_SCRIPT} restart >/dev/null 2>&1"], + [_SH_BIN, "-c", f"sleep 1; exec {INIT_SCRIPT} restart >/dev/null 2>&1"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, stdin=subprocess.DEVNULL, start_new_session=True, - ) + ) # nosec B603 logger.info("Service restart scheduled via Buildroot init script") return True, "Service restart initiated" except Exception as exc: @@ -131,20 +134,23 @@ def restart_service() -> Tuple[bool, str]: # Try polkit-based restart first (works on bare metal / VMs with polkit running) try: result = subprocess.run( - ["systemctl", "restart", "pymc-repeater"], capture_output=True, text=True, timeout=5 - ) + [_SYSTEMCTL_BIN, "restart", "pymc-repeater"], + capture_output=True, + text=True, + timeout=5, + ) # nosec B603 if result.returncode == 0: logger.info("Service restart via polkit succeeded") return True, "Service restart initiated" - + stderr = result.stderr or "" if "Access denied" in stderr or "authorization" in stderr.lower(): logger.info("Polkit denied restart, trying sudo fallback...") else: # Some other error, still try sudo logger.warning(f"systemctl restart failed ({result.returncode}): {stderr.strip()}") - + except subprocess.TimeoutExpired: # Timeout likely means it's restarting - that's success logger.warning("Service restart command timed out (service may be restarting)") @@ -154,16 +160,16 @@ def restart_service() -> Tuple[bool, str]: return False, "systemctl not available" except Exception as e: logger.warning(f"Polkit restart attempt failed: {e}") - + # Fallback: use sudo (requires /etc/sudoers.d/pymc-repeater rule) try: result = subprocess.run( - ['sudo', '--non-interactive', 'systemctl', 'restart', 'pymc-repeater'], + [_SUDO_BIN, "--non-interactive", _SYSTEMCTL_BIN, "restart", "pymc-repeater"], capture_output=True, text=True, - timeout=5 - ) - + timeout=5, + ) # nosec B603 + if result.returncode == 0: logger.info("Service restart via sudo succeeded") return True, "Service restart initiated" @@ -171,7 +177,7 @@ def restart_service() -> Tuple[bool, str]: error_msg = result.stderr or "Unknown error" logger.error(f"Service restart via sudo failed: {error_msg}") return False, f"Restart failed: {error_msg}" - + except subprocess.TimeoutExpired: logger.warning("Sudo restart timed out (service likely restarting)") return True, "Service restart initiated (timeout - likely restarting)" diff --git a/repeater/web/api_endpoints.py b/repeater/web/api_endpoints.py index 56cf8c2..a88ab7c 100644 --- a/repeater/web/api_endpoints.py +++ b/repeater/web/api_endpoints.py @@ -1,9 +1,9 @@ import json import logging import os +import secrets import time from datetime import datetime, timezone -from pathlib import Path from typing import Callable, Optional import cherrypy @@ -15,7 +15,7 @@ find_companion_index, heal_companion_empty_names, ) -from repeater.config import resolve_storage_dir, update_unscoped_flood_policy +from repeater.config import resolve_storage_dir from repeater.service_utils import get_buildroot_image_info from .auth.middleware import require_auth @@ -164,7 +164,6 @@ class APIEndpoints: - def __init__( self, stats_getter: Optional[Callable] = None, @@ -332,12 +331,14 @@ def needs_setup(self): import yaml config = self.config + config_path = getattr(self, "_config_path", None) try: - with open(self._config_path, "r") as f: - config = yaml.safe_load(f) or {} - except Exception: + if config_path: + with open(config_path, "r") as f: + config = yaml.safe_load(f) or {} + except Exception as exc: # Fall back to in-memory config if file cannot be read. - pass + logger.debug(f"needs_setup could not read persisted config {config_path}: {exc}") needs_setup, reasons = self._setup_status_from_config(config) @@ -455,7 +456,12 @@ def serial_ports(self): # Fallback for environments where pyserial is unavailable. import glob - for pattern in ("/dev/ttyACM*", "/dev/ttyUSB*", "/dev/ttyS*", "/dev/serial/by-id/*"): + for pattern in ( + "/dev/ttyACM*", + "/dev/ttyUSB*", + "/dev/ttyS*", + "/dev/serial/by-id/*", + ): for dev in glob.glob(pattern): devices.append({"device": str(dev), "description": str(dev)}) @@ -535,7 +541,10 @@ def setup_wizard(self): hardware_configs = hardware_data.get("hardware", {}) hw_config = hardware_configs.get(hardware_key, {}) if not hw_config: - return {"success": False, "error": f"Hardware configuration not found: {hardware_key}"} + return { + "success": False, + "error": f"Hardware configuration not found: {hardware_key}", + } else: hw_config = {} @@ -581,7 +590,9 @@ def setup_wizard(self): kiss_port = (data.get("kiss_port") or "").strip() or "/dev/ttyUSB0" kiss_baud = int(data.get("kiss_baud_rate", data.get("kiss_baud", 115200))) config_yaml["kiss"] = {"port": kiss_port, "baud_rate": kiss_baud} - config_yaml["radio"]["tx_power"] = tx_power_preset if tx_power_preset is not None else 14 + config_yaml["radio"]["tx_power"] = ( + tx_power_preset if tx_power_preset is not None else 14 + ) if "preamble_length" not in config_yaml["radio"]: config_yaml["radio"]["preamble_length"] = 17 elif hardware_key == "pymc_usb": @@ -638,7 +649,9 @@ def setup_wizard(self): else: config_yaml["radio_type"] = "sx1262" - ch341_cfg = hw_config.get("ch341") if isinstance(hw_config.get("ch341"), dict) else None + ch341_cfg = ( + hw_config.get("ch341") if isinstance(hw_config.get("ch341"), dict) else None + ) vid = (ch341_cfg or {}).get("vid", hw_config.get("vid")) pid = (ch341_cfg or {}).get("pid", hw_config.get("pid")) if vid is not None or pid is not None: @@ -685,7 +698,9 @@ def setup_wizard(self): if "use_dio3_tcxo" in hw_config: config_yaml["sx1262"]["use_dio3_tcxo"] = hw_config.get("use_dio3_tcxo", False) if "dio3_tcxo_voltage" in hw_config: - config_yaml["sx1262"]["dio3_tcxo_voltage"] = hw_config.get("dio3_tcxo_voltage", 1.8) + config_yaml["sx1262"]["dio3_tcxo_voltage"] = hw_config.get( + "dio3_tcxo_voltage", 1.8 + ) if "use_dio2_rf" in hw_config: config_yaml["sx1262"]["use_dio2_rf"] = hw_config.get("use_dio2_rf", False) if "is_waveshare" in hw_config: @@ -699,7 +714,6 @@ def setup_wizard(self): ) # Trigger service restart after setup - import subprocess import threading def delayed_restart(): @@ -708,6 +722,7 @@ def delayed_restart(): time.sleep(2) # Give time for response to be sent try: from repeater.service_utils import restart_service + restart_service() except Exception as e: logger.error(f"Failed to restart service: {e}") @@ -889,8 +904,7 @@ def generate(): if snapshot_json != last_snapshot_json: yield ( - f"data: {json.dumps({'type': 'snapshot', 'data': snapshot})}" - "\n\n" + f"data: {json.dumps({'type': 'snapshot', 'data': snapshot})}\n\n" ) last_snapshot_json = snapshot_json last_keepalive = time.time() @@ -1067,7 +1081,7 @@ def update_duty_cycle_config(self): @cherrypy.tools.json_in() def update_advert_rate_limit_config(self): """Update advert rate limiting configuration using ConfigManager. - + POST /api/update_advert_rate_limit_config Body: { "rate_limit_enabled": true, @@ -1090,16 +1104,16 @@ def update_advert_rate_limit_config(self): } """ self._set_cors_headers() - + if cherrypy.request.method == "OPTIONS": return "" - + try: self._require_post() data = cherrypy.request.json or {} - + applied = [] - + # Ensure config sections exist if "repeater" not in self.config: self.config["repeater"] = {} @@ -1109,117 +1123,117 @@ def update_advert_rate_limit_config(self): self.config["repeater"]["advert_penalty_box"] = {} if "advert_adaptive" not in self.config["repeater"]: self.config["repeater"]["advert_adaptive"] = {"thresholds": {}} - + rate_cfg = self.config["repeater"]["advert_rate_limit"] penalty_cfg = self.config["repeater"]["advert_penalty_box"] adaptive_cfg = self.config["repeater"]["advert_adaptive"] - + # Rate limit settings if "rate_limit_enabled" in data: rate_cfg["enabled"] = bool(data["rate_limit_enabled"]) applied.append(f"rate_limit={'enabled' if rate_cfg['enabled'] else 'disabled'}") - + if "bucket_capacity" in data: cap = max(1, int(data["bucket_capacity"])) rate_cfg["bucket_capacity"] = cap applied.append(f"bucket_capacity={cap}") - + if "refill_tokens" in data: tokens = max(1, int(data["refill_tokens"])) rate_cfg["refill_tokens"] = tokens applied.append(f"refill_tokens={tokens}") - + if "refill_interval_seconds" in data: interval = max(60, int(data["refill_interval_seconds"])) rate_cfg["refill_interval_seconds"] = interval applied.append(f"refill_interval={interval}s") - + if "min_interval_seconds" in data: min_int = max(0, int(data["min_interval_seconds"])) rate_cfg["min_interval_seconds"] = min_int applied.append(f"min_interval={min_int}s") - + # Penalty box settings if "penalty_enabled" in data: penalty_cfg["enabled"] = bool(data["penalty_enabled"]) applied.append(f"penalty={'enabled' if penalty_cfg['enabled'] else 'disabled'}") - + if "violation_threshold" in data: thresh = max(1, int(data["violation_threshold"])) penalty_cfg["violation_threshold"] = thresh applied.append(f"violation_threshold={thresh}") - + if "violation_decay_seconds" in data: decay = max(60, int(data["violation_decay_seconds"])) penalty_cfg["violation_decay_seconds"] = decay applied.append(f"violation_decay={decay}s") - + if "base_penalty_seconds" in data: base = max(60, int(data["base_penalty_seconds"])) penalty_cfg["base_penalty_seconds"] = base applied.append(f"base_penalty={base}s") - + if "penalty_multiplier" in data: mult = max(1.0, float(data["penalty_multiplier"])) penalty_cfg["penalty_multiplier"] = mult applied.append(f"penalty_multiplier={mult}") - + if "max_penalty_seconds" in data: max_pen = max(60, int(data["max_penalty_seconds"])) penalty_cfg["max_penalty_seconds"] = max_pen applied.append(f"max_penalty={max_pen}s") - + # Adaptive settings if "adaptive_enabled" in data: adaptive_cfg["enabled"] = bool(data["adaptive_enabled"]) applied.append(f"adaptive={'enabled' if adaptive_cfg['enabled'] else 'disabled'}") - + if "ewma_alpha" in data: alpha = max(0.01, min(1.0, float(data["ewma_alpha"]))) adaptive_cfg["ewma_alpha"] = alpha applied.append(f"ewma_alpha={alpha}") - + if "hysteresis_seconds" in data: hyst = max(0, int(data["hysteresis_seconds"])) adaptive_cfg["hysteresis_seconds"] = hyst applied.append(f"hysteresis={hyst}s") - + # Adaptive thresholds if "thresholds" not in adaptive_cfg: adaptive_cfg["thresholds"] = {} - + if "quiet_max" in data: adaptive_cfg["thresholds"]["quiet_max"] = float(data["quiet_max"]) applied.append(f"quiet_max={data['quiet_max']}") - + if "normal_max" in data: adaptive_cfg["thresholds"]["normal_max"] = float(data["normal_max"]) applied.append(f"normal_max={data['normal_max']}") - + if "busy_max" in data: adaptive_cfg["thresholds"]["busy_max"] = float(data["busy_max"]) applied.append(f"busy_max={data['busy_max']}") - + if not applied: return self._error("No valid settings provided") - + # Save to config file and live update daemon result = self.config_manager.update_and_save( - updates={}, - live_update=True, - live_update_sections=['repeater'] + updates={}, live_update=True, live_update_sections=["repeater"] ) - + logger.info(f"Advert rate limit config updated: {', '.join(applied)}") - - return self._success({ - "applied": applied, - "persisted": result.get("saved", False), - "live_update": result.get("live_updated", False), - "restart_required": False, - "message": "Advert rate limit settings applied immediately." - }) - + + return self._success( + { + "applied": applied, + "persisted": result.get("saved", False), + "live_update": result.get("live_updated", False), + "restart_required": False, + "message": "Advert rate limit settings applied immediately.", + } + ) + except cherrypy.HTTPError: raise except Exception as e: @@ -1295,27 +1309,31 @@ def mqtt_status(self): try: storage = self._get_storage() handler = getattr(storage, "mqtt_handler", None) - except Exception: - pass + except Exception as exc: + logger.debug(f"mqtt_status could not access mqtt_handler: {exc}") connected_brokers = [] if handler: for conn in getattr(handler, "connections", []): - connected_brokers.append({ - "enabled": conn.enabled, - "name": conn.broker.get("name", ""), - "host": conn.broker.get("host", ""), - "status": { - "connected": conn.is_connected(), - "reconnecting": conn.has_pending_reconnect(), - }, - "format": conn.format - }) + connected_brokers.append( + { + "enabled": conn.enabled, + "name": conn.broker.get("name", ""), + "host": conn.broker.get("host", ""), + "status": { + "connected": conn.is_connected(), + "reconnecting": conn.has_pending_reconnect(), + }, + "format": conn.format, + } + ) - return self._success({ - "handler_active": handler is not None, - "brokers": connected_brokers, - }) + return self._success( + { + "handler_active": handler is not None, + "brokers": connected_brokers, + } + ) except Exception as e: logger.error(f"Error getting MQTT status: {e}") return self._error(str(e)) @@ -1388,7 +1406,7 @@ def update_mqtt_config(self): "email": "user@example.com", "brokers": [ { - + }] } """ @@ -1433,28 +1451,32 @@ def update_mqtt_config(self): for field in ("name", "host", "port", "format"): if not b.get(field, ""): - return self._error(f"Broker at index {i} missing required field: {field}") - + return self._error( + f"Broker at index {i} missing required field: {field}" + ) + try: port = int(b.get("port", 443)) except (ValueError, TypeError): return self._error(f"Broker at index {i} has invalid port") - + new_broker = { - "name": str(b["name"]).strip(), - "enabled": b.get("enabled", False), + "name": str(b["name"]).strip(), + "enabled": b.get("enabled", False), "transport": str(b.get("transport", "websockets")).strip(), - "host": str(b["host"]).strip(), - "port": port, - "format": str(b["format"]).strip(), + "host": str(b["host"]).strip(), + "port": port, + "format": str(b["format"]).strip(), "disallowed_packet_types": list(b.get("disallowed_packet_types", [])), "retain_status": bool(b.get("retain_status", False)), "tls": { - "enabled": bool(b.get("tls", {}).get("enabled", True if port == 443 else False)), + "enabled": bool( + b.get("tls", {}).get("enabled", True if port == 443 else False) + ), "insecure": bool(b.get("tls", {}).get("insecure", False)), - } + }, } - + if b.get("use_jwt_auth", False): new_broker["use_jwt_auth"] = True new_broker["audience"] = str(b["audience"]).strip() @@ -1477,11 +1499,13 @@ def update_mqtt_config(self): if result.get("success"): logger.info(f"MQTT config updated: {list(mqtt_updates.keys())}") - return self._success({ - "persisted": result.get("saved", False), - "restart_required": True, - "message": "Observer settings saved. Restart the service for changes to take effect.", - }) + return self._success( + { + "persisted": result.get("saved", False), + "restart_required": True, + "message": "Observer settings saved. Restart the service for changes to take effect.", + } + ) else: return self._error(result.get("error", "Failed to update LetsMesh configuration")) @@ -1608,7 +1632,9 @@ def as_float(value, path: str): add_error("repeater.security", "Missing required section 'repeater.security'") security = {} - admin_password = (security.get("admin_password") if isinstance(security, dict) else "") or "" + admin_password = ( + security.get("admin_password") if isinstance(security, dict) else "" + ) or "" if not str(admin_password).strip(): add_error("repeater.security.admin_password", "Admin password is required") @@ -1651,13 +1677,29 @@ def as_float(value, path: str): add_error("radio.frequency", "Frequency must be 100-1000 MHz") bandwidth = as_int((radio or {}).get("bandwidth"), "radio.bandwidth") - valid_bw = [7800, 10400, 15600, 20800, 31250, 41700, 62500, 125000, 250000, 500000] + valid_bw = [ + 7800, + 10400, + 15600, + 20800, + 31250, + 41700, + 62500, + 125000, + 250000, + 500000, + ] if bandwidth is None: add_error("radio.bandwidth", "Bandwidth is required") elif bandwidth not in valid_bw: - add_error("radio.bandwidth", f"Bandwidth must be one of {[b / 1000 for b in valid_bw]} kHz") + add_error( + "radio.bandwidth", + f"Bandwidth must be one of {[b / 1000 for b in valid_bw]} kHz", + ) - spreading_factor = as_int((radio or {}).get("spreading_factor"), "radio.spreading_factor") + spreading_factor = as_int( + (radio or {}).get("spreading_factor"), "radio.spreading_factor" + ) if spreading_factor is None: add_error("radio.spreading_factor", "Spreading factor is required") elif spreading_factor < 5 or spreading_factor > 12: @@ -1675,11 +1717,15 @@ def as_float(value, path: str): elif tx_power < -9 or tx_power > 30: add_error("radio.tx_power", "TX power must be between -9 and +30 dBm") - preamble_length = as_int((radio or {}).get("preamble_length"), "radio.preamble_length") + preamble_length = as_int( + (radio or {}).get("preamble_length"), "radio.preamble_length" + ) if preamble_length is None: add_error("radio.preamble_length", "Preamble length is required") elif preamble_length <= 0: - add_error("radio.preamble_length", "Preamble length must be greater than zero") + add_error( + "radio.preamble_length", "Preamble length must be greater than zero" + ) if radio_type in ("sx1262", "sx1262_ch341"): sx1262_cfg = config_yaml.get("sx1262") @@ -1701,7 +1747,9 @@ def as_float(value, path: str): value = sx1262_cfg.get(key) if isinstance(sx1262_cfg, dict) else None parsed = as_int(value, f"sx1262.{key}") if parsed is None: - add_error(f"sx1262.{key}", f"Missing or invalid required setting '{key}'") + add_error( + f"sx1262.{key}", f"Missing or invalid required setting '{key}'" + ) en_pins = sx1262_cfg.get("en_pins") if isinstance(sx1262_cfg, dict) else None if en_pins is not None: @@ -1718,13 +1766,17 @@ def as_float(value, path: str): if radio_type == "sx1262_ch341": ch341_cfg = config_yaml.get("ch341") if not isinstance(ch341_cfg, dict): - add_error("ch341", "Missing required section 'ch341' for radio_type sx1262_ch341") + add_error( + "ch341", "Missing required section 'ch341' for radio_type sx1262_ch341" + ) ch341_cfg = {} for key in ("vid", "pid"): value = ch341_cfg.get(key) if isinstance(ch341_cfg, dict) else None parsed = as_int(value, f"ch341.{key}") if parsed is None: - add_error(f"ch341.{key}", f"Missing or invalid required setting '{key}'") + add_error( + f"ch341.{key}", f"Missing or invalid required setting '{key}'" + ) if radio_type == "kiss": kiss_cfg = config_yaml.get("kiss") @@ -1743,26 +1795,37 @@ def as_float(value, path: str): if radio_type == "pymc_usb": usb_cfg = config_yaml.get("pymc_usb") if not isinstance(usb_cfg, dict): - add_error("pymc_usb", "Missing required section 'pymc_usb' for radio_type pymc_usb") + add_error( + "pymc_usb", + "Missing required section 'pymc_usb' for radio_type pymc_usb", + ) usb_cfg = {} port = (usb_cfg.get("port") if isinstance(usb_cfg, dict) else "") or "" if not str(port).strip(): add_error("pymc_usb.port", "pymc_usb.port is required") baud = as_int((usb_cfg or {}).get("baudrate"), "pymc_usb.baudrate") if baud is not None and baud <= 0: - add_error("pymc_usb.baudrate", "pymc_usb.baudrate must be greater than zero") + add_error( + "pymc_usb.baudrate", "pymc_usb.baudrate must be greater than zero" + ) if radio_type == "pymc_tcp": tcp_cfg = config_yaml.get("pymc_tcp") if not isinstance(tcp_cfg, dict): - add_error("pymc_tcp", "Missing required section 'pymc_tcp' for radio_type pymc_tcp") + add_error( + "pymc_tcp", + "Missing required section 'pymc_tcp' for radio_type pymc_tcp", + ) tcp_cfg = {} host = (tcp_cfg.get("host") if isinstance(tcp_cfg, dict) else "") or "" host_str = str(host).strip() if not host_str: add_error("pymc_tcp.host", "pymc_tcp.host is required") elif host_str == "REPLACE_WITH_MODEM_HOST": - add_error("pymc_tcp.host", "Replace placeholder host with your modem hostname or IP") + add_error( + "pymc_tcp.host", + "Replace placeholder host with your modem hostname or IP", + ) port = as_int((tcp_cfg or {}).get("port"), "pymc_tcp.port") if port is None: add_error("pymc_tcp.port", "pymc_tcp.port is required") @@ -1784,7 +1847,9 @@ def as_float(value, path: str): "warning_count": len(warnings), }, "config_path": self._config_path, - "message": "Configuration is valid" if valid else "Configuration has validation errors", + "message": "Configuration is valid" + if valid + else "Configuration has validation errors", } ) @@ -1862,15 +1927,19 @@ def memory_debug(self, **kwargs): if not tracemalloc.is_tracing(): # Use 1 frame instead of 10 — much less overhead & faster snapshots tracemalloc.start(1) - self._tracemalloc_baseline = tracemalloc.take_snapshot().filter_traces(( - tracemalloc.Filter(False, tracemalloc.__file__), - tracemalloc.Filter(False, ""), - )) + self._tracemalloc_baseline = tracemalloc.take_snapshot().filter_traces( + ( + tracemalloc.Filter(False, tracemalloc.__file__), + tracemalloc.Filter(False, ""), + ) + ) logger.info("Memory tracing started") - return self._success({ - "tracing": True, - "message": "Tracing started — check again after some time to see growth", - }) + return self._success( + { + "tracing": True, + "message": "Tracing started — check again after some time to see growth", + } + ) if action == "stop": if tracemalloc.is_tracing(): @@ -1888,30 +1957,35 @@ def memory_debug(self, **kwargs): # Always include RSS regardless of tracing state try: import resource + rusage = resource.getrusage(resource.RUSAGE_SELF) result["rss_mb"] = round(rusage.ru_maxrss / 1024, 1) - except Exception: - pass + except Exception as exc: + logger.debug(f"Could not read process RSS usage: {exc}") if not tracing: return self._success(result) # Filter out tracemalloc's own allocations to keep snapshot small & fast - current = tracemalloc.take_snapshot().filter_traces(( - tracemalloc.Filter(False, tracemalloc.__file__), - tracemalloc.Filter(False, ""), - )) + current = tracemalloc.take_snapshot().filter_traces( + ( + tracemalloc.Filter(False, tracemalloc.__file__), + tracemalloc.Filter(False, ""), + ) + ) baseline = getattr(self, "_tracemalloc_baseline", None) # Top 20 allocations right now top_current = current.statistics("lineno")[:20] current_stats = [] for stat in top_current: - current_stats.append({ - "file": str(stat.traceback), - "size_kb": round(stat.size / 1024, 1), - "count": stat.count, - }) + current_stats.append( + { + "file": str(stat.traceback), + "size_kb": round(stat.size / 1024, 1), + "count": stat.count, + } + ) result["current_top_20"] = current_stats # Growth since baseline @@ -1921,12 +1995,14 @@ def memory_debug(self, **kwargs): growth.sort(key=lambda d: d.size_diff, reverse=True) growth_stats = [] for stat in growth[:20]: - growth_stats.append({ - "file": str(stat.traceback), - "size_diff_kb": round(stat.size_diff / 1024, 1), - "count_diff": stat.count_diff, - "current_size_kb": round(stat.size / 1024, 1), - }) + growth_stats.append( + { + "file": str(stat.traceback), + "size_diff_kb": round(stat.size_diff / 1024, 1), + "count_diff": stat.count_diff, + "current_size_kb": round(stat.size / 1024, 1), + } + ) result["growth_since_baseline"] = growth_stats traced_current, traced_peak = tracemalloc.get_traced_memory() @@ -2090,7 +2166,9 @@ def airtime_data(self, start_timestamp=None, end_timestamp=None, limit=50000): end_ts = float(end_timestamp) if end_timestamp is not None else None limit_int = min(int(limit), 50000) packets = self._get_storage().get_airtime_data( - start_timestamp=start_ts, end_timestamp=end_ts, limit=limit_int, + start_timestamp=start_ts, + end_timestamp=end_ts, + limit=limit_int, ) return self._success(packets, count=len(packets)) except Exception as e: @@ -2145,17 +2223,6 @@ def packet_by_hash(self, packet_hash=None): logger.error(f"Error getting packet by hash: {e}") return self._error(e) - @cherrypy.expose - @cherrypy.tools.json_out() - def packet_type_stats(self, hours=24): - try: - hours = int(hours) - stats = self._get_storage().get_packet_type_stats(hours=hours) - return self._success(stats) - except Exception as e: - logger.error(f"Error getting packet type stats: {e}") - return self._error(e) - @cherrypy.expose @cherrypy.tools.json_out() def rrd_data(self): @@ -2364,7 +2431,6 @@ def save_cad_settings(self): self.config["radio"]["cad"]["peak_threshold"] = peak self.config["radio"]["cad"]["min_threshold"] = min_val - config_path = getattr(self, "_config_path", "/etc/pymc_repeater/config.yaml") saved = self.config_manager.save_to_file() if not saved: return self._error("Failed to save configuration to file") @@ -2448,16 +2514,18 @@ def update_radio_config(self): if freq < 100_000_000 or freq > 1_000_000_000: return self._error("Frequency must be 100-1000 MHz") self.config["radio"]["frequency"] = freq - applied.append(f"freq={freq/1_000_000:.3f}MHz") + applied.append(f"freq={freq / 1_000_000:.3f}MHz") # Update bandwidth (in Hz) if "bandwidth" in data: bw = int(float(data["bandwidth"])) valid_bw = [7800, 10400, 15600, 20800, 31250, 41700, 62500, 125000, 250000, 500000] if bw not in valid_bw: - return self._error(f"Bandwidth must be one of {[b/1000 for b in valid_bw]} kHz") + return self._error( + f"Bandwidth must be one of {[b / 1000 for b in valid_bw]} kHz" + ) self.config["radio"]["bandwidth"] = bw - applied.append(f"bw={bw/1000}kHz") + applied.append(f"bw={bw / 1000}kHz") # Update spreading factor if "spreading_factor" in data: @@ -2554,7 +2622,9 @@ def update_radio_config(self): if "path_hash_mode" in data: phm = int(data["path_hash_mode"]) if phm not in (0, 1, 2): - return self._error("Path hash mode must be 0 (1-byte), 1 (2-byte), or 2 (3-byte)") + return self._error( + "Path hash mode must be 0 (1-byte), 1 (2-byte), or 2 (3-byte)" + ) self.config["mesh"]["path_hash_mode"] = phm applied.append(f"path_hash_mode={phm}") @@ -2672,10 +2742,7 @@ def crc_error_count(self, hours: int = 24): storage = self._get_storage() hours = int(hours) count = storage.get_crc_error_count(hours=hours) - return self._success({ - "crc_error_count": count, - "hours": hours - }) + return self._success({"crc_error_count": count, "hours": hours}) except Exception as e: logger.error(f"Error fetching CRC error count: {e}") return self._error(e) @@ -2689,11 +2756,7 @@ def crc_error_history(self, hours: int = 24, limit: int = None): hours = int(hours) limit = int(limit) if limit else None history = storage.get_crc_error_history(hours=hours, limit=limit) - return self._success({ - "history": history, - "hours": hours, - "count": len(history) - }) + return self._success({"history": history, "hours": hours, "count": len(history)}) except Exception as e: logger.error(f"Error fetching CRC error history: {e}") return self._error(e) @@ -2775,7 +2838,12 @@ def adverts_by_contact_type(self, contact_type=None, limit=None, offset=None, ho adverts, count=len(adverts), contact_type=contact_type, - filters={"contact_type": contact_type, "limit": limit_int, "offset": offset_int, "hours": hours_int}, + filters={ + "contact_type": contact_type, + "limit": limit_int, + "offset": offset_int, + "hours": hours_int, + }, ) except ValueError as e: @@ -2816,19 +2884,19 @@ def adverts_count_by_contact_type(self, contact_type=None, hours=None): def advert_rate_limit_stats(self): """Get advert rate limiting statistics and adaptive tier info.""" try: - if not self.daemon_instance or not hasattr(self.daemon_instance, 'advert_helper'): + if not self.daemon_instance or not hasattr(self.daemon_instance, "advert_helper"): return self._error("Advert helper not available") - + advert_helper = self.daemon_instance.advert_helper if not advert_helper: return self._error("Advert helper not initialized") - - if not hasattr(advert_helper, 'get_rate_limit_stats'): + + if not hasattr(advert_helper, "get_rate_limit_stats"): return self._error("Rate limit stats not supported by this advert helper version") - + stats = advert_helper.get_rate_limit_stats() return self._success(stats) - + except Exception as e: logger.error(f"Error getting advert rate limit stats: {e}") return self._error(e) @@ -3105,9 +3173,7 @@ def ping_neighbor(self): trace_helper = self.daemon_instance.trace_helper # Generate unique tag for this ping - import random - - trace_tag = random.randint(0, 0xFFFFFFFF) + trace_tag = secrets.randbits(32) # Create trace packet from pymc_core.protocol import PacketBuilder @@ -3127,7 +3193,9 @@ async def send_and_wait(): # Send packet via router await router.inject_packet(packet) - logger.info(f"Ping sent to 0x{target_hash:0{hex_chars}x} with tag {trace_tag} (path_hash_mode={path_hash_mode})") + logger.info( + f"Ping sent to 0x{target_hash:0{hex_chars}x} with tag {trace_tag} (path_hash_mode={path_hash_mode})" + ) try: await asyncio.wait_for(event.wait(), timeout=timeout) @@ -3277,11 +3345,7 @@ def identities(self): settings = comp_config.get("settings", {}) matching = next( - ( - r - for r in registered_identities - if r["name"] == f"companion:{name}" - ), + (r for r in registered_identities if r["name"] == f"companion:{name}"), None, ) @@ -3295,9 +3359,7 @@ def identities(self): { "name": name, "type": "companion", - "identity_key": ( - ik_hex[:16] + "..." if len(ik_hex) > 16 else ik_hex - ), + "identity_key": (ik_hex[:16] + "..." if len(ik_hex) > 16 else ik_hex), "identity_key_length": len(ik_hex), "settings": settings, "hash": matching["hash"] if matching else None, @@ -3663,9 +3725,7 @@ def update_identity(self): key_bytes = bytes.fromhex(new_key) if len(key_bytes) in (32, 64): identity["identity_key"] = new_key - logger.info( - f"Updated identity_key for companion '{resolved_name}'" - ) + logger.info(f"Updated identity_key for companion '{resolved_name}'") except ValueError: pass @@ -3836,7 +3896,9 @@ def update_identity(self): @cherrypy.expose @cherrypy.tools.json_out() - def delete_identity(self, name=None, type=None, lookup_identity_key=None, public_key_prefix=None): + def delete_identity( + self, name=None, type=None, lookup_identity_key=None, public_key_prefix=None + ): """ DELETE /api/delete_identity?name=&type= - Delete an identity Companions may also be deleted with lookup_identity_key or public_key_prefix when name is empty. @@ -3857,9 +3919,7 @@ def delete_identity(self, name=None, type=None, lookup_identity_key=None, public identity_type = (type or "room_server").lower() if identity_type not in ["room_server", "companion"]: - return self._error( - f"Invalid type: {type}. Use 'room_server' or 'companion'." - ) + return self._error(f"Invalid type: {type}. Use 'room_server' or 'companion'.") identities_config = self.config.get("identities", {}) @@ -3911,9 +3971,7 @@ def delete_identity(self, name=None, type=None, lookup_identity_key=None, public # Find and remove the identity initial_count = len(room_servers) - room_servers = [ - r for r in room_servers if str(r.get("name") or "").strip() != name_s - ] + room_servers = [r for r in room_servers if str(r.get("name") or "").strip() != name_s] if len(room_servers) == initial_count: return self._error(f"Identity '{name_s}' not found") @@ -4140,7 +4198,9 @@ def acl_info(self): { "name": "repeater", "type": "repeater", - "hash": self._fmt_hash(self.daemon_instance.local_identity.get_public_key()), + "hash": self._fmt_hash( + self.daemon_instance.local_identity.get_public_key() + ), "max_clients": repeater_acl.max_clients, "authenticated_clients": repeater_acl.get_num_clients(), "has_admin_password": bool(repeater_acl.admin_password), @@ -4185,7 +4245,11 @@ def acl_info(self): writer = getattr(fs, "_client_writer", None) active_by_hash[h] = writer is not None if writer is not None: - peername = writer.get_extra_info("peername") if hasattr(writer, "get_extra_info") else None + peername = ( + writer.get_extra_info("peername") + if hasattr(writer, "get_extra_info") + else None + ) client_ip_by_hash[h] = str(peername[0]) if peername else None except (ValueError, TypeError): pass @@ -5344,9 +5408,7 @@ def config_import(self): # Preserve identity keys that are redacted for id_section in ("room_servers", "companions"): entries = value.get(id_section, []) or [] - cur_entries = ( - self.config.get("identities", {}).get(id_section, []) or [] - ) + cur_entries = self.config.get("identities", {}).get(id_section, []) or [] cur_by_name = {e.get("name"): e for e in cur_entries} for entry in entries: if entry.get("identity_key") == "*** REDACTED ***": @@ -5381,7 +5443,7 @@ def config_import(self): return self._error("No valid configuration sections found in import") # Persist and live-reload - result = self.config_manager.update_and_save( + self.config_manager.update_and_save( updates={}, # Already applied above live_update=True, live_update_sections=updated_sections, @@ -5432,7 +5494,9 @@ def identity_export(self): elif isinstance(identity_key, str): key_hex = identity_key else: - return self._error(f"Identity key has unexpected type: {type(identity_key).__name__}") + return self._error( + f"Identity key has unexpected type: {type(identity_key).__name__}" + ) result = { "identity_key_hex": key_hex, @@ -5446,8 +5510,8 @@ def identity_export(self): pub = li.get_public_key() result["public_key_hex"] = bytes(pub).hex() result["node_address"] = f"0x{pub[0]:02x}" - except Exception: - pass # Not critical + except Exception as exc: + logger.debug(f"Could not derive local identity public key info: {exc}") return {"success": True, "data": result} @@ -5545,9 +5609,7 @@ def db_stats(self): # Add RRD file size if it exists rrd_path = storage.sqlite_handler.storage_dir / "metrics.rrd" - stats["rrd_size_bytes"] = ( - rrd_path.stat().st_size if rrd_path.exists() else 0 - ) + stats["rrd_size_bytes"] = rrd_path.stat().st_size if rrd_path.exists() else 0 return {"success": True, "data": stats} except Exception as e: @@ -5578,10 +5640,16 @@ def db_purge(self): return self._error("Missing 'tables' parameter") ALL_PURGEABLE = [ - "packets", "adverts", "noise_floor", "crc_errors", - "room_messages", "room_client_sync", - "companion_contacts", "companion_channels", - "companion_messages", "companion_prefs", + "packets", + "adverts", + "noise_floor", + "crc_errors", + "room_messages", + "room_client_sync", + "companion_contacts", + "companion_channels", + "companion_messages", + "companion_prefs", ] if tables_param == "all": diff --git a/repeater/web/auth/api_tokens.py b/repeater/web/auth/api_tokens.py index 5105e70..e236f37 100644 --- a/repeater/web/auth/api_tokens.py +++ b/repeater/web/auth/api_tokens.py @@ -22,22 +22,22 @@ def hash_token(self, token: str) -> str: def create_token(self, name: str) -> tuple[int, str]: plaintext_token = self.generate_api_token() token_hash = self.hash_token(plaintext_token) - + token_id = self.db.create_api_token(name, token_hash) - + logger.info(f"Created API token '{name}' with ID {token_id}") return token_id, plaintext_token - + def verify_token(self, token: str) -> Optional[Dict]: token_hash = self.hash_token(token) return self.db.verify_api_token(token_hash) - + def revoke_token(self, token_id: int) -> bool: deleted = self.db.revoke_api_token(token_id) - + if deleted: logger.info(f"Revoked API token ID {token_id}") - + return deleted def list_tokens(self) -> List[Dict]: diff --git a/repeater/web/auth/cherrypy_tool.py b/repeater/web/auth/cherrypy_tool.py index f107dc5..124eb57 100644 --- a/repeater/web/auth/cherrypy_tool.py +++ b/repeater/web/auth/cherrypy_tool.py @@ -8,7 +8,7 @@ def check_auth(): """ CherryPy tool to check authentication before processing request. - + Checks for either JWT in Authorization header, API token in X-API-Key header, or JWT token in query parameter (for EventSource/SSE connections). Sets cherrypy.request.user on success. @@ -17,26 +17,26 @@ def check_auth(): # Skip auth check for OPTIONS requests (CORS preflight) if cherrypy.request.method == "OPTIONS": return - + # Skip auth check for /auth/login endpoint if cherrypy.request.path_info == "/auth/login": return - + # Get auth handlers from config jwt_handler = cherrypy.config.get("jwt_handler") token_manager = cherrypy.config.get("token_manager") - + if not jwt_handler or not token_manager: logger.error("Auth handlers not initialized in cherrypy.config") cherrypy.response.status = 500 return {"success": False, "error": "Authentication system not configured"} - + # Check for JWT token in Authorization header first auth_header = cherrypy.request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): token = auth_header[7:] # Remove "Bearer " prefix payload = jwt_handler.verify_jwt(token) - + if payload: cherrypy.request.user = { "username": payload.get("sub"), @@ -50,7 +50,7 @@ def check_auth(): query_token = cherrypy.request.params.get("token") if query_token: payload = jwt_handler.verify_jwt(query_token) - + if payload: cherrypy.request.user = { "username": payload.get("sub"), @@ -60,12 +60,12 @@ def check_auth(): # Remove token from params to avoid exposing it in logs del cherrypy.request.params["token"] return - + # Check for API token in X-API-Key header api_key = cherrypy.request.headers.get("X-API-Key", "") if api_key: token_info = token_manager.verify_token(api_key) - + if token_info: cherrypy.request.user = { "token_id": token_info["id"], @@ -79,6 +79,7 @@ def check_auth(): raise cherrypy.HTTPError(401, "Unauthorized - Valid JWT or API token required") -# Register the tool -cherrypy.tools.require_auth = cherrypy.Tool("before_handler", check_auth) -logger.info("CherryPy require_auth tool registered") +def register_require_auth_tool(): + if not hasattr(cherrypy.tools, "require_auth"): + cherrypy.tools.require_auth = cherrypy.Tool("before_handler", check_auth) + logger.info("CherryPy require_auth tool registered") diff --git a/repeater/web/auth/jwt_handler.py b/repeater/web/auth/jwt_handler.py index bc9d257..7ac487f 100644 --- a/repeater/web/auth/jwt_handler.py +++ b/repeater/web/auth/jwt_handler.py @@ -11,7 +11,7 @@ class JWTHandler: def __init__(self, secret: str, expiry_minutes: int = 15): self.secret = secret self.expiry_minutes = expiry_minutes - + def create_jwt(self, username: str, client_id: str) -> str: now = int(time.time()) diff --git a/repeater/web/auth_endpoints.py b/repeater/web/auth_endpoints.py index 2ceb572..6580b8a 100644 --- a/repeater/web/auth_endpoints.py +++ b/repeater/web/auth_endpoints.py @@ -1,6 +1,7 @@ """ Authentication endpoints for login and token management """ + import cherrypy import logging from .auth.middleware import require_auth @@ -10,7 +11,7 @@ class AuthAPIEndpoints: """Nested endpoint for /api/auth/* RESTful routes""" - + def __init__(self): # Create tokens nested endpoint for /api/auth/tokens self.tokens = TokensAPIEndpoint() @@ -18,446 +19,428 @@ def __init__(self): class TokensAPIEndpoint: """RESTful token management endpoints for /api/auth/tokens""" - + @cherrypy.expose @cherrypy.tools.json_out() @require_auth def index(self): # Handle CORS preflight - if cherrypy.request.method == 'OPTIONS': + if cherrypy.request.method == "OPTIONS": return {} - + # Get token manager from cherrypy config - token_manager = cherrypy.config.get('token_manager') + token_manager = cherrypy.config.get("token_manager") if not token_manager: cherrypy.response.status = 500 - return {'success': False, 'error': 'Token manager not available'} - - if cherrypy.request.method == 'GET': + return {"success": False, "error": "Token manager not available"} + + if cherrypy.request.method == "GET": try: tokens = token_manager.list_tokens() - return { - 'success': True, - 'tokens': tokens - } + return {"success": True, "tokens": tokens} except Exception as e: logger.error(f"Token list error: {e}") cherrypy.response.status = 500 - return { - 'success': False, - 'error': 'Failed to list tokens' - } - - elif cherrypy.request.method == 'POST': + return {"success": False, "error": "Failed to list tokens"} + + elif cherrypy.request.method == "POST": try: import json - body = cherrypy.request.body.read().decode('utf-8') + + body = cherrypy.request.body.read().decode("utf-8") data = json.loads(body) if body else {} - name = data.get('name', '').strip() - + name = data.get("name", "").strip() + if not name: cherrypy.response.status = 400 - return { - 'success': False, - 'error': 'Token name is required' - } - + return {"success": False, "error": "Token name is required"} + # Create the token token_id, plaintext_token = token_manager.create_token(name) - - logger.info(f"Generated API token '{name}' (ID: {token_id}) by user {cherrypy.request.user['username']}") - + + logger.info( + f"Generated API token '{name}' (ID: {token_id}) by user {cherrypy.request.user['username']}" + ) + return { - 'success': True, - 'token': plaintext_token, - 'token_id': token_id, - 'name': name, - 'warning': 'Save this token securely - it will not be shown again' + "success": True, + "token": plaintext_token, + "token_id": token_id, + "name": name, + "warning": "Save this token securely - it will not be shown again", } - + except Exception as e: logger.error(f"Token generation error: {e}") cherrypy.response.status = 500 - return { - 'success': False, - 'error': 'Failed to generate token' - } + return {"success": False, "error": "Failed to generate token"} else: raise cherrypy.HTTPError(405, "Method not allowed") - + @cherrypy.expose @cherrypy.tools.json_out() @require_auth def default(self, token_id=None): # Handle CORS preflight - if cherrypy.request.method == 'OPTIONS': + if cherrypy.request.method == "OPTIONS": return {} - + # Get token manager from cherrypy config - token_manager = cherrypy.config.get('token_manager') + token_manager = cherrypy.config.get("token_manager") if not token_manager: cherrypy.response.status = 500 - return {'success': False, 'error': 'Token manager not available'} - - if cherrypy.request.method == 'DELETE': + return {"success": False, "error": "Token manager not available"} + + if cherrypy.request.method == "DELETE": try: if not token_id: cherrypy.response.status = 400 - return { - 'success': False, - 'error': 'Token ID is required' - } - + return {"success": False, "error": "Token ID is required"} + # Convert to int try: token_id_int = int(token_id) except ValueError: cherrypy.response.status = 400 - return { - 'success': False, - 'error': 'Invalid token ID' - } - + return {"success": False, "error": "Invalid token ID"} + # Revoke the token success = token_manager.revoke_token(token_id_int) - + if success: - logger.info(f"Revoked API token ID {token_id_int} by user {cherrypy.request.user['username']}") - return { - 'success': True, - 'message': 'Token revoked successfully' - } + logger.info( + f"Revoked API token ID {token_id_int} by user {cherrypy.request.user['username']}" + ) + return {"success": True, "message": "Token revoked successfully"} else: cherrypy.response.status = 404 - return { - 'success': False, - 'error': 'Token not found' - } - + return {"success": False, "error": "Token not found"} + except Exception as e: logger.error(f"Token revocation error: {e}") cherrypy.response.status = 500 - return { - 'success': False, - 'error': 'Failed to revoke token' - } + return {"success": False, "error": "Failed to revoke token"} else: raise cherrypy.HTTPError(405, "Method not allowed") class AuthEndpoints: - def __init__(self, config, jwt_handler, token_manager, config_manager=None): self.config = config self.jwt_handler = jwt_handler self.token_manager = token_manager self.config_manager = config_manager - + @cherrypy.expose def login(self, **kwargs): - cherrypy.response.headers['Content-Type'] = 'application/json' - + cherrypy.response.headers["Content-Type"] = "application/json" + # Handle CORS preflight - if cherrypy.request.method == 'OPTIONS': - cherrypy.response.headers['Access-Control-Allow-Methods'] = 'POST, OPTIONS' - cherrypy.response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-API-Key' - return b'' - - if cherrypy.request.method != 'POST': + if cherrypy.request.method == "OPTIONS": + cherrypy.response.headers["Access-Control-Allow-Methods"] = "POST, OPTIONS" + cherrypy.response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, X-API-Key" + ) + return b"" + + if cherrypy.request.method != "POST": raise cherrypy.HTTPError(405, "Method not allowed") - + try: # Parse JSON body manually since we can't use json_in decorator with OPTIONS import json - body = cherrypy.request.body.read().decode('utf-8') + + body = cherrypy.request.body.read().decode("utf-8") data = json.loads(body) if body else {} - - username = data.get('username', '').strip() - password = data.get('password', '') - client_id = data.get('client_id', '').strip() - + + username = data.get("username", "").strip() + password = data.get("password", "") + client_id = data.get("client_id", "").strip() + if not username or not password or not client_id: - return json.dumps({ - 'success': False, - 'error': 'Missing required fields: username, password, client_id' - }).encode('utf-8') - + return json.dumps( + { + "success": False, + "error": "Missing required fields: username, password, client_id", + } + ).encode("utf-8") + # Validate credentials against config # Check if username is 'admin' and password matches config - repeater_config = self.config.get('repeater', {}) - security_config = repeater_config.get('security', {}) - config_password = security_config.get('admin_password', '') - + repeater_config = self.config.get("repeater", {}) + security_config = repeater_config.get("security", {}) + config_password = security_config.get("admin_password", "") + # Don't allow login with empty or unconfigured password if not config_password: - logger.warning(f"Login attempt rejected - password not configured") - return json.dumps({ - 'success': False, - 'error': 'System not configured. Please complete setup wizard.' - }).encode('utf-8') - - if username == 'admin' and password == config_password: + logger.warning("Login attempt rejected - password not configured") + return json.dumps( + { + "success": False, + "error": "System not configured. Please complete setup wizard.", + } + ).encode("utf-8") + + if username == "admin" and password == config_password: # Create JWT token token = self.jwt_handler.create_jwt(username, client_id) - - logger.info(f"Successful login for user '{username}' from client '{client_id[:8]}...'") - - return json.dumps({ - 'success': True, - 'token': token, - 'expires_in': self.jwt_handler.expiry_minutes * 60, - 'username': username - }).encode('utf-8') + + logger.info( + f"Successful login for user '{username}' from client '{client_id[:8]}...'" + ) + + return json.dumps( + { + "success": True, + "token": token, + "expires_in": self.jwt_handler.expiry_minutes * 60, + "username": username, + } + ).encode("utf-8") else: logger.warning(f"Failed login attempt for user '{username}'") - + # Don't reveal which part was wrong - return json.dumps({ - 'success': False, - 'error': 'Invalid username or password' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "Invalid username or password"} + ).encode("utf-8") + except Exception as e: logger.error(f"Login error: {e}") - return json.dumps({ - 'success': False, - 'error': 'Internal server error' - }).encode('utf-8') - + return json.dumps({"success": False, "error": "Internal server error"}).encode("utf-8") + @cherrypy.expose @cherrypy.tools.json_out() @require_auth def verify(self): - if cherrypy.request.method != 'GET': + if cherrypy.request.method != "GET": raise cherrypy.HTTPError(405, "Method not allowed") - - return { - 'success': True, - 'authenticated': True, - 'user': cherrypy.request.user - } - + + return {"success": True, "authenticated": True, "user": cherrypy.request.user} + @cherrypy.expose def refresh(self, **kwargs): - cherrypy.response.headers['Content-Type'] = 'application/json' - + cherrypy.response.headers["Content-Type"] = "application/json" + # Handle CORS preflight - if cherrypy.request.method == 'OPTIONS': - cherrypy.response.headers['Access-Control-Allow-Methods'] = 'POST, OPTIONS' - cherrypy.response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-API-Key' - return b'' - - if cherrypy.request.method != 'POST': + if cherrypy.request.method == "OPTIONS": + cherrypy.response.headers["Access-Control-Allow-Methods"] = "POST, OPTIONS" + cherrypy.response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, X-API-Key" + ) + return b"" + + if cherrypy.request.method != "POST": raise cherrypy.HTTPError(405, "Method not allowed") - + try: import json - + # Manual authentication check (can't use @require_auth since we need to handle OPTIONS) - auth_header = cherrypy.request.headers.get('Authorization', '') - api_key = cherrypy.request.headers.get('X-API-Key', '') - - jwt_handler = cherrypy.config.get('jwt_handler') - token_manager = cherrypy.config.get('token_manager') - + auth_header = cherrypy.request.headers.get("Authorization", "") + api_key = cherrypy.request.headers.get("X-API-Key", "") + + jwt_handler = cherrypy.config.get("jwt_handler") + token_manager = cherrypy.config.get("token_manager") + user_info = None - + # Check JWT first - if auth_header.startswith('Bearer '): + if auth_header.startswith("Bearer "): token = auth_header[7:] payload = jwt_handler.verify_jwt(token) if payload: user_info = { - 'username': payload['sub'], - 'client_id': payload.get('client_id'), - 'auth_method': 'jwt' + "username": payload["sub"], + "client_id": payload.get("client_id"), + "auth_method": "jwt", } - + # Check API token if not user_info and api_key: token_data = token_manager.verify_token(api_key) if token_data: user_info = { - 'username': 'admin', - 'token_id': token_data['id'], - 'auth_method': 'api_token' + "username": "admin", + "token_id": token_data["id"], + "auth_method": "api_token", } - + if not user_info: - return json.dumps({ - 'success': False, - 'error': 'Unauthorized - Valid JWT or API token required' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "Unauthorized - Valid JWT or API token required"} + ).encode("utf-8") + # Parse request body - body = cherrypy.request.body.read().decode('utf-8') + body = cherrypy.request.body.read().decode("utf-8") data = json.loads(body) if body else {} - - client_id = data.get('client_id', user_info.get('client_id', '')).strip() - + + client_id = data.get("client_id", user_info.get("client_id", "")).strip() + if not client_id: - return json.dumps({ - 'success': False, - 'error': 'Client ID is required' - }).encode('utf-8') - + return json.dumps({"success": False, "error": "Client ID is required"}).encode( + "utf-8" + ) + # Create new JWT token (refreshes expiry time) - new_token = self.jwt_handler.create_jwt(user_info['username'], client_id) - - logger.info(f"Token refreshed for user '{user_info['username']}' from client '{client_id[:8]}...'") - - return json.dumps({ - 'success': True, - 'token': new_token, - 'expires_in': self.jwt_handler.expiry_minutes * 60, - 'username': user_info['username'] - }).encode('utf-8') - + new_token = self.jwt_handler.create_jwt(user_info["username"], client_id) + + logger.info( + f"Token refreshed for user '{user_info['username']}' from client '{client_id[:8]}...'" + ) + + return json.dumps( + { + "success": True, + "token": new_token, + "expires_in": self.jwt_handler.expiry_minutes * 60, + "username": user_info["username"], + } + ).encode("utf-8") + except Exception as e: logger.error(f"Token refresh error: {e}") - return json.dumps({ - 'success': False, - 'error': 'Failed to refresh token' - }).encode('utf-8') - + return json.dumps({"success": False, "error": "Failed to refresh token"}).encode( + "utf-8" + ) + @cherrypy.expose def change_password(self): import json - - cherrypy.response.headers['Content-Type'] = 'application/json' - + + cherrypy.response.headers["Content-Type"] = "application/json" + # Handle CORS preflight - if cherrypy.request.method == 'OPTIONS': - cherrypy.response.headers['Access-Control-Allow-Methods'] = 'POST, OPTIONS' - cherrypy.response.headers['Access-Control-Allow-Headers'] = 'Content-Type, Authorization, X-API-Key' - return b'' - - if cherrypy.request.method != 'POST': + if cherrypy.request.method == "OPTIONS": + cherrypy.response.headers["Access-Control-Allow-Methods"] = "POST, OPTIONS" + cherrypy.response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, X-API-Key" + ) + return b"" + + if cherrypy.request.method != "POST": raise cherrypy.HTTPError(405, "Method not allowed") - + # Require authentication for POST # Get auth handlers from global cherrypy config - jwt_handler = cherrypy.config.get('jwt_handler') - token_manager = cherrypy.config.get('token_manager') - + jwt_handler = cherrypy.config.get("jwt_handler") + token_manager = cherrypy.config.get("token_manager") + if not jwt_handler or not token_manager: logger.error("Auth handlers not configured") raise cherrypy.HTTPError(500, "Authentication not configured") - + # Try JWT authentication first - auth_header = cherrypy.request.headers.get('Authorization', '') + auth_header = cherrypy.request.headers.get("Authorization", "") user = None - - if auth_header.startswith('Bearer '): + + if auth_header.startswith("Bearer "): token = auth_header[7:] # Remove 'Bearer ' prefix payload = jwt_handler.verify_jwt(token) - + if payload: user = { - 'username': payload['sub'], - 'client_id': payload['client_id'], - 'auth_type': 'jwt' + "username": payload["sub"], + "client_id": payload["client_id"], + "auth_type": "jwt", } - + # Try API token authentication if JWT failed if not user: - api_key = cherrypy.request.headers.get('X-API-Key', '') + api_key = cherrypy.request.headers.get("X-API-Key", "") if api_key: token_info = token_manager.verify_token(api_key) - + if token_info: user = { - 'username': 'api_token', - 'token_name': token_info['name'], - 'token_id': token_info['id'], - 'auth_type': 'api_token' + "username": "api_token", + "token_name": token_info["name"], + "token_id": token_info["id"], + "auth_type": "api_token", } - + if not user: cherrypy.response.status = 401 - return json.dumps({ - 'success': False, - 'error': 'Unauthorized - Valid JWT or API token required' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "Unauthorized - Valid JWT or API token required"} + ).encode("utf-8") + try: # Parse JSON body manually - body = cherrypy.request.body.read().decode('utf-8') + body = cherrypy.request.body.read().decode("utf-8") data = json.loads(body) if body else {} - - current_password = data.get('current_password', '') - new_password = data.get('new_password', '') - + + current_password = data.get("current_password", "") + new_password = data.get("new_password", "") + if not current_password or not new_password: cherrypy.response.status = 400 - return json.dumps({ - 'success': False, - 'error': 'Both current_password and new_password are required' - }).encode('utf-8') - + return json.dumps( + { + "success": False, + "error": "Both current_password and new_password are required", + } + ).encode("utf-8") + # Validate new password strength if len(new_password) < 8: cherrypy.response.status = 400 - return json.dumps({ - 'success': False, - 'error': 'New password must be at least 8 characters long' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "New password must be at least 8 characters long"} + ).encode("utf-8") + # Verify current password - repeater_config = self.config.get('repeater', {}) - security_config = repeater_config.get('security', {}) - config_password = security_config.get('admin_password', '') - + repeater_config = self.config.get("repeater", {}) + security_config = repeater_config.get("security", {}) + config_password = security_config.get("admin_password", "") + if not config_password: cherrypy.response.status = 500 - return json.dumps({ - 'success': False, - 'error': 'System configuration error' - }).encode('utf-8') - + return json.dumps({"success": False, "error": "System configuration error"}).encode( + "utf-8" + ) + if current_password != config_password: cherrypy.response.status = 401 - return json.dumps({ - 'success': False, - 'error': 'Current password is incorrect' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "Current password is incorrect"} + ).encode("utf-8") + # Update password in config - if 'repeater' not in self.config: - self.config['repeater'] = {} - if 'security' not in self.config['repeater']: - self.config['repeater']['security'] = {} - - self.config['repeater']['security']['admin_password'] = new_password - + if "repeater" not in self.config: + self.config["repeater"] = {} + if "security" not in self.config["repeater"]: + self.config["repeater"]["security"] = {} + + self.config["repeater"]["security"]["admin_password"] = new_password + # Save to config file using ConfigManager if self.config_manager: if self.config_manager.save_to_file(): logger.info(f"Admin password changed successfully by user {user['username']}") - return json.dumps({ - 'success': True, - 'message': 'Password changed successfully. Please log in again with your new password.' - }).encode('utf-8') + return json.dumps( + { + "success": True, + "message": "Password changed successfully. Please log in again with your new password.", + } + ).encode("utf-8") else: cherrypy.response.status = 500 - return json.dumps({ - 'success': False, - 'error': 'Failed to save password to config file' - }).encode('utf-8') + return json.dumps( + {"success": False, "error": "Failed to save password to config file"} + ).encode("utf-8") else: cherrypy.response.status = 500 - return json.dumps({ - 'success': False, - 'error': 'Config manager not available' - }).encode('utf-8') - + return json.dumps( + {"success": False, "error": "Config manager not available"} + ).encode("utf-8") + except Exception as e: logger.error(f"Password change error: {e}") cherrypy.response.status = 500 - return json.dumps({ - 'success': False, - 'error': 'Failed to change password' - }).encode('utf-8') \ No newline at end of file + return json.dumps({"success": False, "error": "Failed to change password"}).encode( + "utf-8" + ) diff --git a/repeater/web/cad_calibration_engine.py b/repeater/web/cad_calibration_engine.py index f43dbe0..b290326 100644 --- a/repeater/web/cad_calibration_engine.py +++ b/repeater/web/cad_calibration_engine.py @@ -3,13 +3,12 @@ import random import threading import time -from typing import Any, Dict, Optional +from typing import Any, Dict logger = logging.getLogger("HTTPServer") class CADCalibrationEngine: - def __init__(self, daemon_instance=None, event_loop=None): self.daemon_instance = daemon_instance self.event_loop = event_loop @@ -49,8 +48,8 @@ async def test_cad_config( baseline_result = await radio.perform_cad(det_peak=35, det_min=25, timeout=0.3) if baseline_result: baseline_detections += 1 - except Exception: - pass + except Exception as exc: + logger.debug(f"CAD baseline sample failed: {exc}") await asyncio.sleep(0.1) # 100ms between baseline samples # Wait before actual test @@ -62,8 +61,8 @@ async def test_cad_config( result = await radio.perform_cad(det_peak=det_peak, det_min=det_min, timeout=0.3) if result: detections += 1 - except Exception: - pass + except Exception as exc: + logger.debug(f"CAD sample failed for det_peak={det_peak} det_min={det_min}: {exc}") # Variable delay to avoid sampling artifacts delay = 0.05 + (i % 3) * 0.05 # 50ms, 100ms, 150ms rotation diff --git a/repeater/web/companion_endpoints.py b/repeater/web/companion_endpoints.py index 7a5688d..0f433d3 100644 --- a/repeater/web/companion_endpoints.py +++ b/repeater/web/companion_endpoints.py @@ -617,9 +617,11 @@ def set_advert_name(self, **kwargs): companion_name = body.get("companion_name") if companion_name is None and getattr(self.daemon_instance, "identity_manager", None): pubkey = bridge.get_public_key() - for reg_name, identity, _ in self.daemon_instance.identity_manager.get_identities_by_type( - "companion" - ): + for ( + reg_name, + identity, + _, + ) in self.daemon_instance.identity_manager.get_identities_by_type("companion"): if identity.get_public_key() == pubkey: companion_name = reg_name break diff --git a/repeater/web/companion_ws_proxy.py b/repeater/web/companion_ws_proxy.py index d023fc9..29d8f89 100644 --- a/repeater/web/companion_ws_proxy.py +++ b/repeater/web/companion_ws_proxy.py @@ -25,7 +25,6 @@ def set_daemon(instance): class CompanionFrameWebSocket(WebSocket): - def opened(self): """Authenticate, resolve companion, open TCP socket, start reader.""" # JWT auth — same pattern as PacketWebSocket @@ -95,7 +94,9 @@ def opened(self): self._reader.start() user = payload.get("sub", "unknown") - logger.info(f"Companion WS opened: user={user}, companion={companion_name}, tcp={tcp_host}:{tcp_port}") + logger.info( + f"Companion WS opened: user={user}, companion={companion_name}, tcp={tcp_host}:{tcp_port}" + ) def received_message(self, message): """WS → TCP""" @@ -220,11 +221,11 @@ def _teardown(self): if tcp: try: tcp.close() - except Exception: - pass + except Exception as exc: + logger.debug(f"WS proxy TCP close failed for {name!r}: {exc}") self._tcp = None try: self.close() - except Exception: - pass + except Exception as exc: + logger.debug(f"WS proxy close failed for {name!r}: {exc}") diff --git a/repeater/web/http_server.py b/repeater/web/http_server.py index 88573af..ea2e9b6 100644 --- a/repeater/web/http_server.py +++ b/repeater/web/http_server.py @@ -1,7 +1,6 @@ import json import logging import os -import re import secrets from collections import deque from datetime import datetime @@ -10,15 +9,13 @@ import cherrypy import cherrypy_cors -from pymc_core.protocol.utils import PAYLOAD_TYPES, ROUTE_TYPES -from repeater import __version__ from repeater.config import resolve_storage_dir from repeater.data_acquisition import SQLiteHandler from .api_endpoints import APIEndpoints -from .auth import cherrypy_tool # Import to register the tool from .auth.api_tokens import APITokenManager +from .auth.cherrypy_tool import register_require_auth_tool from .auth.jwt_handler import JWTHandler from .auth_endpoints import AuthEndpoints @@ -26,10 +23,11 @@ try: from repeater.data_acquisition.websocket_handler import ( PacketWebSocket, - broadcast_packet, init_websocket, ) - from .companion_ws_proxy import CompanionFrameWebSocket, set_daemon as _set_companion_daemon + + from .companion_ws_proxy import CompanionFrameWebSocket + from .companion_ws_proxy import set_daemon as _set_companion_daemon WEBSOCKET_AVAILABLE = True except ImportError: @@ -42,7 +40,6 @@ # In-memory log buffer class LogBuffer(logging.Handler): - def __init__(self, max_lines=100): super().__init__() self.logs = deque(maxlen=max_lines) @@ -107,7 +104,6 @@ def openapi_json(self): class StatsApp: - def __init__( self, stats_getter: Optional[Callable] = None, @@ -165,7 +161,12 @@ def default(self, *args, **kwargs): raise cherrypy.NotFound() # Handle WebSocket routes - if args and len(args) >= 2 and args[0] == "ws" and args[1] in ("packets", "companion_frame"): + if ( + args + and len(args) >= 2 + and args[0] == "ws" + and args[1] in ("packets", "companion_frame") + ): # WebSocket tool will intercept this return "" @@ -174,7 +175,6 @@ def default(self, *args, **kwargs): class HTTPStatsServer: - def __init__( self, host: str = "0.0.0.0", @@ -288,6 +288,7 @@ def _json_error_handler(self, status, message, traceback, version): def start(self): try: + register_require_auth_tool() if self._cors_enabled: self._setup_server_cors() diff --git a/repeater/web/update_endpoints.py b/repeater/web/update_endpoints.py index 173592f..7ad5ee1 100644 --- a/repeater/web/update_endpoints.py +++ b/repeater/web/update_endpoints.py @@ -20,13 +20,16 @@ import os import re import ssl -import subprocess + +# Required for fixed internal maintenance commands. +import subprocess # nosec B404 import threading import time import urllib.error import urllib.request from datetime import datetime, timezone from typing import List, Optional +from urllib.parse import urlparse import cherrypy from repeater.service_utils import get_container_restart_message, is_buildroot, is_container @@ -44,6 +47,10 @@ # How long (seconds) before a cached check result expires CHECK_CACHE_TTL = 600 # 10 minutes +_RM_BIN = "/bin/rm" +_SED_BIN = "/usr/bin/sed" +_SYSTEMCTL_BIN = "/bin/systemctl" +_SUDO_BIN = "/usr/bin/sudo" _github_ssl_ctx: Optional[ssl.SSLContext] = None _disk_version_mismatch_logged: Optional[tuple] = None @@ -73,6 +80,7 @@ def _find_buildroot_upgrade_helper() -> Optional[str]: class _RateLimitError(Exception): """Raised when GitHub returns HTTP 403 due to rate limiting.""" + def __init__(self, msg: str, reset_at: Optional[datetime] = None): super().__init__(msg) self.reset_at = reset_at @@ -157,6 +165,7 @@ def _cache_and_return(value: str) -> str: else: try: from packaging.version import Version + disk_version = str(max(candidates, key=lambda v: Version(v))) except Exception: # packaging unavailable – sort lexicographically as best-effort @@ -166,13 +175,15 @@ def _cache_and_return(value: str) -> str: if disk_version is None: try: from importlib.metadata import version as _pkg_ver + disk_version = _pkg_ver(PACKAGE_NAME) - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] importlib.metadata fallback unavailable: {exc}") if disk_version is None: try: from repeater import __version__ + return _cache_and_return(__version__) except Exception: return _cache_and_return("unknown") @@ -183,6 +194,7 @@ def _cache_and_return(value: str) -> str: try: from repeater import __version__ as _running from packaging.version import Version + if Version(_running) > Version(disk_version): # status() polls can call this frequently; throttle mismatch logs. global _disk_version_mismatch_logged @@ -206,11 +218,12 @@ def _cache_and_return(value: str) -> str: # Strip PEP 440 local identifier (+gXXXXXX) – it only encodes # the git hash and causes spurious mismatches with GitHub versions. - return _cache_and_return(re.sub(r'\+[a-zA-Z0-9.]+$', '', _running)) - except Exception: - pass + return _cache_and_return(re.sub(r"\+[a-zA-Z0-9.]+$", "", _running)) + except Exception as exc: + logger.debug(f"[Update] Running-version sanity check skipped: {exc}") + + return _cache_and_return(re.sub(r"\+[a-zA-Z0-9.]+$", "", disk_version)) - return _cache_and_return(re.sub(r'\+[a-zA-Z0-9.]+$', '', disk_version)) # Channels file – persisted so the choice survives daemon restarts _CHANNELS_FILE = "/var/lib/pymc_repeater/.update_channel" @@ -259,9 +272,10 @@ def _detect_channel_from_dist_info() -> Optional[str]: # Use the highest-version dist-info so a stale old one doesn't win try: from packaging.version import Version + candidates.sort(key=lambda t: Version(t[0]), reverse=True) - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Could not version-sort direct_url candidates: {exc}") _, best_url_path = candidates[0] try: @@ -271,10 +285,10 @@ def _detect_channel_from_dist_info() -> Optional[str]: # ``requested_revision`` is only present when the user explicitly named # a branch/tag; absent means HEAD of the default branch. revision = vcs_info.get("requested_revision") - if revision and re.match(r'^[a-zA-Z0-9_./\-]+$', revision): + if revision and re.match(r"^[a-zA-Z0-9_./\-]+$", revision): return revision - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Failed to inspect direct_url metadata {best_url_path}: {exc}") return None @@ -294,7 +308,7 @@ def __init__(self): self.channel: str = self._load_channel() self.last_checked: Optional[datetime] = None # progress / install state - self.state: str = "idle" # idle | checking | installing | complete | error + self.state: str = "idle" # idle | checking | installing | complete | error self.error_message: Optional[str] = None self.progress_lines: List[str] = [] self._install_thread: Optional[threading.Thread] = None @@ -357,7 +371,9 @@ def snapshot(self) -> dict: "last_checked": self.last_checked.isoformat() if self.last_checked else None, "state": self.state, "error": self.error_message, - "rate_limit_until": self.rate_limit_until.isoformat() if self.rate_limit_until else None, + "rate_limit_until": self.rate_limit_until.isoformat() + if self.rate_limit_until + else None, } def set_channel(self, channel: str) -> None: @@ -439,6 +455,7 @@ def append_line(self, line: str) -> None: # Internal helpers # --------------------------------------------------------------------------- + def _fetch_url(url: str, timeout: int = 10) -> str: """Perform a simple GET and return text body, or raise on failure. @@ -451,10 +468,16 @@ def _fetch_url(url: str, timeout: int = 10) -> str: token = os.environ.get("GITHUB_TOKEN") or os.environ.get("GH_TOKEN") if token: headers["Authorization"] = f"Bearer {token}" + parsed = urlparse(url) + if parsed.scheme != "https" or parsed.netloc not in { + "api.github.com", + "raw.githubusercontent.com", + }: + raise RuntimeError(f"Refusing to fetch untrusted update URL: {url}") req = urllib.request.Request(url, headers=headers) try: ctx = _get_github_ssl_context() if url.startswith("https") else None - with urllib.request.urlopen(req, timeout=timeout, context=ctx) as resp: + with urllib.request.urlopen(req, timeout=timeout, context=ctx) as resp: # nosec B310 return resp.read().decode("utf-8", errors="replace") except urllib.error.HTTPError as exc: if exc.code == 403: @@ -464,8 +487,10 @@ def _fetch_url(url: str, timeout: int = 10) -> str: reset_ts = exc.headers.get("X-RateLimit-Reset") if reset_ts: reset_at = datetime.fromtimestamp(int(reset_ts), timezone.utc) - except Exception: - pass + except Exception as inner_exc: + logger.debug( + f"[Update] Failed to parse GitHub rate-limit reset header: {inner_exc}" + ) reset_str = reset_at.strftime("%H:%M UTC") if reset_at else "a short while" raise _RateLimitError( f"GitHub API rate limit exceeded — resets at {reset_str}. " @@ -482,7 +507,7 @@ def _get_latest_tag() -> str: tags = json.loads(body) for tag in tags: name = tag.get("name", "").lstrip("v") - if re.match(r'^\d+\.\d+', name): + if re.match(r"^\d+\.\d+", name): return name raise RuntimeError("No semver tags found in repository") @@ -496,10 +521,10 @@ def _branch_is_dynamic(channel: str) -> bool: if re.search(r'^version\s*=\s*["\'][0-9]', toml_text, re.MULTILINE): return False # Dynamic looks like: dynamic = ["version"] - if re.search(r'^dynamic\s*=', toml_text, re.MULTILINE): + if re.search(r"^dynamic\s*=", toml_text, re.MULTILINE): return True - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Could not determine channel versioning mode for {channel!r}: {exc}") return True # assume dynamic if we can't tell @@ -519,7 +544,7 @@ def _next_dev_version(base_tag: str, ahead_by: int) -> str: def _parse_dev_number(version_str: str) -> Optional[int]: """Extract the dev commit count from a setuptools_scm version like 1.0.6.dev118.""" - m = re.search(r'\.dev(\d+)', version_str) + m = re.search(r"\.dev(\d+)", version_str) return int(m.group(1)) if m else None @@ -567,6 +592,7 @@ def _cleanup_stale_dist_info(allow_sudo: bool = True) -> None: try: from packaging.version import Version + keep = max(found, key=lambda p: Version(found[p])) except Exception: return # can't determine winner safely — leave everything alone @@ -589,11 +615,15 @@ def _cleanup_stale_dist_info(allow_sudo: bool = True) -> None: # dist-info is root-owned (pip ran via sudo); use sudo to remove try: subprocess.run( - ["sudo", "--non-interactive", "rm", "-rf", path], - check=True, capture_output=True, timeout=10, - ) + [_SUDO_BIN, "--non-interactive", _RM_BIN, "-rf", path], + check=True, + capture_output=True, + timeout=10, + ) # nosec B603 logger.info(f"[Update] Removed stale dist-info (sudo): {path} (version {ver})") - _state.append_line(f"[pyMC updater] Removed stale dist-info: {os.path.basename(path)}") + _state.append_line( + f"[pyMC updater] Removed stale dist-info: {os.path.basename(path)}" + ) removed_any = True except Exception as exc2: logger.warning(f"[Update] Could not remove stale dist-info {path}: {exc2}") @@ -618,16 +648,15 @@ def _startup_dist_info_cleanup() -> None: def _has_update(installed: str, latest: str) -> bool: - """ - - """ + """ """ if installed == latest: return False try: from packaging.version import Version + return Version(latest) > Version(installed) - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] PEP 440 comparison failed for {installed!r} vs {latest!r}: {exc}") # Fallback: dev-number comparison only when base version is identical installed_dev = _parse_dev_number(installed) latest_dev = _parse_dev_number(latest) @@ -647,7 +676,8 @@ def _fetch_latest_version(channel: str) -> str: data = json.loads(body) ahead_by = int(data.get("ahead_by", 0)) return _next_dev_version(base_tag, ahead_by) - except Exception: + except Exception as exc: + logger.debug(f"[Update] Dynamic version compare failed for {channel!r}: {exc}") return base_tag # fallback: show the tag # Static version channel — read the pinned version from pyproject.toml on @@ -658,8 +688,8 @@ def _fetch_latest_version(channel: str) -> str: m = re.search(r'^version\s*=\s*["\']([0-9][^"\']*)["\']', toml_text, re.MULTILINE) if m: return m.group(1) - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Static version lookup failed for {channel!r}: {exc}") return base_tag # last-resort fallback @@ -673,7 +703,7 @@ def _fetch_changelog(channel: str, installed: str, max_commits: int = 50) -> Lis compare_url = f"{GITHUB_API_BASE}/compare/{base_tag}...{channel}?per_page=100" else: # For static channels compare from the installed tag if we know it - from_ref = installed if re.match(r'^\d+\.\d+', installed) else base_tag + from_ref = installed if re.match(r"^\d+\.\d+", installed) else base_tag compare_url = f"{GITHUB_API_BASE}/compare/{from_ref}...{channel}?per_page=100" body = _fetch_url(compare_url, timeout=12) @@ -701,15 +731,17 @@ def _fetch_changelog(channel: str, installed: str, max_commits: int = 50) -> Lis ) date = commit_data.get("author", {}).get("date", "") sha = c.get("sha", "") - result.append({ - "sha": sha, - "short_sha": sha[:7], - "title": title, - "body": body_text, - "author": author, - "date": date, - "url": c.get("html_url", ""), - }) + result.append( + { + "sha": sha, + "short_sha": sha[:7], + "title": title, + "body": body_text, + "author": author, + "date": date, + "url": c.get("html_url", ""), + } + ) return result except Exception as exc: logger.warning(f"[Update] Changelog fetch failed: {exc}") @@ -760,24 +792,24 @@ def _migrate_service_unit() -> None: logger.info("[Update] Buildroot image detected, skipping systemd unit migration.") return - import subprocess as _sp _SVC_UNIT = "/etc/systemd/system/pymc-repeater.service" _VENV_PYTHON = "/opt/pymc_repeater/venv/bin/python" try: - _sp.run(["sed", "-i", "/^Environment=.*PYTHONPATH/d", _SVC_UNIT], check=False) - _sp.run( - ["sed", "-i", - "s|WorkingDirectory=/opt/pymc_repeater|WorkingDirectory=/var/lib/pymc_repeater|", - _SVC_UNIT], + subprocess.run([_SED_BIN, "-i", "/^Environment=.*PYTHONPATH/d", _SVC_UNIT], check=False) # nosec B603 + subprocess.run( + [ + _SED_BIN, + "-i", + "s|WorkingDirectory=/opt/pymc_repeater|WorkingDirectory=/var/lib/pymc_repeater|", + _SVC_UNIT, + ], check=False, - ) - _sp.run( - ["sed", "-i", - f"s|ExecStart=/usr/bin/python3|ExecStart={_VENV_PYTHON}|", - _SVC_UNIT], + ) # nosec B603 + subprocess.run( + [_SED_BIN, "-i", f"s|ExecStart=/usr/bin/python3|ExecStart={_VENV_PYTHON}|", _SVC_UNIT], check=False, - ) - _sp.run(["systemctl", "daemon-reload"], check=False) + ) # nosec B603 + subprocess.run([_SYSTEMCTL_BIN, "daemon-reload"], check=False) # nosec B603 logger.info("[Update] Service unit migration applied (root path).") except Exception as exc: logger.warning(f"[Update] Service unit migration failed: {exc}") @@ -797,12 +829,13 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: text=True, bufsize=1, env=env, - ) - for line in proc.stdout: - line = line.rstrip() - if line: - _state.append_line(line) - logger.debug(f"[pip] {line}") + ) # nosec B603 + if proc.stdout is not None: + for line in proc.stdout: + line = line.rstrip() + if line: + _state.append_line(line) + logger.debug(f"[pip] {line}") proc.wait() return proc.returncode == 0 except Exception as exc: @@ -811,6 +844,7 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: return False import os as _os + env = _os.environ.copy() env["SETUPTOOLS_SCM_PRETEND_VERSION"] = _state.latest_version or "1.0.0" @@ -822,15 +856,20 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: _UPGRADE_WRAPPER = "/usr/local/bin/pymc-do-upgrade" _BUILDROOT_UPGRADE_HELPER = _find_buildroot_upgrade_helper() - is_root = (_os.geteuid() == 0) + is_root = _os.geteuid() == 0 if is_root and is_buildroot(): env["PYMC_REPEATER_REF"] = channel env["PYMC_CORE_REF"] = channel if not _BUILDROOT_UPGRADE_HELPER: - _state.finish_install(False, "Buildroot upgrade helper not found in repo checkout or image bootstrap paths") + _state.finish_install( + False, + "Buildroot upgrade helper not found in repo checkout or image bootstrap paths", + ) return - _state.append_line(f"[pyMC updater] Buildroot image detected – using {_BUILDROOT_UPGRADE_HELPER}") + _state.append_line( + f"[pyMC updater] Buildroot image detected – using {_BUILDROOT_UPGRADE_HELPER}" + ) cmd = ["/bin/sh", _BUILDROOT_UPGRADE_HELPER, "upgrade"] elif is_root: _migrate_service_unit() @@ -838,27 +877,27 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: # Ensure venv exists (migration from system-pip era) if not os.path.isfile(_VENV_PYTHON): _state.append_line("[pyMC updater] Creating venv (first-time migration)…") - _run(["python3", "-m", "venv", "--system-site-packages", _VENV_DIR], env=env) + _run(["/usr/bin/python3", "-m", "venv", "--system-site-packages", _VENV_DIR], env=env) _run([_VENV_PIP, "install", "--upgrade", "pip", "setuptools", "wheel"], env=env) # Clean up system-level packages to avoid shadowing - _run(["python3", "-m", "pip", "uninstall", "-y", "pymc_repeater"], env=env) - _run(["python3", "-m", "pip", "uninstall", "-y", "pymc_core"], env=env) + _run(["/usr/bin/python3", "-m", "pip", "uninstall", "-y", "pymc_repeater"], env=env) + _run(["/usr/bin/python3", "-m", "pip", "uninstall", "-y", "pymc_core"], env=env) # Remove stale source tree that could shadow the venv package stale_src = "/opt/pymc_repeater/repeater" if os.path.isdir(stale_src): _state.append_line("[pyMC updater] Removing stale source tree…") import shutil + shutil.rmtree(stale_src, ignore_errors=True) - install_spec = ( - f"pymc_repeater[hardware] @ git+https://github.com/{GITHUB_OWNER}/{GITHUB_REPO}.git@{channel}" - ) - _state.append_line(f"[pyMC updater] Running as root – venv pip install") + install_spec = f"pymc_repeater[hardware] @ git+https://github.com/{GITHUB_OWNER}/{GITHUB_REPO}.git@{channel}" + _state.append_line("[pyMC updater] Running as root – venv pip install") _state.append_line(f"[pyMC updater] Target: {install_spec}") cmd = [ - _VENV_PIP, "install", + _VENV_PIP, + "install", "--upgrade", "--no-cache-dir", install_spec, @@ -866,7 +905,7 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: elif _os.path.isfile(_UPGRADE_WRAPPER): _state.append_line(f"[pyMC updater] Using sudo wrapper: {_UPGRADE_WRAPPER}") # The wrapper handles venv creation/migration internally - cmd = ["sudo", _UPGRADE_WRAPPER, channel, _state.latest_version or ""] + cmd = [_SUDO_BIN, _UPGRADE_WRAPPER, channel, _state.latest_version or ""] else: msg = ( f"Upgrade wrapper not found at {_UPGRADE_WRAPPER}. " @@ -885,6 +924,7 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: restart_msg = "Restart not attempted" try: from repeater.service_utils import restart_service + restart_ok, restart_msg = restart_service() logger.info(f"[Update] Post-upgrade restart: {restart_msg}") except Exception as exc: @@ -897,9 +937,13 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: f"Upgraded to latest on channel '{channel}' – {get_container_restart_message()}", ) else: - _state.finish_install(True, f"Upgraded to latest on channel '{channel}' – service restarted") + _state.finish_install( + True, f"Upgraded to latest on channel '{channel}' – service restarted" + ) else: - _state.finish_install(False, f"Upgrade succeeded but service restart failed: {restart_msg}") + _state.finish_install( + False, f"Upgrade succeeded but service restart failed: {restart_msg}" + ) else: _state.finish_install(False, "pip install failed – see progress log for details") @@ -911,13 +955,15 @@ def _run(cmd: List[str], env: Optional[dict] = None) -> bool: # CherryPy Endpoint class # --------------------------------------------------------------------------- -class UpdateAPIEndpoints: +class UpdateAPIEndpoints: def _set_cors_headers(self, config: dict) -> None: if config.get("web", {}).get("cors_enabled", False): cherrypy.response.headers["Access-Control-Allow-Origin"] = "*" cherrypy.response.headers["Access-Control-Allow-Methods"] = "GET, POST, OPTIONS" - cherrypy.response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" + cherrypy.response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization" + ) def _require_post(self): if cherrypy.request.method != "POST": @@ -963,8 +1009,8 @@ def check(self, **kwargs): body = {} try: body = cherrypy.request.json or {} - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Ignoring non-JSON update check payload: {exc}") force = bool(body.get("force", False)) # Honour the cache to avoid hammering GitHub (unless forced) @@ -985,11 +1031,13 @@ def check(self, **kwargs): remaining = (_state.rate_limit_until - datetime.now(timezone.utc)).total_seconds() if remaining > 0: reset_str = _state.rate_limit_until.strftime("%H:%M UTC") - return self._ok({ - "message": f"GitHub rate limit active — resets at {reset_str}", - "state": snap["state"], - **snap, - }) + return self._ok( + { + "message": f"GitHub rate limit active — resets at {reset_str}", + "state": snap["state"], + **snap, + } + ) if not force and snap["last_checked"] is not None: age = (datetime.now(timezone.utc) - _state.last_checked).total_seconds() @@ -1024,8 +1072,8 @@ def install(self, **kwargs): body = {} try: body = cherrypy.request.json or {} - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Ignoring non-JSON update install payload: {exc}") snap = _state.snapshot() @@ -1038,7 +1086,7 @@ def install(self, **kwargs): if snap["latest_version"] is not None: return self._err( f"Already up to date ({snap['current_version']}). " - "Pass {\"force\": true} to reinstall anyway.", + 'Pass {"force": true} to reinstall anyway.', 409, ) @@ -1048,14 +1096,14 @@ def install(self, **kwargs): return self._err("Could not start install thread – check state", 409) t.start() - logger.warning( - f"[Update] Install triggered via API – channel={_state.channel}" + logger.warning(f"[Update] Install triggered via API – channel={_state.channel}") + return self._ok( + { + "message": f"Update started on channel '{_state.channel}'. " + "Watch /api/update/progress for live output.", + "state": "installing", + } ) - return self._ok({ - "message": f"Update started on channel '{_state.channel}'. " - "Watch /api/update/progress for live output.", - "state": "installing", - }) # ------------------------------------------------------------------ # # GET /api/update/progress (SSE) # @@ -1095,11 +1143,13 @@ def generate(): # Terminate stream when install completes or errors if current_state in ("complete", "error") and last_idx >= len(current_lines): - done_payload = json.dumps({ - "type": "done", - "state": current_state, - "error": snap.get("error"), - }) + done_payload = json.dumps( + { + "type": "done", + "state": current_state, + "error": snap.get("error"), + } + ) yield f"data: {done_payload}\n\n" return @@ -1133,10 +1183,12 @@ def channels(self, **kwargs): return "" branch_list = _fetch_branches() - return self._ok({ - "channels": branch_list, - "current_channel": _state.channel, - }) + return self._ok( + { + "channels": branch_list, + "current_channel": _state.channel, + } + ) # ------------------------------------------------------------------ # # POST /api/update/set_channel # @@ -1157,8 +1209,8 @@ def set_channel(self, **kwargs): body = {} try: body = cherrypy.request.json or {} - except Exception: - pass + except Exception as exc: + logger.debug(f"[Update] Ignoring non-JSON update channel payload: {exc}") channel = str(body.get("channel", "")).strip() if not channel: @@ -1169,10 +1221,12 @@ def set_channel(self, **kwargs): _state.set_channel(channel) logger.info(f"[Update] Channel changed to '{channel}' via API") - return self._ok({ - "channel": channel, - "message": f"Channel switched to '{channel}'. Run /api/update/check to verify.", - }) + return self._ok( + { + "channel": channel, + "message": f"Channel switched to '{channel}'. Run /api/update/check to verify.", + } + ) # ------------------------------------------------------------------ # # GET /api/update/changelog # @@ -1195,9 +1249,11 @@ def changelog(self, **kwargs): latest = snap["latest_version"] or "" commits = _fetch_changelog(channel, installed, max_commits) - return self._ok({ - "channel": channel, - "installed": installed, - "latest": latest, - "commits": commits, - }) + return self._ok( + { + "channel": channel, + "installed": installed, + "latest": latest, + "commits": commits, + } + ) diff --git a/tests/test_airtime.py b/tests/test_airtime.py index f765e79..cf61537 100644 --- a/tests/test_airtime.py +++ b/tests/test_airtime.py @@ -18,7 +18,7 @@ def _semtech_airtime_ms(payload_len: int, sf: int, bw_hz: int, cr: int, preamble crc = 1 h = 0 # explicit header de = 1 if (sf >= 11 and bw_hz <= 125000) else 0 - t_sym = (2 ** sf) / (bw_hz / 1000) + t_sym = (2**sf) / (bw_hz / 1000) t_preamble = (preamble + 4.25) * t_sym numerator = max(8 * payload_len - 4 * sf + 28 + 16 * crc - 20 * h, 0) denominator = 4 * (sf - 2 * de) diff --git a/tests/test_api_endpoints_core_coverage.py b/tests/test_api_endpoints_core_coverage.py index feba70a..b55285d 100644 --- a/tests/test_api_endpoints_core_coverage.py +++ b/tests/test_api_endpoints_core_coverage.py @@ -21,9 +21,7 @@ def _make_api(config=None): def _attach_storage(api, storage): - api.daemon_instance = SimpleNamespace( - repeater_handler=SimpleNamespace(storage=storage) - ) + api.daemon_instance = SimpleNamespace(repeater_handler=SimpleNamespace(storage=storage)) @pytest.fixture @@ -333,7 +331,7 @@ def test_config_export_redacts_secrets_and_identity_keys(cherrypy_ctx): "companions": [{"name": "c1", "identity_key": bytes.fromhex("0102")}], "room_servers": [{"name": "r1", "identity_key": bytes.fromhex("0304")}], }, - "misc": {"blob": b"\x0A\x0B"}, + "misc": {"blob": b"\x0a\x0b"}, } ) @@ -689,7 +687,9 @@ def test_db_vacuum_options_success_and_error(cherrypy_ctx): assert result["success"] is True assert result["data"] == {"size_before": 1000, "size_after": 700, "freed_bytes": 300} - sqlite_path.stat = MagicMock(side_effect=[SimpleNamespace(st_size=700), SimpleNamespace(st_size=700)]) + sqlite_path.stat = MagicMock( + side_effect=[SimpleNamespace(st_size=700), SimpleNamespace(st_size=700)] + ) sqlite_handler.vacuum.side_effect = RuntimeError("vacuum failed") err = api.db_vacuum() assert err["success"] is False @@ -840,96 +840,99 @@ def test_validate_config_disabled_radio_warns_but_valid(cherrypy_ctx, tmp_path): def test_update_web_config_options_no_updates_success_failure(cherrypy_ctx): - request, _ = cherrypy_ctx - api = _make_api({"web": {"cors_enabled": True}}) - - request.method = "OPTIONS" - assert api.update_web_config() == "" - - request.method = "POST" - request.json = {} - no_updates = api.update_web_config() - assert no_updates["success"] is False - assert "No configuration updates" in no_updates["error"] - - request.json = {"web": {"cors_enabled": True}} - api.config_manager.update_and_save.return_value = {"success": True, "saved": True} - ok = api.update_web_config() - assert ok["success"] is True - assert ok["data"]["persisted"] is True - api.config_manager.update_and_save.assert_called_with( - updates={"web": {"cors_enabled": True}}, - live_update=False, - ) + request, _ = cherrypy_ctx + api = _make_api({"web": {"cors_enabled": True}}) - api.config_manager.update_and_save.return_value = {"success": False, "error": "bad"} - fail = api.update_web_config() - assert fail["success"] is False - assert fail["error"] == "bad" + request.method = "OPTIONS" + assert api.update_web_config() == "" + + request.method = "POST" + request.json = {} + no_updates = api.update_web_config() + assert no_updates["success"] is False + assert "No configuration updates" in no_updates["error"] + + request.json = {"web": {"cors_enabled": True}} + api.config_manager.update_and_save.return_value = {"success": True, "saved": True} + ok = api.update_web_config() + assert ok["success"] is True + assert ok["data"]["persisted"] is True + api.config_manager.update_and_save.assert_called_with( + updates={"web": {"cors_enabled": True}}, + live_update=False, + ) + + api.config_manager.update_and_save.return_value = {"success": False, "error": "bad"} + fail = api.update_web_config() + assert fail["success"] is False + assert fail["error"] == "bad" def test_update_web_config_requires_post_and_handles_exception(cherrypy_ctx): - request, _ = cherrypy_ctx - api = _make_api({"web": {"cors_enabled": True}}) + request, _ = cherrypy_ctx + api = _make_api({"web": {"cors_enabled": True}}) - request.method = "GET" - with pytest.raises(cherrypy.HTTPError) as exc: - api.update_web_config() - assert exc.value.status == 405 + request.method = "GET" + with pytest.raises(cherrypy.HTTPError) as exc: + api.update_web_config() + assert exc.value.status == 405 - request.method = "POST" - request.json = {"web": {"site_name": "mesh"}} - api.config_manager.update_and_save.side_effect = RuntimeError("write failed") - err = api.update_web_config() - assert err["success"] is False - assert "write failed" in err["error"] + request.method = "POST" + request.json = {"web": {"site_name": "mesh"}} + api.config_manager.update_and_save.side_effect = RuntimeError("write failed") + err = api.update_web_config() + assert err["success"] is False + assert "write failed" in err["error"] def test_validate_config_top_level_must_be_mapping(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text("- list\n- not\n- mapping\n", encoding="utf-8") + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text("- list\n- not\n- mapping\n", encoding="utf-8") - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - assert any(e["message"].startswith("Top-level YAML value must be a mapping") for e in result["data"]["errors"]) + assert result["success"] is True + assert result["data"]["valid"] is False + assert any( + e["message"].startswith("Top-level YAML value must be a mapping") + for e in result["data"]["errors"] + ) def test_validate_config_invalid_radio_type_and_missing_sections(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: "" radio_type: weird_radio """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - paths = {e["path"] for e in result["data"]["errors"]} - assert "repeater.node_name" in paths - assert "repeater.security" in paths - assert "radio_type" in paths + assert result["success"] is True + assert result["data"]["valid"] is False + paths = {e["path"] for e in result["data"]["errors"]} + assert "repeater.node_name" in paths + assert "repeater.security" in paths + assert "radio_type" in paths def test_validate_config_pymc_tcp_placeholder_and_bad_port(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: mesh-node-03 security: @@ -946,25 +949,25 @@ def test_validate_config_pymc_tcp_placeholder_and_bad_port(cherrypy_ctx, tmp_pat host: REPLACE_WITH_MODEM_HOST port: 70000 """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - paths = {e["path"] for e in result["data"]["errors"]} - assert "pymc_tcp.host" in paths - assert "pymc_tcp.port" in paths + assert result["success"] is True + assert result["data"]["valid"] is False + paths = {e["path"] for e in result["data"]["errors"]} + assert "pymc_tcp.host" in paths + assert "pymc_tcp.port" in paths def test_validate_config_sx1262_ch341_missing_sections(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: mesh-node-04 security: @@ -978,26 +981,26 @@ def test_validate_config_sx1262_ch341_missing_sections(cherrypy_ctx, tmp_path): tx_power: 22 preamble_length: 16 """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - paths = {e["path"] for e in result["data"]["errors"]} - assert "sx1262" in paths - assert "ch341" in paths + assert result["success"] is True + assert result["data"]["valid"] is False + paths = {e["path"] for e in result["data"]["errors"]} + assert "sx1262" in paths + assert "ch341" in paths def test_validate_config_rejects_bool_numeric_fields(cherrypy_ctx, tmp_path): - """Booleans silently cast to int in Python, so this guards explicit type checks.""" - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + """Booleans silently cast to int in Python, so this guards explicit type checks.""" + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: mesh-node-bool security: @@ -1014,25 +1017,25 @@ def test_validate_config_rejects_bool_numeric_fields(cherrypy_ctx, tmp_path): port: /dev/ttyUSB0 baud_rate: true """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - errors = {e["path"]: e["message"] for e in result["data"]["errors"]} - assert "radio.bandwidth" in errors - assert "kiss.baud_rate" in errors + assert result["success"] is True + assert result["data"]["valid"] is False + errors = {e["path"]: e["message"] for e in result["data"]["errors"]} + assert "radio.bandwidth" in errors + assert "kiss.baud_rate" in errors def test_validate_config_radio_numeric_ranges_and_modes(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: mesh-node-ranges security: @@ -1055,29 +1058,29 @@ def test_validate_config_radio_numeric_ranges_and_modes(cherrypy_ctx, tmp_path): txen_pin: 18 rxen_pin: 17 """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - paths = {e["path"] for e in result["data"]["errors"]} - assert "radio.frequency" in paths - assert "radio.bandwidth" in paths - assert "radio.spreading_factor" in paths - assert "radio.coding_rate" in paths - assert "radio.tx_power" in paths - assert "radio.preamble_length" in paths + assert result["success"] is True + assert result["data"]["valid"] is False + paths = {e["path"] for e in result["data"]["errors"]} + assert "radio.frequency" in paths + assert "radio.bandwidth" in paths + assert "radio.spreading_factor" in paths + assert "radio.coding_rate" in paths + assert "radio.tx_power" in paths + assert "radio.preamble_length" in paths def test_validate_config_en_pins_type_and_entry_validation(cherrypy_ctx, tmp_path): - request, _ = cherrypy_ctx - request.method = "GET" - api = _make_api() - api._config_path = str(tmp_path / "config.yaml") - (tmp_path / "config.yaml").write_text( - """ + request, _ = cherrypy_ctx + request.method = "GET" + api = _make_api() + api._config_path = str(tmp_path / "config.yaml") + (tmp_path / "config.yaml").write_text( + """ repeater: node_name: mesh-node-enpins security: @@ -1101,67 +1104,67 @@ def test_validate_config_en_pins_type_and_entry_validation(cherrypy_ctx, tmp_pat rxen_pin: 17 en_pins: [21, bad] """.strip(), - encoding="utf-8", - ) + encoding="utf-8", + ) - result = api.validate_config() + result = api.validate_config() - assert result["success"] is True - assert result["data"]["valid"] is False - paths = {e["path"] for e in result["data"]["errors"]} - assert "sx1262.en_pins[1]" in paths + assert result["success"] is True + assert result["data"]["valid"] is False + paths = {e["path"] for e in result["data"]["errors"]} + assert "sx1262.en_pins[1]" in paths def test_config_import_web_only_no_restart_required(cherrypy_ctx): - request, _ = cherrypy_ctx - request.method = "POST" - api = _make_api({"web": {"site_name": "old"}}) - api.config_manager.update_and_save.return_value = {"ok": True} - api.config_manager.save_to_file.return_value = True - request.json = {"config": {"web": {"site_name": "new", "cors_enabled": True}}} + request, _ = cherrypy_ctx + request.method = "POST" + api = _make_api({"web": {"site_name": "old"}}) + api.config_manager.update_and_save.return_value = {"ok": True} + api.config_manager.save_to_file.return_value = True + request.json = {"config": {"web": {"site_name": "new", "cors_enabled": True}}} - result = api.config_import() + result = api.config_import() - assert result["success"] is True - assert result["restart_required"] is False - assert result["sections_updated"] == ["web"] - assert api.config["web"]["site_name"] == "new" - assert api.config["web"]["cors_enabled"] is True + assert result["success"] is True + assert result["restart_required"] is False + assert result["sections_updated"] == ["web"] + assert api.config["web"]["site_name"] == "new" + assert api.config["web"]["cors_enabled"] is True def test_config_import_identity_redaction_preserves_by_name_for_room_servers(cherrypy_ctx): - request, _ = cherrypy_ctx - request.method = "POST" - api = _make_api( - { - "identities": { - "room_servers": [ - {"name": "main-room", "identity_key": bytes.fromhex("ABCD")}, - ] - } - } - ) - api.config_manager.update_and_save.return_value = {"ok": True} - api.config_manager.save_to_file.return_value = True - request.json = { - "config": { - "identities": { - "room_servers": [ - {"name": "main-room", "identity_key": "*** REDACTED ***"}, - {"name": "new-room", "identity_key": "*** REDACTED ***"}, - ] - } - } + request, _ = cherrypy_ctx + request.method = "POST" + api = _make_api( + { + "identities": { + "room_servers": [ + {"name": "main-room", "identity_key": bytes.fromhex("ABCD")}, + ] + } } + ) + api.config_manager.update_and_save.return_value = {"ok": True} + api.config_manager.save_to_file.return_value = True + request.json = { + "config": { + "identities": { + "room_servers": [ + {"name": "main-room", "identity_key": "*** REDACTED ***"}, + {"name": "new-room", "identity_key": "*** REDACTED ***"}, + ] + } + } + } - result = api.config_import() + result = api.config_import() - assert result["success"] is True - rooms = api.config["identities"]["room_servers"] - by_name = {r["name"]: r["identity_key"] for r in rooms} - assert by_name["main-room"] == bytes.fromhex("ABCD") - # Unknown existing room keeps empty value when imported as redacted. - assert by_name["new-room"] == "" + assert result["success"] is True + rooms = api.config["identities"]["room_servers"] + by_name = {r["name"]: r["identity_key"] for r in rooms} + assert by_name["main-room"] == bytes.fromhex("ABCD") + # Unknown existing room keeps empty value when imported as redacted. + assert by_name["new-room"] == "" def test_stats_includes_versions_and_buildroot_image_info(cherrypy_ctx): @@ -1169,7 +1172,10 @@ def test_stats_includes_versions_and_buildroot_image_info(cherrypy_ctx): api = _make_api({"radio_type": "sx1262", "web": {"site_name": "Field"}}) api.stats_getter = lambda: {"uptime": 10} - with patch("repeater.web.api_endpoints.get_buildroot_image_info", return_value={"image_name": "pyMC", "image_version": "1.2.3"}): + with patch( + "repeater.web.api_endpoints.get_buildroot_image_info", + return_value={"image_name": "pyMC", "image_version": "1.2.3"}, + ): out = api.stats() assert out["uptime"] == 10 @@ -1184,7 +1190,9 @@ def test_gps_snapshot_when_service_present_and_default_when_absent(cherrypy_ctx) del cherrypy_ctx api = _make_api({"gps": {"enabled": True}}) - api.daemon_instance = SimpleNamespace(gps_service=SimpleNamespace(get_snapshot=lambda: {"running": True})) + api.daemon_instance = SimpleNamespace( + gps_service=SimpleNamespace(get_snapshot=lambda: {"running": True}) + ) out = api.gps() assert out == {"success": True, "data": {"running": True}} @@ -1226,9 +1234,16 @@ def test_check_pymc_console_and_mqtt_status_and_broker_presets(cherrypy_ctx): assert status2["success"] is True assert status2["data"]["brokers"][0]["name"] == "main" - with patch("repeater.presets.list_presets", return_value=["waev"]), patch( - "repeater.presets.get_preset", - return_value={"display_name": "Waev", "website": "https://waev.app", "brokers": [{"host": "h"}]}, + with ( + patch("repeater.presets.list_presets", return_value=["waev"]), + patch( + "repeater.presets.get_preset", + return_value={ + "display_name": "Waev", + "website": "https://waev.app", + "brokers": [{"host": "h"}], + }, + ), ): presets = api.broker_presets() assert presets["success"] is True @@ -1587,14 +1602,18 @@ def test_identity_endpoints_paths(cherrypy_ctx): if t == "room_server" else [("comp1", _FakeIdentityObj(0x51), {"settings": {"tcp_port": 5000}})] ), - get_identity_by_name=lambda n: (_FakeIdentityObj(0x42), {}, "room_server") if n == "main" else None, + get_identity_by_name=lambda n: ( + (_FakeIdentityObj(0x42), {}, "room_server") if n == "main" else None + ), named_identities={"comp1": 1, "main": 1}, ) api.daemon_instance = SimpleNamespace(identity_manager=id_mgr) api.config = { "identities": { "room_servers": [{"name": "main", "identity_key": "a" * 64, "settings": {"x": 1}}], - "companions": [{"name": "comp1", "identity_key": "b" * 64, "settings": {"tcp_port": 5000}}], + "companions": [ + {"name": "comp1", "identity_key": "b" * 64, "settings": {"tcp_port": 5000}} + ], } } api.config_manager.save_to_file.return_value = True @@ -1616,15 +1635,27 @@ def test_identity_endpoints_paths(cherrypy_ctx): assert api.create_identity()["success"] is False request.json = {"name": "x", "type": "invalid"} assert api.create_identity()["success"] is False - request.json = {"name": "x", "type": "room_server", "settings": {"admin_password": "p", "guest_password": "p"}} + request.json = { + "name": "x", + "type": "room_server", + "settings": {"admin_password": "p", "guest_password": "p"}, + } assert api.create_identity()["success"] is False request.json = {"name": "comp1", "type": "companion", "identity_key": "aa" * 32} assert api.create_identity()["success"] is False - request.json = {"name": "new-comp", "type": "companion", "identity_key": "cc" * 32, "settings": {"node_name": "N"}} + request.json = { + "name": "new-comp", + "type": "companion", + "identity_key": "cc" * 32, + "settings": {"node_name": "N"}, + } api.event_loop = object() api.daemon_instance = SimpleNamespace(add_companion_from_config=MagicMock()) - with patch("asyncio.run_coroutine_threadsafe", return_value=SimpleNamespace(result=lambda timeout: True)): + with patch( + "asyncio.run_coroutine_threadsafe", + return_value=SimpleNamespace(result=lambda timeout: True), + ): created = api.create_identity() assert created["success"] is True @@ -1674,11 +1705,16 @@ def test_acl_endpoints_paths(cherrypy_ctx): login_helper = SimpleNamespace(get_acl_dict=lambda: {0x42: acl, 0x51: _FakeACL([])}) id_mgr = SimpleNamespace( get_identities_by_type=lambda t: ( - [("room1", _FakeIdentityObj(0x42), {})] if t == "room_server" else [("comp1", _FakeIdentityObj(0x51), {})] + [("room1", _FakeIdentityObj(0x42), {})] + if t == "room_server" + else [("comp1", _FakeIdentityObj(0x51), {})] ) ) local = _FakeIdentityObj(0x42) - frame_server = SimpleNamespace(companion_hash="0x51", _client_writer=SimpleNamespace(get_extra_info=lambda k: ("10.0.0.2", 1234))) + frame_server = SimpleNamespace( + companion_hash="0x51", + _client_writer=SimpleNamespace(get_extra_info=lambda k: ("10.0.0.2", 1234)), + ) api.daemon_instance = SimpleNamespace( login_helper=login_helper, identity_manager=id_mgr, @@ -1727,14 +1763,37 @@ def test_room_endpoint_slice(cherrypy_ctx): request.method = "GET" db = SimpleNamespace( get_room_message_count=MagicMock(return_value=1), - get_room_messages=MagicMock(return_value=[{"id": 1, "author_pubkey": "aa" * 32, "post_timestamp": 1.0, "sender_timestamp": 1, "message_text": "m", "txt_type": 0}]), + get_room_messages=MagicMock( + return_value=[ + { + "id": 1, + "author_pubkey": "aa" * 32, + "post_timestamp": 1.0, + "sender_timestamp": 1, + "message_text": "m", + "txt_type": 0, + } + ] + ), get_messages_since=MagicMock(return_value=[]), delete_room_message=MagicMock(return_value=True), clear_room_messages=MagicMock(return_value=1), get_all_room_clients=MagicMock(return_value=[]), ) - room = SimpleNamespace(db=db, max_posts=10, _running=True, next_push_time=0, last_cleanup_time=0) - with patch.object(api, "_get_room_server_by_name_or_hash", return_value={"room_server": room, "name": "room", "hash": 0x42, "identity": None, "config": {}}): + room = SimpleNamespace( + db=db, max_posts=10, _running=True, next_push_time=0, last_cleanup_time=0 + ) + with patch.object( + api, + "_get_room_server_by_name_or_hash", + return_value={ + "room_server": room, + "name": "room", + "hash": 0x42, + "identity": None, + "config": {}, + }, + ): _attach_storage(api, SimpleNamespace(get_node_name_by_pubkey=lambda _pk: "Node")) msgs = api.room_messages(room_name="room") assert msgs["success"] is True diff --git a/tests/test_auth_components.py b/tests/test_auth_components.py index b36cd54..c3c3ad4 100644 --- a/tests/test_auth_components.py +++ b/tests/test_auth_components.py @@ -20,7 +20,9 @@ def test_jwt_handler_create_and_verify_and_invalid_cases(): assert payload["sub"] == "admin" assert payload["client_id"] == "client-1" - expired = jwt.encode({"sub": "admin", "client_id": "c", "iat": 1, "exp": 1}, secret, algorithm="HS256") + expired = jwt.encode( + {"sub": "admin", "client_id": "c", "iat": 1, "exp": 1}, secret, algorithm="HS256" + ) assert h.verify_jwt(expired) is None assert h.verify_jwt("not-a-token") is None diff --git a/tests/test_auth_endpoints.py b/tests/test_auth_endpoints.py index 9a49413..c2b5651 100644 --- a/tests/test_auth_endpoints.py +++ b/tests/test_auth_endpoints.py @@ -35,8 +35,14 @@ def _jwt_ok_payload(): def _jwt_handler(ok=True): if ok: - return SimpleNamespace(verify_jwt=lambda _token: _jwt_ok_payload(), create_jwt=lambda u, c: "jwt-new", expiry_minutes=15) - return SimpleNamespace(verify_jwt=lambda _token: None, create_jwt=lambda u, c: "jwt-new", expiry_minutes=15) + return SimpleNamespace( + verify_jwt=lambda _token: _jwt_ok_payload(), + create_jwt=lambda u, c: "jwt-new", + expiry_minutes=15, + ) + return SimpleNamespace( + verify_jwt=lambda _token: None, create_jwt=lambda u, c: "jwt-new", expiry_minutes=15 + ) def _token_mgr(): @@ -78,7 +84,9 @@ def test_tokens_index_get_post_and_error_paths(cp_ctx): # GET exception _req, _resp, cfg = cp_ctx(method="GET", headers={"Authorization": "Bearer ok"}) cfg["jwt_handler"] = _jwt_handler(ok=True) - cfg["token_manager"] = SimpleNamespace(list_tokens=lambda: (_ for _ in ()).throw(RuntimeError("db"))) + cfg["token_manager"] = SimpleNamespace( + list_tokens=lambda: (_ for _ in ()).throw(RuntimeError("db")) + ) out = endpoint.index() assert out["success"] is False assert cherrypy.response.status == 500 @@ -144,7 +152,11 @@ def test_tokens_default_delete_paths(cp_ctx): def test_login_paths(cp_ctx): - auth = AuthEndpoints(config={"repeater": {"security": {"admin_password": "pw"}}}, jwt_handler=_jwt_handler(ok=True), token_manager=_token_mgr()) + auth = AuthEndpoints( + config={"repeater": {"security": {"admin_password": "pw"}}}, + jwt_handler=_jwt_handler(ok=True), + token_manager=_token_mgr(), + ) cp_ctx(method="OPTIONS") assert auth.login() == b"" @@ -153,12 +165,18 @@ def test_login_paths(cp_ctx): out = json.loads(auth.login().decode()) assert out["success"] is False - cp_ctx(method="POST", body=json.dumps({"username": "admin", "password": "pw", "client_id": "abc"}).encode()) + cp_ctx( + method="POST", + body=json.dumps({"username": "admin", "password": "pw", "client_id": "abc"}).encode(), + ) out = json.loads(auth.login().decode()) assert out["success"] is True assert out["token"] == "jwt-new" - cp_ctx(method="POST", body=json.dumps({"username": "admin", "password": "bad", "client_id": "abc"}).encode()) + cp_ctx( + method="POST", + body=json.dumps({"username": "admin", "password": "bad", "client_id": "abc"}).encode(), + ) out = json.loads(auth.login().decode()) assert out["success"] is False @@ -201,7 +219,9 @@ def test_refresh_paths(cp_ctx): assert out["success"] is True # falls back to payload client_id # api token path - _req, _resp, cfg = cp_ctx(method="POST", headers={"X-API-Key": "k"}, body=json.dumps({"client_id": "z"}).encode()) + _req, _resp, cfg = cp_ctx( + method="POST", headers={"X-API-Key": "k"}, body=json.dumps({"client_id": "z"}).encode() + ) cfg["jwt_handler"] = _jwt_handler(ok=False) cfg["token_manager"] = _token_mgr() out = json.loads(auth.refresh().decode()) @@ -269,7 +289,9 @@ def test_change_password_paths(cp_ctx): _req, _resp, cfg = cp_ctx( method="POST", headers={"Authorization": "Bearer ok"}, - body=json.dumps({"current_password": "old-password", "new_password": "new-password"}).encode(), + body=json.dumps( + {"current_password": "old-password", "new_password": "new-password"} + ).encode(), ) cfg["jwt_handler"] = _jwt_handler(ok=True) cfg["token_manager"] = _token_mgr() @@ -286,7 +308,9 @@ def test_change_password_paths(cp_ctx): _req, _resp, cfg = cp_ctx( method="POST", headers={"Authorization": "Bearer ok"}, - body=json.dumps({"current_password": "old-password", "new_password": "new-password"}).encode(), + body=json.dumps( + {"current_password": "old-password", "new_password": "new-password"} + ).encode(), ) cfg["jwt_handler"] = _jwt_handler(ok=True) cfg["token_manager"] = _token_mgr() diff --git a/tests/test_companion_bridge_frame_utils.py b/tests/test_companion_bridge_frame_utils.py index cc79738..fcd1c09 100644 --- a/tests/test_companion_bridge_frame_utils.py +++ b/tests/test_companion_bridge_frame_utils.py @@ -134,9 +134,12 @@ async def test_frame_server_persistence_paths_and_stop(): get_channel=lambda idx: None, ) - with patch("repeater.companion.frame_server._BaseFrameServer.__init__", lambda self, **kwargs: None), patch( - "repeater.companion.frame_server._BaseFrameServer.stop", AsyncMock() - ) as base_stop: + with ( + patch( + "repeater.companion.frame_server._BaseFrameServer.__init__", lambda self, **kwargs: None + ), + patch("repeater.companion.frame_server._BaseFrameServer.stop", AsyncMock()) as base_stop, + ): srv = CompanionFrameServer(bridge=bridge, companion_hash="h", sqlite_handler=sqlite) srv.bridge = bridge srv.companion_hash = "h" @@ -171,7 +174,9 @@ async def test_frame_server_persistence_paths_and_stop(): sqlite.companion_upsert_contact.assert_called_once() bridge.get_contacts = lambda: [contact] - bridge.get_channel = lambda idx: (SimpleNamespace(name="c1", secret="s") if idx == 1 else None) + bridge.get_channel = lambda idx: ( + SimpleNamespace(name="c1", secret="s") if idx == 1 else None + ) await srv.stop() sqlite.companion_save_contacts.assert_called_once() @@ -185,7 +190,9 @@ async def test_frame_server_persistence_paths_and_stop(): async def test_frame_server_no_more_messages_response_when_empty(): bridge = SimpleNamespace(sync_next_message=lambda: None) - with patch("repeater.companion.frame_server._BaseFrameServer.__init__", lambda self, **kwargs: None): + with patch( + "repeater.companion.frame_server._BaseFrameServer.__init__", lambda self, **kwargs: None + ): srv = CompanionFrameServer(bridge=bridge, companion_hash="h", sqlite_handler=None) srv.bridge = bridge srv._write_frame = MagicMock() diff --git a/tests/test_companion_ws_proxy.py b/tests/test_companion_ws_proxy.py index f2f9d18..3bfd6db 100644 --- a/tests/test_companion_ws_proxy.py +++ b/tests/test_companion_ws_proxy.py @@ -110,7 +110,11 @@ def test_resolve_tcp_endpoint_paths(monkeypatch): # daemon with empty bridges daemon = SimpleNamespace( - identity_manager=SimpleNamespace(get_identities_by_type=lambda _t: [("c1", SimpleNamespace(get_public_key=lambda: b"\x01"), {})]), + identity_manager=SimpleNamespace( + get_identities_by_type=lambda _t: [ + ("c1", SimpleNamespace(get_public_key=lambda: b"\x01"), {}) + ] + ), companion_bridges={}, config={"identities": {"companions": []}}, ) @@ -119,9 +123,19 @@ def test_resolve_tcp_endpoint_paths(monkeypatch): # found in identity+bridge and in config, bind 0.0.0.0 => loopback daemon = SimpleNamespace( - identity_manager=SimpleNamespace(get_identities_by_type=lambda _t: [("c1", SimpleNamespace(get_public_key=lambda: b"\x01"), {})]), + identity_manager=SimpleNamespace( + get_identities_by_type=lambda _t: [ + ("c1", SimpleNamespace(get_public_key=lambda: b"\x01"), {}) + ] + ), companion_bridges={1: object()}, - config={"identities": {"companions": [{"name": "c1", "settings": {"tcp_port": 6000, "bind_address": "0.0.0.0"}}]}}, + config={ + "identities": { + "companions": [ + {"name": "c1", "settings": {"tcp_port": 6000, "bind_address": "0.0.0.0"}} + ] + } + }, ) proxy.set_daemon(daemon) assert ws._resolve_tcp_endpoint("c1") == ("127.0.0.1", 6000) @@ -166,7 +180,9 @@ def test_tcp_to_ws_and_teardown(): ws2._companion_name = "c2" tcp_ref = MagicMock() ws2._tcp = tcp_ref - ws2._teardown = proxy.CompanionFrameWebSocket._teardown.__get__(ws2, proxy.CompanionFrameWebSocket) + ws2._teardown = proxy.CompanionFrameWebSocket._teardown.__get__( + ws2, proxy.CompanionFrameWebSocket + ) ws2._teardown() tcp_ref.close.assert_called_once() ws2.close.assert_called_once() diff --git a/tests/test_config_manager.py b/tests/test_config_manager.py index 55edff9..ba66ff5 100644 --- a/tests/test_config_manager.py +++ b/tests/test_config_manager.py @@ -127,4 +127,4 @@ def test_live_update_daemon_applies_kiss_radio_config(): ) ] assert radio.radio_config == config["radio"] - assert daemon.repeater_handler.radio_config == config["radio"] \ No newline at end of file + assert daemon.repeater_handler.radio_config == config["radio"] diff --git a/tests/test_engine.py b/tests/test_engine.py index c5dedea..cf312ee 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -5,6 +5,7 @@ mark_seen, validate_packet, packet scoring, TX delay, cache management, airtime duty-cycle, TX mode (forward/monitor/no_tx), and config reloading. """ + import asyncio import base64 import time @@ -101,13 +102,14 @@ def handler(): patch("repeater.engine.RepeaterHandler._start_background_tasks"), ): from repeater.engine import RepeaterHandler + h = RepeaterHandler(config, dispatcher, LOCAL_HASH) return h -def _make_flood_packet(payload: bytes = b"\x01\x02\x03\x04", - path: bytes = b"", - payload_type: int = 0x01) -> Packet: +def _make_flood_packet( + payload: bytes = b"\x01\x02\x03\x04", path: bytes = b"", payload_type: int = 0x01 +) -> Packet: """Build a FLOOD-routed packet.""" pkt = Packet() # header: route=FLOOD(0x01), payload_type shifted, version=0 @@ -119,9 +121,9 @@ def _make_flood_packet(payload: bytes = b"\x01\x02\x03\x04", return pkt -def _make_direct_packet(payload: bytes = b"\x01\x02\x03\x04", - path: bytes = None, - payload_type: int = 0x01) -> Packet: +def _make_direct_packet( + payload: bytes = b"\x01\x02\x03\x04", path: bytes = None, payload_type: int = 0x01 +) -> Packet: """Build a DIRECT-routed packet with path[0] == LOCAL_HASH by default.""" if path is None: path = bytes([LOCAL_HASH, 0xCC, 0xDD]) @@ -134,10 +136,12 @@ def _make_direct_packet(payload: bytes = b"\x01\x02\x03\x04", return pkt -def _make_transport_flood_packet(payload: bytes = b"\x01\x02\x03\x04", - path: bytes = b"", - payload_type: int = 0x01, - transport_codes=(0x1234, 0x5678)) -> Packet: +def _make_transport_flood_packet( + payload: bytes = b"\x01\x02\x03\x04", + path: bytes = b"", + payload_type: int = 0x01, + transport_codes=(0x1234, 0x5678), +) -> Packet: """Build a TRANSPORT_FLOOD-routed packet.""" pkt = Packet() pkt.header = ROUTE_TYPE_TRANSPORT_FLOOD | (payload_type << PH_TYPE_SHIFT) @@ -149,10 +153,12 @@ def _make_transport_flood_packet(payload: bytes = b"\x01\x02\x03\x04", return pkt -def _make_transport_direct_packet(payload: bytes = b"\x01\x02\x03\x04", - path: bytes = None, - payload_type: int = 0x01, - transport_codes=(0x1234, 0x5678)) -> Packet: +def _make_transport_direct_packet( + payload: bytes = b"\x01\x02\x03\x04", + path: bytes = None, + payload_type: int = 0x01, + transport_codes=(0x1234, 0x5678), +) -> Packet: """Build a TRANSPORT_DIRECT-routed packet with path[0] == LOCAL_HASH.""" if path is None: path = bytes([LOCAL_HASH, 0xCC]) @@ -170,6 +176,7 @@ def _make_transport_direct_packet(payload: bytes = b"\x01\x02\x03\x04", # 1. flood_forward # =================================================================== + class TestFloodForward: """flood_forward: validation, duplicate suppression, path append.""" @@ -237,7 +244,7 @@ def test_unscoped_flood_deny_plain_flood(self, handler): def test_hash_computed_before_path_append(self, handler): """mark_seen must use the pre-append hash so duplicate detection works when another node sends the same packet with or without our hash.""" - pkt1 = _make_flood_packet(payload=b"\xAA\xBB") + pkt1 = _make_flood_packet(payload=b"\xaa\xbb") hash_before = pkt1.calculate_packet_hash().hex().upper() handler.flood_forward(pkt1) @@ -271,6 +278,7 @@ def test_path_none_is_handled(self, handler): # 2. direct_forward # =================================================================== + class TestDirectForward: """direct_forward: next-hop check, path consumption, duplicate suppression.""" @@ -343,6 +351,7 @@ def test_path_len_updated_after_consume(self, handler): # 3. process_packet — route dispatch # =================================================================== + class TestProcessPacket: """process_packet routes to flood_forward or direct_forward.""" @@ -364,7 +373,7 @@ def test_direct_route_dispatched(self, handler): def test_transport_flood_dispatched(self, handler): pkt = _make_transport_flood_packet() - with patch.object(handler, '_check_transport_codes', return_value=(True, "")): + with patch.object(handler, "_check_transport_codes", return_value=(True, "")): result = handler.process_packet(pkt, snr=5.0) assert result is not None fwd_pkt, _ = result @@ -410,6 +419,7 @@ def test_direct_forward_failure_returns_none(self, handler): # 4. is_duplicate / mark_seen / cache management # =================================================================== + class TestDuplicateDetection: """Duplicate tracking, TTL clean-up, and cache eviction.""" @@ -464,6 +474,7 @@ def test_mark_seen_stores_hex_upper_key(self, handler): # 5. validate_packet # =================================================================== + class TestValidatePacket: """validate_packet: empty payload, oversized path.""" @@ -500,51 +511,62 @@ def test_none_packet(self, handler): # 6. calculate_packet_score — static method # =================================================================== + class TestPacketScore: """Score: SNR thresholds, collision penalty, clamping.""" def test_below_threshold_returns_zero(self): from repeater.engine import RepeaterHandler + # SF8 threshold is -10.0 score = RepeaterHandler.calculate_packet_score(snr=-15.0, packet_len=50, spreading_factor=8) assert score == 0.0 def test_at_threshold_returns_zero(self): from repeater.engine import RepeaterHandler + score = RepeaterHandler.calculate_packet_score(snr=-10.0, packet_len=50, spreading_factor=8) assert score == 0.0 def test_above_threshold_positive(self): from repeater.engine import RepeaterHandler + score = RepeaterHandler.calculate_packet_score(snr=0.0, packet_len=50, spreading_factor=8) assert score > 0.0 def test_high_snr_high_score(self): from repeater.engine import RepeaterHandler + score = RepeaterHandler.calculate_packet_score(snr=10.0, packet_len=10, spreading_factor=8) assert score > 0.5 def test_long_packet_collision_penalty(self): from repeater.engine import RepeaterHandler + short = RepeaterHandler.calculate_packet_score(snr=5.0, packet_len=10, spreading_factor=8) long_ = RepeaterHandler.calculate_packet_score(snr=5.0, packet_len=250, spreading_factor=8) assert short > long_ def test_score_clamped_to_0_1(self): from repeater.engine import RepeaterHandler + score = RepeaterHandler.calculate_packet_score(snr=50.0, packet_len=1, spreading_factor=8) assert 0.0 <= score <= 1.0 def test_sf_below_7_returns_zero(self): from repeater.engine import RepeaterHandler + score = RepeaterHandler.calculate_packet_score(snr=10.0, packet_len=50, spreading_factor=6) assert score == 0.0 def test_each_sf_has_different_threshold(self): from repeater.engine import RepeaterHandler + scores = {} for sf in (7, 8, 9, 10, 11, 12): - scores[sf] = RepeaterHandler.calculate_packet_score(snr=-5.0, packet_len=50, spreading_factor=sf) + scores[sf] = RepeaterHandler.calculate_packet_score( + snr=-5.0, packet_len=50, spreading_factor=sf + ) # Higher SF → lower threshold → better reception at same SNR # At SNR=-5, SF7 (threshold -7.5) should be worse than SF12 (threshold -20) assert scores[12] > scores[7] @@ -554,6 +576,7 @@ def test_each_sf_has_different_threshold(self): # 7. _calculate_tx_delay # =================================================================== + class TestTxDelay: """TX delay: flood random, direct fixed, score adjustment, cap.""" @@ -616,11 +639,12 @@ def test_transport_direct_uses_direct_delay(self, handler): # 8. Hash stability through forwarding operations # =================================================================== + class TestHashStabilityThroughForwarding: """Verify hash is computed on original packet (before path mutation).""" def test_flood_hash_unchanged_after_forward(self, handler): - pkt = _make_flood_packet(payload=b"\xDE\xAD") + pkt = _make_flood_packet(payload=b"\xde\xad") hash_before = pkt.calculate_packet_hash().hex().upper() handler.flood_forward(pkt) @@ -628,8 +652,7 @@ def test_flood_hash_unchanged_after_forward(self, handler): assert hash_before in handler.seen_packets def test_direct_hash_unchanged_after_forward(self, handler): - pkt = _make_direct_packet(payload=b"\xBE\xEF", - path=bytes([LOCAL_HASH, 0xCC])) + pkt = _make_direct_packet(payload=b"\xbe\xef", path=bytes([LOCAL_HASH, 0xCC])) hash_before = pkt.calculate_packet_hash().hex().upper() handler.direct_forward(pkt) @@ -638,17 +661,15 @@ def test_direct_hash_unchanged_after_forward(self, handler): def test_flood_second_identical_detected_as_duplicate(self, handler): """Two identical packets with the same payload (but path not yet modified) should be correctly detected as duplicates.""" - p1 = _make_flood_packet(payload=b"\xCA\xFE") - p2 = _make_flood_packet(payload=b"\xCA\xFE") + p1 = _make_flood_packet(payload=b"\xca\xfe") + p2 = _make_flood_packet(payload=b"\xca\xfe") handler.flood_forward(p1) result = handler.flood_forward(p2) assert result is None def test_direct_second_identical_detected_as_duplicate(self, handler): - p1 = _make_direct_packet(payload=b"\xCA\xFE", - path=bytes([LOCAL_HASH, 0x11])) - p2 = _make_direct_packet(payload=b"\xCA\xFE", - path=bytes([LOCAL_HASH, 0x11])) + p1 = _make_direct_packet(payload=b"\xca\xfe", path=bytes([LOCAL_HASH, 0x11])) + p2 = _make_direct_packet(payload=b"\xca\xfe", path=bytes([LOCAL_HASH, 0x11])) handler.direct_forward(p1) result = handler.direct_forward(p2) assert result is None @@ -658,6 +679,7 @@ def test_direct_second_identical_detected_as_duplicate(self, handler): # 9. unscoped flood policy # =================================================================== + class TestUnscopedFloodPolicy: """unscoped_flood_allow=False blocks plain flood, transport checked.""" @@ -680,7 +702,7 @@ def test_transport_flood_unaffected_by_unscoped_policy(self, handler): # unscoped traffic is denied — the two settings are fully independent. handler.config["mesh"]["unscoped_flood_allow"] = False pkt = _make_transport_flood_packet() - with patch.object(handler, '_check_transport_codes', return_value=(True, "")): + with patch.object(handler, "_check_transport_codes", return_value=(True, "")): result = handler.flood_forward(pkt) assert result is not None # transport flood passes; unscoped=False did not block it @@ -735,6 +757,7 @@ def test_loop_detect_strict_drops_at_one(self, handler): # 10. Airtime / duty-cycle integration # =================================================================== + class TestAirtimeIntegration: """Airtime calculation and duty-cycle enforcement.""" @@ -771,6 +794,7 @@ def test_airtime_increases_with_packet_size(self, handler): # 11. Config reload # =================================================================== + class TestConfigReload: """reload_runtime_config updates in-memory state.""" @@ -799,6 +823,7 @@ def test_cache_ttl_reloaded(self, handler): # 12. _get_drop_reason # =================================================================== + class TestGetDropReason: """_get_drop_reason: determine why a packet was not forwarded.""" @@ -842,12 +867,13 @@ def test_direct_no_path_reason(self, handler): # 13. Transport route forwarding # =================================================================== + class TestTransportForwarding: """TRANSPORT_FLOOD and TRANSPORT_DIRECT: packet routing through process_packet.""" def test_transport_flood_appends_path(self, handler): pkt = _make_transport_flood_packet(path=b"\x11") - with patch.object(handler, '_check_transport_codes', return_value=(True, "")): + with patch.object(handler, "_check_transport_codes", return_value=(True, "")): result = handler.process_packet(pkt, snr=5.0) assert result is not None fwd_pkt, _ = result @@ -863,7 +889,7 @@ def test_transport_direct_consumes_path(self, handler): def test_transport_codes_preserved_after_flood(self, handler): pkt = _make_transport_flood_packet(transport_codes=(0xAAAA, 0xBBBB)) - with patch.object(handler, '_check_transport_codes', return_value=(True, "")): + with patch.object(handler, "_check_transport_codes", return_value=(True, "")): result = handler.process_packet(pkt, snr=5.0) assert result is not None fwd_pkt, _ = result @@ -881,6 +907,7 @@ def test_transport_codes_preserved_after_direct(self, handler): # 14. Statistics tracking # =================================================================== + class TestStatistics: """RX/TX/dropped counters and recent_packets list.""" @@ -908,6 +935,7 @@ def test_get_stats_local_hash_format(self, handler): # 15. Edge cases and regression tests # =================================================================== + class TestEdgeCases: """Miscellaneous edge cases and regressions.""" @@ -924,12 +952,12 @@ def test_path_as_list_converted_to_bytearray(self, handler): def test_flood_forward_idempotent_on_second_call(self, handler): """Calling flood_forward again with the SAME packet object should detect as duplicate (the first call already mark_seen'd it).""" - pkt = _make_flood_packet(payload=b"\xFF" * 10) + pkt = _make_flood_packet(payload=b"\xff" * 10) r1 = handler.flood_forward(pkt) assert r1 is not None # Now pkt has local_hash appended, but hash was computed pre-append. # A new packet with same original payload should be duplicate. - pkt2 = _make_flood_packet(payload=b"\xFF" * 10) + pkt2 = _make_flood_packet(payload=b"\xff" * 10) r2 = handler.flood_forward(pkt2) assert r2 is None @@ -988,6 +1016,7 @@ def test_monitor_mode_skips_processing(self, handler): # 15b. TX mode: forward, monitor, no_tx # =================================================================== + @pytest.mark.asyncio class TestTxMode: """forward = repeat on; monitor = no repeat, local TX allowed; no_tx = all TX off.""" @@ -1048,6 +1077,7 @@ async def test_forward_mode_allows_local_tx(self, handler): # 16. Airtime calculation correctness # =================================================================== + class TestAirtimeCalculation: """Semtech LoRa airtime formula validation.""" @@ -1055,8 +1085,9 @@ def test_known_airtime_sf7_125khz(self, handler): """SF7, 125kHz, CR4/5, 10-byte payload — well-known reference value.""" mgr = handler.airtime_mgr # Override to known settings - at = mgr.calculate_airtime(10, spreading_factor=7, bandwidth_hz=125000, - coding_rate=5, preamble_len=8) + at = mgr.calculate_airtime( + 10, spreading_factor=7, bandwidth_hz=125000, coding_rate=5, preamble_len=8 + ) # Semtech calculator: ~36ms for these params assert 30.0 < at < 50.0 @@ -1080,190 +1111,240 @@ def test_zero_payload_still_has_preamble(self, handler): # ---- 20 GOOD packets: all should be forwarded by process_packet ---- GOOD_PACKETS = [ # (id, description, builder) - ("good_flood_minimal", - "Flood, 1-byte payload, empty path", - lambda: _make_flood_packet(payload=b"\x01")), - - ("good_flood_typical", - "Flood, 10-byte payload, 2-hop path", - lambda: _make_flood_packet(payload=bytes(range(10)), path=b"\x11\x22")), - - ("good_flood_max_payload_type", - "Flood, payload_type=15 (max 4-bit)", - lambda: _make_flood_packet(payload=b"\xAA\xBB", payload_type=15)), - - ("good_flood_payload_type_0", - "Flood, payload_type=0 (plain text)", - lambda: _make_flood_packet(payload=b"\x01\x02\x03", payload_type=0)), - - ("good_flood_long_payload", - "Flood, 200-byte payload", - lambda: _make_flood_packet(payload=bytes(range(200)))), - - ("good_flood_single_byte_path", - "Flood, path has 1 prior hop", - lambda: _make_flood_packet(payload=b"\xDE\xAD", path=b"\x42")), - - ("good_flood_binary_payload", - "Flood, all-zero payload", - lambda: _make_flood_packet(payload=b"\x00" * 16)), - - ("good_flood_high_entropy", - "Flood, high-entropy random-looking payload", - lambda: _make_flood_packet(payload=bytes(i ^ 0xA5 for i in range(64)))), - - ("good_flood_advert_type", - "Flood, payload_type=4 (ADVERT)", - lambda: _make_flood_packet(payload=b"\xAB\x01\x02\x03", payload_type=4)), - - ("good_direct_minimal", - "Direct, 1-byte payload, single hop to us (forward with empty path)", - lambda: _make_direct_packet(payload=b"\x01", path=bytes([LOCAL_HASH]))), - - ("good_direct_multihop", - "Direct, 3-hop remaining path (us + 2 more)", - lambda: _make_direct_packet(payload=b"\xCA\xFE", path=bytes([LOCAL_HASH, 0x11, 0x22]))), - - ("good_direct_long_payload", - "Direct, 150-byte payload", - lambda: _make_direct_packet(payload=bytes(range(150)), path=bytes([LOCAL_HASH, 0xBB]))), - - ("good_direct_type_2", - "Direct, payload_type=2 (ACK)", - lambda: _make_direct_packet(payload=b"\x01\x02", path=bytes([LOCAL_HASH]), - payload_type=2)), - - ("good_direct_long_remaining_path", - "Direct, 10 hops remaining after us", - lambda: _make_direct_packet(payload=b"\xFF\xEE", - path=bytes([LOCAL_HASH] + list(range(10))))), - - ("good_transport_direct_basic", - "Transport direct, basic hop to us", - lambda: _make_transport_direct_packet(payload=b"\x01\x02")), - ("good_transport_direct_long_path", - "Transport direct, 5 remaining hops", - lambda: _make_transport_direct_packet( - payload=b"\xDE\xAD\xBE\xEF", - path=bytes([LOCAL_HASH, 0x11, 0x22, 0x33, 0x44]))), + ( + "good_flood_minimal", + "Flood, 1-byte payload, empty path", + lambda: _make_flood_packet(payload=b"\x01"), + ), + ( + "good_flood_typical", + "Flood, 10-byte payload, 2-hop path", + lambda: _make_flood_packet(payload=bytes(range(10)), path=b"\x11\x22"), + ), + ( + "good_flood_max_payload_type", + "Flood, payload_type=15 (max 4-bit)", + lambda: _make_flood_packet(payload=b"\xaa\xbb", payload_type=15), + ), + ( + "good_flood_payload_type_0", + "Flood, payload_type=0 (plain text)", + lambda: _make_flood_packet(payload=b"\x01\x02\x03", payload_type=0), + ), + ( + "good_flood_long_payload", + "Flood, 200-byte payload", + lambda: _make_flood_packet(payload=bytes(range(200))), + ), + ( + "good_flood_single_byte_path", + "Flood, path has 1 prior hop", + lambda: _make_flood_packet(payload=b"\xde\xad", path=b"\x42"), + ), + ( + "good_flood_binary_payload", + "Flood, all-zero payload", + lambda: _make_flood_packet(payload=b"\x00" * 16), + ), + ( + "good_flood_high_entropy", + "Flood, high-entropy random-looking payload", + lambda: _make_flood_packet(payload=bytes(i ^ 0xA5 for i in range(64))), + ), + ( + "good_flood_advert_type", + "Flood, payload_type=4 (ADVERT)", + lambda: _make_flood_packet(payload=b"\xab\x01\x02\x03", payload_type=4), + ), + ( + "good_direct_minimal", + "Direct, 1-byte payload, single hop to us (forward with empty path)", + lambda: _make_direct_packet(payload=b"\x01", path=bytes([LOCAL_HASH])), + ), + ( + "good_direct_multihop", + "Direct, 3-hop remaining path (us + 2 more)", + lambda: _make_direct_packet(payload=b"\xca\xfe", path=bytes([LOCAL_HASH, 0x11, 0x22])), + ), + ( + "good_direct_long_payload", + "Direct, 150-byte payload", + lambda: _make_direct_packet(payload=bytes(range(150)), path=bytes([LOCAL_HASH, 0xBB])), + ), + ( + "good_direct_type_2", + "Direct, payload_type=2 (ACK)", + lambda: _make_direct_packet(payload=b"\x01\x02", path=bytes([LOCAL_HASH]), payload_type=2), + ), + ( + "good_direct_long_remaining_path", + "Direct, 10 hops remaining after us", + lambda: _make_direct_packet( + payload=b"\xff\xee", path=bytes([LOCAL_HASH] + list(range(10))) + ), + ), + ( + "good_transport_direct_basic", + "Transport direct, basic hop to us", + lambda: _make_transport_direct_packet(payload=b"\x01\x02"), + ), + ( + "good_transport_direct_long_path", + "Transport direct, 5 remaining hops", + lambda: _make_transport_direct_packet( + payload=b"\xde\xad\xbe\xef", path=bytes([LOCAL_HASH, 0x11, 0x22, 0x33, 0x44]) + ), + ), ] # ---- 20 BAD packets: all should be dropped / return None ---- BAD_PACKETS = [ # (id, description, builder) - ("bad_empty_payload", - "Empty bytearray payload", - lambda: _make_flood_packet(payload=b""), - "Empty payload"), - - ("bad_none_payload", - "payload = None", - lambda: (lambda p: (setattr(p, "payload", None), p)[-1])(_make_flood_packet()), - "Empty payload"), - - ("bad_path_at_max", - "Path exactly MAX_PATH_SIZE — no room to append", - lambda: _make_flood_packet(payload=b"\x01", path=bytes(range(MAX_PATH_SIZE))), - "Path length"), - - ("bad_flood_path_near_max", - "Flood, path = MAX_PATH_SIZE - 1 (63 hops; path_len encodes 0-63, cannot append)", - lambda: _make_flood_packet(payload=b"\xFF", path=bytes(range(MAX_PATH_SIZE - 1))), - "cannot append"), - - ("bad_path_over_max", - "Path exceeds MAX_PATH_SIZE", - lambda: _make_flood_packet(payload=b"\x01", path=bytes(range(MAX_PATH_SIZE + 5))), - "Path length"), - - ("bad_do_not_retransmit", - "Marked do-not-retransmit", - lambda: (lambda p: (p.mark_do_not_retransmit(), p)[-1])(_make_flood_packet()), - "do not retransmit"), - - ("bad_direct_wrong_hop", - "Direct packet, path[0] != LOCAL_HASH", - lambda: _make_direct_packet(path=bytes([0xFF, 0xCC])), - "not for us"), - - ("bad_direct_empty_path", - "Direct packet with empty path", - lambda: _make_direct_packet(path=b""), - "no path"), - - ("bad_direct_none_path", - "Direct packet with path = None", - lambda: (lambda p: (setattr(p, "path", None), setattr(p, "path_len", 0), p)[-1])( - _make_direct_packet()), - "no path"), - - ("bad_flood_policy_off", - "Plain flood when unscoped_flood_allow=False (needs config override)", - lambda: _make_flood_packet(payload=b"\x01\x02"), - "unscoped flood"), - - ("bad_transport_flood_no_keys", - "Transport flood with no configured transport keys — always denied", - lambda: _make_transport_flood_packet(payload=b"\x01\x02"), - "transport"), - - ("bad_direct_empty_payload", - "Direct with empty payload (now caught by validate_packet)", - lambda: (lambda p: (setattr(p, "payload", bytearray()), setattr(p, "payload_len", 0), p)[-1])( - _make_direct_packet(path=bytes([LOCAL_HASH]))), - "Empty payload"), - - ("bad_flood_zero_len_payload", - "Flood with payload_len forced to 0", - lambda: (lambda p: (setattr(p, "payload_len", 0), setattr(p, "payload", bytearray()), p)[-1])( - _make_flood_packet(payload=b"\x01")), - "Empty payload"), - - ("bad_direct_only_wrong_hops", - "Direct path of all 0xFF bytes (none match LOCAL_HASH)", - lambda: _make_direct_packet(path=bytes([0xFF, 0xFE, 0xFD])), - "not for us"), - - ("bad_transport_direct_wrong_hop", - "Transport direct with wrong first hop", - lambda: _make_transport_direct_packet(path=bytes([0x01, 0x02])), - "not for us"), - - ("bad_transport_direct_empty_path", - "Transport direct with empty path", - lambda: _make_transport_direct_packet(path=b""), - "no path"), - - ("bad_transport_direct_none_path", - "Transport direct with path = None", - lambda: (lambda p: (setattr(p, "path", None), setattr(p, "path_len", 0), p)[-1])( - _make_transport_direct_packet()), - "no path"), - - ("bad_flood_payload_255_zeros", - "Flood with payload = bytearray(0) (empty)", - lambda: (lambda p: (setattr(p, "payload", bytearray()), setattr(p, "payload_len", 0), p)[-1])( - _make_flood_packet()), - "Empty payload"), - - ("bad_direct_none_payload", - "Direct with None payload (now caught by validate_packet)", - lambda: (lambda p: (setattr(p, "payload", None), p)[-1])( - _make_direct_packet(path=bytes([LOCAL_HASH]))), - "Empty payload"), - - ("bad_flood_do_not_retransmit_custom", - "Flood, do-not-retransmit with custom drop reason", - lambda: (lambda p: (p.mark_do_not_retransmit(), setattr(p, "drop_reason", "Advert consumed"), p)[-1])( - _make_flood_packet(payload=b"\xAB")), - "Advert consumed"), - - ("bad_direct_do_not_retransmit", - "Direct, marked do-not-retransmit (now caught by direct_forward)", - lambda: (lambda p: (p.mark_do_not_retransmit(), p)[-1])( - _make_direct_packet(payload=b"\x99", path=bytes([LOCAL_HASH, 0x11]))), - "do not retransmit"), + ( + "bad_empty_payload", + "Empty bytearray payload", + lambda: _make_flood_packet(payload=b""), + "Empty payload", + ), + ( + "bad_none_payload", + "payload = None", + lambda: (lambda p: (setattr(p, "payload", None), p)[-1])(_make_flood_packet()), + "Empty payload", + ), + ( + "bad_path_at_max", + "Path exactly MAX_PATH_SIZE — no room to append", + lambda: _make_flood_packet(payload=b"\x01", path=bytes(range(MAX_PATH_SIZE))), + "Path length", + ), + ( + "bad_flood_path_near_max", + "Flood, path = MAX_PATH_SIZE - 1 (63 hops; path_len encodes 0-63, cannot append)", + lambda: _make_flood_packet(payload=b"\xff", path=bytes(range(MAX_PATH_SIZE - 1))), + "cannot append", + ), + ( + "bad_path_over_max", + "Path exceeds MAX_PATH_SIZE", + lambda: _make_flood_packet(payload=b"\x01", path=bytes(range(MAX_PATH_SIZE + 5))), + "Path length", + ), + ( + "bad_do_not_retransmit", + "Marked do-not-retransmit", + lambda: (lambda p: (p.mark_do_not_retransmit(), p)[-1])(_make_flood_packet()), + "do not retransmit", + ), + ( + "bad_direct_wrong_hop", + "Direct packet, path[0] != LOCAL_HASH", + lambda: _make_direct_packet(path=bytes([0xFF, 0xCC])), + "not for us", + ), + ( + "bad_direct_empty_path", + "Direct packet with empty path", + lambda: _make_direct_packet(path=b""), + "no path", + ), + ( + "bad_direct_none_path", + "Direct packet with path = None", + lambda: (lambda p: (setattr(p, "path", None), setattr(p, "path_len", 0), p)[-1])( + _make_direct_packet() + ), + "no path", + ), + ( + "bad_flood_policy_off", + "Plain flood when unscoped_flood_allow=False (needs config override)", + lambda: _make_flood_packet(payload=b"\x01\x02"), + "unscoped flood", + ), + ( + "bad_transport_flood_no_keys", + "Transport flood with no configured transport keys — always denied", + lambda: _make_transport_flood_packet(payload=b"\x01\x02"), + "transport", + ), + ( + "bad_direct_empty_payload", + "Direct with empty payload (now caught by validate_packet)", + lambda: ( + lambda p: (setattr(p, "payload", bytearray()), setattr(p, "payload_len", 0), p)[-1] + )(_make_direct_packet(path=bytes([LOCAL_HASH]))), + "Empty payload", + ), + ( + "bad_flood_zero_len_payload", + "Flood with payload_len forced to 0", + lambda: ( + lambda p: (setattr(p, "payload_len", 0), setattr(p, "payload", bytearray()), p)[-1] + )(_make_flood_packet(payload=b"\x01")), + "Empty payload", + ), + ( + "bad_direct_only_wrong_hops", + "Direct path of all 0xFF bytes (none match LOCAL_HASH)", + lambda: _make_direct_packet(path=bytes([0xFF, 0xFE, 0xFD])), + "not for us", + ), + ( + "bad_transport_direct_wrong_hop", + "Transport direct with wrong first hop", + lambda: _make_transport_direct_packet(path=bytes([0x01, 0x02])), + "not for us", + ), + ( + "bad_transport_direct_empty_path", + "Transport direct with empty path", + lambda: _make_transport_direct_packet(path=b""), + "no path", + ), + ( + "bad_transport_direct_none_path", + "Transport direct with path = None", + lambda: (lambda p: (setattr(p, "path", None), setattr(p, "path_len", 0), p)[-1])( + _make_transport_direct_packet() + ), + "no path", + ), + ( + "bad_flood_payload_255_zeros", + "Flood with payload = bytearray(0) (empty)", + lambda: ( + lambda p: (setattr(p, "payload", bytearray()), setattr(p, "payload_len", 0), p)[-1] + )(_make_flood_packet()), + "Empty payload", + ), + ( + "bad_direct_none_payload", + "Direct with None payload (now caught by validate_packet)", + lambda: (lambda p: (setattr(p, "payload", None), p)[-1])( + _make_direct_packet(path=bytes([LOCAL_HASH])) + ), + "Empty payload", + ), + ( + "bad_flood_do_not_retransmit_custom", + "Flood, do-not-retransmit with custom drop reason", + lambda: ( + lambda p: (p.mark_do_not_retransmit(), setattr(p, "drop_reason", "Advert consumed"), p)[ + -1 + ] + )(_make_flood_packet(payload=b"\xab")), + "Advert consumed", + ), + ( + "bad_direct_do_not_retransmit", + "Direct, marked do-not-retransmit (now caught by direct_forward)", + lambda: (lambda p: (p.mark_do_not_retransmit(), p)[-1])( + _make_direct_packet(payload=b"\x99", path=bytes([LOCAL_HASH, 0x11])) + ), + "do not retransmit", + ), ] @@ -1276,7 +1357,9 @@ class TestGoodPacketArray: """All 20 good packets should be forwarded successfully.""" @pytest.mark.parametrize( - "name, desc, builder", GOOD_PACKETS, ids=_good_ids, + "name, desc, builder", + GOOD_PACKETS, + ids=_good_ids, ) def test_process_packet_forwards(self, handler, name, desc, builder): pkt = builder() @@ -1286,14 +1369,18 @@ def test_process_packet_forwards(self, handler, name, desc, builder): assert delay >= 0.0 @pytest.mark.parametrize( - "name, desc, builder", GOOD_PACKETS, ids=_good_ids, + "name, desc, builder", + GOOD_PACKETS, + ids=_good_ids, ) def test_good_packet_not_duplicate_on_first_see(self, handler, name, desc, builder): pkt = builder() assert handler.is_duplicate(pkt) is False, f"[{name}] falsely flagged as duplicate" @pytest.mark.parametrize( - "name, desc, builder", GOOD_PACKETS, ids=_good_ids, + "name, desc, builder", + GOOD_PACKETS, + ids=_good_ids, ) def test_good_packet_path_modified(self, handler, name, desc, builder): pkt = builder() @@ -1317,7 +1404,8 @@ class TestBadPacketArray: @pytest.mark.parametrize( "name, desc, builder, expected_reason", - BAD_PACKETS, ids=_bad_ids, + BAD_PACKETS, + ids=_bad_ids, ) def test_process_packet_drops(self, handler, name, desc, builder, expected_reason): # Two entries need unscoped_flood_allow=False @@ -1330,7 +1418,8 @@ def test_process_packet_drops(self, handler, name, desc, builder, expected_reaso @pytest.mark.parametrize( "name, desc, builder, expected_reason", - BAD_PACKETS, ids=_bad_ids, + BAD_PACKETS, + ids=_bad_ids, ) def test_drop_reason_set(self, handler, name, desc, builder, expected_reason): if "policy_off" in name: @@ -1345,7 +1434,8 @@ def test_drop_reason_set(self, handler, name, desc, builder, expected_reason): @pytest.mark.parametrize( "name, desc, builder, expected_reason", - BAD_PACKETS, ids=_bad_ids, + BAD_PACKETS, + ids=_bad_ids, ) def test_bad_packet_not_marked_seen(self, handler, name, desc, builder, expected_reason): """Dropped packets must NOT pollute the seen cache.""" @@ -1405,9 +1495,7 @@ def _prepare_fast_tx(handler): async def test_injected_flood_forwards_and_appends_path(self, handler): self._prepare_fast_tx(handler) - pkt = _inject_from_wire( - _make_flood_packet(payload=b"\x10\x20\x30", path=b"\x11") - ) + pkt = _inject_from_wire(_make_flood_packet(payload=b"\x10\x20\x30", path=b"\x11")) with ( patch.object(handler, "_calculate_tx_delay", return_value=0.0), @@ -1426,7 +1514,7 @@ async def test_injected_direct_forwards_and_consumes_hop(self, handler): self._prepare_fast_tx(handler) pkt = _inject_from_wire( - _make_direct_packet(payload=b"\xAA\xBB", path=bytes([LOCAL_HASH, 0x44, 0x55])) + _make_direct_packet(payload=b"\xaa\xbb", path=bytes([LOCAL_HASH, 0x44, 0x55])) ) with ( @@ -1440,9 +1528,7 @@ async def test_injected_direct_forwards_and_consumes_hop(self, handler): assert bytes(sent_pkt.path) == b"\x44\x55" async def test_direct_for_other_node_is_dropped(self, handler): - pkt = _inject_from_wire( - _make_direct_packet(payload=b"\xAA\xBB", path=b"\xFE\x44") - ) + pkt = _inject_from_wire(_make_direct_packet(payload=b"\xaa\xbb", path=b"\xfe\x44")) with patch("repeater.engine.asyncio.sleep", new_callable=AsyncMock): await handler(pkt, {"snr": 2.0, "rssi": -90}, local_transmission=False) @@ -1474,9 +1560,7 @@ async def test_duplicate_wire_packet_not_retransmitted(self, handler): assert original["duplicates"][0]["drop_reason"] == "Duplicate" async def test_transport_flood_injection_honors_transport_key_decision(self, handler): - pkt = _inject_from_wire( - _make_transport_flood_packet(payload=b"\x01\x02\x03\x04", path=b"") - ) + pkt = _inject_from_wire(_make_transport_flood_packet(payload=b"\x01\x02\x03\x04", path=b"")) with ( patch.object(handler, "_check_transport_codes", return_value=(False, "denied")), @@ -1490,16 +1574,14 @@ async def test_transport_flood_injection_honors_transport_key_decision(self, han async def test_local_tx_then_rf_echo_is_duplicate(self, handler): self._prepare_fast_tx(handler) - local_pkt = _make_flood_packet(payload=b"\x0A\x0B\x0C", path=b"") + local_pkt = _make_flood_packet(payload=b"\x0a\x0b\x0c", path=b"") with ( patch.object(handler, "_calculate_tx_delay", return_value=0.0), patch("repeater.engine.asyncio.sleep", new_callable=AsyncMock), ): await handler(local_pkt, {"snr": 0.0, "rssi": -50}, local_transmission=True) - rf_echo = _inject_from_wire( - _make_flood_packet(payload=b"\x0A\x0B\x0C", path=b"") - ) + rf_echo = _inject_from_wire(_make_flood_packet(payload=b"\x0a\x0b\x0c", path=b"")) await handler(rf_echo, {"snr": 0.0, "rssi": -70}, local_transmission=False) assert handler.dispatcher.send_packet.call_count == 1 @@ -1555,7 +1637,9 @@ async def test_all_payload_types_direct_injection_forwards(self, handler, payloa assert bytes(sent_pkt.path) == b"\x44\x55" @pytest.mark.parametrize("payload_type", range(16)) - async def test_all_payload_types_transport_flood_injection_forwards(self, handler, payload_type): + async def test_all_payload_types_transport_flood_injection_forwards( + self, handler, payload_type + ): self._prepare_fast_tx(handler) pkt = _inject_from_wire( _make_transport_flood_packet( @@ -1579,7 +1663,9 @@ async def test_all_payload_types_transport_flood_injection_forwards(self, handle assert sent_pkt.transport_codes == [0x1111, 0x2222] @pytest.mark.parametrize("payload_type", range(16)) - async def test_all_payload_types_transport_direct_injection_forwards(self, handler, payload_type): + async def test_all_payload_types_transport_direct_injection_forwards( + self, handler, payload_type + ): self._prepare_fast_tx(handler) pkt = _inject_from_wire( _make_transport_direct_packet( @@ -1845,7 +1931,11 @@ async def test_background_timer_loop_runs_tasks_and_handles_cancel(self, handler with ( patch("repeater.engine.time.time", return_value=100000.0), - patch("repeater.engine.asyncio.sleep", new_callable=AsyncMock, side_effect=asyncio.CancelledError), + patch( + "repeater.engine.asyncio.sleep", + new_callable=AsyncMock, + side_effect=asyncio.CancelledError, + ), ): with pytest.raises(asyncio.CancelledError): await handler._background_timer_loop() @@ -1870,7 +1960,11 @@ async def test_background_timer_loop_continues_when_db_cleanup_fails(self, handl with ( patch("repeater.engine.time.time", return_value=100000.0), - patch("repeater.engine.asyncio.sleep", new_callable=AsyncMock, side_effect=asyncio.CancelledError), + patch( + "repeater.engine.asyncio.sleep", + new_callable=AsyncMock, + side_effect=asyncio.CancelledError, + ), ): with pytest.raises(asyncio.CancelledError): await handler._background_timer_loop() @@ -1893,7 +1987,9 @@ def _fake_create_task(coro): with ( patch("repeater.engine.time.time", return_value=100000.0), - patch("repeater.engine.asyncio.sleep", new_callable=AsyncMock, return_value=None) as sleep_mock, + patch( + "repeater.engine.asyncio.sleep", new_callable=AsyncMock, return_value=None + ) as sleep_mock, patch("repeater.engine.asyncio.create_task", side_effect=_fake_create_task), ): await handler._background_timer_loop() @@ -1912,7 +2008,9 @@ async def test_record_noise_floor_handles_none_and_exceptions(self, handler): await handler._record_noise_floor_async() @pytest.mark.asyncio - async def test_record_crc_errors_returns_without_storage_and_handles_storage_exception(self, handler): + async def test_record_crc_errors_returns_without_storage_and_handles_storage_exception( + self, handler + ): # No storage configured: should return early. handler.storage = None await handler._record_crc_errors_async() @@ -1926,7 +2024,9 @@ async def test_record_crc_errors_returns_without_storage_and_handles_storage_exc await handler._record_crc_errors_async() @pytest.mark.asyncio - async def test_send_periodic_advert_handles_missing_handler_and_handler_exception(self, handler): + async def test_send_periodic_advert_handles_missing_handler_and_handler_exception( + self, handler + ): handler.send_advert_func = None await handler._send_periodic_advert_async() @@ -1945,7 +2045,10 @@ def test_record_duplicate_appends_when_original_not_found(self, handler): handler.record_duplicate(pkt, rssi=-85, snr=1.0) assert handler.recent_packets[-1]["drop_reason"] == "Duplicate" - assert handler.recent_packets[-1]["packet_hash"] == pkt.calculate_packet_hash().hex().upper()[:16] + assert ( + handler.recent_packets[-1]["packet_hash"] + == pkt.calculate_packet_hash().hex().upper()[:16] + ) def test_record_duplicate_appends_when_recent_packets_empty(self, handler): handler.recent_packets.clear() @@ -1960,7 +2063,7 @@ def test_record_duplicate_appends_when_recent_packets_empty(self, handler): def test_record_duplicate_route_zero_maps_to_flood_counters(self, handler): pkt = _make_flood_packet(payload=b"\x75\x76") # Route nibble 0 is parsed as FLOOD in current protocol constants. - pkt.header = (0x00 << PH_TYPE_SHIFT) + pkt.header = 0x00 << PH_TYPE_SHIFT handler.record_duplicate(pkt, rssi=-90, snr=0.5) diff --git a/tests/test_flood_loop_dedup.py b/tests/test_flood_loop_dedup.py index dfc6b0d..6ffac3d 100644 --- a/tests/test_flood_loop_dedup.py +++ b/tests/test_flood_loop_dedup.py @@ -11,6 +11,7 @@ - mark_seen / is_duplicate cache behaviour - do_not_retransmit flag handling """ + from unittest.mock import MagicMock, patch @@ -106,12 +107,12 @@ class TestDuplicateSuppression: def test_same_packet_forwarded_twice_is_duplicate(self): """Forwarding the same packet a second time must be rejected as duplicate.""" h = _make_handler() - pkt1 = _make_flood_packet(payload=b"\xDE\xAD") + pkt1 = _make_flood_packet(payload=b"\xde\xad") result1 = h.flood_forward(pkt1) assert result1 is not None # Same content in a fresh Packet object - pkt2 = _make_flood_packet(payload=b"\xDE\xAD") + pkt2 = _make_flood_packet(payload=b"\xde\xad") result2 = h.flood_forward(pkt2) assert result2 is None assert pkt2.drop_reason == "Duplicate" @@ -128,7 +129,7 @@ def test_different_payload_not_duplicate(self): def test_mark_seen_makes_is_duplicate_true(self): """mark_seen records the hash; is_duplicate finds it.""" h = _make_handler() - pkt = _make_flood_packet(payload=b"\xAA\xBB") + pkt = _make_flood_packet(payload=b"\xaa\xbb") assert not h.is_duplicate(pkt) h.mark_seen(pkt) assert h.is_duplicate(pkt) @@ -149,10 +150,8 @@ def test_different_path_same_payload_same_hash(self): except for TRACE packets. Two flood packets with different paths but same payload have the same hash. """ - pkt_a = _make_flood_packet(path_bytes=b"\x11", hash_size=1, hash_count=1, - payload=b"\xFF") - pkt_b = _make_flood_packet(path_bytes=b"\x22", hash_size=1, hash_count=1, - payload=b"\xFF") + pkt_a = _make_flood_packet(path_bytes=b"\x11", hash_size=1, hash_count=1, payload=b"\xff") + pkt_b = _make_flood_packet(path_bytes=b"\x22", hash_size=1, hash_count=1, payload=b"\xff") assert pkt_a.calculate_packet_hash() == pkt_b.calculate_packet_hash() def test_seen_cache_eviction(self): @@ -183,62 +182,53 @@ class TestLoopDetection1Byte: def test_loop_detect_off_allows_own_hash(self): """With loop_detect=off, packet with our hash in path is forwarded.""" - h = _make_handler(loop_detect="off", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path contains our 1-byte hash (0xAB) once - pkt = _make_flood_packet(b"\xAB", hash_size=1, hash_count=1) + pkt = _make_flood_packet(b"\xab", hash_size=1, hash_count=1) result = h.flood_forward(pkt) assert result is not None def test_loop_detect_strict_blocks_single_occurrence(self): """strict mode (threshold=1): one occurrence of our hash → loop.""" - h = _make_handler(loop_detect="strict", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet(b"\xAB", hash_size=1, hash_count=1) + h = _make_handler(loop_detect="strict", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\xab", hash_size=1, hash_count=1) result = h.flood_forward(pkt) assert result is None assert "loop" in pkt.drop_reason.lower() def test_loop_detect_moderate_allows_one_occurrence(self): """moderate mode (threshold=2): one occurrence is fine.""" - h = _make_handler(loop_detect="moderate", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet(b"\x11\xAB", hash_size=1, hash_count=2) + h = _make_handler(loop_detect="moderate", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\x11\xab", hash_size=1, hash_count=2) result = h.flood_forward(pkt) assert result is not None def test_loop_detect_moderate_blocks_two_occurrences(self): """moderate mode (threshold=2): two occurrences → loop.""" - h = _make_handler(loop_detect="moderate", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet(b"\xAB\x11\xAB", hash_size=1, hash_count=3) + h = _make_handler(loop_detect="moderate", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\xab\x11\xab", hash_size=1, hash_count=3) result = h.flood_forward(pkt) assert result is None assert "loop" in pkt.drop_reason.lower() def test_loop_detect_minimal_allows_three_occurrences(self): """minimal mode (threshold=4): three occurrences still OK.""" - h = _make_handler(loop_detect="minimal", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet(b"\xAB\x11\xAB\x22\xAB", hash_size=1, hash_count=5) + h = _make_handler(loop_detect="minimal", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\xab\x11\xab\x22\xab", hash_size=1, hash_count=5) result = h.flood_forward(pkt) assert result is not None def test_loop_detect_minimal_blocks_four_occurrences(self): """minimal mode (threshold=4): four occurrences → loop.""" - h = _make_handler(loop_detect="minimal", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet( - b"\xAB\x11\xAB\x22\xAB\x33\xAB", hash_size=1, hash_count=7 - ) + h = _make_handler(loop_detect="minimal", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\xab\x11\xab\x22\xab\x33\xab", hash_size=1, hash_count=7) result = h.flood_forward(pkt) assert result is None assert "loop" in pkt.drop_reason.lower() def test_loop_detect_no_match_passes(self): """Strict mode still passes if our hash is not in the path.""" - h = _make_handler(loop_detect="strict", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="strict", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) pkt = _make_flood_packet(b"\x11\x22\x33", hash_size=1, hash_count=3) result = h.flood_forward(pkt) assert result is not None @@ -261,15 +251,13 @@ def test_forward_then_receive_again_is_duplicate(self): Receiving an identical packet is a duplicate. """ h = _make_handler(loop_detect="off") - pkt = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, - payload=b"\xAA\xBB") + pkt = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, payload=b"\xaa\xbb") result = h.flood_forward(pkt) assert result is not None # The original packet's payload hash was marked seen # Receiving same original packet again (before our hop was appended) - pkt2 = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, - payload=b"\xAA\xBB") + pkt2 = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, payload=b"\xaa\xbb") result2 = h.flood_forward(pkt2) assert result2 is None assert pkt2.drop_reason == "Duplicate" @@ -283,8 +271,7 @@ def test_strict_detects_own_hash_after_flood_chain(self): h = _make_handler(loop_detect="strict", local_hash_bytes=our_hash) # Original packet arrives, we forward (appending 0xAB) - pkt = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, - payload=b"\xDD\xEE") + pkt = _make_flood_packet(b"\x11", hash_size=1, hash_count=1, payload=b"\xdd\xee") result = h.flood_forward(pkt) assert result is not None # Now path is [0x11, 0xAB], and this exact payload is in seen_packets @@ -293,8 +280,10 @@ def test_strict_detects_own_hash_after_flood_chain(self): # so it's a new payload in the packet hash sense (different path iteration) # but path contains our hash 0xAB looped_pkt = _make_flood_packet( - b"\x11\xAB\x22", hash_size=1, hash_count=3, - payload=b"\xDD\xEE\xFF" # different payload → not a duplicate + b"\x11\xab\x22", + hash_size=1, + hash_count=3, + payload=b"\xdd\xee\xff", # different payload → not a duplicate ) result2 = h.flood_forward(looped_pkt) assert result2 is None @@ -328,25 +317,22 @@ def test_2_byte_mode_strict_partial_byte_match_does_not_loop(self): In 2-byte mode, a partial byte overlap (0xABxx) is not a loop unless the full 2-byte local hash (0xABCD) matches a hop. """ - h = _make_handler(loop_detect="strict", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="strict", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path hop is AB11; local 2-byte hash is ABCD. - pkt = _make_flood_packet(b"\xAB\x11", hash_size=2, hash_count=1) + pkt = _make_flood_packet(b"\xab\x11", hash_size=2, hash_count=1) result = h.flood_forward(pkt) assert result is not None def test_2_byte_mode_off_ignores_byte_match(self): """With loop_detect=off, even byte-level 0xAB matches are ignored.""" - h = _make_handler(loop_detect="off", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_flood_packet(b"\xAB\x11", hash_size=2, hash_count=1) + h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + pkt = _make_flood_packet(b"\xab\x11", hash_size=2, hash_count=1) result = h.flood_forward(pkt) assert result is not None def test_2_byte_no_local_hash_byte_passes_strict(self): """If local_hash byte doesn't appear anywhere in the 2-byte path, strict passes.""" - h = _make_handler(loop_detect="strict", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="strict", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path: [0x11, 0x22] — no 0xAB byte pkt = _make_flood_packet(b"\x11\x22", hash_size=2, hash_count=1) result = h.flood_forward(pkt) @@ -354,10 +340,9 @@ def test_2_byte_no_local_hash_byte_passes_strict(self): def test_3_byte_mode_partial_byte_match_does_not_loop(self): """In 3-byte mode, partial byte overlap is not enough to trigger strict.""" - h = _make_handler(loop_detect="strict", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="strict", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Hop 11AB33 does not equal local 3-byte hash ABCDEF. - pkt = _make_flood_packet(b"\x11\xAB\x33", hash_size=3, hash_count=1) + pkt = _make_flood_packet(b"\x11\xab\x33", hash_size=3, hash_count=1) result = h.flood_forward(pkt) assert result is not None @@ -366,10 +351,9 @@ def test_moderate_multi_byte_requires_full_hash_occurrences(self): moderate threshold=2 counts full 2-byte hash matches only. Two hops with ABxx but not ABCD must not loop. """ - h = _make_handler(loop_detect="moderate", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="moderate", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Two 2-byte hops: AB11 and AB22 (neither equals ABCD) - pkt = _make_flood_packet(b"\xAB\x11\xAB\x22", hash_size=2, hash_count=2) + pkt = _make_flood_packet(b"\xab\x11\xab\x22", hash_size=2, hash_count=2) result = h.flood_forward(pkt) assert result is not None @@ -378,14 +362,13 @@ def test_2_byte_flood_forward_appends_correctly(self): After flood_forward in 2-byte mode, verify the path contains only the expected bytes (no extra, no corruption). """ - h = _make_handler(loop_detect="off", - local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) + h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) pkt = _make_flood_packet(b"\x11\x22", hash_size=2, hash_count=1) result = h.flood_forward(pkt) assert result is not None assert result.get_path_hash_count() == 2 hashes = result.get_path_hashes() - assert hashes == [b"\x11\x22", b"\xAB\xCD"] + assert hashes == [b"\x11\x22", b"\xab\xcd"] # =================================================================== @@ -494,8 +477,10 @@ def test_hop_count_63_rejected(self): pkt = _make_flood_packet(bytes(63), hash_size=1, hash_count=63) result = h.flood_forward(pkt) assert result is None - assert "maximum" in (pkt.drop_reason or "").lower() or \ - "exceed" in (pkt.drop_reason or "").lower() + assert ( + "maximum" in (pkt.drop_reason or "").lower() + or "exceed" in (pkt.drop_reason or "").lower() + ) # =================================================================== @@ -510,10 +495,8 @@ class TestSerializationAfterForward: """ def test_forwarded_1_byte_round_trips(self): - h = _make_handler(loop_detect="moderate", - local_hash_bytes=bytes([0x42, 0x00, 0x00])) - pkt = _make_flood_packet(b"\x11\x22", hash_size=1, hash_count=2, - payload=b"\xAA\xBB") + h = _make_handler(loop_detect="moderate", local_hash_bytes=bytes([0x42, 0x00, 0x00])) + pkt = _make_flood_packet(b"\x11\x22", hash_size=1, hash_count=2, payload=b"\xaa\xbb") result = h.flood_forward(pkt) assert result is not None wire = result.write_to() @@ -521,13 +504,11 @@ def test_forwarded_1_byte_round_trips(self): pkt2.read_from(wire) assert pkt2.get_path_hash_count() == 3 assert pkt2.get_path_hashes() == [b"\x11", b"\x22", b"\x42"] - assert pkt2.get_payload() == b"\xAA\xBB" + assert pkt2.get_payload() == b"\xaa\xbb" def test_forwarded_2_byte_round_trips(self): - h = _make_handler(loop_detect="off", - local_hash_bytes=bytes([0xAA, 0xBB, 0xCC])) - pkt = _make_flood_packet(b"\x11\x22", hash_size=2, hash_count=1, - payload=b"\xDE\xAD") + h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0xAA, 0xBB, 0xCC])) + pkt = _make_flood_packet(b"\x11\x22", hash_size=2, hash_count=1, payload=b"\xde\xad") result = h.flood_forward(pkt) assert result is not None wire = result.write_to() @@ -535,13 +516,11 @@ def test_forwarded_2_byte_round_trips(self): pkt2.read_from(wire) assert pkt2.get_path_hash_size() == 2 assert pkt2.get_path_hash_count() == 2 - assert pkt2.get_path_hashes() == [b"\x11\x22", b"\xAA\xBB"] + assert pkt2.get_path_hashes() == [b"\x11\x22", b"\xaa\xbb"] def test_forwarded_3_byte_round_trips(self): - h = _make_handler(loop_detect="off", - local_hash_bytes=bytes([0xAA, 0xBB, 0xCC])) - pkt = _make_flood_packet(b"\x11\x22\x33", hash_size=3, hash_count=1, - payload=b"\xBE\xEF") + h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0xAA, 0xBB, 0xCC])) + pkt = _make_flood_packet(b"\x11\x22\x33", hash_size=3, hash_count=1, payload=b"\xbe\xef") result = h.flood_forward(pkt) assert result is not None wire = result.write_to() @@ -549,7 +528,7 @@ def test_forwarded_3_byte_round_trips(self): pkt2.read_from(wire) assert pkt2.get_path_hash_size() == 3 assert pkt2.get_path_hash_count() == 2 - assert pkt2.get_path_hashes() == [b"\x11\x22\x33", b"\xAA\xBB\xCC"] + assert pkt2.get_path_hashes() == [b"\x11\x22\x33", b"\xaa\xbb\xcc"] # =================================================================== @@ -572,7 +551,7 @@ def test_three_repeater_chain_no_loop(self): ] handlers = [_make_handler(loop_detect="strict", local_hash_bytes=h) for h in hashes] - pkt = _make_flood_packet(payload=b"\xFE\xED") + pkt = _make_flood_packet(payload=b"\xfe\xed") for i, h in enumerate(handlers): result = h.flood_forward(pkt) assert result is not None, f"repeater {i} unexpectedly dropped packet" @@ -582,7 +561,7 @@ def test_three_repeater_chain_no_loop(self): path_bytes=bytes(result.path), hash_size=1, hash_count=result.get_path_hash_count(), - payload=b"\xFE\xED" + bytes([i + 1]), + payload=b"\xfe\xed" + bytes([i + 1]), ) assert pkt.get_path_hash_count() == 3 @@ -607,7 +586,8 @@ def test_circular_topology_strict_blocks_loop(self): # B forwards (new payload to avoid dedup) pkt_b = _make_flood_packet( - bytes(pkt.path), hash_size=1, + bytes(pkt.path), + hash_size=1, hash_count=pkt.get_path_hash_count(), payload=b"\x01\x02\x03\x04", ) @@ -616,7 +596,8 @@ def test_circular_topology_strict_blocks_loop(self): # C forwards pkt_c = _make_flood_packet( - bytes(pkt_b.path), hash_size=1, + bytes(pkt_b.path), + hash_size=1, hash_count=pkt_b.get_path_hash_count(), payload=b"\x01\x02\x03\x04\x05", ) @@ -625,7 +606,8 @@ def test_circular_topology_strict_blocks_loop(self): # Back to A — 0x11 is already in path → strict blocks it pkt_a2 = _make_flood_packet( - bytes(pkt_c.path), hash_size=1, + bytes(pkt_c.path), + hash_size=1, hash_count=pkt_c.get_path_hash_count(), payload=b"\x01\x02\x03\x04\x05\x06", ) @@ -640,11 +622,11 @@ def test_circular_topology_off_allows_loop_but_dedup_catches(self): """ h = _make_handler(loop_detect="off", local_hash_bytes=bytes([0x11, 0x00, 0x00])) - pkt = _make_flood_packet(payload=b"\xAA\xBB") + pkt = _make_flood_packet(payload=b"\xaa\xbb") assert h.flood_forward(pkt) is not None # Same payload comes back - pkt2 = _make_flood_packet(payload=b"\xAA\xBB") + pkt2 = _make_flood_packet(payload=b"\xaa\xbb") result = h.flood_forward(pkt2) assert result is None assert pkt2.drop_reason == "Duplicate" @@ -664,7 +646,8 @@ def test_two_byte_chain_loop_detected(self): # B forwards pkt_b = _make_flood_packet( - bytes(pkt.path), hash_size=2, + bytes(pkt.path), + hash_size=2, hash_count=pkt.get_path_hash_count(), payload=b"\x01\x02\x03", ) @@ -673,7 +656,8 @@ def test_two_byte_chain_loop_detected(self): # Back to A — byte 0xAA is in path → strict detects it pkt_a2 = _make_flood_packet( - bytes(pkt_b.path), hash_size=2, + bytes(pkt_b.path), + hash_size=2, hash_count=pkt_b.get_path_hash_count(), payload=b"\x01\x02\x03\x04", ) diff --git a/tests/test_glass_handler.py b/tests/test_glass_handler.py index 2624e5e..d99e56c 100644 --- a/tests/test_glass_handler.py +++ b/tests/test_glass_handler.py @@ -5,7 +5,9 @@ from pathlib import Path import pytest -_MODULE_PATH = Path(__file__).resolve().parents[1] / "repeater" / "data_acquisition" / "glass_handler.py" +_MODULE_PATH = ( + Path(__file__).resolve().parents[1] / "repeater" / "data_acquisition" / "glass_handler.py" +) _SPEC = importlib.util.spec_from_file_location("repeater_glass_handler", _MODULE_PATH) _MODULE = importlib.util.module_from_spec(_SPEC) assert _SPEC and _SPEC.loader diff --git a/tests/test_gps_service.py b/tests/test_gps_service.py index 12049b5..8995003 100644 --- a/tests/test_gps_service.py +++ b/tests/test_gps_service.py @@ -3,7 +3,9 @@ from datetime import datetime, timezone from pathlib import Path -_MODULE_PATH = Path(__file__).resolve().parents[1] / "repeater" / "data_acquisition" / "gps_service.py" +_MODULE_PATH = ( + Path(__file__).resolve().parents[1] / "repeater" / "data_acquisition" / "gps_service.py" +) _SPEC = importlib.util.spec_from_file_location("repeater_gps_service", _MODULE_PATH) _MODULE = importlib.util.module_from_spec(_SPEC) assert _SPEC and _SPEC.loader @@ -28,12 +30,8 @@ def test_nmea_parser_combines_rmc_gga_gsa_gsv_attributes(): assert parser.ingest_sentence( _sentence("GPGGA,123519,4807.038,N,01131.000,E,1,08,0.9,545.4,M,46.9,M,,") ) - assert parser.ingest_sentence( - _sentence("GPGSA,A,3,04,05,09,12,24,25,29,,,,,,1.8,1.0,1.5") - ) - assert parser.ingest_sentence( - _sentence("GPGSV,1,1,03,04,77,045,42,05,13,180,35,09,07,095,29") - ) + assert parser.ingest_sentence(_sentence("GPGSA,A,3,04,05,09,12,24,25,29,,,,,,1.8,1.0,1.5")) + assert parser.ingest_sentence(_sentence("GPGSV,1,1,03,04,77,045,42,05,13,180,35,09,07,095,29")) snapshot = parser.snapshot() @@ -378,6 +376,7 @@ def test_gps_service_reflects_runtime_manual_location_updates(): assert snapshot["gps_position"]["longitude"] == -71.1076 assert snapshot["position_meta"]["source"] == "manual_config" + def test_repeater_location_uses_config_when_gps_opt_in_disabled(): service = GPSService( { diff --git a/tests/test_handler_helpers_acl_advert.py b/tests/test_handler_helpers_acl_advert.py index 7f5af2f..fecc80f 100644 --- a/tests/test_handler_helpers_acl_advert.py +++ b/tests/test_handler_helpers_acl_advert.py @@ -17,7 +17,7 @@ def get_public_key(self): class _FakePacket: - def __init__(self, *, header=0x00, path=None, pkt_hash=b"\xAA" * 16): + def __init__(self, *, header=0x00, path=None, pkt_hash=b"\xaa" * 16): self.header = header self.path = path if path is not None else bytearray() self._pkt_hash = pkt_hash @@ -286,7 +286,9 @@ def test_advert_reload_config_and_cleanup_old_state_bounds_memory(): now = time.time() helper._recent_advert_hashes["old"] = now - 1 helper._penalty_until["pk"] = now - 1 - helper._bucket_state["oldpk"] = {"last_seen": now - (helper._bucket_state_retention_seconds + 1)} + helper._bucket_state["oldpk"] = { + "last_seen": now - (helper._bucket_state_retention_seconds + 1) + } helper._violation_state["oldpk"] = {"count": 3, "last_violation": now - 9999} helper._cleanup_old_state(now) diff --git a/tests/test_handler_helpers_mesh_cli.py b/tests/test_handler_helpers_mesh_cli.py index 2e82d9d..e9d454a 100644 --- a/tests/test_handler_helpers_mesh_cli.py +++ b/tests/test_handler_helpers_mesh_cli.py @@ -77,7 +77,9 @@ def test_cmd_advert_branches_and_success_schedule(): fake_loop = SimpleNamespace(is_running=lambda: True) cli._event_loop = fake_loop - with patch("asyncio.run_coroutine_threadsafe", side_effect=lambda coro, _loop: coro.close()) as run_ts: + with patch( + "asyncio.run_coroutine_threadsafe", side_effect=lambda coro, _loop: coro.close() + ) as run_ts: out = cli._cmd_advert() assert out == "OK - Advert sent" @@ -124,7 +126,9 @@ def test_cmd_get_public_key_and_neighbor_branches(): cli.storage_handler = storage assert cli._cmd_neighbors() == "No neighbors discovered yet" - storage.get_neighbors = lambda: {"aa": {"is_repeater": False, "zero_hop": False, "last_seen": 1}} + storage.get_neighbors = lambda: { + "aa": {"is_repeater": False, "zero_hop": False, "last_seen": 1} + } assert "No repeaters or zero hop" in cli._cmd_neighbors() storage.get_neighbors = lambda: { @@ -137,7 +141,9 @@ def test_cmd_get_public_key_and_neighbor_branches(): assert "abcdef12:20:4" in out assert "11223344:10:1" in out - cli.storage_handler = SimpleNamespace(get_neighbors=MagicMock(side_effect=RuntimeError("db fail"))) + cli.storage_handler = SimpleNamespace( + get_neighbors=MagicMock(side_effect=RuntimeError("db fail")) + ) assert cli._cmd_neighbors().startswith("Error:") @@ -179,7 +185,7 @@ def test_cmd_set_updates_and_validation_errors(): def test_misc_commands_and_routes(): cli = MeshCLI("/tmp/cfg.yaml", _base_config(), _cfg_mgr(), enable_regions=True) - assert cli._cmd_region("region") .startswith("Error:") + assert cli._cmd_region("region").startswith("Error:") assert cli._cmd_region("region load us").startswith("Error:") assert cli._cmd_region("region save").startswith("Error:") assert cli._cmd_region("region remove x").startswith("Error:") diff --git a/tests/test_handler_helpers_path_protocol_text.py b/tests/test_handler_helpers_path_protocol_text.py index 6e5df86..474ff83 100644 --- a/tests/test_handler_helpers_path_protocol_text.py +++ b/tests/test_handler_helpers_path_protocol_text.py @@ -53,9 +53,11 @@ async def test_path_helper_updates_client_out_path_on_valid_decrypt(): helper = PathHelper(acl_dict={0x11: acl}) # Payload: dest(0x11), src(0x22), mac+data... - packet = _PathPacket(payload=b"\x11\x22\xAA\xBB\xCC") + packet = _PathPacket(payload=b"\x11\x22\xaa\xbb\xcc") - with patch("pymc_core.protocol.crypto.CryptoUtils.mac_then_decrypt", return_value=b"\x02\x99\x88\x01"): + with patch( + "pymc_core.protocol.crypto.CryptoUtils.mac_then_decrypt", return_value=b"\x02\x99\x88\x01" + ): handled = await helper.process_path_packet(packet) assert handled is False @@ -71,14 +73,17 @@ async def test_path_helper_returns_false_for_non_matching_or_invalid_inputs(): helper = PathHelper(acl_dict={0x11: acl}) assert await helper.process_path_packet(_PathPacket(payload=b"\x11")) is False - assert await helper.process_path_packet(_PathPacket(payload=b"\x33\x22\xAA\xBB")) is False + assert await helper.process_path_packet(_PathPacket(payload=b"\x33\x22\xaa\xbb")) is False no_secret_client = _FakeClient(pubkey=bytes([0x22]) + b"x" * 31, shared_secret=b"") helper_no_secret = PathHelper(acl_dict={0x11: _FakeACL([no_secret_client])}) - assert await helper_no_secret.process_path_packet(_PathPacket(payload=b"\x11\x22\xAA\xBB")) is False + assert ( + await helper_no_secret.process_path_packet(_PathPacket(payload=b"\x11\x22\xaa\xbb")) + is False + ) with patch("pymc_core.protocol.crypto.CryptoUtils.mac_then_decrypt", return_value=None): - assert await helper.process_path_packet(_PathPacket(payload=b"\x11\x22\xAA\xBB")) is False + assert await helper.process_path_packet(_PathPacket(payload=b"\x11\x22\xaa\xbb")) is False @pytest.mark.asyncio @@ -173,9 +178,24 @@ def test_protocol_request_access_list_admin_and_reserved_rules(): def test_protocol_request_get_neighbours_sort_and_pagination(): neighbors = { - "AA" * 16: {"is_repeater": True, "zero_hop": True, "last_seen": time.time() - 1, "snr": 5.0}, - "BB" * 16: {"is_repeater": True, "zero_hop": True, "last_seen": time.time() - 10, "snr": 1.0}, - "CC" * 16: {"is_repeater": False, "zero_hop": True, "last_seen": time.time() - 1, "snr": 9.0}, + "AA" * 16: { + "is_repeater": True, + "zero_hop": True, + "last_seen": time.time() - 1, + "snr": 5.0, + }, + "BB" * 16: { + "is_repeater": True, + "zero_hop": True, + "last_seen": time.time() - 10, + "snr": 1.0, + }, + "CC" * 16: { + "is_repeater": False, + "zero_hop": True, + "last_seen": time.time() - 1, + "snr": 9.0, + }, } storage = SimpleNamespace(get_neighbors=lambda: neighbors) helper = ProtocolRequestHelper( @@ -209,10 +229,16 @@ def test_protocol_request_owner_info_fallback_version(): def test_text_helper_cli_prefix_and_admin_permission_checks(): - acl = _FakeACL([ - _FakeClient(pubkey=bytes([0x21]) + b"x" * 31, shared_secret=b"k" * 32, permissions=0x02), - _FakeClient(pubkey=bytes([0x22]) + b"x" * 31, shared_secret=b"k" * 32, permissions=0x01), - ]) + acl = _FakeACL( + [ + _FakeClient( + pubkey=bytes([0x21]) + b"x" * 31, shared_secret=b"k" * 32, permissions=0x02 + ), + _FakeClient( + pubkey=bytes([0x22]) + b"x" * 31, shared_secret=b"k" * 32, permissions=0x01 + ), + ] + ) helper = TextHelper(identity_manager=MagicMock(), acl_dict={0x41: acl}) assert helper._is_cli_command("get status") is True @@ -301,11 +327,16 @@ def test_text_helper_register_identity_room_server_without_event_loop_is_safe(): with ( patch("repeater.handler_helpers.text.TextMessageHandler", return_value=MagicMock()), patch("repeater.handler_helpers.text.RoomServer") as room_server_cls, - patch("repeater.handler_helpers.text.asyncio.get_running_loop", side_effect=RuntimeError("no loop")), + patch( + "repeater.handler_helpers.text.asyncio.get_running_loop", + side_effect=RuntimeError("no loop"), + ), ): room_server_obj = MagicMock() room_server_cls.return_value = room_server_obj - helper.register_identity("room-a", identity, identity_type="room_server", radio_config={"max_posts": 2}) + helper.register_identity( + "room-a", identity, identity_type="room_server", radio_config={"max_posts": 2} + ) assert 0x34 in helper.room_servers @@ -313,7 +344,9 @@ def test_text_helper_register_identity_room_server_without_event_loop_is_safe(): @pytest.mark.asyncio async def test_text_helper_send_cli_reply_uses_direct_path_from_client(): helper = TextHelper(identity_manager=MagicMock(), packet_injector=AsyncMock()) - sender = _FakeClient(pubkey=bytes([0x99]) + b"x" * 31, shared_secret=b"s" * 32, permissions=0x02) + sender = _FakeClient( + pubkey=bytes([0x99]) + b"x" * 31, shared_secret=b"s" * 32, permissions=0x02 + ) sender.out_path = bytearray([0xAA, 0xBB]) sender.out_path_len = 2 helper.acl_dict = {0x10: _FakeACL([sender])} @@ -332,6 +365,6 @@ async def test_text_helper_send_cli_reply_uses_direct_path_from_client(): handler_info={"identity": _FakeId(bytes([0x10]) + b"i" * 31)}, ) - assert bytes(reply_packet.path) == b"\xAA\xBB" + assert bytes(reply_packet.path) == b"\xaa\xbb" assert reply_packet.path_len == 2 helper._send_packet.assert_awaited_once_with(reply_packet, wait_for_ack=False) diff --git a/tests/test_handler_helpers_room_server.py b/tests/test_handler_helpers_room_server.py index 1332f0c..cd1d9c3 100644 --- a/tests/test_handler_helpers_room_server.py +++ b/tests/test_handler_helpers_room_server.py @@ -122,7 +122,7 @@ async def test_room_server_push_post_to_client_success_direct_route_sets_path_an rs.global_limiter = SimpleNamespace(acquire=AsyncMock(), release=MagicMock()) rs._handle_ack_received = AsyncMock() - client = _FakeClient(pubkey=b"E" * 32, out_path=b"\xAA\xBB", out_path_len=2) + client = _FakeClient(pubkey=b"E" * 32, out_path=b"\xaa\xbb", out_path_len=2) post = { "author_pubkey": (b"F" * 32).hex(), "message_text": "payload", @@ -131,17 +131,28 @@ async def test_room_server_push_post_to_client_success_direct_route_sets_path_an packet = SimpleNamespace(path=bytearray(), path_len=0) with ( - patch("repeater.handler_helpers.room_server.PacketBuilder._pack_timestamp_data", return_value=b"pk"), - patch("repeater.handler_helpers.room_server.CryptoUtils.sha256", return_value=b"\x01\x02\x03\x04abcd"), - patch("repeater.handler_helpers.room_server.PacketBuilder.create_datagram", return_value=packet), + patch( + "repeater.handler_helpers.room_server.PacketBuilder._pack_timestamp_data", + return_value=b"pk", + ), + patch( + "repeater.handler_helpers.room_server.CryptoUtils.sha256", + return_value=b"\x01\x02\x03\x04abcd", + ), + patch( + "repeater.handler_helpers.room_server.PacketBuilder.create_datagram", + return_value=packet, + ), ): ok = await rs.push_post_to_client(client, post) assert ok is True - assert bytes(packet.path) == b"\xAA\xBB" + assert bytes(packet.path) == b"\xaa\xbb" assert packet.path_len == 2 injector.assert_awaited_once_with(packet, wait_for_ack=True) - rs._handle_ack_received.assert_awaited_once_with(client.id.get_public_key(), post["post_timestamp"]) + rs._handle_ack_received.assert_awaited_once_with( + client.id.get_public_key(), post["post_timestamp"] + ) rs.global_limiter.release.assert_called_once() @@ -169,9 +180,18 @@ async def test_room_server_push_post_to_client_backoff_skip_and_timeout_path(): # Out of backoff and send fails -> timeout handler called. db.get_client_sync.return_value = {"push_failures": 1, "updated_at": time.time() - 9999} with ( - patch("repeater.handler_helpers.room_server.PacketBuilder._pack_timestamp_data", return_value=b"pk"), - patch("repeater.handler_helpers.room_server.CryptoUtils.sha256", return_value=b"\x01\x02\x03\x04abcd"), - patch("repeater.handler_helpers.room_server.PacketBuilder.create_datagram", return_value=SimpleNamespace(path=bytearray(), path_len=0)), + patch( + "repeater.handler_helpers.room_server.PacketBuilder._pack_timestamp_data", + return_value=b"pk", + ), + patch( + "repeater.handler_helpers.room_server.CryptoUtils.sha256", + return_value=b"\x01\x02\x03\x04abcd", + ), + patch( + "repeater.handler_helpers.room_server.PacketBuilder.create_datagram", + return_value=SimpleNamespace(path=bytearray(), path_len=0), + ), ): fail_ok = await rs.push_post_to_client(client, post) diff --git a/tests/test_handler_helpers_trace_discovery_login.py b/tests/test_handler_helpers_trace_discovery_login.py index e711adb..593976e 100644 --- a/tests/test_handler_helpers_trace_discovery_login.py +++ b/tests/test_handler_helpers_trace_discovery_login.py @@ -12,7 +12,9 @@ class DummyPacket: - def __init__(self, *, route=ROUTE_TYPE_DIRECT, path=b"", payload=b"\x01\x02", snr=2.5, rssi=-70): + def __init__( + self, *, route=ROUTE_TYPE_DIRECT, path=b"", payload=b"\x01\x02", snr=2.5, rssi=-70 + ): self.header = route self.path = bytearray(path) self.path_len = len(self.path) @@ -86,7 +88,7 @@ async def test_trace_helper_process_sets_pending_ping_and_forwards(): tag = 77 evt = helper.register_ping(tag, 0x42) - packet = DummyPacket(path=b"\x01", payload=b"\xAA\xBB\xCC") + packet = DummyPacket(path=b"\x01", payload=b"\xaa\xbb\xcc") helper._forward_trace_packet = AsyncMock() helper._extract_path_info = MagicMock(return_value=([], [])) helper._should_forward_trace = MagicMock(return_value=True) @@ -112,7 +114,9 @@ async def test_trace_helper_process_sets_pending_ping_and_forwards(): @pytest.mark.asyncio async def test_trace_helper_ignores_zero_rssi_pending_ping_response(): - helper = TraceHelper(local_hash=0x42, local_identity=FakeIdentity(0x42), repeater_handler=MagicMock()) + helper = TraceHelper( + local_hash=0x42, local_identity=FakeIdentity(0x42), repeater_handler=MagicMock() + ) tag = 9 evt = helper.register_ping(tag, 0x42) @@ -158,7 +162,9 @@ async def test_trace_helper_forward_trace_packet_updates_recent_record_and_injec def test_trace_helper_cleanup_stale_pings(): - helper = TraceHelper(local_hash=0x42, local_identity=FakeIdentity(0x42), repeater_handler=MagicMock()) + helper = TraceHelper( + local_hash=0x42, local_identity=FakeIdentity(0x42), repeater_handler=MagicMock() + ) helper.pending_pings = { 1: {"sent_at": time.time() - 100, "event": asyncio.Event(), "result": None, "target": 1}, 2: {"sent_at": time.time(), "event": asyncio.Event(), "result": None, "target": 2}, @@ -171,13 +177,19 @@ def test_trace_helper_cleanup_stale_pings(): def test_discovery_request_filter_match_and_mismatch(): - helper = DiscoveryHelper(local_identity=FakeIdentity(0x42), packet_injector=AsyncMock(), node_type=2) + helper = DiscoveryHelper( + local_identity=FakeIdentity(0x42), packet_injector=AsyncMock(), node_type=2 + ) helper._send_discovery_response = MagicMock() - helper._on_discovery_request({"tag": 1, "filter": 0x00, "prefix_only": False, "snr": 1.2, "rssi": -80}) + helper._on_discovery_request( + {"tag": 1, "filter": 0x00, "prefix_only": False, "snr": 1.2, "rssi": -80} + ) helper._send_discovery_response.assert_not_called() - helper._on_discovery_request({"tag": 2, "filter": 0x04, "prefix_only": True, "snr": 2.3, "rssi": -70}) + helper._on_discovery_request( + {"tag": 2, "filter": 0x04, "prefix_only": True, "snr": 2.3, "rssi": -70} + ) helper._send_discovery_response.assert_called_once_with(2, 2, 2.3, True) @@ -185,7 +197,9 @@ def test_discovery_request_without_identity_does_not_send(): helper = DiscoveryHelper(local_identity=None, packet_injector=AsyncMock(), node_type=2) helper._send_discovery_response = MagicMock() - helper._on_discovery_request({"tag": 7, "filter": 0x04, "prefix_only": False, "snr": 0.0, "rssi": -90}) + helper._on_discovery_request( + {"tag": 7, "filter": 0x04, "prefix_only": False, "snr": 0.0, "rssi": -90} + ) helper._send_discovery_response.assert_not_called() @@ -205,7 +219,10 @@ async def test_discovery_send_packet_async_success_failure_and_exception(): def test_discovery_send_response_without_injector_is_safe(): helper = DiscoveryHelper(local_identity=FakeIdentity(0x42), packet_injector=None) - with patch("pymc_core.protocol.packet_builder.PacketBuilder.create_discovery_response", return_value=object()): + with patch( + "pymc_core.protocol.packet_builder.PacketBuilder.create_discovery_response", + return_value=object(), + ): helper._send_discovery_response(tag=5, node_type=2, inbound_snr=1.0, prefix_only=False) @@ -237,13 +254,19 @@ def test_login_register_identity_repeater_creates_acl_and_handler(): with ( patch("repeater.handler_helpers.acl.ACL", return_value=acl_obj) as acl_cls, - patch("repeater.handler_helpers.login.LoginServerHandler", return_value=handler_obj) as handler_cls, + patch( + "repeater.handler_helpers.login.LoginServerHandler", return_value=handler_obj + ) as handler_cls, ): helper.register_identity( name="repeater-main", identity=identity, identity_type="repeater", - config={"repeater": {"security": {"max_clients": 3, "admin_password": "a", "guest_password": "g"}}}, + config={ + "repeater": { + "security": {"max_clients": 3, "admin_password": "a", "guest_password": "g"} + } + }, ) acl_cls.assert_called_once() diff --git a/tests/test_http_server_unit.py b/tests/test_http_server_unit.py index 6f4d788..fb909a0 100644 --- a/tests/test_http_server_unit.py +++ b/tests/test_http_server_unit.py @@ -31,7 +31,9 @@ def test_doc_endpoint_routes_and_openapi_json_paths(monkeypatch): assert doc.index() == "docs-html" assert doc.docs() == "docs-html" - monkeypatch.setattr(cherrypy, "response", SimpleNamespace(headers={}, status=200), raising=False) + monkeypatch.setattr( + cherrypy, "response", SimpleNamespace(headers={}, status=200), raising=False + ) # success path monkeypatch.setattr("builtins.open", lambda *args, **kwargs: io.StringIO("openapi: 3.0.0\n")) @@ -90,7 +92,6 @@ def test_stats_app_index_error_paths(monkeypatch, tmp_path): with pytest.raises(cherrypy.HTTPError): app.index() - # Force generic open() exception branch def _explode(*_args, **_kwargs): raise RuntimeError("boom") @@ -107,15 +108,21 @@ def _fake_init_auth(self): self.token_manager = object() monkeypatch.setattr(hs.HTTPStatsServer, "_init_auth_handlers", _fake_init_auth) - monkeypatch.setattr(hs, "StatsApp", lambda *args, **kwargs: SimpleNamespace(api=SimpleNamespace(config_manager=object()))) + monkeypatch.setattr( + hs, + "StatsApp", + lambda *args, **kwargs: SimpleNamespace(api=SimpleNamespace(config_manager=object())), + ) monkeypatch.setattr(hs, "AuthEndpoints", lambda *args, **kwargs: object()) monkeypatch.setattr(hs, "DocEndpoint", lambda *_args, **_kwargs: object()) - server = hs.HTTPStatsServer(config={"web": {"cors_enabled": False}}, config_path=str(Path(tmp_path) / "cfg.yml")) + server = hs.HTTPStatsServer( + config={"web": {"cors_enabled": False}}, config_path=str(Path(tmp_path) / "cfg.yml") + ) monkeypatch.setattr(cherrypy, "response", SimpleNamespace(headers={}), raising=False) out = server._json_error_handler(401, "no", "", "") - assert "\"success\": false" in out + assert '"success": false' in out install_called = {"v": False} monkeypatch.setattr(hs.cherrypy_cors, "install", lambda: install_called.__setitem__("v", True)) @@ -123,6 +130,11 @@ def _fake_init_auth(self): assert install_called["v"] is True exited = {"v": False} - monkeypatch.setattr(cherrypy, "engine", SimpleNamespace(exit=lambda: exited.__setitem__("v", True)), raising=False) + monkeypatch.setattr( + cherrypy, + "engine", + SimpleNamespace(exit=lambda: exited.__setitem__("v", True)), + raising=False, + ) server.stop() assert exited["v"] is True diff --git a/tests/test_identity_manager_and_repeater_cli.py b/tests/test_identity_manager_and_repeater_cli.py index 20df1b5..a519dc0 100644 --- a/tests/test_identity_manager_and_repeater_cli.py +++ b/tests/test_identity_manager_and_repeater_cli.py @@ -6,7 +6,7 @@ class _FakeIdentity: - def __init__(self, pubkey: bytes, addr: bytes = b"\xAA\xBB"): + def __init__(self, pubkey: bytes, addr: bytes = b"\xaa\xbb"): self._pubkey = pubkey self._addr = addr @@ -194,7 +194,7 @@ def test_cli_set_commands_apply_and_validate_ranges(): assert cli._cmd_set("radio 900000000 250000 9 6").startswith("OK") assert cfg["radio"]["frequency"] == 900000000.0 - assert cli._cmd_set("freq 868000000") .startswith("OK") + assert cli._cmd_set("freq 868000000").startswith("OK") assert cli._cmd_set("tx 17") == "OK" assert cli._cmd_set("guest.password gpw") == "OK" assert cli._cmd_set("allow.read.only off") == "OK" @@ -251,7 +251,7 @@ def test_cli_setperm_region_neighbor_tempradio_log_paths(): assert cli._cmd_neighbor_remove("neighbor.remove ") == "ERR: Missing pubkey" assert cli._cmd_neighbor_remove("neighbor.remove 001122").startswith("Error:") - assert cli._cmd_tempradio("tempradio 1 2 3") .startswith("Error:") + assert cli._cmd_tempradio("tempradio 1 2 3").startswith("Error:") assert cli._cmd_tempradio("tempradio 299 125 7 5 10") == "Error: invalid frequency" assert cli._cmd_tempradio("tempradio 915 6 7 5 10") == "Error: invalid bandwidth" assert cli._cmd_tempradio("tempradio 915 125 4 5 10") == "Error: invalid spreading factor" diff --git a/tests/test_keygen_local_cli.py b/tests/test_keygen_local_cli.py index 639c74e..9cd6131 100644 --- a/tests/test_keygen_local_cli.py +++ b/tests/test_keygen_local_cli.py @@ -28,11 +28,13 @@ def test_generate_meshcore_keypair_clamps_scalar_and_shapes_output(): def _fake_scalarmult(scalar_bytes): captured["scalar"] = scalar_bytes - return b"\xAA" * 32 + return b"\xaa" * 32 with ( patch("repeater.keygen.secrets.token_bytes", return_value=seed), - patch("repeater.keygen.crypto_scalarmult_ed25519_base_noclamp", side_effect=_fake_scalarmult), + patch( + "repeater.keygen.crypto_scalarmult_ed25519_base_noclamp", side_effect=_fake_scalarmult + ), ): pub, priv = keygen.generate_meshcore_keypair() @@ -42,7 +44,7 @@ def _fake_scalarmult(scalar_bytes): expected[31] &= 63 expected[31] |= 64 - assert pub == b"\xAA" * 32 + assert pub == b"\xaa" * 32 assert len(pub) == 32 assert len(priv) == 64 assert captured["scalar"] == bytes(expected) diff --git a/tests/test_main_py_coverage.py b/tests/test_main_py_coverage.py index e7e832a..ec7330f 100644 --- a/tests/test_main_py_coverage.py +++ b/tests/test_main_py_coverage.py @@ -166,7 +166,7 @@ async def test_deliver_control_data_filters_non_discovery_and_pushes_valid(): fs_ok.push_control_data.assert_not_awaited() payload = bytes([0x90, 0x00, 0x11, 0x22, 0x33, 0x44]) - await daemon.deliver_control_data(1.0, -70, 2, b"\xAA\xBB", payload) + await daemon.deliver_control_data(1.0, -70, 2, b"\xaa\xbb", payload) fs_ok.push_control_data.assert_awaited_once() @@ -182,7 +182,7 @@ async def test_trace_complete_for_companions_requires_valid_lengths(): fs.push_trace_data_async.assert_not_awaited() parsed = { - "trace_path_bytes": b"\xAA\xBB\xCC\xDD", + "trace_path_bytes": b"\xaa\xbb\xcc\xdd", "flags": 0, "tag": 1, "auth_code": 2, @@ -218,7 +218,9 @@ async def test_send_advert_branches_and_success_path(): # Missing dispatcher/local identity assert await daemon.send_advert() is False - daemon.dispatcher = SimpleNamespace(send_packet=AsyncMock(), packet_filter=SimpleNamespace(track_packet=MagicMock())) + daemon.dispatcher = SimpleNamespace( + send_packet=AsyncMock(), packet_filter=SimpleNamespace(track_packet=MagicMock()) + ) daemon.local_identity = _FakeIdentity(b"\x21" + b"x" * 31) daemon.config["repeater"]["mode"] = "no_tx" assert await daemon.send_advert() is False @@ -229,7 +231,7 @@ async def test_send_advert_branches_and_success_path(): get_repeater_location=lambda: {"latitude": 9.1, "longitude": 8.2, "source": "gps"} ) - packet = SimpleNamespace(calculate_packet_hash=lambda: b"\xAB" * 16) + packet = SimpleNamespace(calculate_packet_hash=lambda: b"\xab" * 16) with patch("pymc_core.protocol.PacketBuilder.create_advert", return_value=packet): ok = await daemon.send_advert() @@ -254,10 +256,14 @@ def test_update_repeater_location_from_gps_branches(): assert daemon.config["repeater"]["latitude"] == 3.5 assert daemon.config["repeater"]["longitude"] == 4.5 - daemon.config_manager = SimpleNamespace(update_and_save=MagicMock(return_value={"success": False, "error": "nope"})) + daemon.config_manager = SimpleNamespace( + update_and_save=MagicMock(return_value={"success": False, "error": "nope"}) + ) assert daemon._update_repeater_location_from_gps({"latitude": 5.5, "longitude": 6.5}) is False - daemon.config_manager = SimpleNamespace(update_and_save=MagicMock(return_value={"success": True})) + daemon.config_manager = SimpleNamespace( + update_and_save=MagicMock(return_value={"success": True}) + ) assert daemon._update_repeater_location_from_gps({"latitude": 6.5, "longitude": 7.5}) is True diff --git a/tests/test_main_py_more.py b/tests/test_main_py_more.py index b807300..6f46c13 100644 --- a/tests/test_main_py_more.py +++ b/tests/test_main_py_more.py @@ -16,7 +16,7 @@ def get_public_key(self): return bytes([self._seed[0]]) + (b"P" * 31) def get_address_bytes(self): - return b"\xAB\xCD" + return b"\xab\xcd" def _base_config(): @@ -60,7 +60,9 @@ async def test_run_starts_http_and_handles_dispatcher_cancelled_gracefully(): async def _init_stub(): daemon.local_identity = SimpleNamespace(get_public_key=lambda: b"\x22" * 32) - daemon.dispatcher = SimpleNamespace(run_forever=AsyncMock(side_effect=asyncio.CancelledError())) + daemon.dispatcher = SimpleNamespace( + run_forever=AsyncMock(side_effect=asyncio.CancelledError()) + ) daemon.initialize = _init_stub diff --git a/tests/test_mqtt_publish_integration.py b/tests/test_mqtt_publish_integration.py index 9099922..ef3ee7e 100644 --- a/tests/test_mqtt_publish_integration.py +++ b/tests/test_mqtt_publish_integration.py @@ -75,9 +75,7 @@ def _attach_capturing_client(conn) -> list: captured: list = [] def _fake_publish(topic, payload, retain=False, qos=0): - captured.append( - {"topic": topic, "payload": payload, "retain": retain, "qos": qos} - ) + captured.append({"topic": topic, "payload": payload, "retain": retain, "qos": qos}) return None conn._running = True @@ -150,9 +148,7 @@ def test_mqtt_published_packet_carries_semtech_duration_end_to_end(): payload_dict = json.loads(publish["payload"]) assert payload_dict["duration"] == expected_duration assert payload_dict["duration"] != "0", "duration must not be hard-coded zero" - assert 0 < int(payload_dict["duration"]) < 10_000, ( - "duration should be a sane time-on-air in ms" - ) + assert 0 < int(payload_dict["duration"]) < 10_000, "duration should be a sane time-on-air in ms" # Sanity: other key fields flowed through correctly. assert payload_dict["origin"] == "test-node" diff --git a/tests/test_packet_duration.py b/tests/test_packet_duration.py index c28481c..3c646c8 100644 --- a/tests/test_packet_duration.py +++ b/tests/test_packet_duration.py @@ -19,7 +19,7 @@ def _semtech_airtime_ms(payload_len: int, sf: int, bw_hz: int, cr: int, preamble crc = 1 h = 0 # explicit header de = 1 if (sf >= 11 and bw_hz <= 125000) else 0 - t_sym = (2 ** sf) / (bw_hz / 1000) + t_sym = (2**sf) / (bw_hz / 1000) t_preamble = (preamble + 4.25) * t_sym numerator = max(8 * payload_len - 4 * sf + 28 + 16 * crc - 20 * h, 0) denominator = 4 * (sf - 2 * de) diff --git a/tests/test_packet_router.py b/tests/test_packet_router.py index 1e62ca5..384b21b 100644 --- a/tests/test_packet_router.py +++ b/tests/test_packet_router.py @@ -38,11 +38,11 @@ _is_direct_final_hop, ) - # --------------------------------------------------------------------------- # Minimal daemon stub # --------------------------------------------------------------------------- + def _make_daemon(): """Minimal daemon that satisfies PacketRouter without touching hardware.""" daemon = MagicMock() @@ -84,8 +84,8 @@ def _make_bridge(): # Tests # --------------------------------------------------------------------------- -class TestInFlightCap(unittest.IsolatedAsyncioTestCase): +class TestInFlightCap(unittest.IsolatedAsyncioTestCase): # ── 1. Cap enforcement ────────────────────────────────────────────────── async def test_cap_drops_packets_when_full(self): @@ -100,7 +100,7 @@ async def test_cap_drops_packets_when_full(self): barrier = asyncio.Event() async def slow_route(pkt): - await barrier.wait() # blocks until we release + await barrier.wait() # blocks until we release routed = [] @@ -115,7 +115,7 @@ async def counting_route(pkt): # Fill the cap for _ in range(3): await router.enqueue(_make_packet()) - await asyncio.sleep(0.05) # let queue drain into tasks + await asyncio.sleep(0.05) # let queue drain into tasks self.assertEqual(router._in_flight, 3) # These should be dropped @@ -123,12 +123,10 @@ async def counting_route(pkt): await router.enqueue(_make_packet()) await asyncio.sleep(0.05) - self.assertEqual(router._in_flight, 3, - "In-flight count exceeded cap") - self.assertEqual(router._cap_drop_count, 5, - "Expected 5 cap-drops, got different count") + self.assertEqual(router._in_flight, 3, "In-flight count exceeded cap") + self.assertEqual(router._cap_drop_count, 5, "Expected 5 cap-drops, got different count") - barrier.set() # release blocked tasks + barrier.set() # release blocked tasks await router.stop() # ── 2. Drop counter ───────────────────────────────────────────────────── @@ -218,7 +216,7 @@ async def test_stop_waits_for_in_flight_tasks(self): async def slow_route(pkt): started.set() - await asyncio.sleep(0.2) # finishes well within 5 s timeout + await asyncio.sleep(0.2) # finishes well within 5 s timeout completed.append(pkt) router._route_packet = slow_route @@ -233,8 +231,7 @@ async def slow_route(pkt): await router.stop() # Task should have completed, not been cancelled - self.assertEqual(len(completed), 1, - "In-flight task was cancelled instead of drained") + self.assertEqual(len(completed), 1, "In-flight task was cancelled instead of drained") async def test_stop_cancels_tasks_that_exceed_timeout(self): """ @@ -250,16 +247,13 @@ async def test_stop_cancels_tasks_that_exceed_timeout(self): async def hanging_route(pkt): started.set() try: - await asyncio.sleep(999) # will not finish within 5 s + await asyncio.sleep(999) # will not finish within 5 s except asyncio.CancelledError: cancelled.append(pkt) raise router._route_packet = hanging_route - # Patch the timeout to 0.1 s so the test runs fast - original_stop = router.stop - async def fast_stop(): router.running = False if router.router_task: @@ -283,8 +277,7 @@ async def fast_stop(): await router.stop() - self.assertEqual(len(cancelled), 1, - "Hanging task was not cancelled on shutdown") + self.assertEqual(len(cancelled), 1, "Hanging task was not cancelled on shutdown") # ── 4. Route-tasks set stays in sync with counter ─────────────────────── @@ -296,7 +289,7 @@ async def test_route_tasks_set_cleaned_up_on_completion(self): router = PacketRouter(_make_daemon()) async def fast_route(pkt): - await asyncio.sleep(0) # yield, then done + await asyncio.sleep(0) # yield, then done router._route_packet = fast_route @@ -308,10 +301,10 @@ async def fast_route(pkt): # Give tasks time to complete await asyncio.sleep(0.1) - self.assertEqual(len(router._route_tasks), 0, - "_route_tasks not cleaned up after task completion") - self.assertEqual(router._in_flight, 0, - "_in_flight counter not back to 0 after completion") + self.assertEqual( + len(router._route_tasks), 0, "_route_tasks not cleaned up after task completion" + ) + self.assertEqual(router._in_flight, 0, "_in_flight counter not back to 0 after completion") await router.stop() @@ -339,8 +332,9 @@ async def blocking_route(pkt): await asyncio.sleep(0.05) self.assertEqual( - router._in_flight, len(router._route_tasks), - f"Counter ({router._in_flight}) != set size ({len(router._route_tasks)})" + router._in_flight, + len(router._route_tasks), + f"Counter ({router._in_flight}) != set size ({len(router._route_tasks)})", ) barrier.set() diff --git a/tests/test_path_hash_protocol.py b/tests/test_path_hash_protocol.py index 2f93778..12fc685 100644 --- a/tests/test_path_hash_protocol.py +++ b/tests/test_path_hash_protocol.py @@ -11,11 +11,12 @@ - PacketBuilder.create_trace payload structure + TraceHandler parsing - Max-hop boundary enforcement per hash size """ + import struct from unittest.mock import MagicMock, patch import pytest - +from pymc_core.node.handlers.trace import TraceHandler from pymc_core.protocol import Packet, PacketBuilder, PathUtils from pymc_core.protocol.constants import ( MAX_PATH_SIZE, @@ -25,8 +26,6 @@ ROUTE_TYPE_DIRECT, ROUTE_TYPE_FLOOD, ) -from pymc_core.node.handlers.trace import TraceHandler - # --------------------------------------------------------------------------- # Helpers @@ -35,8 +34,9 @@ LOCAL_HASH_BYTES = bytes([0xAB, 0xCD, 0xEF]) -def _make_flood_packet(path_bytes: bytes, hash_size: int, hash_count: int, - payload: bytes = b"\x01\x02\x03\x04") -> Packet: +def _make_flood_packet( + path_bytes: bytes, hash_size: int, hash_count: int, payload: bytes = b"\x01\x02\x03\x04" +) -> Packet: """Create a real flood Packet with the given multi-byte path encoding.""" pkt = Packet() pkt.header = ROUTE_TYPE_FLOOD @@ -47,8 +47,9 @@ def _make_flood_packet(path_bytes: bytes, hash_size: int, hash_count: int, return pkt -def _make_direct_packet(path_bytes: bytes, hash_size: int, hash_count: int, - payload: bytes = b"\x01\x02\x03\x04") -> Packet: +def _make_direct_packet( + path_bytes: bytes, hash_size: int, hash_count: int, payload: bytes = b"\x01\x02\x03\x04" +) -> Packet: """Create a real direct-routed Packet.""" pkt = Packet() pkt.header = ROUTE_TYPE_DIRECT @@ -87,8 +88,12 @@ def _make_handler(path_hash_mode=0, local_hash_bytes=None): } dispatcher = MagicMock() dispatcher.radio = MagicMock( - spreading_factor=8, bandwidth=125000, coding_rate=8, - preamble_length=17, frequency=915000000, tx_power=14, + spreading_factor=8, + bandwidth=125000, + coding_rate=8, + preamble_length=17, + frequency=915000000, + tx_power=14, ) dispatcher.local_identity = MagicMock() with ( @@ -96,6 +101,7 @@ def _make_handler(path_hash_mode=0, local_hash_bytes=None): patch("repeater.engine.RepeaterHandler._start_background_tasks"), ): from repeater.engine import RepeaterHandler + h = RepeaterHandler(config, dispatcher, lhb[0], local_hash_bytes=lhb) return h @@ -114,11 +120,20 @@ def test_encode_decode_hash_size(self, hash_size): assert PathUtils.get_path_hash_size(encoded) == hash_size assert PathUtils.get_path_hash_count(encoded) == 0 - @pytest.mark.parametrize("hash_size,count", [ - (1, 1), (1, 10), (1, 63), - (2, 1), (2, 15), (2, 32), - (3, 1), (3, 10), (3, 21), - ]) + @pytest.mark.parametrize( + "hash_size,count", + [ + (1, 1), + (1, 10), + (1, 63), + (2, 1), + (2, 15), + (2, 32), + (3, 1), + (3, 10), + (3, 21), + ], + ) def test_encode_decode_round_trip(self, hash_size, count): encoded = PathUtils.encode_path_len(hash_size, count) assert PathUtils.get_path_hash_size(encoded) == hash_size @@ -186,16 +201,16 @@ class TestPacketMultiBytePath: """Verify Packet write_to/read_from preserves multi-byte path encoding.""" def test_1_byte_path_round_trip(self): - pkt = _make_flood_packet(b"\xAA\xBB\xCC", hash_size=1, hash_count=3) + pkt = _make_flood_packet(b"\xaa\xbb\xcc", hash_size=1, hash_count=3) wire = pkt.write_to() pkt2 = Packet() pkt2.read_from(wire) assert pkt2.get_path_hash_size() == 1 assert pkt2.get_path_hash_count() == 3 - assert bytes(pkt2.path) == b"\xAA\xBB\xCC" + assert bytes(pkt2.path) == b"\xaa\xbb\xcc" def test_2_byte_path_round_trip(self): - path = b"\xAA\xBB\xCC\xDD" # 2 hops of 2 bytes + path = b"\xaa\xbb\xcc\xdd" # 2 hops of 2 bytes pkt = _make_flood_packet(path, hash_size=2, hash_count=2) wire = pkt.write_to() pkt2 = Packet() @@ -205,7 +220,7 @@ def test_2_byte_path_round_trip(self): assert bytes(pkt2.path) == path def test_3_byte_path_round_trip(self): - path = b"\xAA\xBB\xCC\xDD\xEE\xFF" # 2 hops of 3 bytes + path = b"\xaa\xbb\xcc\xdd\xee\xff" # 2 hops of 3 bytes pkt = _make_flood_packet(path, hash_size=3, hash_count=2) wire = pkt.write_to() pkt2 = Packet() @@ -225,7 +240,7 @@ def test_empty_path_2_byte_mode(self): def test_payload_preserved_after_multibyte_path(self): """Payload bytes after a multi-byte path are correctly sliced.""" - payload = b"\xDE\xAD\xBE\xEF" + payload = b"\xde\xad\xbe\xef" path = b"\x11\x22\x33\x44\x55\x66" pkt = _make_flood_packet(path, hash_size=3, hash_count=2, payload=payload) wire = pkt.write_to() @@ -248,28 +263,26 @@ class TestPacketGetPathHashes: """Verify Packet.get_path_hashes splits path into per-hop byte entries.""" def test_1_byte_hashes(self): - pkt = _make_flood_packet(b"\xAA\xBB\xCC", hash_size=1, hash_count=3) + pkt = _make_flood_packet(b"\xaa\xbb\xcc", hash_size=1, hash_count=3) hashes = pkt.get_path_hashes() - assert hashes == [b"\xAA", b"\xBB", b"\xCC"] + assert hashes == [b"\xaa", b"\xbb", b"\xcc"] def test_2_byte_hashes(self): - pkt = _make_flood_packet(b"\xAA\xBB\xCC\xDD", hash_size=2, hash_count=2) + pkt = _make_flood_packet(b"\xaa\xbb\xcc\xdd", hash_size=2, hash_count=2) hashes = pkt.get_path_hashes() - assert hashes == [b"\xAA\xBB", b"\xCC\xDD"] + assert hashes == [b"\xaa\xbb", b"\xcc\xdd"] def test_3_byte_hashes(self): - pkt = _make_flood_packet( - b"\xAA\xBB\xCC\xDD\xEE\xFF", hash_size=3, hash_count=2 - ) + pkt = _make_flood_packet(b"\xaa\xbb\xcc\xdd\xee\xff", hash_size=3, hash_count=2) hashes = pkt.get_path_hashes() - assert hashes == [b"\xAA\xBB\xCC", b"\xDD\xEE\xFF"] + assert hashes == [b"\xaa\xbb\xcc", b"\xdd\xee\xff"] def test_empty_path(self): pkt = _make_flood_packet(b"", hash_size=2, hash_count=0) assert pkt.get_path_hashes() == [] def test_hashes_hex_output(self): - pkt = _make_flood_packet(b"\x0A\x0B\x0C\x0D", hash_size=2, hash_count=2) + pkt = _make_flood_packet(b"\x0a\x0b\x0c\x0d", hash_size=2, hash_count=2) hex_hashes = pkt.get_path_hashes_hex() assert hex_hashes == ["0A0B", "0C0D"] @@ -289,7 +302,7 @@ def test_apply_mode_sets_hash_size(self, mode, expected_hash_size): def test_apply_mode_skips_nonzero_hop_count(self): """Mode should not be re-applied if path already has hops.""" - pkt = _make_flood_packet(b"\xAA\xBB", hash_size=2, hash_count=1) + pkt = _make_flood_packet(b"\xaa\xbb", hash_size=2, hash_count=1) original_path_len = pkt.path_len pkt.apply_path_hash_mode(0) # try to override to 1-byte assert pkt.path_len == original_path_len # unchanged @@ -314,7 +327,7 @@ class TestPacketSetPath: def test_set_path_with_encoded_len(self): pkt = Packet() pkt.header = ROUTE_TYPE_FLOOD - path = b"\xAA\xBB\xCC\xDD" + path = b"\xaa\xbb\xcc\xdd" encoded = PathUtils.encode_path_len(2, 2) pkt.set_path(path, path_len_encoded=encoded) assert pkt.get_path_hash_size() == 2 @@ -325,7 +338,7 @@ def test_set_path_without_encoded_defaults_1_byte(self): """Without explicit path_len_encoded, defaults to 1-byte hash_size.""" pkt = Packet() pkt.header = ROUTE_TYPE_FLOOD - pkt.set_path(b"\xAA\xBB\xCC") + pkt.set_path(b"\xaa\xbb\xcc") assert pkt.get_path_hash_size() == 1 assert pkt.get_path_hash_count() == 3 @@ -347,7 +360,7 @@ def test_1_byte_mode_appends_single_byte(self): assert result.get_path_hash_size() == 1 hashes = result.get_path_hashes() assert hashes[0] == b"\x11" - assert hashes[1] == b"\xAB" # first byte of local_hash_bytes + assert hashes[1] == b"\xab" # first byte of local_hash_bytes def test_2_byte_mode_appends_two_bytes(self): h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) @@ -358,7 +371,7 @@ def test_2_byte_mode_appends_two_bytes(self): assert result.get_path_hash_size() == 2 hashes = result.get_path_hashes() assert hashes[0] == b"\x11\x22" - assert hashes[1] == b"\xAB\xCD" + assert hashes[1] == b"\xab\xcd" def test_3_byte_mode_appends_three_bytes(self): h = _make_handler(path_hash_mode=2, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) @@ -369,7 +382,7 @@ def test_3_byte_mode_appends_three_bytes(self): assert result.get_path_hash_size() == 3 hashes = result.get_path_hashes() assert hashes[0] == b"\x11\x22\x33" - assert hashes[1] == b"\xAB\xCD\xEF" + assert hashes[1] == b"\xab\xcd\xef" def test_empty_path_gets_local_hash_appended(self): h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) @@ -378,7 +391,7 @@ def test_empty_path_gets_local_hash_appended(self): assert result is not None assert result.get_path_hash_count() == 1 hashes = result.get_path_hashes() - assert hashes[0] == b"\xAB\xCD" + assert hashes[0] == b"\xab\xcd" def test_path_len_re_encoded_after_forward(self): """After appending, path_len byte should encode (hash_size, count+1).""" @@ -400,7 +413,7 @@ def test_forwarded_packet_serializes_correctly(self): pkt2.read_from(wire) assert pkt2.get_path_hash_size() == 2 assert pkt2.get_path_hash_count() == 2 - assert pkt2.get_path_hashes() == [b"\x11\x22", b"\xAB\xCD"] + assert pkt2.get_path_hashes() == [b"\x11\x22", b"\xab\xcd"] def test_flood_rejects_at_max_hops_2_byte(self): """At 32 hops (2-byte mode), flood_forward should drop the packet.""" @@ -449,7 +462,7 @@ class TestDirectForwardMultiByte: def test_1_byte_match_strips_first_hop(self): h = _make_handler(path_hash_mode=0, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path: [0xAB, 0x11] — first hop matches local_hash_bytes[0] - pkt = _make_direct_packet(b"\xAB\x11", hash_size=1, hash_count=2) + pkt = _make_direct_packet(b"\xab\x11", hash_size=1, hash_count=2) result = h.direct_forward(pkt) assert result is not None assert result.get_path_hash_count() == 1 @@ -459,7 +472,7 @@ def test_1_byte_match_strips_first_hop(self): def test_2_byte_match_strips_first_hop(self): h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path: [0xAB,0xCD, 0x11,0x22] — first 2-byte hop matches local_hash_bytes[:2] - pkt = _make_direct_packet(b"\xAB\xCD\x11\x22", hash_size=2, hash_count=2) + pkt = _make_direct_packet(b"\xab\xcd\x11\x22", hash_size=2, hash_count=2) result = h.direct_forward(pkt) assert result is not None assert result.get_path_hash_count() == 1 @@ -468,9 +481,7 @@ def test_2_byte_match_strips_first_hop(self): def test_3_byte_match_strips_first_hop(self): h = _make_handler(path_hash_mode=2, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_direct_packet( - b"\xAB\xCD\xEF\x11\x22\x33", hash_size=3, hash_count=2 - ) + pkt = _make_direct_packet(b"\xab\xcd\xef\x11\x22\x33", hash_size=3, hash_count=2) result = h.direct_forward(pkt) assert result is not None assert result.get_path_hash_count() == 1 @@ -480,14 +491,14 @@ def test_3_byte_match_strips_first_hop(self): def test_2_byte_mismatch_rejects(self): h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) # Path: [0xFF,0xEE, ...] — first 2-byte hop doesn't match - pkt = _make_direct_packet(b"\xFF\xEE\x11\x22", hash_size=2, hash_count=2) + pkt = _make_direct_packet(b"\xff\xee\x11\x22", hash_size=2, hash_count=2) result = h.direct_forward(pkt) assert result is None assert "not for us" in (pkt.drop_reason or "") def test_path_len_re_encoded_after_strip(self): h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_direct_packet(b"\xAB\xCD\x11\x22\x33\x44", hash_size=2, hash_count=3) + pkt = _make_direct_packet(b"\xab\xcd\x11\x22\x33\x44", hash_size=2, hash_count=3) result = h.direct_forward(pkt) assert result is not None expected_path_len = PathUtils.encode_path_len(2, 2) @@ -496,7 +507,7 @@ def test_path_len_re_encoded_after_strip(self): def test_last_hop_strips_to_empty(self): """When only one hop remains and it matches, path becomes empty.""" h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_direct_packet(b"\xAB\xCD", hash_size=2, hash_count=1) + pkt = _make_direct_packet(b"\xab\xcd", hash_size=2, hash_count=1) result = h.direct_forward(pkt) assert result is not None assert result.get_path_hash_count() == 0 @@ -506,7 +517,7 @@ def test_forwarded_direct_serializes_correctly(self): """After stripping, the packet should serialize/deserialize cleanly.""" h = _make_handler(path_hash_mode=2, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) pkt = _make_direct_packet( - b"\xAB\xCD\xEF\x11\x22\x33\x44\x55\x66", hash_size=3, hash_count=3 + b"\xab\xcd\xef\x11\x22\x33\x44\x55\x66", hash_size=3, hash_count=3 ) result = h.direct_forward(pkt) assert result is not None @@ -527,7 +538,7 @@ def test_no_path_rejects(self): def test_path_too_short_for_hash_size(self): """If path has fewer bytes than hash_size, reject.""" h = _make_handler(path_hash_mode=1, local_hash_bytes=bytes([0xAB, 0xCD, 0xEF])) - pkt = _make_direct_packet(b"\xAB", hash_size=2, hash_count=1) + pkt = _make_direct_packet(b"\xab", hash_size=2, hash_count=1) # path has 1 byte but hash_size is 2 result = h.direct_forward(pkt) assert result is None @@ -546,7 +557,6 @@ def test_flood_chain_builds_path_then_direct_consumes(self): Simulate: node_A floods → repeater_1 forwards → repeater_2 forwards Then the return direct packet strips hops in reverse order. """ - node_a_hash = bytes([0x11, 0x22, 0x33]) rep1_hash = bytes([0xAA, 0xBB, 0xCC]) rep2_hash = bytes([0xDD, 0xEE, 0xFF]) @@ -576,8 +586,7 @@ def test_flood_chain_builds_path_then_direct_consumes(self): # The path should be [rep1, rep2] — direct packet addressed to rep1 first # (Direct packets strip from the front) direct_pkt = _make_direct_packet( - bytes(pkt_rx.path), hash_size=2, hash_count=2, - payload=b"\xFE\xED" + bytes(pkt_rx.path), hash_size=2, hash_count=2, payload=b"\xfe\xed" ) # repeater_1 strips its hop @@ -616,9 +625,7 @@ def test_create_trace_basic(self): def test_create_trace_with_path_bytes(self): """Trace path goes into payload, not routing path.""" path_bytes = [0xAA, 0xBB, 0xCC, 0xDD] - pkt = PacketBuilder.create_trace( - tag=1, auth_code=2, flags=0, path=path_bytes - ) + pkt = PacketBuilder.create_trace(tag=1, auth_code=2, flags=0, path=path_bytes) payload = pkt.get_payload() assert len(payload) == 9 + 4 # Routing path stays empty @@ -699,9 +706,7 @@ def test_parse_from_real_packet(self): """Create a trace with PacketBuilder, serialize, deserialize, then parse.""" th = self._make_trace_handler() trace_path = [0x11, 0x22, 0x33, 0x44, 0x55, 0x66] - pkt = PacketBuilder.create_trace( - tag=100, auth_code=200, flags=0, path=trace_path - ) + pkt = PacketBuilder.create_trace(tag=100, auth_code=200, flags=0, path=trace_path) wire = pkt.write_to() pkt2 = Packet() pkt2.read_from(wire) @@ -725,10 +730,10 @@ class TestTraceHelperMultibyte: """TraceHelper._should_forward_trace with 2-byte TRACE payload hashes.""" def test_should_forward_when_next_hop_matches_pubkey_prefix(self): - from repeater.handler_helpers.trace import TraceHelper - from pymc_core.protocol import LocalIdentity + from repeater.handler_helpers.trace import TraceHelper + identity = LocalIdentity() pub = bytes(identity.get_public_key()) rh = MagicMock() @@ -747,10 +752,10 @@ def test_should_forward_when_next_hop_matches_pubkey_prefix(self): assert th._should_forward_trace(pkt, trace_bytes, flags, hash_width) def test_should_not_forward_when_next_hop_mismatch(self): - from repeater.handler_helpers.trace import TraceHelper - from pymc_core.protocol import LocalIdentity + from repeater.handler_helpers.trace import TraceHelper + identity = LocalIdentity() pub = bytes(identity.get_public_key()) rh = MagicMock() @@ -782,22 +787,18 @@ def test_2_byte_mode_wire_format(self): ROUTE_TYPE_FLOOD (no transport codes): [header(1)] [path_len(1)] [path(N)] [payload(M)] """ - pkt = _make_flood_packet( - b"\xAA\xBB\xCC\xDD", hash_size=2, hash_count=2, - payload=b"\xFE" - ) + pkt = _make_flood_packet(b"\xaa\xbb\xcc\xdd", hash_size=2, hash_count=2, payload=b"\xfe") wire = pkt.write_to() assert wire[0] == ROUTE_TYPE_FLOOD # header path_len = wire[1] assert PathUtils.get_path_hash_size(path_len) == 2 assert PathUtils.get_path_hash_count(path_len) == 2 - assert wire[2:6] == b"\xAA\xBB\xCC\xDD" # path bytes - assert wire[6:] == b"\xFE" # payload + assert wire[2:6] == b"\xaa\xbb\xcc\xdd" # path bytes + assert wire[6:] == b"\xfe" # payload def test_3_byte_mode_wire_format(self): pkt = _make_flood_packet( - b"\x11\x22\x33\x44\x55\x66", hash_size=3, hash_count=2, - payload=b"\xAA" + b"\x11\x22\x33\x44\x55\x66", hash_size=3, hash_count=2, payload=b"\xaa" ) wire = pkt.write_to() assert wire[0] == ROUTE_TYPE_FLOOD @@ -805,11 +806,11 @@ def test_3_byte_mode_wire_format(self): assert PathUtils.get_path_hash_size(path_len) == 3 assert PathUtils.get_path_hash_count(path_len) == 2 assert wire[2:8] == b"\x11\x22\x33\x44\x55\x66" - assert wire[8:] == b"\xAA" + assert wire[8:] == b"\xaa" def test_1_byte_mode_backward_compat_wire(self): """1-byte mode: path_len byte on wire == hop count (legacy format).""" - pkt = _make_flood_packet(b"\xAA\xBB", hash_size=1, hash_count=2) + pkt = _make_flood_packet(b"\xaa\xbb", hash_size=1, hash_count=2) wire = pkt.write_to() assert wire[1] == 2 # path_len == hop_count for 1-byte mode @@ -817,10 +818,10 @@ def test_read_from_2_byte_wire(self): """Manually construct wire bytes and verify read_from parses correctly.""" # header=ROUTE_TYPE_FLOOD, path_len=encode(2, 2), path=4 bytes, payload=2 bytes path_len = PathUtils.encode_path_len(2, 2) - wire = bytes([ROUTE_TYPE_FLOOD, path_len]) + b"\xAA\xBB\xCC\xDD" + b"\xFE\xED" + wire = bytes([ROUTE_TYPE_FLOOD, path_len]) + b"\xaa\xbb\xcc\xdd" + b"\xfe\xed" pkt = Packet() pkt.read_from(wire) assert pkt.get_path_hash_size() == 2 assert pkt.get_path_hash_count() == 2 - assert pkt.get_path_hashes() == [b"\xAA\xBB", b"\xCC\xDD"] - assert pkt.get_payload() == b"\xFE\xED" + assert pkt.get_path_hashes() == [b"\xaa\xbb", b"\xcc\xdd"] + assert pkt.get_payload() == b"\xfe\xed" diff --git a/tests/test_radio_config.py b/tests/test_radio_config.py index a1495e4..dab8d30 100644 --- a/tests/test_radio_config.py +++ b/tests/test_radio_config.py @@ -187,4 +187,4 @@ def test_get_radio_for_board_pymc_usb_requires_port(monkeypatch): } with pytest.raises(ValueError, match="Missing 'port'"): - get_radio_for_board(board_config) \ No newline at end of file + get_radio_for_board(board_config) diff --git a/tests/test_sensors.py b/tests/test_sensors.py index fc04b23..3936fd0 100644 --- a/tests/test_sensors.py +++ b/tests/test_sensors.py @@ -64,7 +64,9 @@ def _load_hardware_stats_sensor_module(monkeypatch): setattr(fake_hardware_stats, "HardwareStatsCollector", MagicMock) monkeypatch.setitem(sys.modules, "repeater.data_acquisition", fake_data_acquisition) - monkeypatch.setitem(sys.modules, "repeater.data_acquisition.hardware_stats", fake_hardware_stats) + monkeypatch.setitem( + sys.modules, "repeater.data_acquisition.hardware_stats", fake_hardware_stats + ) module_name = "repeater.sensors._hardware_stats_test" module_path = Path(__file__).resolve().parents[1] / "repeater" / "sensors" / "hardware_stats.py" @@ -289,7 +291,20 @@ def read_i2c_block_data(self, addr, register, length): values = { waveshare_ups_e_module._REG_STATUS: [waveshare_ups_e_module._FLAG_CHARGING], waveshare_ups_e_module._REG_VBUS: [0xA0, 0x0F, 0x2C, 0x01, 0x58, 0x1B], - waveshare_ups_e_module._REG_BATT: [0x80, 0x3E, 0xFA, 0x00, 0x4E, 0x00, 0x98, 0x08, 0x2D, 0x00, 0x5A, 0x00], + waveshare_ups_e_module._REG_BATT: [ + 0x80, + 0x3E, + 0xFA, + 0x00, + 0x4E, + 0x00, + 0x98, + 0x08, + 0x2D, + 0x00, + 0x5A, + 0x00, + ], waveshare_ups_e_module._REG_CELLS: [0x80, 0x0C, 0x6C, 0x0C, 0x1C, 0x0C, 0x76, 0x0C], } return values[register] diff --git a/tests/test_service_utils.py b/tests/test_service_utils.py index 379957f..d0ed3d9 100644 --- a/tests/test_service_utils.py +++ b/tests/test_service_utils.py @@ -63,8 +63,12 @@ def test_is_container_detection_paths(monkeypatch): (b"abc", "1:name=systemd:/", False, False), ], ) -def test_is_container_proc_and_host_paths(monkeypatch, environ_bytes, cgroup_text, host_path, expected): - monkeypatch.setattr(su.os.path, "exists", lambda p: p == "/run/host/container-manager" and host_path) +def test_is_container_proc_and_host_paths( + monkeypatch, environ_bytes, cgroup_text, host_path, expected +): + monkeypatch.setattr( + su.os.path, "exists", lambda p: p == "/run/host/container-manager" and host_path + ) monkeypatch.delenv("container", raising=False) def _open(path, mode="r", encoding=None): diff --git a/tests/test_sqlite_handler_easy.py b/tests/test_sqlite_handler_easy.py index 9d2510c..210b35c 100644 --- a/tests/test_sqlite_handler_easy.py +++ b/tests/test_sqlite_handler_easy.py @@ -176,9 +176,7 @@ def test_verify_api_token_last_used_throttle(tmp_path, monkeypatch): now = {"v": 1000.0} - monkeypatch.setattr( - "repeater.data_acquisition.sqlite_handler.time.time", lambda: now["v"] - ) + monkeypatch.setattr("repeater.data_acquisition.sqlite_handler.time.time", lambda: now["v"]) token_id = h.create_api_token("svc-throttle", "hash-throttle") assert token_id > 0 diff --git a/tests/test_storage_collector_ws_stats_throttle.py b/tests/test_storage_collector_ws_stats_throttle.py index d6e4814..fb3f7f0 100644 --- a/tests/test_storage_collector_ws_stats_throttle.py +++ b/tests/test_storage_collector_ws_stats_throttle.py @@ -3,6 +3,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch +from repeater.data_acquisition.storage_collector import StorageCollector + sys.modules.setdefault("psutil", types.ModuleType("psutil")) nacl_module = types.ModuleType("nacl") @@ -19,8 +21,6 @@ class _SigningKeyStub: sys.modules.setdefault("nacl", nacl_module) sys.modules.setdefault("nacl.signing", nacl_signing_module) -from repeater.data_acquisition.storage_collector import StorageCollector - def _make_collector() -> StorageCollector: with ( diff --git a/tests/test_tx_lock.py b/tests/test_tx_lock.py index 0ca9193..2484a44 100644 --- a/tests/test_tx_lock.py +++ b/tests/test_tx_lock.py @@ -24,6 +24,7 @@ # Minimal handler factory # --------------------------------------------------------------------------- + def _make_handler(): """ Return a RepeaterHandler instance with all external I/O mocked. @@ -49,11 +50,9 @@ def _make_handler(): h = RepeaterHandler.__new__(RepeaterHandler) h.config = { - "repeater": {"mode": "forward", "cache_ttl": 3600, - "send_advert_interval_hours": 0}, + "repeater": {"mode": "forward", "cache_ttl": 3600, "send_advert_interval_hours": 0}, "delays": {"tx_delay_factor": 1.0, "direct_tx_delay_factor": 0.5}, - "duty_cycle": {"enforcement_enabled": True, - "max_airtime_per_minute": 3600}, + "duty_cycle": {"enforcement_enabled": True, "max_airtime_per_minute": 3600}, "storage": {}, "mesh": {}, } @@ -80,8 +79,8 @@ def _make_packet(size: int = 50) -> MagicMock: # Tests # --------------------------------------------------------------------------- -class TestTxLockSerialisation(unittest.IsolatedAsyncioTestCase): +class TestTxLockSerialisation(unittest.IsolatedAsyncioTestCase): # ── Test 1: no interleaving ───────────────────────────────────────────── async def test_concurrent_sends_do_not_interleave(self): @@ -100,7 +99,7 @@ async def send_with_overlap_check(*args, **kwargs): if in_flight[0]: overlap_detected[0] = True in_flight[0] = True - await asyncio.sleep(0.05) # simulate ~50ms radio TX + await asyncio.sleep(0.05) # simulate ~50ms radio TX in_flight[0] = False h.dispatcher.send_packet.side_effect = send_with_overlap_check @@ -115,8 +114,9 @@ async def send_with_overlap_check(*args, **kwargs): "send_packet was entered while another call was already in-flight " "— _tx_lock is not serialising correctly", ) - self.assertEqual(h.dispatcher.send_packet.call_count, 2, - "Expected exactly 2 send_packet calls") + self.assertEqual( + h.dispatcher.send_packet.call_count, 2, "Expected exactly 2 send_packet calls" + ) # ── Test 2: TOCTOU duty-cycle fix ────────────────────────────────────── @@ -150,7 +150,8 @@ def can_tx(ms): await asyncio.gather(t1, t2, return_exceptions=True) self.assertEqual( - h.dispatcher.send_packet.call_count, 1, + h.dispatcher.send_packet.call_count, + 1, "Both packets were sent — duty-cycle TOCTOU race was NOT fixed", ) @@ -193,10 +194,8 @@ async def tracked_send(*args, **kwargs): ) await asyncio.gather(t_local, t_other, return_exceptions=True) - self.assertIn(id(pkt_other), send_times, - "pkt_other was never sent") - self.assertIn(id(pkt_local), send_times, - "pkt_local retry was never sent") + self.assertIn(id(pkt_other), send_times, "pkt_other was never sent") + self.assertIn(id(pkt_local), send_times, "pkt_local retry was never sent") # pkt_other fires at ~0.1s; pkt_local retry fires at ~1.0s. # If the lock were held during backoff, pkt_other would block until ~1.0s @@ -218,8 +217,7 @@ async def test_non_local_failure_propagates(self): h.dispatcher.send_packet.side_effect = RuntimeError("radio error") - task = await h.schedule_retransmit(pkt, delay=0.0, airtime_ms=0, - local_transmission=False) + task = await h.schedule_retransmit(pkt, delay=0.0, airtime_ms=0, local_transmission=False) with self.assertRaises(RuntimeError): await task @@ -259,8 +257,9 @@ async def failing_then_gone(*args, **kwargs): ) await task # should complete without error (gate returns silently) - self.assertEqual(send_calls[0], 1, - "send_packet called on retry despite duty-cycle rejection") + self.assertEqual( + send_calls[0], 1, "send_packet called on retry despite duty-cycle rejection" + ) if __name__ == "__main__": diff --git a/tests/test_update_endpoints_unit.py b/tests/test_update_endpoints_unit.py index 168e8e8..3426934 100644 --- a/tests/test_update_endpoints_unit.py +++ b/tests/test_update_endpoints_unit.py @@ -54,7 +54,7 @@ def read(self): return b"ok" monkeypatch.setattr(ue.urllib.request, "urlopen", lambda *args, **kwargs: _Resp()) - assert ue._fetch_url("https://example.com") == "ok" + assert ue._fetch_url("https://api.github.com/test") == "ok" reset = int((datetime.now(timezone.utc) + timedelta(minutes=5)).timestamp()) hdrs = {"X-RateLimit-Reset": str(reset)} @@ -242,7 +242,9 @@ def test_channels_set_channel_and_changelog(cherrypy_ctx, isolated_state, monkey assert ok["channel"] == "dev" request.method = "GET" - monkeypatch.setattr(ue, "_fetch_changelog", lambda channel, installed, max_commits: [{"title": "t"}]) + monkeypatch.setattr( + ue, "_fetch_changelog", lambda channel, installed, max_commits: [{"title": "t"}] + ) c = api.changelog(channel="dev", max="5") assert c["success"] is True assert c["commits"][0]["title"] == "t" @@ -376,9 +378,7 @@ def __init__(self, cmd): self.cmd = cmd self.stdout = [] self.returncode = ( - 1 - if any(isinstance(x, str) and "git+https://github.com" in x for x in cmd) - else 0 + 1 if any(isinstance(x, str) and "git+https://github.com" in x for x in cmd) else 0 ) def wait(self): @@ -402,7 +402,9 @@ def test_do_install_wrapper_success_then_restart_failure(isolated_state, monkeyp monkeypatch.setattr(ue, "_cleanup_stale_dist_info", lambda *args, **kwargs: None) monkeypatch.setattr(ue.time, "sleep", lambda _s: None) monkeypatch.setattr(ue.os.path, "isfile", lambda p: p == "/usr/local/bin/pymc-do-upgrade") - monkeypatch.setattr("repeater.service_utils.restart_service", lambda: (False, "systemctl failed")) + monkeypatch.setattr( + "repeater.service_utils.restart_service", lambda: (False, "systemctl failed") + ) class _Proc: def __init__(self, cmd): From 0c334839472940dcba32a10d9d9c0b5f495e412b Mon Sep 17 00:00:00 2001 From: Lloyd Date: Wed, 27 May 2026 20:16:23 +0100 Subject: [PATCH 5/8] refactor: clean up import statements and whitespace in local_cli, base, and update_endpoints modules --- repeater/local_cli.py | 8 ++++---- repeater/sensors/base.py | 1 - repeater/web/update_endpoints.py | 4 +++- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/repeater/local_cli.py b/repeater/local_cli.py index 9ee79d7..59225a6 100644 --- a/repeater/local_cli.py +++ b/repeater/local_cli.py @@ -8,7 +8,6 @@ from typing import Optional from urllib.parse import urlparse - CONFIG_PATHS = [ "/etc/pymc_repeater/config.yaml", "config.yaml", @@ -25,9 +24,10 @@ def _validate_http_url(url: str) -> None: def _load_config(config_path=None): """Load repeater config.yaml, trying common paths.""" - import yaml from pathlib import Path + import yaml + paths = [config_path] if config_path else CONFIG_PATHS for p in paths: path = Path(p) @@ -41,9 +41,9 @@ def run_client_cli(host: str = "127.0.0.1", port: int = 8000, password: Optional """ Standalone CLI client that connects to a running repeater's HTTP API. """ - import urllib.request - import urllib.error import json + import urllib.error + import urllib.request base_url = f"http://{host}:{port}" diff --git a/repeater/sensors/base.py b/repeater/sensors/base.py index b14748d..faa9010 100644 --- a/repeater/sensors/base.py +++ b/repeater/sensors/base.py @@ -11,7 +11,6 @@ from datetime import datetime, timezone from typing import Any, Dict, Iterable, Optional, Tuple - _PIP_PACKAGE_RE = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]*$") diff --git a/repeater/web/update_endpoints.py b/repeater/web/update_endpoints.py index 7ad5ee1..9810e97 100644 --- a/repeater/web/update_endpoints.py +++ b/repeater/web/update_endpoints.py @@ -32,6 +32,7 @@ from urllib.parse import urlparse import cherrypy + from repeater.service_utils import get_container_restart_message, is_buildroot, is_container logger = logging.getLogger("HTTPServer") @@ -192,9 +193,10 @@ def _cache_and_return(value: str) -> str: # If the running process is already on a higher version than anything found # on disk, the dist-info dirs are stale leftovers and __version__ is truth. try: - from repeater import __version__ as _running from packaging.version import Version + from repeater import __version__ as _running + if Version(_running) > Version(disk_version): # status() polls can call this frequently; throttle mismatch logs. global _disk_version_mismatch_logged From 5f25d3bd26e3e3b91197200ca29634cd0766f95f Mon Sep 17 00:00:00 2001 From: Rightup Date: Wed, 27 May 2026 21:18:11 +0100 Subject: [PATCH 6/8] refactor: update bandit arguments and change pytest entry to use a script --- .pre-commit-config.yaml | 5 ++--- scripts/precommit-pytest.sh | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) create mode 100755 scripts/precommit-pytest.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a1bd84..1297d36 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,8 +37,7 @@ repos: rev: 1.7.9 hooks: - id: bandit - # B104: intentional LAN listeners, B105: setup-required placeholder credentials. - args: ["-q", "-l", "-i", "-s", "B104,B105"] + args: ["-q", "-l", "-i"] files: ^.*\.py$ exclude: ^tests/ @@ -47,7 +46,7 @@ repos: hooks: - id: pytest name: pytest - entry: python -m pytest -q + entry: ./scripts/precommit-pytest.sh language: system pass_filenames: false always_run: true diff --git a/scripts/precommit-pytest.sh b/scripts/precommit-pytest.sh new file mode 100755 index 0000000..888dfe9 --- /dev/null +++ b/scripts/precommit-pytest.sh @@ -0,0 +1,18 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Prefer the currently activated venv, then repo-local .venv, then python3/python. +if [ -n "${VIRTUAL_ENV:-}" ] && [ -x "${VIRTUAL_ENV}/bin/python" ]; then + PYTHON_BIN="${VIRTUAL_ENV}/bin/python" +elif [ -x ".venv/bin/python" ]; then + PYTHON_BIN=".venv/bin/python" +elif command -v python3 >/dev/null 2>&1; then + PYTHON_BIN="python3" +elif command -v python >/dev/null 2>&1; then + PYTHON_BIN="python" +else + echo "No Python interpreter found for pytest hook." >&2 + exit 1 +fi + +exec "${PYTHON_BIN}" -m pytest -q \ No newline at end of file From a5355f188dc26e8a772860e40ef4c9ad893283e2 Mon Sep 17 00:00:00 2001 From: Rightup Date: Wed, 27 May 2026 21:23:32 +0100 Subject: [PATCH 7/8] feat: add PR checks workflow for pre-commit validation --- .github/workflows/pr-checks.yml | 37 +++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 .github/workflows/pr-checks.yml diff --git a/.github/workflows/pr-checks.yml b/.github/workflows/pr-checks.yml new file mode 100644 index 0000000..5bb3cb0 --- /dev/null +++ b/.github/workflows/pr-checks.yml @@ -0,0 +1,37 @@ +name: PR Checks + +on: + pull_request: + workflow_dispatch: + +concurrency: + group: pr-checks-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + pre-commit: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: | + pyproject.toml + .pre-commit-config.yaml + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install pre-commit + python -m pip install -e .[dev] + + - name: Run pre-commit + run: pre-commit run --all-files \ No newline at end of file From 60ca184dbd1d1f36b13de3427cd597c9c5ccc249 Mon Sep 17 00:00:00 2001 From: Rightup Date: Wed, 27 May 2026 22:07:34 +0100 Subject: [PATCH 8/8] refactor: enhance security comments and error handling across multiple modules --- repeater/companion/frame_server.py | 2 +- repeater/data_acquisition/glass_handler.py | 2 +- repeater/data_acquisition/sqlite_handler.py | 9 ++++++--- repeater/handler_helpers/acl.py | 7 +++++-- repeater/main.py | 6 +++--- repeater/service_utils.py | 7 ++++--- repeater/web/api_endpoints.py | 2 +- repeater/web/auth/cherrypy_tool.py | 3 +-- repeater/web/companion_ws_proxy.py | 4 ++-- repeater/web/http_server.py | 2 +- tests/test_auth_components.py | 9 +++++---- 11 files changed, 30 insertions(+), 23 deletions(-) diff --git a/repeater/companion/frame_server.py b/repeater/companion/frame_server.py index 530d520..0d55eb1 100644 --- a/repeater/companion/frame_server.py +++ b/repeater/companion/frame_server.py @@ -32,7 +32,7 @@ def __init__( bridge, companion_hash: str, port: int = 5000, - bind_address: str = "0.0.0.0", + bind_address: str = "0.0.0.0", # nosec B104 - intentional default for LAN reachability client_idle_timeout_sec: Optional[int] = 8 * 60 * 60, # 8 hours sqlite_handler=None, local_hash: Optional[int] = None, diff --git a/repeater/data_acquisition/glass_handler.py b/repeater/data_acquisition/glass_handler.py index 4a4cffa..e80ebd6 100644 --- a/repeater/data_acquisition/glass_handler.py +++ b/repeater/data_acquisition/glass_handler.py @@ -45,7 +45,7 @@ def __init__(self, config: dict, daemon_instance=None, config_manager=None): self.base_url = "http://localhost:8080" self.request_timeout_seconds = 10 self.verify_tls = True - self.api_token = "" + self.api_token = "" # nosec - runtime config value, not a hardcoded credential self.inform_interval_seconds = 30 self.cert_store_dir = "/etc/pymc_repeater/glass" self._cert_expires_at: Optional[str] = None diff --git a/repeater/data_acquisition/sqlite_handler.py b/repeater/data_acquisition/sqlite_handler.py index a41e0da..fc15edc 100644 --- a/repeater/data_acquisition/sqlite_handler.py +++ b/repeater/data_acquisition/sqlite_handler.py @@ -1715,11 +1715,14 @@ def generate_transport_key(self, name: str, key_length_bytes: int = 16) -> str: except Exception as e: logger.error(f"Failed to generate transport key using get_auto_key_for: {e}") - # Fallback to a transport-compatible random 16-byte key if derivation fails. + # Fallback to a transport-compatible random key if derivation fails. try: - random_bytes = secrets.token_bytes(16) + fallback_length = max(1, int(key_length_bytes)) + random_bytes = secrets.token_bytes(fallback_length) key = base64.b64encode(random_bytes).decode("utf-8") - logger.warning(f"Using fallback random key generation for '{name}'") + logger.warning( + f"Using fallback random key generation for '{name}' with {fallback_length} bytes" + ) return key except Exception as fallback_e: logger.error(f"Fallback key generation also failed: {fallback_e}") diff --git a/repeater/handler_helpers/acl.py b/repeater/handler_helpers/acl.py index 0d5d659..e371422 100644 --- a/repeater/handler_helpers/acl.py +++ b/repeater/handler_helpers/acl.py @@ -43,8 +43,8 @@ def __init__( allow_read_only: bool = True, ): self.max_clients = max_clients - self.admin_password = admin_password - self.guest_password = guest_password + self.admin_password = admin_password or "" + self.guest_password = guest_password or "" self.allow_read_only = allow_read_only self.clients: Dict[bytes, ClientInfo] = {} @@ -93,6 +93,9 @@ def authenticate_client( f"guest: {'SET' if guest_pwd else 'NONE'}" ) + admin_pwd = admin_pwd or "" + guest_pwd = guest_pwd or "" + if target_identity_name: logger.debug( f"Authenticating for identity '{target_identity_name}' (room_server={is_room_server})" diff --git a/repeater/main.py b/repeater/main.py index 157a69f..8646a1b 100644 --- a/repeater/main.py +++ b/repeater/main.py @@ -539,7 +539,7 @@ async def _load_companion_identities(self) -> None: node_name = settings.get("node_name", name) tcp_port = settings.get("tcp_port", 5000) - bind_address = settings.get("bind_address", "0.0.0.0") + bind_address = settings.get("bind_address", "0.0.0.0") # nosec B104 tcp_timeout_raw = settings.get("tcp_timeout", 8 * 60 * 60) # 8 hours client_idle_timeout_sec = None if tcp_timeout_raw == 0 else int(tcp_timeout_raw) @@ -721,7 +721,7 @@ async def add_companion_from_config(self, comp_config: dict) -> None: node_name = settings.get("node_name", name) tcp_port = settings.get("tcp_port", 5000) - bind_address = settings.get("bind_address", "0.0.0.0") + bind_address = settings.get("bind_address", "0.0.0.0") # nosec B104 tcp_timeout_raw = settings.get("tcp_timeout", 120) client_idle_timeout_sec = None if tcp_timeout_raw == 0 else int(tcp_timeout_raw) @@ -1291,7 +1291,7 @@ async def run(self): # Start HTTP stats server http_port = self.config.get("http", {}).get("port", 8000) - http_host = self.config.get("http", {}).get("host", "0.0.0.0") + http_host = self.config.get("http", {}).get("host", "0.0.0.0") # nosec B104 node_name = self.config.get("repeater", {}).get("node_name", "Repeater") diff --git a/repeater/service_utils.py b/repeater/service_utils.py index e6c14ef..2419701 100644 --- a/repeater/service_utils.py +++ b/repeater/service_utils.py @@ -5,6 +5,7 @@ import logging import os +import shutil import subprocess # nosec B404 import threading import time @@ -14,9 +15,9 @@ INIT_SCRIPT = "/etc/init.d/S80pymc-repeater" BUILDROOT_METADATA_PATH = "/etc/pymc-image-build-id" _CONTAINER_RESTART_DELAY_SECONDS = 1.0 -_SH_BIN = "/bin/sh" -_SYSTEMCTL_BIN = "/bin/systemctl" -_SUDO_BIN = "/usr/bin/sudo" +_SH_BIN = shutil.which("sh") or "sh" +_SYSTEMCTL_BIN = shutil.which("systemctl") or "systemctl" +_SUDO_BIN = shutil.which("sudo") or "sudo" def is_buildroot() -> bool: diff --git a/repeater/web/api_endpoints.py b/repeater/web/api_endpoints.py index a88ab7c..09442a0 100644 --- a/repeater/web/api_endpoints.py +++ b/repeater/web/api_endpoints.py @@ -3522,7 +3522,7 @@ def create_identity(self): comp_settings = { "node_name": settings.get("node_name") or name, "tcp_port": settings.get("tcp_port", 5000), - "bind_address": settings.get("bind_address", "0.0.0.0"), + "bind_address": settings.get("bind_address", "0.0.0.0"), # nosec B104 } if "tcp_timeout" in settings: comp_settings["tcp_timeout"] = settings["tcp_timeout"] diff --git a/repeater/web/auth/cherrypy_tool.py b/repeater/web/auth/cherrypy_tool.py index 124eb57..1e18acc 100644 --- a/repeater/web/auth/cherrypy_tool.py +++ b/repeater/web/auth/cherrypy_tool.py @@ -28,8 +28,7 @@ def check_auth(): if not jwt_handler or not token_manager: logger.error("Auth handlers not initialized in cherrypy.config") - cherrypy.response.status = 500 - return {"success": False, "error": "Authentication system not configured"} + raise cherrypy.HTTPError(500, "Authentication system not configured") # Check for JWT token in Authorization header first auth_header = cherrypy.request.headers.get("Authorization", "") diff --git a/repeater/web/companion_ws_proxy.py b/repeater/web/companion_ws_proxy.py index 29d8f89..e12218e 100644 --- a/repeater/web/companion_ws_proxy.py +++ b/repeater/web/companion_ws_proxy.py @@ -172,9 +172,9 @@ def _resolve_tcp_endpoint(self, companion_name): if entry.get("name") == companion_name: settings = entry.get("settings") or {} port = settings.get("tcp_port", 5000) - bind = settings.get("bind_address", "0.0.0.0") + bind = settings.get("bind_address", "0.0.0.0") # nosec B104 # 0.0.0.0 = all interfaces — connect via loopback - host = "127.0.0.1" if bind == "0.0.0.0" else bind + host = "127.0.0.1" if bind == "0.0.0.0" else bind # nosec B104 logger.debug(f"_resolve_tcp_endpoint: '{companion_name}' → {host}:{port}") return (host, port) diff --git a/repeater/web/http_server.py b/repeater/web/http_server.py index ea2e9b6..0b3ab57 100644 --- a/repeater/web/http_server.py +++ b/repeater/web/http_server.py @@ -177,7 +177,7 @@ def default(self, *args, **kwargs): class HTTPStatsServer: def __init__( self, - host: str = "0.0.0.0", + host: str = "0.0.0.0", # nosec B104 - intentional default for service exposure port: int = 8000, stats_getter: Optional[Callable] = None, node_name: str = "Repeater", diff --git a/tests/test_auth_components.py b/tests/test_auth_components.py index c3c3ad4..2762637 100644 --- a/tests/test_auth_components.py +++ b/tests/test_auth_components.py @@ -73,11 +73,12 @@ def test_check_auth_skips_options_and_login(monkeypatch): assert check_auth() is None -def test_check_auth_missing_handlers_returns_500_json(monkeypatch): +def test_check_auth_missing_handlers_raises_http_500(monkeypatch): _set_cp(monkeypatch, cfg={}) - out = check_auth() - assert out["success"] is False - assert cherrypy.response.status == 500 + with pytest.raises(cherrypy.HTTPError) as exc_info: + check_auth() + + assert exc_info.value.status == 500 def test_check_auth_accepts_bearer_token(monkeypatch):